From: Tang Yizhou <tangyizhou@huawei.com>
ascend inclusion
category: perf
bugzilla: 47462
CVE: NA
-------------------------------------------------
After a caller obtains a pointer to a struct sp_proc_stat via sp_get_proc_stat(), the stat may be freed if the target process exits while the pointer is still in use.
To solve this problem, sp_get_proc_stat() now takes a reference on the stat it returns. Callers must call sp_proc_stat_drop() once they are done with it.
Signed-off-by: Tang Yizhou <tangyizhou@huawei.com>
Reviewed-by: Ding Tianhong <dingtianhong@huawei.com>
Reviewed-by: Kefeng Wang <wangkefeng.wang@huawei.com>
Signed-off-by: Yang Yingliang <yangyingliang@huawei.com>
---
 include/linux/share_pool.h |   7 +++
 mm/oom_kill.c              |   4 +-
 mm/share_pool.c            | 120 +++++++++++++++++++++++++------------
 3 files changed, 91 insertions(+), 40 deletions(-)
diff --git a/include/linux/share_pool.h b/include/linux/share_pool.h index 356781bfe3e0a..d94d48f57798c 100644 --- a/include/linux/share_pool.h +++ b/include/linux/share_pool.h @@ -109,6 +109,8 @@ struct sp_walk_data {
/* per process memory usage statistics indexed by tgid */ struct sp_proc_stat { + atomic_t use_count; /* refcount; sp_proc_stat_drop() frees the stat when it hits zero */ + int tgid; /* also the id of this stat in sp_stat_idr */ struct mm_struct *mm; char comm[TASK_COMM_LEN]; /* @@ -170,6 +172,7 @@ extern int sp_unregister_notifier(struct notifier_block *nb); extern bool sp_config_dvpp_range(size_t start, size_t size, int device_id, int pid); extern bool is_sharepool_addr(unsigned long addr); extern struct sp_proc_stat *sp_get_proc_stat(int tgid); +extern void sp_proc_stat_drop(struct sp_proc_stat *stat); /* pairs with sp_get_proc_stat() */ extern void spa_overview_show(struct seq_file *seq); extern void spg_overview_show(struct seq_file *seq); extern void proc_sharepool_init(void); @@ -373,6 +376,10 @@ static inline struct sp_proc_stat *sp_get_proc_stat(int tgid) return NULL; }
+static inline void sp_proc_stat_drop(struct sp_proc_stat *stat) +{ +} + static inline void spa_overview_show(struct seq_file *seq) { } diff --git a/mm/oom_kill.c b/mm/oom_kill.c index 86db5d5508234..b10a38f58a55c 100644 --- a/mm/oom_kill.c +++ b/mm/oom_kill.c @@ -433,10 +433,12 @@ static void dump_tasks(struct mem_cgroup *memcg, const nodemask_t *nodemask) task->tgid, task->mm->total_vm, get_mm_rss(task->mm)); if (!stat) pr_cont("%-9c %-9c ", '-', '-'); - else + else { pr_cont("%-9ld %-9ld ", /* byte to KB */ atomic64_read(&stat->alloc_size) >> 10, atomic64_read(&stat->k2u_size) >> 10); + sp_proc_stat_drop(stat); + } pr_cont("%8ld %8lu %5hd %s\n", mm_pgtables_bytes(task->mm), get_mm_counter(task->mm, MM_SWAPENTS), diff --git a/mm/share_pool.c b/mm/share_pool.c index 7d50b55b80cae..6ba479887f0da 100644 --- a/mm/share_pool.c +++ b/mm/share_pool.c @@ -87,47 +87,72 @@ static DEFINE_IDA(sp_group_id_ida);
/* idr of all sp_proc_stats */ static DEFINE_IDR(sp_stat_idr); -/* rw semaphore for sp_stat_idr */ +/* rw semaphore for sp_stat_idr and mm->sp_stat_id */ static DECLARE_RWSEM(sp_stat_sem);
/* for kthread buff_module_guard_work */ static struct sp_proc_stat kthread_stat = {0};
+/* The caller must hold sp_stat_sem (read or write); takes a reference if found */ +static struct sp_proc_stat *sp_get_proc_stat_locked(int tgid) +{ + struct sp_proc_stat *stat; + + stat = idr_find(&sp_stat_idr, tgid); + if (stat) + atomic_inc(&stat->use_count); + + /* maybe NULL or not, we always return it */ + return stat; +} + /* * The caller must ensure no concurrency problem * for task_struct and mm_struct. + * + * On success the caller must call sp_proc_stat_drop() after use. */ static struct sp_proc_stat *sp_init_proc_stat(struct task_struct *tsk, struct mm_struct *mm) { struct sp_proc_stat *stat; - int id = mm->sp_stat_id; - int tgid = tsk->tgid; + int id, tgid = tsk->tgid; int ret;
+ down_write(&sp_stat_sem); + id = mm->sp_stat_id; if (id) { - stat = sp_get_proc_stat(id); /* other threads in the same process may have initialized it */ - if (stat) + stat = sp_get_proc_stat_locked(tgid); + if (stat) { + up_write(&sp_stat_sem); return stat; + } + /* stale sp_stat_id without a stat is our mistake; never return with sp_stat_sem held */ + up_write(&sp_stat_sem); + pr_err("share pool: sp_init_proc_stat invalid id %d\n", id); + return ERR_PTR(-EBUSY); }
stat = kzalloc(sizeof(*stat), GFP_KERNEL); if (stat == NULL) { + up_write(&sp_stat_sem); if (printk_ratelimit()) pr_err("share pool: alloc proc stat failed due to lack of memory\n"); return ERR_PTR(-ENOMEM); }
+ /* use_count = 2: one ref for the caller, one dropped at process exit (sp_group_post_exit) */ + atomic_set(&stat->use_count, 2); atomic64_set(&stat->alloc_size, 0); atomic64_set(&stat->k2u_size, 0); + stat->tgid = tgid; stat->mm = mm; get_task_comm(stat->comm, tsk);
- down_write(&sp_stat_sem); ret = idr_alloc(&sp_stat_idr, stat, tgid, tgid + 1, GFP_KERNEL); - up_write(&sp_stat_sem); if (ret < 0) { + up_write(&sp_stat_sem); if (printk_ratelimit()) pr_err("share pool: proc stat idr alloc failed %d\n", ret); kfree(stat); @@ -135,6 +160,7 @@ static struct sp_proc_stat *sp_init_proc_stat(struct task_struct *tsk, }
mm->sp_stat_id = ret; + up_write(&sp_stat_sem); return stat; }
@@ -727,13 +753,10 @@ int sp_group_add_task(int pid, int spg_id) } up_write(&spg->rw_lock);
- if (unlikely(ret)) { - down_write(&sp_stat_sem); - idr_remove(&sp_stat_idr, mm->sp_stat_id); - up_write(&sp_stat_sem); - kfree(stat); - mm->sp_stat_id = 0; - } + /* on failure also drop the long-lived ref so the stat is released; NOTE(review): if stat pre-existed, that ref was not taken by this call */ + if (unlikely(ret)) + sp_proc_stat_drop(stat); + sp_proc_stat_drop(stat); /* drop the caller's ref from sp_init_proc_stat() */
out_drop_group: if (unlikely(ret)) @@ -780,16 +803,15 @@ void sp_group_post_exit(struct mm_struct *mm) "It applied %ld aligned KB, k2u shared %ld aligned KB\n", stat->comm, mm->sp_stat_id, mm->sp_group->id, byte2kb(alloc_size), byte2kb(k2u_size)); - }
- down_write(&sp_stat_sem); - idr_remove(&sp_stat_idr, mm->sp_stat_id); - up_write(&sp_stat_sem); + /* drop the ref taken by sp_get_proc_stat() earlier in this function */ + sp_proc_stat_drop(stat); + /* drop the long-lived ref from sp_init_proc_stat(); normally the last one, which frees stat */ + sp_proc_stat_drop(stat); + }
/* match with sp_group_add_task -> find_or_alloc_sp_group */ sp_group_drop(spg); - - kfree(stat); }
/* the caller must hold sp_area_lock */ @@ -1240,9 +1262,10 @@ int sp_free(unsigned long addr) atomic64_sub(spa->real_size, &kthread_stat.alloc_size); } else { stat = sp_get_proc_stat(current->mm->sp_stat_id); - if (stat) + if (stat) { atomic64_sub(spa->real_size, &stat->alloc_size); - else + sp_proc_stat_drop(stat); /* pairs with sp_get_proc_stat() above */ + } else BUG(); }
@@ -1489,8 +1512,10 @@ void *sp_alloc(unsigned long size, unsigned long sp_flags, int spg_id)
if (!IS_ERR(p)) { stat = sp_get_proc_stat(current->mm->sp_stat_id); - if (stat) + if (stat) { atomic64_add(size_aligned, &stat->alloc_size); + sp_proc_stat_drop(stat); /* pairs with sp_get_proc_stat() above */ + } }
/* this will free spa if mmap failed */ @@ -1769,13 +1794,12 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, /* * Process statistics initialization. if the target process has been * added to a sp group, then stat will be returned immediately. - * I believe there is no need to free stat in error handling branches. */ stat = sp_init_proc_stat(tsk, mm); if (IS_ERR(stat)) { uva = stat; pr_err("share pool: init proc stat failed, ret %lx\n", PTR_ERR(stat)); - goto out; + goto out_put_mm; /* IS_ERR(stat): no stat reference to drop */ }
spg = __sp_find_spg(pid, SPG_ID_DEFAULT); @@ -1785,7 +1809,7 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, if (printk_ratelimit()) pr_err("share pool: k2task invalid spg id %d\n", spg_id); uva = ERR_PTR(-EINVAL); - goto out; + goto out_drop_proc_stat; } spa = sp_alloc_area(size_aligned, sp_flags, NULL, SPA_TYPE_K2TASK); if (IS_ERR(spa)) { @@ -1794,7 +1818,7 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, "(potential no enough virtual memory when -75): %ld\n", PTR_ERR(spa)); uva = spa; - goto out; + goto out_drop_proc_stat; }
if (!vmalloc_area_set_flag(spa, kva_aligned, VM_SHAREPOOL)) { @@ -1815,8 +1839,7 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, if (printk_ratelimit()) pr_err("share pool: k2spg invalid spg id %d\n", spg_id); uva = ERR_PTR(-EINVAL); - sp_group_drop(spg); - goto out; + goto out_drop_spg; }
if (enable_share_k2u_spg) @@ -1831,14 +1854,12 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, "(potential no enough virtual memory when -75): %ld\n", PTR_ERR(spa)); uva = spa; - sp_group_drop(spg); - goto out; + goto out_drop_spg; }
if (!vmalloc_area_set_flag(spa, kva_aligned, VM_SHAREPOOL)) { up_read(&spg->rw_lock); pr_err("share pool: %s: the kva %pK is not valid\n", __func__, (void *)kva_aligned); - sp_group_drop(spg); goto out_drop_spa; }
@@ -1853,7 +1874,6 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, uva = ERR_PTR(-ENODEV); } up_read(&spg->rw_lock); - sp_group_drop(spg);
accounting: if (!IS_ERR(uva)) { @@ -1868,7 +1888,12 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size,
out_drop_spa: __sp_area_drop(spa); -out: +out_drop_spg: + if (spg) + sp_group_drop(spg); +out_drop_proc_stat: + sp_proc_stat_drop(stat); /* reference from sp_init_proc_stat() */ +out_put_mm: mmput(mm); out_put_task: put_task_struct(tsk); @@ -2298,9 +2323,10 @@ static int sp_unshare_uva(unsigned long uva, unsigned long size, int pid, int sp atomic64_sub(spa->real_size, &kthread_stat.k2u_size); } else { stat = sp_get_proc_stat(current->mm->sp_stat_id); - if (stat) + if (stat) { atomic64_sub(spa->real_size, &stat->k2u_size); - else + sp_proc_stat_drop(stat); /* pairs with sp_get_proc_stat() above */ + } else WARN(1, "share pool: %s: null process stat\n", __func__); }
@@ -2525,18 +2551,33 @@ __setup("enable_sp_share_k2u_spg", enable_share_k2u_to_group);
/*** Statistical and maintenance functions ***/
+/* Return the stat for tgid with a reference held, or NULL; release it with sp_proc_stat_drop() */ struct sp_proc_stat *sp_get_proc_stat(int tgid) { struct sp_proc_stat *stat;
down_read(&sp_stat_sem); - stat = idr_find(&sp_stat_idr, tgid); + stat = sp_get_proc_stat_locked(tgid); up_read(&sp_stat_sem); - - /* maybe NULL or not, we always return it */ return stat; }
+/* stat must not be NULL; may take sp_stat_sem for write, so never call it with the sem held */ +void sp_proc_stat_drop(struct sp_proc_stat *stat) +{ + /* fast path: not the last reference, so nothing can be freed */ + if (atomic_add_unless(&stat->use_count, -1, 1)) + return; + /* maybe the last ref: drop it under the sem so sp_get_proc_stat_locked() cannot revive it */ + down_write(&sp_stat_sem); + if (atomic_dec_and_test(&stat->use_count)) { + stat->mm->sp_stat_id = 0; /* NOTE(review): assumes stat->mm is still alive */ + idr_remove(&sp_stat_idr, stat->tgid); + kfree(stat); + } + up_write(&sp_stat_sem); +} + int proc_sp_group_state(struct seq_file *m, struct pid_namespace *ns, struct pid *pid, struct task_struct *task) { @@ -2569,6 +2610,7 @@ int proc_sp_group_state(struct seq_file *m, struct pid_namespace *ns, byte2kb(atomic64_read(&stat->alloc_size)), hugepage_failures);
+ sp_proc_stat_drop(stat); /* release our reference on stat */ sp_group_drop(spg); return 0; }