
From: Wenkai Lin <linwenkai6@hisilicon.com> 1. Check the build and probe index row sizes against the hardware spec. 2. Check the build and probe index row numbers against the output row number. Signed-off-by: Wenkai Lin <linwenkai6@hisilicon.com> Signed-off-by: Qi Tao <taoqi10@huawei.com> --- drv/hisi_dae_join_gather.c | 40 +++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/drv/hisi_dae_join_gather.c b/drv/hisi_dae_join_gather.c index da10550b..b617071a 100644 --- a/drv/hisi_dae_join_gather.c +++ b/drv/hisi_dae_join_gather.c @@ -24,6 +24,8 @@ #define DAE_JOIN_MAX_ROW_NUN 50000 #define DAE_JOIN_MAX_BATCH_NUM 2800 #define DAE_MAX_TABLE_NUM 16 +#define BUILD_INDEX_ROW_SIZE 8 +#define PROBE_INDEX_ROW_SIZE 4 /* align size */ #define DAE_CHAR_ALIGN_SIZE 4 @@ -550,6 +552,7 @@ static void fill_join_gather_info(struct dae_sqe *sqe, struct dae_ext_sqe *ext_s static int check_join_gather_param(struct wd_join_gather_msg *msg) { struct wd_probe_out_info *output = &msg->req.join_req.probe_output; + struct wd_gather_req *greq = &msg->req.gather_req; __u64 row_num; __u64 size; @@ -587,27 +590,50 @@ static int check_join_gather_param(struct wd_join_gather_msg *msg) return -WD_EINVAL; } if (msg->index_type == WD_BATCH_ADDR_INDEX) { - row_num = output->probe_index.row_num << DAE_ADDR_INDEX_SHIFT; - if (output->build_index.row_num != row_num) { - WD_ERR("invalid: build index row number need be %llu\n", row_num); + row_num = msg->req.output_row_num << DAE_ADDR_INDEX_SHIFT; + if (output->build_index.row_num < row_num) { + WD_ERR("build index row number is less than: %llu\n", + row_num); return -WD_EINVAL; } } + + if (output->probe_index.row_size != PROBE_INDEX_ROW_SIZE || + output->build_index.row_size != BUILD_INDEX_ROW_SIZE) { + WD_ERR("build and probe index row size need be %d, %d!\n", + BUILD_INDEX_ROW_SIZE, PROBE_INDEX_ROW_SIZE); + return -WD_EINVAL; + } break; case WD_JOIN_REHASH: case WD_GATHER_CONVERT: break; case WD_GATHER_COMPLETE: + if
(!msg->multi_batch_en) { + if (greq->index.row_size != PROBE_INDEX_ROW_SIZE) { + WD_ERR("invalid: probe index row size need be %d!\n", + PROBE_INDEX_ROW_SIZE); + return -WD_EINVAL; + } + break; + } + + if (greq->index.row_size != BUILD_INDEX_ROW_SIZE) { + WD_ERR("invalid: build index row size need be %d!\n", + BUILD_INDEX_ROW_SIZE); + return -WD_EINVAL; + } if (msg->index_type == WD_BATCH_NUMBER_INDEX) { - if (msg->req.gather_req.row_batchs.batch_num > DAE_JOIN_MAX_BATCH_NUM) { + if (greq->row_batchs.batch_num > DAE_JOIN_MAX_BATCH_NUM) { WD_ERR("invalid: gather row batch num is more than %d!\n", DAE_JOIN_MAX_BATCH_NUM); return -WD_EINVAL; } } else { - row_num = output->probe_index.row_num << DAE_ADDR_INDEX_SHIFT; - if (output->build_index.row_num != row_num) { - WD_ERR("invalid: build index row number need be: %llu\n", row_num); + row_num = msg->req.output_row_num << DAE_ADDR_INDEX_SHIFT; + if (greq->index.row_num < row_num) { + WD_ERR("build index row number is less than: %llu\n", + row_num); return -WD_EINVAL; } } -- 2.33.0