From: Tang Yizhou <tangyizhou@huawei.com>
ascend inclusion
category: perf
bugzilla: 47462
CVE: NA
-------------------------------------------------
After a caller obtains a pointer to an sp_group (spg) from __sp_find_spg(), the memory backing spg may be freed if the group dies and free_sp_group() is called. The caller is then left holding a dangling pointer.
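To solve this problem, __sp_find_spg() now increments the refcount (use_count) of the spg it returns. It uses atomic_inc_not_zero() so that a group whose refcount has already reached zero is never revived. Callers must call sp_group_drop() once they have finished using the group.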
Signed-off-by: Tang Yizhou <tangyizhou@huawei.com>
Reviewed-by: Ding Tianhong <dingtianhong@huawei.com>
Reviewed-by: Kefeng Wang <wangkefeng.wang@huawei.com>
Signed-off-by: Yang Yingliang <yangyingliang@huawei.com>
---
 mm/share_pool.c | 131 +++++++++++++++++++++++++++++-------------------
 1 file changed, 80 insertions(+), 51 deletions(-)
diff --git a/mm/share_pool.c b/mm/share_pool.c index 1bf6df6e42417..7d50b55b80cae 100644 --- a/mm/share_pool.c +++ b/mm/share_pool.c @@ -196,21 +196,6 @@ static bool host_svm_sp_enable = false;
int sysctl_share_pool_hugepage_enable = 1;
-static void free_sp_group(struct sp_group *spg); - -/* the caller make sure spg is not NULL */ -static bool sp_group_get(struct sp_group *spg) -{ - down_read(&spg->rw_lock); - if (spg_valid(spg) && atomic_inc_not_zero(&spg->use_count)) { - up_read(&spg->rw_lock); - return true; - } - up_read(&spg->rw_lock); - - return false; -} - static unsigned long spa_size(struct sp_area *spa) { return spa->real_size; @@ -322,7 +307,8 @@ static void free_sp_group(struct sp_group *spg) kfree(spg); }
-static struct sp_group *__sp_find_spg(int pid, int spg_id) +/* user must call sp_group_drop() after use */ +static struct sp_group *__sp_find_spg_locked(int pid, int spg_id) { struct sp_group *spg; int ret = 0; @@ -347,20 +333,41 @@ static struct sp_group *__sp_find_spg(int pid, int spg_id) task_lock(tsk); if (tsk->mm == NULL) spg = NULL; - else + else { spg = tsk->mm->sp_group; + /* don't revive a dead group */ + if (!spg || !atomic_inc_not_zero(&spg->use_count)) + spg = NULL; + } task_unlock(tsk);
put_task_struct(tsk); } else { - down_read(&sp_group_sem); spg = idr_find(&sp_group_idr, spg_id); - up_read(&sp_group_sem); + /* don't revive a dead group */ + if (!spg || !atomic_inc_not_zero(&spg->use_count)) + spg = NULL; }
return spg; }
+static struct sp_group *__sp_find_spg(int pid, int spg_id) +{ + struct sp_group *spg; + + down_read(&sp_group_sem); + spg = __sp_find_spg_locked(pid, spg_id); + up_read(&sp_group_sem); + return spg; +} + +static void sp_group_drop(struct sp_group *spg) +{ + if (atomic_dec_and_test(&spg->use_count)) + free_sp_group(spg); +} + int sp_group_id_by_pid(int pid) { struct sp_group *spg; @@ -377,6 +384,7 @@ int sp_group_id_by_pid(int pid) spg_id = spg->id; up_read(&spg->rw_lock);
+ sp_group_drop(spg); return spg_id; } EXPORT_SYMBOL_GPL(sp_group_id_by_pid); @@ -387,9 +395,8 @@ static struct sp_group *find_or_alloc_sp_group(int spg_id) int ret; char name[20];
- down_read(&sp_group_sem); - spg = idr_find(&sp_group_idr, spg_id); - up_read(&sp_group_sem); + down_write(&sp_group_sem); + spg = __sp_find_spg_locked(current->pid, spg_id);
if (!spg) { struct user_struct *user = NULL; @@ -401,6 +408,15 @@ static struct sp_group *find_or_alloc_sp_group(int spg_id) pr_err("share pool: alloc spg failed due to lack of memory\n"); return ERR_PTR(-ENOMEM); } + ret = idr_alloc(&sp_group_idr, spg, spg_id, spg_id + 1, + GFP_KERNEL); + up_write(&sp_group_sem); + if (ret < 0) { + if (printk_ratelimit()) + pr_err("share pool: create group idr alloc failed\n"); + goto out_kfree; + } + spg->id = spg_id; atomic_set(&spg->spa_num, 0); atomic64_set(&spg->size, 0); @@ -417,16 +433,6 @@ static struct sp_group *find_or_alloc_sp_group(int spg_id)
init_rwsem(&spg->rw_lock);
- down_write(&sp_group_sem); - ret = idr_alloc(&sp_group_idr, spg, spg_id, spg_id + 1, - GFP_KERNEL); - up_write(&sp_group_sem); - if (ret < 0) { - if (printk_ratelimit()) - pr_err("share pool: create group idr alloc failed\n"); - goto out_kfree; - } - sprintf(name, "sp_group_%d", spg_id); spg->file = shmem_kernel_file_setup(name, MAX_LFS_FILESIZE, VM_NORESERVE); @@ -449,8 +455,15 @@ static struct sp_group *find_or_alloc_sp_group(int spg_id) goto out_fput; } } else { - if (!sp_group_get(spg)) + up_write(&sp_group_sem); + down_read(&spg->rw_lock); + if (!spg_valid(spg)) { + up_read(&spg->rw_lock); + sp_group_drop(spg); return ERR_PTR(-ENODEV); + } + up_read(&spg->rw_lock); + /* spg->use_count has increased due to __sp_find_spg() */ }
return spg; @@ -500,12 +513,6 @@ static void sp_munmap_task_areas(struct mm_struct *mm, struct list_head *stop) spin_unlock(&sp_area_lock); }
-static void sp_group_drop(struct sp_group *spg) -{ - if (atomic_dec_and_test(&spg->use_count)) - free_sp_group(spg); -} - /** * sp_group_add_task - add a process to an sp_group * @pid: the pid of the task to be added @@ -541,9 +548,7 @@ int sp_group_add_task(int pid, int spg_id) }
if (spg_id >= SPG_ID_AUTO_MIN && spg_id <= SPG_ID_AUTO_MAX) { - down_read(&sp_group_sem); - spg = idr_find(&sp_group_idr, spg_id); - up_read(&sp_group_sem); + spg = __sp_find_spg(pid, spg_id);
if (!spg) { if (printk_ratelimit()) @@ -557,9 +562,12 @@ int sp_group_add_task(int pid, int spg_id) if (printk_ratelimit()) pr_err("share pool: task add group failed because group id %d " "is dead\n", spg_id); + sp_group_drop(spg); return -EINVAL; } up_read(&spg->rw_lock); + + sp_group_drop(spg); }
if (spg_id == SPG_ID_AUTO) { @@ -778,6 +786,7 @@ void sp_group_post_exit(struct mm_struct *mm) idr_remove(&sp_stat_idr, mm->sp_stat_id); up_write(&sp_stat_sem);
+ /* match with sp_group_add_task -> find_or_alloc_sp_group */ sp_group_drop(spg);
kfree(stat); @@ -1286,12 +1295,12 @@ static unsigned long sp_mmap(struct mm_struct *mm, struct file *file, */ void *sp_alloc(unsigned long size, unsigned long sp_flags, int spg_id) { - struct sp_group *spg = NULL; + struct sp_group *spg, *spg_tmp; struct sp_area *spa = NULL; struct sp_proc_stat *stat; unsigned long sp_addr; unsigned long mmap_addr; - void *p = ERR_PTR(-ENODEV); + void *p; /* return value */ struct mm_struct *mm; struct file *file; unsigned long size_aligned; @@ -1339,21 +1348,28 @@ void *sp_alloc(unsigned long size, unsigned long sp_flags, int spg_id) return ERR_PTR(ret); } spg = current->mm->sp_group; + /* + * increase use_count deliberately, due to __sp_find_spg is + * matched with sp_group_drop + */ + atomic_inc(&spg->use_count); } else { /* other scenes */ if (spg_id != SPG_ID_DEFAULT) { - down_read(&sp_group_sem); - /* the caller should be a member of the sp group */ - if (spg != idr_find(&sp_group_idr, spg_id)) { - up_read(&sp_group_sem); - return ERR_PTR(-EINVAL); + spg_tmp = __sp_find_spg(current->pid, spg_id); + if (spg != spg_tmp) { + sp_group_drop(spg); + if (spg_tmp) + sp_group_drop(spg_tmp); + return ERR_PTR(-ENODEV); } - up_read(&sp_group_sem); + sp_group_drop(spg_tmp); } }
down_read(&spg->rw_lock); if (!spg_valid(spg)) { up_read(&spg->rw_lock); + sp_group_drop(spg); pr_err("share pool: sp alloc failed, spg is dead\n"); return ERR_PTR(-ENODEV); } @@ -1481,6 +1497,8 @@ void *sp_alloc(unsigned long size, unsigned long sp_flags, int spg_id) if (spa && !IS_ERR(spa)) __sp_area_drop(spa);
+ sp_group_drop(spg); + sp_dump_stack(); sp_try_to_compact(); return p; @@ -1797,6 +1815,7 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, if (printk_ratelimit()) pr_err("share pool: k2spg invalid spg id %d\n", spg_id); uva = ERR_PTR(-EINVAL); + sp_group_drop(spg); goto out; }
@@ -1812,12 +1831,14 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, "(potential no enough virtual memory when -75): %ld\n", PTR_ERR(spa)); uva = spa; + sp_group_drop(spg); goto out; }
if (!vmalloc_area_set_flag(spa, kva_aligned, VM_SHAREPOOL)) { up_read(&spg->rw_lock); pr_err("share pool: %s: the kva %pK is not valid\n", __func__, (void *)kva_aligned); + sp_group_drop(spg); goto out_drop_spa; }
@@ -1832,6 +1853,7 @@ void *sp_make_share_k2u(unsigned long kva, unsigned long size, uva = ERR_PTR(-ENODEV); } up_read(&spg->rw_lock); + sp_group_drop(spg);
accounting: if (!IS_ERR(uva)) { @@ -2472,6 +2494,7 @@ bool sp_config_dvpp_range(size_t start, size_t size, int device_id, int pid)
host_svm_sp_enable = true;
+ sp_group_drop(spg); return true; } EXPORT_SYMBOL_GPL(sp_config_dvpp_range); @@ -2533,8 +2556,10 @@ int proc_sp_group_state(struct seq_file *m, struct pid_namespace *ns,
/* eliminate potential ABBA deadlock */ stat = sp_get_proc_stat(task->mm->sp_stat_id); - if (!stat) + if (unlikely(!stat)) { + sp_group_drop(spg); return 0; + }
/* print the file header */ seq_printf(m, "%-8s %-9s %-13s\n", @@ -2543,10 +2568,13 @@ int proc_sp_group_state(struct seq_file *m, struct pid_namespace *ns, spg_id, byte2kb(atomic64_read(&stat->alloc_size)), hugepage_failures); + + sp_group_drop(spg); return 0; } up_read(&spg->rw_lock);
+ sp_group_drop(spg); return 0; }
@@ -2755,6 +2783,7 @@ static int idr_proc_stat_cb(int id, void *p, void *data) sp_res = byte2kb(atomic64_read(&spg->alloc_size)); } up_read(&spg->rw_lock); + sp_group_drop(spg);
anon = get_mm_counter(mm, MM_ANONPAGES); file = get_mm_counter(mm, MM_FILEPAGES);