[PATCH] uadk: add new alg called lz77_only

From: Chenghai Huang <huangchenghai2@huawei.com> Supports LZ77 encoding for LZ4 without additional offset processing. The output includes literal and sequence (LitLength, MatchLength, Offset). Signed-off-by: Chenghai Huang <huangchenghai2@huawei.com> --- drv/hisi_comp.c | 310 +++++++++++++++++++++++++++++++++++++--------- include/wd_comp.h | 1 + wd_comp.c | 2 +- wd_util.c | 1 + 4 files changed, 255 insertions(+), 59 deletions(-) diff --git a/drv/hisi_comp.c b/drv/hisi_comp.c index 0c36301d..1c9f438f 100644 --- a/drv/hisi_comp.c +++ b/drv/hisi_comp.c @@ -84,6 +84,9 @@ #define OVERFLOW_DATA_SIZE 8 #define SEQ_DATA_SIZE_SHIFT 3 #define ZSTD_FREQ_DATA_SIZE 784 +#define ZSTD_MIN_OUT_SIZE 1000 +#define LZ77_MIN_OUT_SIZE 200 +#define PRICE_MIN_OUT_SIZE 4096 #define ZSTD_LIT_RESV_SIZE 16 #define REPCODE_SIZE 12 @@ -108,6 +111,8 @@ enum alg_type { HW_GZIP, HW_LZ77_ZSTD_PRICE = 0x42, HW_LZ77_ZSTD, + HW_LZ77_ONLY = 0x40, + HW_LZ77_ONLY_PRICE, }; enum hw_state { @@ -616,31 +621,30 @@ static void fill_buf_addr_lz77_zstd(struct hisi_zip_sqe *sqe, sqe->stream_ctx_addr_h = upper_32_bits(ctx_buf); } -static int fill_buf_lz77_zstd(handle_t h_qp, struct hisi_zip_sqe *sqe, - struct wd_comp_msg *msg) +static int lz77_zstd_buf_check(struct wd_comp_msg *msg) { - struct wd_comp_req *req = &msg->req; - struct wd_lz77_zstd_data *data = req->priv; __u32 in_size = msg->req.src_len; - __u32 lits_size = in_size + ZSTD_LIT_RESV_SIZE; __u32 out_size = msg->avail_out; - void *ctx_buf = NULL; + __u32 lits_size = in_size + ZSTD_LIT_RESV_SIZE; + __u32 seq_avail_out = out_size - lits_size; - if (unlikely(!data)) { - WD_ERR("invalid: wd_lz77_zstd_data address is NULL!\n"); - return -WD_EINVAL; + if (unlikely(in_size > ZSTD_MAX_SIZE)) { + WD_ERR("invalid: in_len(%u) of lz77_zstd is out of range!\n", in_size); + return -WD_EINVAL; } - if (unlikely(in_size > ZSTD_MAX_SIZE)) { - WD_ERR("invalid: in_len(%u) of lz77_zstd is out of range!\n", - in_size); + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && 
msg->comp_lv < WD_COMP_L9 && + seq_avail_out <= ZSTD_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum!\n", + out_size, ZSTD_MIN_OUT_SIZE + lits_size); return -WD_EINVAL; } - if (unlikely(out_size > HZ_MAX_SIZE)) { - WD_ERR("warning: avail_out(%u) is out of range , will set 8MB size max!\n", - out_size); - out_size = HZ_MAX_SIZE; + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv == WD_COMP_L9 && + seq_avail_out <= PRICE_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum in price mode!\n", + out_size, PRICE_MIN_OUT_SIZE + lits_size); + return -WD_EINVAL; } /* @@ -653,14 +657,92 @@ static int fill_buf_lz77_zstd(handle_t h_qp, struct hisi_zip_sqe *sqe, return -WD_EINVAL; } + return 0; +} + +static int lz77_only_buf_check(struct wd_comp_msg *msg) +{ + __u32 in_size = msg->req.src_len; + __u32 out_size = msg->avail_out; + __u32 lits_size = in_size + ZSTD_LIT_RESV_SIZE; + __u32 seq_avail_out = out_size - lits_size; + + /* lits_size need to be less than 8M when use pbuffer */ + if (unlikely(lits_size > HZ_MAX_SIZE)) { + WD_ERR("invalid: in_len(%u) of lz77_only is out of range!\n", in_size); + return -WD_EINVAL; + } + + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv < WD_COMP_L9 && + seq_avail_out <= LZ77_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum!\n", + out_size, LZ77_MIN_OUT_SIZE + lits_size); + return -WD_EINVAL; + } + + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv == WD_COMP_L9 && + seq_avail_out <= PRICE_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum in price mode!\n", + out_size, PRICE_MIN_OUT_SIZE + lits_size); + return -WD_EINVAL; + } + + /* For lz77_only, the hardware needs 32 Bytes buffer to output the dfx information */ + if (unlikely(out_size < ZSTD_LIT_RESV_SIZE + lits_size)) { + WD_ERR("invalid: output is not enough, %u bytes are minimum!\n", + ZSTD_LIT_RESV_SIZE + 
lits_size); + return -WD_EINVAL; + } + + return 0; +} + +static int lz77_buf_check(struct wd_comp_msg *msg) +{ + enum wd_comp_alg_type alg_type = msg->alg_type; + + if (alg_type == WD_LZ77_ZSTD) + return lz77_zstd_buf_check(msg); + else if (alg_type == WD_LZ77_ONLY) + return lz77_only_buf_check(msg); + + return 0; +} + +static int fill_buf_lz77_zstd(handle_t h_qp, struct hisi_zip_sqe *sqe, + struct wd_comp_msg *msg) +{ + struct wd_comp_req *req = &msg->req; + struct wd_lz77_zstd_data *data = req->priv; + __u32 in_size = msg->req.src_len; + __u32 lits_size = in_size + ZSTD_LIT_RESV_SIZE; + __u32 seq_avail_out = msg->avail_out - lits_size; + void *ctx_buf = NULL; + int ret; + + if (unlikely(!data)) { + WD_ERR("invalid: wd_lz77_zstd_data address is NULL!\n"); + return -WD_EINVAL; + } + + ret = lz77_buf_check(msg); + if (ret) + return ret; + + if (unlikely(seq_avail_out > HZ_MAX_SIZE)) { + WD_ERR("warning: sequence avail_out(%u) is out of range , will set 8MB size max!\n", + seq_avail_out); + seq_avail_out = HZ_MAX_SIZE; + } + if (msg->ctx_buf) { ctx_buf = msg->ctx_buf + RSV_OFFSET; - if (data->blk_type != COMP_BLK) + if (msg->alg_type == WD_LZ77_ZSTD && data->blk_type != COMP_BLK) memcpy(ctx_buf + CTX_HW_REPCODE_OFFSET, msg->ctx_buf + CTX_REPCODE2_OFFSET, REPCODE_SIZE); } - fill_buf_size_lz77_zstd(sqe, in_size, lits_size, out_size - lits_size); + fill_buf_size_lz77_zstd(sqe, in_size, lits_size, seq_avail_out); fill_buf_addr_lz77_zstd(sqe, req->src, req->dst, req->dst + lits_size, ctx_buf); @@ -685,6 +767,103 @@ static struct wd_datalist *get_seq_start_list(struct wd_comp_req *req) return cur; } +static int lz77_zstd_buf_check_sgl(struct wd_comp_msg *msg, __u32 lits_size) +{ + __u32 in_size = msg->req.src_len; + __u32 out_size = msg->avail_out; + __u32 seq_avail_out; + + if (unlikely(in_size > ZSTD_MAX_SIZE)) { + WD_ERR("invalid: in_len(%u) of lz77_zstd is out of range!\n", in_size); + return -WD_EINVAL; + } + + /* + * For lz77_zstd, the hardware needs 784 Bytes buffer 
to output + * the frequency information about input data. The sequences + * and frequency data need to be written to an independent sgl + * splited from list_dst. + */ + if (unlikely(lits_size < in_size + ZSTD_LIT_RESV_SIZE)) { + WD_ERR("invalid: output is not enough for literals, at least %u bytes!\n", + ZSTD_FREQ_DATA_SIZE + lits_size); + return -WD_EINVAL; + } else if (unlikely(out_size < ZSTD_FREQ_DATA_SIZE + lits_size)) { + WD_ERR("invalid: output is not enough for sequences, at least %u bytes more!\n", + ZSTD_FREQ_DATA_SIZE + lits_size - out_size); + return -WD_EINVAL; + } + + seq_avail_out = out_size - lits_size; + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv < WD_COMP_L9 && + seq_avail_out <= ZSTD_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum!\n", + out_size, ZSTD_MIN_OUT_SIZE + lits_size); + return -WD_EINVAL; + } + + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv == WD_COMP_L9 && + seq_avail_out <= PRICE_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum in price mode!\n", + out_size, PRICE_MIN_OUT_SIZE + lits_size); + return -WD_EINVAL; + } + + return 0; +} + +static int lz77_only_buf_check_sgl(struct wd_comp_msg *msg, __u32 lits_size) +{ + __u32 in_size = msg->req.src_len; + __u32 out_size = msg->avail_out; + __u32 seq_avail_out; + + /* + * For lz77_only, the hardware needs 32 Bytes buffer to output + * the dfx information. The literals and sequences data need to be written + * to an independent sgl splited from list_dst. 
+ */ + if (unlikely(lits_size < in_size + ZSTD_LIT_RESV_SIZE)) { + WD_ERR("invalid: output is not enough for literals, at least %u bytes!\n", + ZSTD_LIT_RESV_SIZE + lits_size); + return -WD_EINVAL; + } else if (unlikely(out_size < ZSTD_LIT_RESV_SIZE + lits_size)) { + WD_ERR("invalid: output is not enough for sequences, at least %u bytes more!\n", + ZSTD_LIT_RESV_SIZE + lits_size - out_size); + return -WD_EINVAL; + } + + seq_avail_out = out_size - lits_size; + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv < WD_COMP_L9 && + seq_avail_out <= LZ77_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum!\n", + out_size, LZ77_MIN_OUT_SIZE + lits_size); + return -WD_EINVAL; + } + + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv == WD_COMP_L9 && + seq_avail_out <= PRICE_MIN_OUT_SIZE)) { + WD_ERR("invalid: out_len(%u) not enough, %u bytes are minimum in price mode!\n", + out_size, PRICE_MIN_OUT_SIZE + lits_size); + return -WD_EINVAL; + } + + return 0; +} + + +static int lz77_buf_check_sgl(struct wd_comp_msg *msg, __u32 lits_size) +{ + enum wd_comp_alg_type alg_type = msg->alg_type; + + if (alg_type == WD_LZ77_ZSTD) + return lz77_zstd_buf_check_sgl(msg, lits_size); + else if (alg_type == WD_LZ77_ONLY) + return lz77_only_buf_check_sgl(msg, lits_size); + + return 0; +} + static int fill_buf_lz77_zstd_sgl(handle_t h_qp, struct hisi_zip_sqe *sqe, struct wd_comp_msg *msg) { @@ -698,12 +877,6 @@ static int fill_buf_lz77_zstd_sgl(handle_t h_qp, struct hisi_zip_sqe *sqe, __u32 lits_size; int ret; - if (unlikely(in_size > ZSTD_MAX_SIZE)) { - WD_ERR("invalid: in_len(%u) of lz77_zstd is out of range!\n", - in_size); - return -WD_EINVAL; - } - if (unlikely(!data)) { WD_ERR("invalid: wd_lz77_zstd_data address is NULL!\n"); return -WD_EINVAL; @@ -715,26 +888,15 @@ static int fill_buf_lz77_zstd_sgl(handle_t h_qp, struct hisi_zip_sqe *sqe, if (unlikely(!seq_start)) return -WD_EINVAL; + lits_size = 
hisi_qm_get_list_size(req->list_dst, seq_start); + + ret = lz77_buf_check_sgl(msg, lits_size); + if (ret) + return ret; + data->literals_start = req->list_dst; data->sequences_start = seq_start; - /* - * For lz77_zstd, the hardware needs 784 Bytes buffer to output - * the frequency information about input data. The sequences - * and frequency data need to be written to an independent sgl - * splited from list_dst. - */ - lits_size = hisi_qm_get_list_size(req->list_dst, seq_start); - if (unlikely(lits_size < in_size + ZSTD_LIT_RESV_SIZE)) { - WD_ERR("invalid: output is not enough for literals, %u bytes are minimum!\n", - ZSTD_FREQ_DATA_SIZE + lits_size); - return -WD_EINVAL; - } else if (unlikely(out_size < ZSTD_FREQ_DATA_SIZE + lits_size)) { - WD_ERR("invalid: output is not enough for sequences, at least %u bytes more!\n", - ZSTD_FREQ_DATA_SIZE + lits_size - out_size); - return -WD_EINVAL; - } - fill_buf_size_lz77_zstd(sqe, in_size, lits_size, out_size - lits_size); h_sgl_pool = hisi_qm_get_sglpool(h_qp); @@ -824,6 +986,15 @@ static void fill_alg_lz77_zstd(struct hisi_zip_sqe *sqe) sqe->dw9 = val; } +static void fill_alg_lz77_only(struct hisi_zip_sqe *sqe) +{ + __u32 val; + + val = sqe->dw9 & ~HZ_REQ_TYPE_MASK; + val |= HW_LZ77_ONLY; + sqe->dw9 = val; +} + static void fill_tag_v1(struct hisi_zip_sqe *sqe, __u32 tag) { sqe->dw13 = tag; @@ -841,7 +1012,7 @@ static int fill_comp_level_deflate(struct hisi_zip_sqe *sqe, enum wd_comp_level static int fill_comp_level_lz77_zstd(struct hisi_zip_sqe *sqe, enum wd_comp_level comp_lv) { - __u32 val; + __u32 val, alg; switch (comp_lv) { case WD_COMP_L8: @@ -851,8 +1022,12 @@ static int fill_comp_level_lz77_zstd(struct hisi_zip_sqe *sqe, enum wd_comp_leve */ break; case WD_COMP_L9: + alg = sqe->dw9 & HZ_REQ_TYPE_MASK; val = sqe->dw9 & ~HZ_REQ_TYPE_MASK; - val |= HW_LZ77_ZSTD_PRICE; + if (alg == HW_LZ77_ZSTD) + val |= HW_LZ77_ZSTD_PRICE; + else if (alg == HW_LZ77_ONLY) + val |= HW_LZ77_ONLY_PRICE; sqe->dw9 = val; break; default: 
@@ -911,18 +1086,22 @@ static void get_data_size_lz77_zstd(struct hisi_zip_sqe *sqe, enum wd_comp_op_ty if (unlikely(!data)) return; + recv_msg->in_cons = sqe->consumed; data->lit_num = sqe->comp_data_length; data->seq_num = sqe->produced; - data->lit_length_overflow_cnt = sqe->dw31 >> LITLEN_OVERFLOW_CNT_SHIFT; - data->lit_length_overflow_pos = sqe->dw31 & LITLEN_OVERFLOW_POS_MASK; - data->freq = data->sequences_start + (data->seq_num << SEQ_DATA_SIZE_SHIFT) + - OVERFLOW_DATA_SIZE; - - if (ctx_buf) { - memcpy(ctx_buf + CTX_REPCODE2_OFFSET, - ctx_buf + CTX_REPCODE1_OFFSET, REPCODE_SIZE); - memcpy(ctx_buf + CTX_REPCODE1_OFFSET, - ctx_buf + RSV_OFFSET + CTX_HW_REPCODE_OFFSET, REPCODE_SIZE); + + if (recv_msg->alg_type == WD_LZ77_ZSTD) { + data->lit_length_overflow_cnt = sqe->dw31 >> LITLEN_OVERFLOW_CNT_SHIFT; + data->lit_length_overflow_pos = sqe->dw31 & LITLEN_OVERFLOW_POS_MASK; + data->freq = data->sequences_start + (data->seq_num << SEQ_DATA_SIZE_SHIFT) + + OVERFLOW_DATA_SIZE; + + if (ctx_buf) { + memcpy(ctx_buf + CTX_REPCODE2_OFFSET, + ctx_buf + CTX_REPCODE1_OFFSET, REPCODE_SIZE); + memcpy(ctx_buf + CTX_REPCODE1_OFFSET, + ctx_buf + RSV_OFFSET + CTX_HW_REPCODE_OFFSET, REPCODE_SIZE); + } } } @@ -970,6 +1149,16 @@ struct hisi_zip_sqe_ops ops[] = { { .fill_comp_level = fill_comp_level_lz77_zstd, .get_data_size = get_data_size_lz77_zstd, .get_tag = get_tag_v3, + }, { + .alg_name = "lz77_only", + .fill_buf[WD_FLAT_BUF] = fill_buf_lz77_zstd, + .fill_buf[WD_SGL_BUF] = fill_buf_lz77_zstd_sgl, + .fill_sqe_type = fill_sqe_type_v3, + .fill_alg = fill_alg_lz77_only, + .fill_tag = fill_tag_v3, + .fill_comp_level = fill_comp_level_lz77_zstd, + .get_data_size = get_data_size_lz77_zstd, + .get_tag = get_tag_v3, } }; @@ -1079,10 +1268,6 @@ static int fill_zip_comp_sqe(struct hisi_qp *qp, struct wd_comp_msg *msg, return -WD_EINVAL; } - ret = ops[alg_type].fill_comp_level(sqe, msg->comp_lv); - if (unlikely(ret)) - return ret; - ret = 
ops[alg_type].fill_buf[msg->req.data_fmt]((handle_t)qp, sqe, msg); if (unlikely(ret)) return ret; @@ -1091,6 +1276,10 @@ static int fill_zip_comp_sqe(struct hisi_qp *qp, struct wd_comp_msg *msg, ops[alg_type].fill_alg(sqe); + ret = ops[alg_type].fill_comp_level(sqe, msg->comp_lv); + if (unlikely(ret)) + return ret; + ops[alg_type].fill_tag(sqe, msg->tag); state = (msg->stream_mode == WD_COMP_STATEFUL) ? HZ_STATEFUL : @@ -1132,7 +1321,7 @@ static void free_hw_sgl(handle_t h_qp, struct hisi_zip_sqe *sqe, hw_sgl_out = VA_ADDR(sqe->dest_addr_h, sqe->dest_addr_l); hisi_qm_put_hw_sgl(h_sgl_pool, hw_sgl_out); - if (alg_type == WD_LZ77_ZSTD) { + if (alg_type == WD_LZ77_ZSTD || alg_type == WD_LZ77_ONLY) { hw_sgl_out = VA_ADDR(sqe->literals_addr_h, sqe->literals_addr_l); hisi_qm_put_hw_sgl(h_sgl_pool, hw_sgl_out); @@ -1190,6 +1379,10 @@ static int get_alg_type(__u32 type) case HW_LZ77_ZSTD_PRICE: alg_type = WD_LZ77_ZSTD; break; + case HW_LZ77_ONLY: + case HW_LZ77_ONLY_PRICE: + alg_type = WD_LZ77_ONLY; + break; default: break; } @@ -1369,6 +1562,7 @@ static struct wd_alg_driver zip_alg_driver[] = { GEN_ZIP_ALG_DRIVER("deflate"), GEN_ZIP_ALG_DRIVER("lz77_zstd"), + GEN_ZIP_ALG_DRIVER("lz77_only"), }; #ifdef WD_STATIC_DRV diff --git a/include/wd_comp.h b/include/wd_comp.h index 45994ff6..0012ef6b 100644 --- a/include/wd_comp.h +++ b/include/wd_comp.h @@ -20,6 +20,7 @@ enum wd_comp_alg_type { WD_ZLIB, WD_GZIP, WD_LZ77_ZSTD, + WD_LZ77_ONLY, WD_COMP_ALG_MAX, }; diff --git a/wd_comp.c b/wd_comp.c index 647c320e..8e47a32f 100644 --- a/wd_comp.c +++ b/wd_comp.c @@ -27,7 +27,7 @@ #define cpu_to_be32(x) swap_byte(x) static const char *wd_comp_alg_name[WD_COMP_ALG_MAX] = { - "zlib", "gzip", "deflate", "lz77_zstd" + "zlib", "gzip", "deflate", "lz77_zstd", "lz77_only" }; struct wd_comp_sess { diff --git a/wd_util.c b/wd_util.c index 669743cb..f21b3236 100644 --- a/wd_util.c +++ b/wd_util.c @@ -107,6 +107,7 @@ static struct acc_alg_item alg_options[] = { {"gzip", "gzip"}, {"deflate", 
"deflate"}, {"lz77_zstd", "lz77_zstd"}, + {"lz77_only", "lz77_only"}, {"hashagg", "hashagg"}, {"udma", "udma"}, -- 2.33.0

From: Chenghai Huang <huangchenghai2@huawei.com> 1.When determining whether the output size meets the threshold requirements, only compression needs to consider the length of the head size. 2.check_store_buf will not return value less than 0. 3.msg will not be null, because it is a local variable in wd_comp. Signed-off-by: Chenghai Huang <huangchenghai2@huawei.com> --- drv/hisi_comp.c | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/drv/hisi_comp.c b/drv/hisi_comp.c index 98b45d71..0c36301d 100644 --- a/drv/hisi_comp.c +++ b/drv/hisi_comp.c @@ -343,15 +343,14 @@ static int check_enable_store_buf(struct wd_comp_msg *msg, __u32 out_size, int h if (msg->stream_mode != WD_COMP_STATEFUL) return 0; - if (msg->stream_pos != WD_COMP_STREAM_NEW && out_size > SW_STOREBUF_TH) - return 0; + if (msg->stream_pos == WD_COMP_STREAM_NEW && msg->req.op_type == WD_DIR_COMPRESS && + out_size - head_size <= SW_STOREBUF_TH) + return 1; - if (msg->stream_pos == WD_COMP_STREAM_NEW && - out_size - head_size > SW_STOREBUF_TH) - return 0; + if (out_size <= SW_STOREBUF_TH) + return 1; - /* 1 mean it need store buf */ - return 1; + return 0; } static void fill_buf_size_deflate(struct hisi_zip_sqe *sqe, __u32 in_size, @@ -386,7 +385,7 @@ static int fill_buf_deflate_generic(struct hisi_zip_sqe *sqe, /* * When the output buffer is smaller than the SW_STOREBUF_TH in STATEFUL, - * the internal buffer is used. + * the internal buffer is used. It requires a storage buffer when returning 1. */ ret = check_enable_store_buf(msg, out_size, head_size); if (ret) { @@ -524,7 +523,7 @@ static int fill_buf_deflate_sgl_generic(handle_t h_qp, struct hisi_zip_sqe *sqe, /* * When the output buffer is smaller than the SW_STOREBUF_TH in STATEFUL, - * the internal buffer is used. + * the internal buffer is used. It requires a storage buffer when returning 1. 
*/ ret = check_enable_store_buf(msg, out_size, head_size); if (ret) { @@ -1152,7 +1151,7 @@ static int hisi_zip_comp_send(struct wd_alg_driver *drv, handle_t ctx, void *com /* Skip hardware, if the store buffer need to be copied to output */ ret = check_store_buf(msg); if (ret) - return ret < 0 ? ret : 0; + return 0; hisi_set_msg_id(h_qp, &msg->tag); ret = fill_zip_comp_sqe(qp, msg, &sqe); @@ -1322,7 +1321,7 @@ static int hisi_zip_comp_recv(struct wd_alg_driver *drv, handle_t ctx, void *com __u16 count = 0; int ret; - if (recv_msg && recv_msg->ctx_buf) { + if (recv_msg->ctx_buf) { buf = (struct hisi_comp_buf *)(recv_msg->ctx_buf + CTX_STOREBUF_OFFSET); /* * The output has been copied from the storage buffer, -- 2.33.0

From: Wenkai Lin <linwenkai6@hisilicon.com> UADK changes the output address information of the request before the rehash input task, so that information must be restored before the next rehash task. Signed-off-by: Wenkai Lin <linwenkai6@hisilicon.com> --- drv/hisi_dae.c | 3 ++- wd_agg.c | 63 +++++++++++++++++++++++++++++--------------------- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/drv/hisi_dae.c b/drv/hisi_dae.c index 2ab0e7f5..d6f96caa 100644 --- a/drv/hisi_dae.c +++ b/drv/hisi_dae.c @@ -531,8 +531,9 @@ static void fill_hashagg_msg_task_done(struct dae_sqe *sqe, struct wd_agg_msg *m msg->out_row_count = sqe->out_raw_num; msg->output_done = sqe->output_end; } else if (sqe->task_type_ext == DAE_HASHAGG_MERGE) { - msg->out_row_count = temp_msg->row_count; msg->output_done = sqe->output_end; + if (!msg->output_done) + msg->out_row_count = temp_msg->row_count; } else { msg->in_row_count = temp_msg->row_count; } diff --git a/wd_agg.c b/wd_agg.c index 8869ab84..686e7393 100644 --- a/wd_agg.c +++ b/wd_agg.c @@ -1406,16 +1406,16 @@ static int wd_agg_set_col_size(struct wd_agg_sess *sess, struct wd_agg_req *req, return WD_SUCCESS; } -static int wd_agg_rehash_sync_inner(struct wd_agg_sess *sess, struct wd_agg_req *req) +static int wd_agg_rehash_sync_inner(struct wd_agg_sess *sess, struct wd_agg_req *in_req, + struct wd_agg_req *out_req) { + struct wd_agg_msg in_msg = {0}; struct wd_agg_msg msg = {0}; - bool output_done; int ret; - fill_request_msg_output(&msg, req, sess, true); - req->state = 0; + fill_request_msg_output(&msg, out_req, sess, true); - ret = wd_agg_sync_job(sess, req, &msg); + ret = wd_agg_sync_job(sess, out_req, &msg); if (unlikely(ret)) return ret; @@ -1423,33 +1423,26 @@ static int wd_agg_rehash_sync_inner(struct wd_agg_sess *sess, struct wd_agg_req if (unlikely(ret)) return ret; - req->real_out_row_count = msg.out_row_count; - output_done = msg.output_done; if (!msg.out_row_count) { - req->output_done = true; + 
out_req->output_done = true; return 
WD_SUCCESS; } - req->key_cols = req->out_key_cols; - req->agg_cols = req->out_agg_cols; - req->key_cols_num = req->out_key_cols_num; - req->agg_cols_num = req->out_agg_cols_num; - wd_agg_set_col_size(sess, req, req->real_out_row_count); - req->in_row_count = req->real_out_row_count; + out_req->real_out_row_count = msg.out_row_count; + wd_agg_set_col_size(sess, in_req, out_req->real_out_row_count); + in_req->in_row_count = out_req->real_out_row_count; - memset(&msg, 0, sizeof(struct wd_agg_msg)); - fill_request_msg_input(&msg, req, sess, true); + fill_request_msg_input(&in_msg, in_req, sess, true); - ret = wd_agg_sync_job(sess, req, &msg); + ret = wd_agg_sync_job(sess, in_req, &in_msg); if (unlikely(ret)) return ret; - ret = wd_agg_check_msg_result(msg.result); + ret = wd_agg_check_msg_result(in_msg.result); if (unlikely(ret)) return ret; - req->state = msg.result; - req->output_done = output_done; + out_req->output_done = msg.output_done; return WD_SUCCESS; } @@ -1472,9 +1465,9 @@ int wd_agg_rehash_sync(handle_t h_sess, struct wd_agg_req *req) { struct wd_agg_sess *sess = (struct wd_agg_sess *)h_sess; enum wd_agg_sess_state expected = WD_AGG_SESS_RESET; - struct wd_agg_req src_req; - __u64 cnt = 0; - __u64 max_cnt; + struct wd_dae_col_addr *cols; + struct wd_agg_req in_req; + __u64 max_cnt, key_len, agg_len, cnt = 0; int ret; ret = wd_agg_check_rehash_params(sess, req); @@ -1487,21 +1480,39 @@ int wd_agg_rehash_sync(handle_t h_sess, struct wd_agg_req *req) if (unlikely(ret)) return ret; - memcpy(&src_req, req, sizeof(struct wd_agg_req)); + memcpy(&in_req, req, sizeof(struct wd_agg_req)); + + key_len = req->out_key_cols_num * sizeof(struct wd_dae_col_addr); + agg_len = req->out_agg_cols_num * sizeof(struct wd_dae_col_addr); + cols = malloc(key_len + agg_len); + if (unlikely(!cols)) + return -WD_ENOMEM; + + /* The input task uses the address of the output task as input address. 
*/ + in_req.key_cols = cols; + in_req.agg_cols = cols + req->out_key_cols_num; + in_req.key_cols_num = req->out_key_cols_num; + in_req.agg_cols_num = req->out_agg_cols_num; + memcpy(in_req.key_cols, req->out_key_cols, key_len); + memcpy(in_req.agg_cols, req->out_agg_cols, agg_len); + max_cnt = MAX_HASH_TABLE_ROW_NUM / req->out_row_count; + while (cnt < max_cnt) { - ret = wd_agg_rehash_sync_inner(sess, &src_req); + ret = wd_agg_rehash_sync_inner(sess, &in_req, req); if (ret) { __atomic_store_n(&sess->state, WD_AGG_SESS_RESET, __ATOMIC_RELEASE); WD_ERR("failed to do agg rehash task!\n"); + free(cols); return ret; } - if (src_req.output_done) + if (req->output_done) break; cnt++; } __atomic_store_n(&sess->state, WD_AGG_SESS_INPUT, __ATOMIC_RELEASE); + free(cols); return WD_SUCCESS; } -- 2.33.0

From: Zhushuai Yin <yinzhushuai@huawei.com> In the comp stream scenario, only the first packet (stream_pos is WD_COMP_STREAM_NEW) needs its buffer length checked against the zlib/gzip header size; for the middle and last packets it is enough to check that dst_len is non-zero. Signed-off-by: Zhushuai Yin <yinzhushuai@huawei.com> --- wd_comp.c | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/wd_comp.c b/wd_comp.c index 0871abf6..0f88589a 100644 --- a/wd_comp.c +++ b/wd_comp.c @@ -524,34 +524,40 @@ static void fill_comp_msg(struct wd_comp_sess *sess, struct wd_comp_msg *msg, msg->req.last = 1; } -static int wd_check_alg_buff_size(__u32 dst_len, __u32 src_len, - enum wd_comp_alg_type alg_type, enum wd_comp_op_type op_type) +static int wd_check_alg_buff_size(struct wd_comp_req *req, struct wd_comp_sess *sess) { - if (!dst_len) { + if (!req->dst_len) { WD_ERR("invalid: dst_len is 0!\n"); return -WD_EINVAL; } - if (alg_type == WD_ZLIB) { - if (dst_len <= WD_ZLIB_HEADER_SZ && op_type == WD_DIR_COMPRESS) { - WD_ERR("invalid: zlib dst_len(%u) is too samll!\n", dst_len); + /* + * Only the first package needs to be checked, + * the middle and last packages do not need to be checked + */ + if (sess->stream_pos != WD_COMP_STREAM_NEW) + return 0; + + if (sess->alg_type == WD_ZLIB) { + if (req->dst_len <= WD_ZLIB_HEADER_SZ && req->op_type == WD_DIR_COMPRESS) { + WD_ERR("invalid: zlib dst_len(%u) is too small!\n", req->dst_len); return -WD_EINVAL; } - if (src_len <= WD_ZLIB_HEADER_SZ && op_type == WD_DIR_DECOMPRESS) { - WD_ERR("invalid: zlib src_len(%u) is too samll!\n", src_len); + if (req->src_len <= WD_ZLIB_HEADER_SZ && req->op_type == WD_DIR_DECOMPRESS) { + WD_ERR("invalid: zlib src_len(%u) is too small!\n", req->src_len); return -WD_EINVAL; } } - if (alg_type == WD_GZIP) { - if (dst_len <= WD_GZIP_HEADER_SZ && op_type == WD_DIR_COMPRESS) { - WD_ERR("invalid: gzip dst_len(%u) is too samll!\n", dst_len); + if (sess->alg_type == WD_GZIP) { + if (req->dst_len <= 
WD_GZIP_HEADER_SZ && req->op_type == WD_DIR_COMPRESS) { + WD_ERR("invalid: gzip 
dst_len(%u) is too small!\n", req->dst_len); return -WD_EINVAL; } - if (src_len <= WD_GZIP_HEADER_SZ && op_type == WD_DIR_DECOMPRESS) { - WD_ERR("invalid: gzip src_len(%u) is too samll!\n", src_len); + if (req->src_len <= WD_GZIP_HEADER_SZ && req->op_type == WD_DIR_DECOMPRESS) { + WD_ERR("invalid: gzip src_len(%u) is too small!\n", req->src_len); return -WD_EINVAL; } } @@ -559,7 +565,7 @@ static int wd_check_alg_buff_size(__u32 dst_len, __u32 src_len, return 0; } -static int wd_comp_check_buffer(struct wd_comp_req *req, enum wd_comp_alg_type alg_type) +static int wd_comp_check_buffer(struct wd_comp_req *req, struct wd_comp_sess *sess) { if (req->data_fmt == WD_FLAT_BUF) { if (unlikely(!req->src || !req->dst)) { @@ -573,7 +579,7 @@ static int wd_comp_check_buffer(struct wd_comp_req *req, enum wd_comp_alg_type a } } - return wd_check_alg_buff_size(req->dst_len, req->src_len, alg_type, req->op_type); + return wd_check_alg_buff_size(req, sess); } static int wd_comp_check_params(struct wd_comp_sess *sess, @@ -592,7 +598,7 @@ static int wd_comp_check_params(struct wd_comp_sess *sess, return -WD_EINVAL; } - ret = wd_comp_check_buffer(req, sess->alg_type); + ret = wd_comp_check_buffer(req, sess); if (unlikely(ret)) return ret; -- 2.33.0

From: Wenkai Lin <linwenkai6@hisilicon.com> UADK supports hardware acceleration for the hashjoin and gather. Hashjoin is used to construct a hash table to join two tables, gather is used to combine data of different types in multiple columns in a specified order to obtain a new column. Signed-off-by: Wenkai Lin <linwenkai6@hisilicon.com> --- Makefile.am | 8 +- drv/hisi_dae.c | 602 +--------- drv/hisi_dae.h | 229 ++++ drv/hisi_dae_common.c | 387 +++++++ drv/hisi_dae_join_gather.c | 1040 +++++++++++++++++ include/drv/wd_join_gather_drv.h | 52 + include/wd_alg.h | 2 + include/wd_dae.h | 12 + include/wd_join_gather.h | 352 ++++++ include/wd_util.h | 1 + libwd_dae.map | 19 + wd_join_gather.c | 1823 ++++++++++++++++++++++++++++++ wd_util.c | 6 +- 13 files changed, 3959 insertions(+), 574 deletions(-) create mode 100644 drv/hisi_dae.h create mode 100644 drv/hisi_dae_common.c create mode 100644 drv/hisi_dae_join_gather.c create mode 100644 include/drv/wd_join_gather_drv.h create mode 100644 include/wd_join_gather.h create mode 100644 wd_join_gather.c diff --git a/Makefile.am b/Makefile.am index 6af4295a..ae6db732 100644 --- a/Makefile.am +++ b/Makefile.am @@ -47,7 +47,7 @@ pkginclude_HEADERS = include/wd.h include/wd_cipher.h include/wd_aead.h \ include/wd_rsa.h include/uacce.h include/wd_alg_common.h \ include/wd_ecc.h include/wd_sched.h include/wd_alg.h \ include/wd_zlibwrapper.h include/wd_dae.h include/wd_agg.h \ - include/wd_udma.h + include/wd_udma.h include/wd_join_gather.h nobase_pkginclude_HEADERS=v1/wd.h v1/wd_cipher.h v1/wd_aead.h v1/uacce.h v1/wd_dh.h \ v1/wd_digest.h v1/wd_rsa.h v1/wd_bmm.h v1/wd_ecc.h v1/wd_comp.h @@ -81,7 +81,7 @@ libwd_la_SOURCES=wd.c wd_mempool.c wd.h wd_alg.c wd_alg.h \ libwd_udma_la_SOURCES=wd_udma.h wd_udma_drv.h wd_udma.c \ wd_util.c wd_util.h wd_sched.c wd_sched.h wd.c wd.h -libwd_dae_la_SOURCES=wd_dae.h wd_agg.h wd_agg_drv.h wd_agg.c \ +libwd_dae_la_SOURCES=wd_dae.h wd_agg.h wd_agg_drv.h wd_agg.c wd_join_gather.h wd_join_gather_drv.h 
wd_join_gather.c \ wd_util.c wd_util.h wd_sched.c wd_sched.h wd.c wd.h libwd_comp_la_SOURCES=wd_comp.c wd_comp.h wd_comp_drv.h wd_util.c wd_util.h \ @@ -117,8 +117,8 @@ libisa_sve_la_SOURCES=drv/hash_mb/hash_mb.c wd_digest_drv.h drv/hash_mb/hash_mb. drv/hash_mb/md5_sve_common.S drv/hash_mb/md5_mb_asimd_x1.S \ drv/hash_mb/md5_mb_asimd_x4.S drv/hash_mb/md5_mb_sve.S -libhisi_dae_la_SOURCES=drv/hisi_dae.c drv/hisi_qm_udrv.c \ - hisi_qm_udrv.h +libhisi_dae_la_SOURCES=drv/hisi_dae.c hisi_dae.h drv/hisi_qm_udrv.c \ + hisi_qm_udrv.h drv/hisi_dae_join_gather.c drv/hisi_dae_common.c libhisi_udma_la_SOURCES=drv/hisi_udma.c drv/hisi_qm_udrv.c \ hisi_qm_udrv.h diff --git a/drv/hisi_dae.c b/drv/hisi_dae.c index 57758a2d..d6f96caa 100644 --- a/drv/hisi_dae.c +++ b/drv/hisi_dae.c @@ -1,19 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 /* Copyright 2024 Huawei Technologies Co.,Ltd. All rights reserved. */ -#include <math.h> -#include <stdint.h> -#include <stdlib.h> -#include <stdio.h> -#include <unistd.h> -#include <sys/epoll.h> -#include <sys/eventfd.h> -#include <sys/mman.h> -#include <sys/types.h> #include "hisi_qm_udrv.h" +#include "hisi_dae.h" #include "../include/drv/wd_agg_drv.h" -#define DAE_HASH_AGG_TYPE 2 #define DAE_EXT_SQE_SIZE 128 #define DAE_CTX_Q_NUM_DEF 1 @@ -39,37 +30,14 @@ /* align size */ #define DAE_CHAR_ALIGN_SIZE 4 -#define DAE_TABLE_ALIGN_SIZE 128 -#define DAE_ADDR_ALIGN_SIZE 128 - -/* decimal infomartion */ -#define DAE_DECIMAL_PRECISION_OFFSET 8 -#define DAE_DECIMAL128_MAX_PRECISION 38 -#define DAE_DECIMAL64_MAX_PRECISION 18 /* hash table */ -#define HASH_EXT_TABLE_INVALID_OFFSET 5 -#define HASH_EXT_TABLE_VALID 0x80 #define HASH_TABLE_HEAD_TAIL_SIZE 8 #define HASH_TABLE_EMPTY_SIZE 4 -#define HASH_TABLE_WITDH_POWER 2 -#define HASH_TABLE_MIN_WIDTH 10 -#define HASH_TABLE_MAX_WIDTH 43 -#define HASH_TABLE_OFFSET_3ROW 3 -#define HASH_TABLE_OFFSET_1ROW 1 /* hash agg operations col max num */ #define DAE_AGG_COL_ALG_MAX_NUM 2 -#define __ALIGN_MASK(x, mask) (((x) + 
(mask)) & ~(mask)) -#define ALIGN(x, a) __ALIGN_MASK(x, (typeof(x))(a)-1) -#define PTR_ALIGN(p, a) ((typeof(p))ALIGN((uintptr_t)(p), (a))) - -#define BIT(nr) (1UL << (nr)) -#define BITS_PER_LONG (__SIZEOF_LONG__ * 8) -#define GENMASK(h, l) \ - (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h)))) - /* DAE hardware protocol data */ enum dae_stage { DAE_HASH_AGGREGATE = 0x0, @@ -85,48 +53,6 @@ enum dae_op_type { DAE_SUM = 0x5, }; -enum dae_done_flag { - DAE_HW_TASK_NOT_PROCESS = 0x0, - DAE_HW_TASK_DONE = 0x1, - DAE_HW_TASK_ERR = 0x2, -}; - -enum dae_error_type { - DAE_TASK_SUCCESS = 0x0, - DAE_TASK_BD_ERROR_MIN = 0x1, - DAE_TASK_BD_ERROR_MAX = 0x7f, - DAE_HASH_TABLE_NEED_REHASH = 0x82, - DAE_HASH_TABLE_INVALID = 0x83, - DAE_HASHAGG_VCHAR_OVERFLOW = 0x84, - DAE_HASHAGG_RESULT_OVERFLOW = 0x85, - DAE_HASHAGG_BUS_ERROR = 0x86, - DAE_HASHAGG_VCHAR_LEN_ERROR = 0x87, -}; - -enum dae_data_type { - DAE_SINT32 = 0x0, - DAE_SINT64 = 0x2, - DAE_DECIMAL64 = 0x9, - DAE_DECIMAL128 = 0xA, - DAE_CHAR = 0xC, - DAE_VCHAR = 0xD, -}; - -enum dae_date_type_size { - SINT32_SIZE = 4, - SINT64_SIZE = 8, - DECIMAL128_SIZE = 16, - DEFAULT_VCHAR_SIZE = 30, -}; - -enum dae_table_row_size { - ROW_SIZE32 = 32, - ROW_SIZE64 = 64, - ROW_SIZE128 = 128, - ROW_SIZE256 = 256, - ROW_SIZE512 = 512, -}; - enum dae_sum_optype { DECIMAL64_TO_DECIMAL64 = 0x2, DECIMAL64_TO_DECIMAL128 = 0x3, @@ -139,99 +65,6 @@ enum dae_alg_optype { DAE_HASHAGG_MIN = 0x8, }; -enum dae_bd_type { - DAE_BD_TYPE_V1 = 0x0, - DAE_BD_TYPE_V2 = 0x1, -}; - -struct dae_sqe { - __u32 bd_type : 6; - __u32 resv1 : 2; - __u32 task_type : 6; - __u32 resv2 : 2; - __u32 task_type_ext : 6; - __u32 resv3 : 9; - __u32 bd_invlid : 1; - __u16 table_row_size; - __u16 resv4; - __u32 resv5; - __u32 low_tag; - __u32 hi_tag; - __u32 row_num; - __u32 resv6; - __u32 src_table_width : 6; - __u32 dst_table_width : 6; - __u32 resv7 : 4; - __u32 counta_vld : 1; - __u32 resv8 : 15; - /* - * high 4bits: compare mode if data type is char/vchar, - * out type if 
operation is sum. - * low 4bits: input value type. - */ - __u8 key_data_type[16]; - __u8 agg_data_type[16]; - __u32 resv9[8]; - __u32 key_col_bitmap; - __u32 agg_col_bitmap; - __u64 addr_list; - __u32 done_flag : 3; - __u32 output_end : 1; - __u32 ext_err_type : 12; - __u32 err_type : 8; - __u32 wtype : 8; - __u32 out_raw_num; - __u32 vchar_err_offset; - __u16 sum_overflow_cols; - __u16 resv10; -}; - -struct dae_ext_sqe { - /* - * If date type is char/vchar, data info fill data type size - * If data type is decimal64/decimal128, data info fill data precision - */ - __u16 key_data_info[16]; - __u16 agg_data_info[16]; - /* Aggregated output from input agg col index */ - __u64 out_from_in_idx; - /* Aggregated output from input agg col operation, sum or count */ - __u64 out_optype; - __u32 resv[12]; -}; - -struct dae_col_addr { - __u64 empty_addr; - __u64 empty_size; - __u64 value_addr; - __u64 value_size; -}; - -struct dae_table_addr { - __u64 std_table_addr; - __u64 std_table_size; - __u64 ext_table_addr; - __u64 ext_table_size; -}; - -struct dae_addr_list { - __u64 ext_sqe_addr; - __u64 ext_sqe_size; - struct dae_table_addr src_table; - struct dae_table_addr dst_table; - __u64 resv_addr[6]; - struct dae_col_addr input_addr[32]; - struct dae_col_addr output_addr[32]; -}; - -struct dae_extend_addr { - struct dae_ext_sqe *ext_sqe; - struct dae_addr_list *addr_list; - __u8 *addr_status; - __u16 addr_num; - __u16 tail; -}; - static enum dae_data_type hw_data_type_order[] = { DAE_VCHAR, DAE_CHAR, DAE_DECIMAL128, DAE_DECIMAL64, DAE_SINT64, DAE_SINT32, @@ -265,14 +98,6 @@ struct hashagg_col_data { bool is_count_all; }; -struct hash_table_data { - void *std_table; - void *ext_table; - __u64 std_table_size; - __u64 ext_table_size; - __u32 table_width; -}; - struct hashagg_ctx { struct hashagg_col_data cols_data; struct hash_table_data table_data; @@ -282,33 +107,6 @@ struct hashagg_ctx { __u16 sum_overflow_cols; }; -struct hisi_dae_ctx { - struct wd_ctx_config_internal 
config; -}; - -static int get_free_ext_addr(struct dae_extend_addr *ext_addr) -{ - __u16 addr_num = ext_addr->addr_num; - __u16 idx = ext_addr->tail; - __u16 cnt = 0; - - while (__atomic_test_and_set(&ext_addr->addr_status[idx], __ATOMIC_ACQUIRE)) { - idx = (idx + 1) % addr_num; - cnt++; - if (cnt == addr_num) - return -WD_EBUSY; - } - - ext_addr->tail = (idx + 1) % addr_num; - - return idx; -} - -static void put_ext_addr(struct dae_extend_addr *ext_addr, int idx) -{ - __atomic_clear(&ext_addr->addr_status[idx], __ATOMIC_RELEASE); -} - static void fill_hashagg_task_type(struct wd_agg_msg *msg, struct dae_sqe *sqe, __u16 hw_type) { /* @@ -670,7 +468,7 @@ static int hashagg_send(struct wd_alg_driver *drv, handle_t ctx, void *hashagg_m return WD_SUCCESS; fill_hashagg_task_type(msg, &sqe, qp->q_info.hw_type); - sqe.row_num = msg->row_count; + sqe.data_row_num = msg->row_count; idx = get_free_ext_addr(ext_addr); if (idx < 0) @@ -757,15 +555,15 @@ static void fill_hashagg_msg_task_err(struct dae_sqe *sqe, struct wd_agg_msg *ms break; case DAE_HASHAGG_VCHAR_OVERFLOW: WD_ERR("failed to do hashagg task, vchar size overflow! consumed row num: %u!\n", - sqe->vchar_err_offset); + sqe->data_row_offset); msg->result = WD_AGG_INVALID_VARCHAR; - msg->in_row_count = sqe->vchar_err_offset; + msg->in_row_count = sqe->data_row_offset; break; case DAE_HASHAGG_RESULT_OVERFLOW: msg->in_row_count = temp_msg->row_count; msg->result = WD_AGG_SUM_OVERFLOW; break; - case DAE_HASHAGG_BUS_ERROR: + case DAE_TASK_BUS_ERROR: WD_ERR("failed to do hashagg task, bus error! 
etype %u!\n", sqe->err_type); msg->result = WD_AGG_BUS_ERROR; break; @@ -965,31 +763,11 @@ static int hashagg_init_param_check(struct wd_agg_sess_setup *setup, __u16 hw_ty setup->is_count_all, hw_type); } -static __u32 hashagg_get_data_type_size(enum dae_data_type type, __u16 data_info) -{ - switch (type) { - case DAE_SINT32: - return SINT32_SIZE; - case DAE_SINT64: - case DAE_DECIMAL64: - return SINT64_SIZE; - case DAE_DECIMAL128: - return DECIMAL128_SIZE; - case DAE_CHAR: - return ALIGN(data_info, DAE_CHAR_ALIGN_SIZE); - case DAE_VCHAR: - return data_info; - default: - break; - } - - return 0; -} - static int transfer_key_col_info(struct wd_key_col_info *key_cols, struct hw_agg_data *key_data, __u32 col_num) { __u32 i; + int ret; for (i = 0; i < col_num; i++) { switch (key_cols[i].input_data_type) { @@ -1007,9 +785,15 @@ static int transfer_key_col_info(struct wd_key_col_info *key_cols, key_data[i].hw_type = DAE_CHAR; break; case WD_DAE_LONG_DECIMAL: + ret = dae_decimal_precision_check(key_cols[i].col_data_info, true); + if (ret) + return ret; key_data[i].hw_type = DAE_DECIMAL128; break; case WD_DAE_SHORT_DECIMAL: + ret = dae_decimal_precision_check(key_cols[i].col_data_info, false); + if (ret) + return ret; key_data[i].hw_type = DAE_DECIMAL64; break; case WD_DAE_LONG: @@ -1059,33 +843,6 @@ static int transfer_key_to_hw_type(struct hashagg_col_data *cols_data, return WD_SUCCESS; } -static int hashagg_decimal_precision_check(__u16 data_info, bool longdecimal) -{ - __u8 all_precision; - - /* - * low 8bits: overall precision - * high 8bits: precision of the decimal part - */ - all_precision = data_info; - if (longdecimal) { - if (all_precision > DAE_DECIMAL128_MAX_PRECISION) { - WD_ERR("invalid: longdecimal precision %u is more than support %d!\n", - all_precision, DAE_DECIMAL128_MAX_PRECISION); - return -WD_EINVAL; - } - return WD_SUCCESS; - } - - if (all_precision > DAE_DECIMAL64_MAX_PRECISION) { - WD_ERR("invalid: shortdecimal precision %u is more than support 
%d!\n", - all_precision, DAE_DECIMAL64_MAX_PRECISION); - return -WD_EINVAL; - } - - return WD_SUCCESS; -} - static int hashagg_check_sum_info(struct wd_agg_col_info *agg_col, struct hw_agg_data *user_input_data, struct hw_agg_data *user_output_data, __u32 index) @@ -1104,7 +861,7 @@ static int hashagg_check_sum_info(struct wd_agg_col_info *agg_col, break; case WD_DAE_SHORT_DECIMAL: if (agg_col->output_data_types[index] == WD_DAE_SHORT_DECIMAL) { - ret = hashagg_decimal_precision_check(agg_col->col_data_info, false); + ret = dae_decimal_precision_check(agg_col->col_data_info, false); if (ret) return ret; user_input_data->sum_outtype = DECIMAL64_TO_DECIMAL64; @@ -1112,7 +869,7 @@ static int hashagg_check_sum_info(struct wd_agg_col_info *agg_col, /* For rehash, rehash will do sum */ user_output_data->sum_outtype = DECIMAL64_TO_DECIMAL64; } else if (agg_col->output_data_types[index] == WD_DAE_LONG_DECIMAL) { - ret = hashagg_decimal_precision_check(agg_col->col_data_info, true); + ret = dae_decimal_precision_check(agg_col->col_data_info, true); if (ret) return ret; user_input_data->sum_outtype = DECIMAL64_TO_DECIMAL128; @@ -1130,7 +887,7 @@ static int hashagg_check_sum_info(struct wd_agg_col_info *agg_col, agg_col->output_data_types[index]); return -WD_EINVAL; } - ret = hashagg_decimal_precision_check(agg_col->col_data_info, true); + ret = dae_decimal_precision_check(agg_col->col_data_info, true); if (ret) return ret; user_input_data->hw_type = DAE_DECIMAL128; @@ -1166,16 +923,24 @@ static int hashagg_check_max_min_info(struct wd_agg_col_info *agg_col, struct hw_agg_data *user_input_data, struct hw_agg_data *user_output_data) { + int ret; + switch (agg_col->input_data_type) { case WD_DAE_LONG: user_input_data->hw_type = DAE_SINT64; user_output_data->hw_type = DAE_SINT64; break; case WD_DAE_SHORT_DECIMAL: + ret = dae_decimal_precision_check(agg_col->col_data_info, false); + if (ret) + return ret; user_input_data->hw_type = DAE_DECIMAL64; user_output_data->hw_type = 
DAE_DECIMAL64; break; case WD_DAE_LONG_DECIMAL: + ret = dae_decimal_precision_check(agg_col->col_data_info, true); + if (ret) + return ret; user_input_data->hw_type = DAE_DECIMAL128; user_output_data->hw_type = DAE_DECIMAL128; break; @@ -1394,12 +1159,12 @@ static int hashagg_get_table_rowsize(struct hashagg_col_data *cols_data) __u32 i; for (i = 0; i < key_num; i++) - row_count_size += hashagg_get_data_type_size(key_data[i].hw_type, - key_data[i].data_info); + row_count_size += get_data_type_size(key_data[i].hw_type, + key_data[i].data_info); for (i = 0; i < output_num; i++) - row_count_size += hashagg_get_data_type_size(output_col[i].hw_type, - output_col[i].data_info); + row_count_size += get_data_type_size(output_col[i].hw_type, + output_col[i].data_info); row_count_size += HASH_TABLE_EMPTY_SIZE; if (row_count_size < DAE_MIN_ROW_SIZE || row_count_size > DAE_MAX_ROW_SIZE) { @@ -1509,58 +1274,7 @@ free_agg_ctx: return ret; } -static void dae_uninit_qp_priv(handle_t h_qp) -{ - struct hisi_qp *qp = (struct hisi_qp *)h_qp; - struct dae_extend_addr *ext_addr = (struct dae_extend_addr *)qp->priv; - - free(ext_addr->addr_list); - free(ext_addr->addr_status); - free(ext_addr->ext_sqe); - free(ext_addr); - qp->priv = NULL; -} - -static int dae_init_qp_priv(handle_t h_qp) -{ - struct hisi_qp *qp = (struct hisi_qp *)h_qp; - __u16 sq_depth = qp->q_info.sq_depth; - struct dae_extend_addr *ext_addr; - int ret = -WD_ENOMEM; - - ext_addr = calloc(1, sizeof(struct dae_extend_addr)); - if (!ext_addr) - return ret; - - ext_addr->ext_sqe = aligned_alloc(DAE_ADDR_ALIGN_SIZE, DAE_EXT_SQE_SIZE * sq_depth); - if (!ext_addr->ext_sqe) - goto free_ext_addr; - - ext_addr->addr_status = calloc(1, sizeof(__u8) * sq_depth); - if (!ext_addr->addr_status) - goto free_ext_sqe; - - ext_addr->addr_list = aligned_alloc(DAE_ADDR_ALIGN_SIZE, - sizeof(struct dae_addr_list) * sq_depth); - if (!ext_addr->addr_list) - goto free_addr_status; - - ext_addr->addr_num = sq_depth; - qp->priv = ext_addr; - - 
return WD_SUCCESS; - -free_addr_status: - free(ext_addr->addr_status); -free_ext_sqe: - free(ext_addr->ext_sqe); -free_ext_addr: - free(ext_addr); - - return ret; -} - -static int dae_get_row_size(struct wd_alg_driver *drv, void *param) +static int agg_get_row_size(struct wd_alg_driver *drv, void *param) { struct hashagg_ctx *agg_ctx = param; @@ -1570,266 +1284,16 @@ static int dae_get_row_size(struct wd_alg_driver *drv, void *param) return agg_ctx->row_size; } -static __u32 dae_ext_table_rownum(void **ext_table, struct wd_dae_hash_table *hash_table, - __u32 row_size) -{ - __u64 tlb_size, tmp_size, row_num; - void *tmp_table; - - /* - * The first row of the extended hash table stores the hash table information, - * and the second row stores the aggregated data. The 128-bytes aligned address - * in the second row provides the optimal performance. - */ - tmp_table = PTR_ALIGN(hash_table->ext_table, DAE_TABLE_ALIGN_SIZE); - tlb_size = (__u64)hash_table->table_row_size * hash_table->ext_table_row_num; - tmp_size = (__u64)(uintptr_t)tmp_table - (__u64)(uintptr_t)hash_table->ext_table; - if (tmp_size >= tlb_size) - return 0; - - row_num = (tlb_size - tmp_size) / row_size; - if (row_size == ROW_SIZE32) { - if (tmp_size >= row_size) { - tmp_table = (__u8 *)tmp_table - row_size; - row_num += 1; - } else { - /* - * When row size is 32 bytes, the first 96 bytes are not used. - * Ensure that the address of the second row is 128 bytes aligned. - */ - if (row_num > HASH_TABLE_OFFSET_3ROW) { - tmp_table = (__u8 *)tmp_table + HASH_TABLE_OFFSET_3ROW * row_size; - row_num -= HASH_TABLE_OFFSET_3ROW; - } else { - return 0; - } - } - } else if (row_size == ROW_SIZE64) { - if (tmp_size >= row_size) { - tmp_table = (__u8 *)tmp_table - row_size; - row_num += 1; - } else { - /* - * When row size is 64 bytes, the first 64 bytes are not used. - * Ensure that the address of the second row is 128 bytes aligned. 
- */ - if (row_num > HASH_TABLE_OFFSET_1ROW) { - tmp_table = (__u8 *)tmp_table + HASH_TABLE_OFFSET_1ROW * row_size; - row_num -= HASH_TABLE_OFFSET_1ROW; - } else { - return 0; - } - } - } - - *ext_table = tmp_table; - - return row_num; -} - -static int dae_ext_table_init(struct hashagg_ctx *agg_ctx, - struct wd_dae_hash_table *hash_table, bool is_rehash) -{ - struct hash_table_data *hw_table = &agg_ctx->table_data; - __u64 ext_size = hw_table->ext_table_size; - __u32 row_size = agg_ctx->row_size; - __u64 tlb_size, row_num; - void *ext_table; - __u8 *ext_valid; - __u64 *ext_row; - - row_num = dae_ext_table_rownum(&ext_table, hash_table, row_size); - if (row_num <= 1) { - WD_ERR("invalid: after aligned, extend table row num is less than device need!\n"); - return -WD_EINVAL; - } - - tlb_size = row_num * row_size; - if (is_rehash && tlb_size <= ext_size) { - WD_ERR("invalid: rehash extend table size %llu is not longer than current %llu!\n", - tlb_size, ext_size); - return -WD_EINVAL; - } - - /* - * If table has been initialized, save the previous data - * before replacing the new table. - */ - if (is_rehash) - memcpy(&agg_ctx->rehash_table, hw_table, sizeof(struct hash_table_data)); - - /* Initialize the extend table value. */ - memset(ext_table, 0, tlb_size); - ext_valid = (__u8 *)ext_table + HASH_EXT_TABLE_INVALID_OFFSET; - *ext_valid = HASH_EXT_TABLE_VALID; - ext_row = (__u64 *)ext_table + 1; - *ext_row = row_num - 1; - - hw_table->ext_table = ext_table; - hw_table->ext_table_size = tlb_size; - - return WD_SUCCESS; -} - -static int dae_std_table_init(struct hash_table_data *hw_table, - struct wd_dae_hash_table *hash_table, __u32 row_size) -{ - __u64 tlb_size, row_num, tmp_size; - - /* - * Hash table address must be 128-bytes aligned, and the number - * of rows in a standard hash table must be a power of 2. 
- */ - hw_table->std_table = PTR_ALIGN(hash_table->std_table, DAE_TABLE_ALIGN_SIZE); - tlb_size = (__u64)hash_table->table_row_size * hash_table->std_table_row_num; - tmp_size = (__u64)(uintptr_t)hw_table->std_table - (__u64)(uintptr_t)hash_table->std_table; - if (tmp_size >= tlb_size) { - WD_ERR("invalid: after aligned, standard table size is less than 0!\n"); - return -WD_EINVAL; - } - - row_num = (tlb_size - tmp_size) / row_size; - if (!row_num) { - WD_ERR("invalid: standard table row num is 0!\n"); - return -WD_EINVAL; - } - - hw_table->table_width = (__u32)log2(row_num); - if (hw_table->table_width < HASH_TABLE_MIN_WIDTH || - hw_table->table_width > HASH_TABLE_MAX_WIDTH) { - WD_ERR("invalid: standard table width %u is out of device support range %d~%d!\n", - hw_table->table_width, HASH_TABLE_MIN_WIDTH, HASH_TABLE_MAX_WIDTH); - return -WD_EINVAL; - } - - row_num = (__u64)pow(HASH_TABLE_WITDH_POWER, hw_table->table_width); - hw_table->std_table_size = row_num * row_size; - memset(hw_table->std_table, 0, hw_table->std_table_size); - - return WD_SUCCESS; -} - -static int dae_hash_table_init(struct wd_alg_driver *drv, +static int agg_hash_table_init(struct wd_alg_driver *drv, struct wd_dae_hash_table *hash_table, void *priv) { struct hashagg_ctx *agg_ctx = priv; - struct hash_table_data *hw_table; - bool is_rehash = false; - int ret; if (!agg_ctx || !hash_table) return -WD_EINVAL; - if (!agg_ctx->row_size || agg_ctx->row_size > hash_table->table_row_size) { - WD_ERR("invalid: row size %u is error, device need %u!\n", - hash_table->table_row_size, agg_ctx->row_size); - return -WD_EINVAL; - } - - /* hash_std_table is checked by caller */ - if (!hash_table->ext_table || !hash_table->ext_table_row_num) { - WD_ERR("invalid: hash extend table is null!\n"); - return -WD_EINVAL; - } - - hw_table = &agg_ctx->table_data; - if (hw_table->std_table_size) - is_rehash = true; - - ret = dae_ext_table_init(agg_ctx, hash_table, is_rehash); - if (ret) - return ret; - - ret = 
dae_std_table_init(hw_table, hash_table, agg_ctx->row_size); - if (ret) - goto update_table; - - return WD_SUCCESS; - -update_table: - if (is_rehash) - memcpy(hw_table, &agg_ctx->rehash_table, sizeof(struct hash_table_data)); - else - memset(hw_table, 0, sizeof(struct hash_table_data)); - return ret; -} - -static int dae_init(struct wd_alg_driver *drv, void *conf) -{ - struct wd_ctx_config_internal *config = conf; - struct hisi_qm_priv qm_priv; - struct hisi_dae_ctx *priv; - handle_t h_qp = 0; - handle_t h_ctx; - __u32 i, j; - int ret; - - if (!config || !config->ctx_num) { - WD_ERR("invalid: dae init config is null or ctx num is 0!\n"); - return -WD_EINVAL; - } - - priv = malloc(sizeof(struct hisi_dae_ctx)); - if (!priv) - return -WD_ENOMEM; - - qm_priv.op_type = DAE_HASH_AGG_TYPE; - qm_priv.sqe_size = sizeof(struct dae_sqe); - /* Allocate qp for each context */ - for (i = 0; i < config->ctx_num; i++) { - h_ctx = config->ctxs[i].ctx; - qm_priv.qp_mode = config->ctxs[i].ctx_mode; - /* Setting the epoll en to 0 for ASYNC ctx */ - qm_priv.epoll_en = (qm_priv.qp_mode == CTX_MODE_SYNC) ? 
- config->epoll_en : 0; - qm_priv.idx = i; - h_qp = hisi_qm_alloc_qp(&qm_priv, h_ctx); - if (!h_qp) { - ret = -WD_ENOMEM; - goto out; - } - config->ctxs[i].sqn = qm_priv.sqn; - ret = dae_init_qp_priv(h_qp); - if (ret) - goto free_h_qp; - } - memcpy(&priv->config, config, sizeof(struct wd_ctx_config_internal)); - drv->priv = priv; - - return WD_SUCCESS; - -free_h_qp: - hisi_qm_free_qp(h_qp); -out: - for (j = 0; j < i; j++) { - h_qp = (handle_t)wd_ctx_get_priv(config->ctxs[j].ctx); - dae_uninit_qp_priv(h_qp); - hisi_qm_free_qp(h_qp); - } - free(priv); - return ret; -} - -static void dae_exit(struct wd_alg_driver *drv) -{ - struct wd_ctx_config_internal *config; - struct hisi_dae_ctx *priv; - handle_t h_qp; - __u32 i; - - if (!drv || !drv->priv) - return; - - priv = (struct hisi_dae_ctx *)drv->priv; - config = &priv->config; - for (i = 0; i < config->ctx_num; i++) { - h_qp = (handle_t)wd_ctx_get_priv(config->ctxs[i].ctx); - dae_uninit_qp_priv(h_qp); - hisi_qm_free_qp(h_qp); - } - - free(priv); - drv->priv = NULL; + return dae_hash_table_init(&agg_ctx->table_data, &agg_ctx->rehash_table, + hash_table, agg_ctx->row_size); } static int dae_get_usage(void *param) @@ -1844,8 +1308,8 @@ static int dae_get_extend_ops(void *ops) if (!agg_ops) return -WD_EINVAL; - agg_ops->get_row_size = dae_get_row_size; - agg_ops->hash_table_init = dae_hash_table_init; + agg_ops->get_row_size = agg_get_row_size; + agg_ops->hash_table_init = agg_hash_table_init; agg_ops->sess_init = hashagg_sess_priv_init; agg_ops->sess_uninit = hashagg_sess_priv_uninit; diff --git a/drv/hisi_dae.h b/drv/hisi_dae.h new file mode 100644 index 00000000..12648138 --- /dev/null +++ b/drv/hisi_dae.h @@ -0,0 +1,229 @@ +/* SPDX-License-Identifier: Apache-2.0 */ +/* + * Copyright 2025 Huawei Technologies Co.,Ltd. All rights reserved. 
+ */ + +#ifndef __HDAE_DRV_H__ +#define __HDAE_DRV_H__ + +#include <stdbool.h> +#include <stddef.h> +#include <pthread.h> +#include <linux/types.h> + +#include "config.h" +#include "wd_alg.h" +#include "wd_dae.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define DAE_SQC_ALG_TYPE 2 +#define DAE_EXT_SQE_SIZE 128 + +/* align size */ +#define DAE_TABLE_ALIGN_SIZE 128 +#define DAE_ADDR_ALIGN_SIZE 128 +#define DAE_CHAR_ALIGN_SIZE 4 + +/* decimal infomartion */ +#define DAE_DECIMAL_PRECISION_OFFSET 8 +#define DAE_DECIMAL128_MAX_PRECISION 38 +#define DAE_DECIMAL64_MAX_PRECISION 18 + +/* hash table */ +#define HASH_EXT_TABLE_INVALID_OFFSET 5 +#define HASH_EXT_TABLE_VALID 0x80 +#define HASH_TABLE_HEAD_TAIL_SIZE 8 +#define HASH_TABLE_EMPTY_SIZE 4 +#define HASH_TABLE_WITDH_POWER 2 +#define HASH_TABLE_MIN_WIDTH 10 +#define HASH_TABLE_MAX_WIDTH 43 +#define HASH_TABLE_OFFSET_3ROW 3 +#define HASH_TABLE_OFFSET_1ROW 1 + +#define __ALIGN_MASK(x, mask) (((x) + (mask)) & ~(mask)) +#define ALIGN(x, a) __ALIGN_MASK(x, (typeof(x))(a)-1) +#define PTR_ALIGN(p, a) ((typeof(p))ALIGN((uintptr_t)(p), (a))) + +#define BIT(nr) (1UL << (nr)) +#define BITS_PER_LONG (__SIZEOF_LONG__ * 8) +#define GENMASK(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h)))) + +/* DAE hardware protocol data */ +enum dae_done_flag { + DAE_HW_TASK_NOT_PROCESS = 0x0, + DAE_HW_TASK_DONE = 0x1, + DAE_HW_TASK_ERR = 0x2, +}; + +enum dae_error_type { + DAE_TASK_SUCCESS = 0x0, + DAE_TASK_BD_ERROR_MIN = 0x1, + DAE_TASK_BD_ERROR_MAX = 0x7f, + DAE_HASH_TABLE_NEED_REHASH = 0x82, + DAE_HASH_TABLE_INVALID = 0x83, + DAE_HASHAGG_VCHAR_OVERFLOW = 0x84, + DAE_HASHAGG_RESULT_OVERFLOW = 0x85, + DAE_TASK_BUS_ERROR = 0x86, + DAE_HASHAGG_VCHAR_LEN_ERROR = 0x87, +}; + +enum dae_data_type { + DAE_SINT32 = 0x0, + DAE_SINT64 = 0x2, + DAE_DECIMAL64 = 0x9, + DAE_DECIMAL128 = 0xA, + DAE_CHAR = 0xC, + DAE_VCHAR = 0xD, +}; + +enum dae_date_type_size { + SINT32_SIZE = 4, + SINT64_SIZE = 8, + DECIMAL128_SIZE = 16, + DEFAULT_VCHAR_SIZE = 
30, +}; + +enum dae_table_row_size { + ROW_SIZE32 = 32, + ROW_SIZE64 = 64, + ROW_SIZE128 = 128, + ROW_SIZE256 = 256, + ROW_SIZE512 = 512, +}; + +enum dae_bd_type { + DAE_BD_TYPE_V1 = 0x0, + DAE_BD_TYPE_V2 = 0x1, +}; + +struct dae_sqe { + __u32 bd_type : 6; + __u32 resv1 : 2; + __u32 task_type : 6; + __u32 resv2 : 2; + __u32 task_type_ext : 6; + __u32 resv3 : 9; + __u32 bd_invlid : 1; + __u16 table_row_size; + __u16 resv4; + __u32 batch_num; + __u32 low_tag; + __u32 hi_tag; + __u32 data_row_num; + __u32 init_row_num; + __u32 src_table_width : 6; + __u32 dst_table_width : 6; + __u32 key_out_en : 1; + __u32 break_point_en : 1; + __u32 multi_batch_en : 1; + __u32 sva_prefetch_en : 1; + __u32 counta_vld : 1; + __u32 index_num : 5; + __u32 resv5 : 8; + __u32 index_batch_type : 1; + __u32 resv6 : 1; + __u8 key_data_type[16]; + __u8 agg_data_type[16]; + __u32 resv9[6]; + __u64 addr_ext; + __u16 key_col_bitmap; + __u16 has_empty; + __u32 agg_col_bitmap; + __u64 addr_list; + __u32 done_flag : 3; + __u32 output_end : 1; + __u32 ext_err_type : 12; + __u32 err_type : 8; + __u32 wtype : 8; + __u32 out_raw_num; + __u32 data_row_offset; + __u16 sum_overflow_cols; + __u16 resv10; +}; + +struct dae_ext_sqe { + /* + * If date type is char/vchar, data info fill data type size + * If data type is decimal64/decimal128, data info fill data precision + */ + __u16 key_data_info[16]; + __u16 agg_data_info[16]; + /* Aggregated output from input agg col index */ + __u64 out_from_in_idx; + /* Aggregated output from input agg col operation, sum or count */ + __u64 out_optype; + __u32 resv[12]; +}; + +struct dae_col_addr { + __u64 empty_addr; + __u64 empty_size; + __u64 value_addr; + __u64 value_size; +}; + +struct dae_table_addr { + __u64 std_table_addr; + __u64 std_table_size; + __u64 ext_table_addr; + __u64 ext_table_size; +}; + +struct dae_probe_info_addr { + __u64 batch_num_index; + __u64 batch_addr_index; + __u64 probe_index_addr; + __u64 resv1; + __u64 break_point_addr; + __u64 resv2; +}; 
+ +struct dae_addr_list { + __u64 ext_sqe_addr; + __u64 ext_sqe_size; + struct dae_table_addr src_table; + struct dae_table_addr dst_table; + struct dae_probe_info_addr probe_info; + struct dae_col_addr input_addr[32]; + struct dae_col_addr output_addr[32]; +}; + +struct dae_extend_addr { + struct dae_ext_sqe *ext_sqe; + struct dae_addr_list *addr_list; + __u8 *addr_status; + __u16 addr_num; + __u16 tail; +}; + +struct hash_table_data { + void *std_table; + void *ext_table; + __u64 std_table_size; + __u64 ext_table_size; + __u32 table_width; +}; + +struct hisi_dae_ctx { + struct wd_ctx_config_internal config; +}; + +void dae_exit(struct wd_alg_driver *drv); +int dae_init(struct wd_alg_driver *drv, void *conf); +int dae_hash_table_init(struct hash_table_data *hw_table, + struct hash_table_data *rehash_table, + struct wd_dae_hash_table *hash_table, + __u32 row_size); +int get_free_ext_addr(struct dae_extend_addr *ext_addr); +void put_ext_addr(struct dae_extend_addr *ext_addr, int idx); +__u32 get_data_type_size(enum dae_data_type type, __u16 data_info); +int dae_decimal_precision_check(__u16 data_info, bool longdecimal); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/drv/hisi_dae_common.c b/drv/hisi_dae_common.c new file mode 100644 index 00000000..43b53e0f --- /dev/null +++ b/drv/hisi_dae_common.c @@ -0,0 +1,387 @@ +// SPDX-License-Identifier: Apache-2.0 +/* Copyright 2024 Huawei Technologies Co.,Ltd. All rights reserved. 
*/ + +#include <math.h> +#include "hisi_qm_udrv.h" +#include "hisi_dae.h" + +int dae_decimal_precision_check(__u16 data_info, bool longdecimal) +{ + __u8 all_precision; + + /* + * low 8bits: overall precision + * high 8bits: precision of the decimal part + */ + all_precision = data_info; + if (longdecimal) { + if (all_precision > DAE_DECIMAL128_MAX_PRECISION) { + WD_ERR("invalid: longdecimal precision %u is more than support %d!\n", + all_precision, DAE_DECIMAL128_MAX_PRECISION); + return -WD_EINVAL; + } + return WD_SUCCESS; + } + + if (all_precision > DAE_DECIMAL64_MAX_PRECISION) { + WD_ERR("invalid: shortdecimal precision %u is more than support %d!\n", + all_precision, DAE_DECIMAL64_MAX_PRECISION); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +__u32 get_data_type_size(enum dae_data_type type, __u16 data_info) +{ + switch (type) { + case DAE_SINT32: + return SINT32_SIZE; + case DAE_SINT64: + case DAE_DECIMAL64: + return SINT64_SIZE; + case DAE_DECIMAL128: + return DECIMAL128_SIZE; + case DAE_CHAR: + return ALIGN(data_info, DAE_CHAR_ALIGN_SIZE); + case DAE_VCHAR: + return data_info; + default: + break; + } + + return 0; +} + +/* The caller ensures that the address pointer or num is not null. 
*/ +int get_free_ext_addr(struct dae_extend_addr *ext_addr) +{ + __u16 addr_num = ext_addr->addr_num; + __u16 idx = ext_addr->tail; + __u16 cnt = 0; + + while (__atomic_test_and_set(&ext_addr->addr_status[idx], __ATOMIC_ACQUIRE)) { + idx = (idx + 1) % addr_num; + cnt++; + if (cnt == addr_num) + return -WD_EBUSY; + } + + ext_addr->tail = (idx + 1) % addr_num; + + return idx; +} + +void put_ext_addr(struct dae_extend_addr *ext_addr, int idx) +{ + __atomic_clear(&ext_addr->addr_status[idx], __ATOMIC_RELEASE); +} + +static void dae_uninit_qp_priv(handle_t h_qp) +{ + struct hisi_qp *qp = (struct hisi_qp *)h_qp; + struct dae_extend_addr *ext_addr = (struct dae_extend_addr *)qp->priv; + + free(ext_addr->addr_list); + free(ext_addr->addr_status); + free(ext_addr->ext_sqe); + free(ext_addr); + qp->priv = NULL; +} + +static int dae_init_qp_priv(handle_t h_qp) +{ + struct hisi_qp *qp = (struct hisi_qp *)h_qp; + __u16 sq_depth = qp->q_info.sq_depth; + struct dae_extend_addr *ext_addr; + int ret = -WD_ENOMEM; + + ext_addr = calloc(1, sizeof(struct dae_extend_addr)); + if (!ext_addr) + return ret; + + ext_addr->ext_sqe = aligned_alloc(DAE_ADDR_ALIGN_SIZE, DAE_EXT_SQE_SIZE * sq_depth); + if (!ext_addr->ext_sqe) + goto free_ext_addr; + + ext_addr->addr_status = calloc(1, sizeof(__u8) * sq_depth); + if (!ext_addr->addr_status) + goto free_ext_sqe; + + ext_addr->addr_list = aligned_alloc(DAE_ADDR_ALIGN_SIZE, + sizeof(struct dae_addr_list) * sq_depth); + if (!ext_addr->addr_list) + goto free_addr_status; + + ext_addr->addr_num = sq_depth; + qp->priv = ext_addr; + + return WD_SUCCESS; + +free_addr_status: + free(ext_addr->addr_status); +free_ext_sqe: + free(ext_addr->ext_sqe); +free_ext_addr: + free(ext_addr); + + return ret; +} + +static __u32 dae_ext_table_rownum(void **ext_table, struct wd_dae_hash_table *hash_table, + __u32 row_size) +{ + __u64 tlb_size, tmp_size, row_num; + void *tmp_table; + + /* + * The first row of the extended hash table stores the hash table information, + * 
and the second row stores the aggregated data. The 128-bytes aligned address + * in the second row provides the optimal performance. + */ + tmp_table = PTR_ALIGN(hash_table->ext_table, DAE_TABLE_ALIGN_SIZE); + tlb_size = (__u64)hash_table->table_row_size * hash_table->ext_table_row_num; + tmp_size = (__u64)(uintptr_t)tmp_table - (__u64)(uintptr_t)hash_table->ext_table; + if (tmp_size >= tlb_size) + return 0; + + row_num = (tlb_size - tmp_size) / row_size; + if (row_size == ROW_SIZE32) { + if (tmp_size >= row_size) { + tmp_table = (__u8 *)tmp_table - row_size; + row_num += 1; + } else { + /* + * When row size is 32 bytes, the first 96 bytes are not used. + * Ensure that the address of the second row is 128 bytes aligned. + */ + if (row_num > HASH_TABLE_OFFSET_3ROW) { + tmp_table = (__u8 *)tmp_table + HASH_TABLE_OFFSET_3ROW * row_size; + row_num -= HASH_TABLE_OFFSET_3ROW; + } else { + return 0; + } + } + } else if (row_size == ROW_SIZE64) { + if (tmp_size >= row_size) { + tmp_table = (__u8 *)tmp_table - row_size; + row_num += 1; + } else { + /* + * When row size is 64 bytes, the first 64 bytes are not used. + * Ensure that the address of the second row is 128 bytes aligned. 
+ */ + if (row_num > HASH_TABLE_OFFSET_1ROW) { + tmp_table = (__u8 *)tmp_table + HASH_TABLE_OFFSET_1ROW * row_size; + row_num -= HASH_TABLE_OFFSET_1ROW; + } else { + return 0; + } + } + } + + *ext_table = tmp_table; + + return row_num; +} + +static int dae_ext_table_init(struct hash_table_data *hw_table, + struct hash_table_data *rehash_table, + struct wd_dae_hash_table *hash_table, + __u32 row_size, bool is_rehash) +{ + __u64 ext_size = hw_table->ext_table_size; + __u64 tlb_size, row_num; + void *ext_table; + __u8 *ext_valid; + __u64 *ext_row; + + row_num = dae_ext_table_rownum(&ext_table, hash_table, row_size); + if (row_num <= 1) { + WD_ERR("invalid: after aligned, extend table row num is less than device need!\n"); + return -WD_EINVAL; + } + + tlb_size = row_num * row_size; + if (is_rehash && tlb_size <= ext_size) { + WD_ERR("invalid: rehash extend table size %llu is not longer than current %llu!\n", + tlb_size, ext_size); + return -WD_EINVAL; + } + + /* + * If table has been initialized, save the previous data + * before replacing the new table. + */ + if (is_rehash) + memcpy(rehash_table, hw_table, sizeof(struct hash_table_data)); + + /* Initialize the extend table value. */ + memset(ext_table, 0, tlb_size); + ext_valid = (__u8 *)ext_table + HASH_EXT_TABLE_INVALID_OFFSET; + *ext_valid = HASH_EXT_TABLE_VALID; + ext_row = (__u64 *)ext_table + 1; + *ext_row = row_num - 1; + + hw_table->ext_table = ext_table; + hw_table->ext_table_size = tlb_size; + + return WD_SUCCESS; +} + +static int dae_std_table_init(struct hash_table_data *hw_table, + struct wd_dae_hash_table *hash_table, __u32 row_size) +{ + __u64 tlb_size, row_num, tmp_size; + + /* + * Hash table address must be 128-bytes aligned, and the number + * of rows in a standard hash table must be a power of 2. 
+ */ + hw_table->std_table = PTR_ALIGN(hash_table->std_table, DAE_TABLE_ALIGN_SIZE); + tlb_size = (__u64)hash_table->table_row_size * hash_table->std_table_row_num; + tmp_size = (__u64)(uintptr_t)hw_table->std_table - (__u64)(uintptr_t)hash_table->std_table; + if (tmp_size >= tlb_size) { + WD_ERR("invalid: after aligned, standard table size is less than 0!\n"); + return -WD_EINVAL; + } + + row_num = (tlb_size - tmp_size) / row_size; + if (!row_num) { + WD_ERR("invalid: standard table row num is 0!\n"); + return -WD_EINVAL; + } + + hw_table->table_width = (__u32)log2(row_num); + if (hw_table->table_width < HASH_TABLE_MIN_WIDTH || + hw_table->table_width > HASH_TABLE_MAX_WIDTH) { + WD_ERR("invalid: standard table width %u is out of device support range %d~%d!\n", + hw_table->table_width, HASH_TABLE_MIN_WIDTH, HASH_TABLE_MAX_WIDTH); + return -WD_EINVAL; + } + + row_num = (__u64)pow(HASH_TABLE_WITDH_POWER, hw_table->table_width); + hw_table->std_table_size = row_num * row_size; + memset(hw_table->std_table, 0, hw_table->std_table_size); + + return WD_SUCCESS; +} + +int dae_hash_table_init(struct hash_table_data *hw_table, + struct hash_table_data *rehash_table, + struct wd_dae_hash_table *hash_table, + __u32 row_size) +{ + bool is_rehash = false; + int ret; + + if (!row_size || row_size > hash_table->table_row_size) { + WD_ERR("invalid: row size %u is error, device need %u!\n", + hash_table->table_row_size, row_size); + return -WD_EINVAL; + } + + /* hash_std_table is checked by caller */ + if (!hash_table->ext_table || !hash_table->ext_table_row_num) { + WD_ERR("invalid: hash extend table is null!\n"); + return -WD_EINVAL; + } + + if (hw_table->std_table_size) + is_rehash = true; + + ret = dae_ext_table_init(hw_table, rehash_table, hash_table, row_size, is_rehash); + if (ret) + return ret; + + ret = dae_std_table_init(hw_table, hash_table, row_size); + if (ret) + goto update_table; + + return WD_SUCCESS; + +update_table: + if (is_rehash) + memcpy(hw_table, 
rehash_table, sizeof(struct hash_table_data)); + else + memset(hw_table, 0, sizeof(struct hash_table_data)); + return ret; +} + +int dae_init(struct wd_alg_driver *drv, void *conf) +{ + struct wd_ctx_config_internal *config = conf; + struct hisi_qm_priv qm_priv; + struct hisi_dae_ctx *priv; + handle_t h_qp = 0; + handle_t h_ctx; + __u32 i, j; + int ret; + + if (!config || !config->ctx_num) { + WD_ERR("invalid: dae init config is null or ctx num is 0!\n"); + return -WD_EINVAL; + } + + priv = malloc(sizeof(struct hisi_dae_ctx)); + if (!priv) + return -WD_ENOMEM; + + qm_priv.op_type = DAE_SQC_ALG_TYPE; + qm_priv.sqe_size = sizeof(struct dae_sqe); + /* Allocate qp for each context */ + for (i = 0; i < config->ctx_num; i++) { + h_ctx = config->ctxs[i].ctx; + qm_priv.qp_mode = config->ctxs[i].ctx_mode; + /* Setting the epoll en to 0 for ASYNC ctx */ + qm_priv.epoll_en = (qm_priv.qp_mode == CTX_MODE_SYNC) ? + config->epoll_en : 0; + qm_priv.idx = i; + h_qp = hisi_qm_alloc_qp(&qm_priv, h_ctx); + if (!h_qp) { + ret = -WD_ENOMEM; + goto out; + } + config->ctxs[i].sqn = qm_priv.sqn; + ret = dae_init_qp_priv(h_qp); + if (ret) + goto free_h_qp; + } + memcpy(&priv->config, config, sizeof(struct wd_ctx_config_internal)); + drv->priv = priv; + + return WD_SUCCESS; + +free_h_qp: + hisi_qm_free_qp(h_qp); +out: + for (j = 0; j < i; j++) { + h_qp = (handle_t)wd_ctx_get_priv(config->ctxs[j].ctx); + dae_uninit_qp_priv(h_qp); + hisi_qm_free_qp(h_qp); + } + free(priv); + return ret; +} + +void dae_exit(struct wd_alg_driver *drv) +{ + struct wd_ctx_config_internal *config; + struct hisi_dae_ctx *priv; + handle_t h_qp; + __u32 i; + + if (!drv || !drv->priv) + return; + + priv = (struct hisi_dae_ctx *)drv->priv; + config = &priv->config; + for (i = 0; i < config->ctx_num; i++) { + h_qp = (handle_t)wd_ctx_get_priv(config->ctxs[i].ctx); + dae_uninit_qp_priv(h_qp); + hisi_qm_free_qp(h_qp); + } + + free(priv); + drv->priv = NULL; +} diff --git a/drv/hisi_dae_join_gather.c 
b/drv/hisi_dae_join_gather.c new file mode 100644 index 00000000..db965d35 --- /dev/null +++ b/drv/hisi_dae_join_gather.c @@ -0,0 +1,1040 @@ +// SPDX-License-Identifier: Apache-2.0 +/* Copyright 2025 Huawei Technologies Co.,Ltd. All rights reserved. */ + +#include "hisi_qm_udrv.h" +#include "hisi_dae.h" +#include "../include/drv/wd_join_gather_drv.h" + +#define DAE_EXT_SQE_SIZE 128 +#define DAE_CTX_Q_NUM_DEF 1 + +/* column information */ +#define DAE_MAX_KEY_COLS 9 +#define DAE_MAX_CHAR_SIZE 32 +#define DAE_MAX_ROW_SIZE 512 +#define DAE_JOIN_MAX_ROW_NUN 50000 +#define DAE_JOIN_MAX_BATCH_NUM 2800 +#define DAE_MAX_TABLE_NUM 16 +#define BUILD_INDEX_ROW_SIZE 8 +#define PROBE_INDEX_ROW_SIZE 4 + +/* align size */ +#define DAE_KEY_ALIGN_SIZE 8 +#define DAE_BREAKPOINT_SIZE 81920 +#define DAE_ADDR_INDEX_SHIFT 1 + +/* hash table */ +#define HASH_TABLE_HEAD_TAIL_SIZE 8 +#define HASH_TABLE_INDEX_NUM 1 +#define HASH_TABLE_MAX_INDEX_NUM 15 +#define HASH_TABLE_INDEX_SIZE 12 +#define HASH_TABLE_EMPTY_SIZE 4 +#define GATHER_ROW_BATCH_EMPTY_SIZE 2 + +/* DAE hardware protocol data */ +enum dae_join_stage { + DAE_JOIN_BUILD_HASH = 0x0, + DAE_JOIN_REHASH = 0x6, + DAE_JOIN_PROBE = 0x7, +}; + +enum dae_gather_stage { + DAE_GATHER_CONVERT = 0x0, + DAE_GATHER_COMPLETE = 0x7, +}; + +enum dae_task_type { + DAE_HASH_JOIN = 0x1, + DAE_GATHER = 0x2, +}; + +static enum dae_data_type hw_data_type_order[] = { + DAE_VCHAR, DAE_CHAR, DAE_DECIMAL128, + DAE_DECIMAL64, DAE_SINT64, DAE_SINT32, +}; + +struct hw_join_gather_data { + enum dae_data_type hw_type; + __u32 optype; + __u32 usr_col_idx; + __u16 data_info; +}; + +struct join_gather_col_data { + struct hw_join_gather_data key_data[DAE_MAX_KEY_COLS]; + struct hw_join_gather_data gather_data[DAE_MAX_TABLE_NUM][DAE_MAX_KEY_COLS]; + + __u32 key_num; + __u32 gather_table_num; + __u32 gather_cols_num[DAE_MAX_TABLE_NUM]; + __u16 has_empty[DAE_MAX_TABLE_NUM]; + __u8 index_num; +}; + +struct join_gather_ctx { + struct join_gather_col_data cols_data; + 
struct hash_table_data table_data; + struct hash_table_data rehash_table; + pthread_spinlock_t lock; + __u32 hash_table_row_size; + __u32 batch_row_size[DAE_MAX_TABLE_NUM]; +}; + +static void fill_join_gather_misc_field(struct wd_join_gather_msg *msg, + struct dae_sqe *sqe) +{ + struct join_gather_ctx *ctx = msg->priv; + struct join_gather_col_data *cols_data = &ctx->cols_data; + + sqe->sva_prefetch_en = true; + + switch (msg->op_type) { + case WD_JOIN_BUILD_HASH: + sqe->task_type = DAE_HASH_JOIN; + sqe->task_type_ext = DAE_JOIN_BUILD_HASH; + sqe->data_row_num = msg->req.input_row_num; + sqe->batch_num = msg->req.join_req.build_batch_index; + sqe->init_row_num = msg->req.join_req.batch_row_offset; + sqe->index_num = cols_data->index_num; + break; + case WD_JOIN_PROBE: + sqe->task_type = DAE_HASH_JOIN; + sqe->task_type_ext = DAE_JOIN_PROBE; + sqe->data_row_num = msg->req.output_row_num; + sqe->batch_num = msg->req.input_row_num; + sqe->init_row_num = msg->req.join_req.batch_row_offset; + sqe->index_num = cols_data->index_num; + sqe->key_out_en = msg->key_out_en; + sqe->break_point_en = sqe->init_row_num ? 
true : false; + sqe->index_batch_type = msg->index_type; + break; + case WD_JOIN_REHASH: + sqe->task_type = DAE_HASH_JOIN; + sqe->task_type_ext = DAE_JOIN_REHASH; + sqe->data_row_num = msg->req.output_row_num; + sqe->index_num = cols_data->index_num; + break; + case WD_GATHER_CONVERT: + sqe->task_type = DAE_GATHER; + sqe->task_type_ext = DAE_GATHER_CONVERT; + sqe->data_row_num = msg->req.input_row_num; + break; + case WD_GATHER_COMPLETE: + sqe->task_type = DAE_GATHER; + sqe->task_type_ext = DAE_GATHER_COMPLETE; + sqe->multi_batch_en = msg->multi_batch_en; + sqe->index_batch_type = msg->index_type; + sqe->data_row_num = msg->req.output_row_num; + break; + default: + break; + } +} + +static void fill_join_table_data(struct dae_sqe *sqe, struct dae_addr_list *addr_list, + struct wd_join_gather_msg *msg, struct join_gather_ctx *ctx) +{ + struct dae_table_addr *hw_table_src = &addr_list->src_table; + struct dae_table_addr *hw_table_dst = &addr_list->dst_table; + struct hash_table_data *table_data_src, *table_data_dst; + + switch (msg->op_type) { + case WD_JOIN_BUILD_HASH: + table_data_src = NULL; + table_data_dst = &ctx->table_data; + break; + case WD_JOIN_REHASH: + table_data_src = &ctx->rehash_table; + table_data_dst = &ctx->table_data; + break; + case WD_JOIN_PROBE: + table_data_src = &ctx->table_data; + table_data_dst = NULL; + break; + default: + return; + } + + sqe->table_row_size = ctx->hash_table_row_size; + + if (table_data_src) { + sqe->src_table_width = table_data_src->table_width; + hw_table_src->std_table_addr = (__u64)(uintptr_t)table_data_src->std_table; + hw_table_src->std_table_size = table_data_src->std_table_size; + hw_table_src->ext_table_addr = (__u64)(uintptr_t)table_data_src->ext_table; + hw_table_src->ext_table_size = table_data_src->ext_table_size; + } + + if (table_data_dst) { + sqe->dst_table_width = table_data_dst->table_width; + hw_table_dst->std_table_addr = (__u64)(uintptr_t)table_data_dst->std_table; + hw_table_dst->std_table_size = 
table_data_dst->std_table_size; + hw_table_dst->ext_table_addr = (__u64)(uintptr_t)table_data_dst->ext_table; + hw_table_dst->ext_table_size = table_data_dst->ext_table_size; + } +} + +static void fill_join_key_data(struct dae_sqe *sqe, struct dae_ext_sqe *ext_sqe, + struct dae_addr_list *addr_list, + struct wd_join_gather_msg *msg, struct join_gather_ctx *ctx) +{ + struct dae_probe_info_addr *info = &addr_list->probe_info; + struct hw_join_gather_data *key_data = ctx->cols_data.key_data; + struct wd_dae_col_addr *usr_key, *out_usr_key = NULL; + struct dae_col_addr *hw_key, *out_hw_key = NULL; + struct wd_join_req *req = &msg->req.join_req; + struct wd_probe_out_info *output = &req->probe_output; + __u16 usr_col_idx; + __u64 offset; + __u32 i; + + sqe->key_col_bitmap = GENMASK(msg->key_cols_num - 1, 0); + + for (i = 0; i < msg->key_cols_num; i++) { + sqe->key_data_type[i] = key_data[i].hw_type; + ext_sqe->key_data_info[i] = key_data[i].data_info; + } + + switch (msg->op_type) { + case WD_JOIN_BUILD_HASH: + usr_key = req->key_cols; + hw_key = addr_list->input_addr; + if (msg->index_type == WD_BATCH_ADDR_INDEX) + sqe->addr_ext = (__u64)(uintptr_t)req->build_batch_addr.addr; + break; + case WD_JOIN_PROBE: + usr_key = req->key_cols; + hw_key = addr_list->input_addr; + if (msg->key_out_en) { + out_usr_key = output->key_cols; + out_hw_key = addr_list->output_addr; + } + + info->batch_num_index = (__u64)(uintptr_t)output->build_index.addr; + info->probe_index_addr = (__u64)(uintptr_t)output->probe_index.addr; + info->break_point_addr = (__u64)(uintptr_t)output->breakpoint.addr; + + if (msg->index_type == WD_BATCH_ADDR_INDEX) { + offset = (__u64)output->build_index.row_size * output->build_index.row_num; + offset = offset >> DAE_ADDR_INDEX_SHIFT; + info->batch_addr_index = info->batch_num_index + offset; + } + break; + default: + return; + } + + for (i = 0; i < msg->key_cols_num; i++) { + usr_col_idx = key_data[i].usr_col_idx; + hw_key[i].empty_addr = 
(__u64)(uintptr_t)usr_key[usr_col_idx].empty; + hw_key[i].empty_size = usr_key[usr_col_idx].empty_size; + hw_key[i].value_addr = (__u64)(uintptr_t)usr_key[usr_col_idx].value; + hw_key[i].value_size = usr_key[usr_col_idx].value_size; + + if (!out_usr_key) + continue; + out_hw_key[i].empty_addr = (__u64)(uintptr_t)out_usr_key[usr_col_idx].empty; + out_hw_key[i].empty_size = out_usr_key[usr_col_idx].empty_size; + /* The hardware does not output the empty data, set the data by software. */ + memset(out_usr_key[usr_col_idx].empty, 0, out_usr_key[usr_col_idx].empty_size); + + out_hw_key[i].value_addr = (__u64)(uintptr_t)out_usr_key[usr_col_idx].value; + out_hw_key[i].value_size = out_usr_key[usr_col_idx].value_size; + } +} + +static void fill_gather_col_data(struct dae_sqe *sqe, struct dae_ext_sqe *ext_sqe, + struct dae_addr_list *addr_list, + struct wd_join_gather_msg *msg, struct join_gather_ctx *ctx) +{ + struct dae_probe_info_addr *info = &addr_list->probe_info; + struct join_gather_col_data *cols_data = &ctx->cols_data; + struct wd_gather_req *gather_req = &msg->req.gather_req; + __u32 table_index = gather_req->table_index; + struct hw_join_gather_data *gather_data = cols_data->gather_data[table_index]; + __u16 cols_num = cols_data->gather_cols_num[table_index]; + struct wd_dae_col_addr *usr_data; + struct dae_col_addr *hw_data; + __u16 usr_col_idx; + void **batch_addr; + __u64 offset; + __u32 i; + + sqe->key_col_bitmap = GENMASK(cols_num - 1, 0); + sqe->has_empty = cols_data->has_empty[table_index]; + sqe->table_row_size = ctx->batch_row_size[table_index]; + + usr_data = gather_req->data_cols; + batch_addr = gather_req->row_batchs.batch_addr; + + switch (msg->op_type) { + case WD_GATHER_CONVERT: + hw_data = addr_list->input_addr; + /* Single batch tasks use the first element of the array. 
*/ + addr_list->dst_table.std_table_addr = (__u64)(uintptr_t)batch_addr[0]; + break; + case WD_GATHER_COMPLETE: + hw_data = addr_list->output_addr; + if (!msg->multi_batch_en) { + info->probe_index_addr = (__u64)(uintptr_t)gather_req->index.addr; + addr_list->src_table.std_table_addr = (__u64)(uintptr_t)batch_addr[0]; + break; + } + + info->batch_num_index = (__u64)(uintptr_t)gather_req->index.addr; + if (msg->index_type == WD_BATCH_ADDR_INDEX) { + offset = (__u64)gather_req->index.row_size * gather_req->index.row_num; + offset = offset >> DAE_ADDR_INDEX_SHIFT; + info->batch_addr_index = info->batch_num_index + offset; + } else { + addr_list->src_table.std_table_addr = (__u64)(uintptr_t)batch_addr; + } + break; + default: + return; + } + + for (i = 0; i < cols_num; i++) { + sqe->key_data_type[i] = gather_data[i].hw_type; + ext_sqe->key_data_info[i] = gather_data[i].data_info; + + usr_col_idx = gather_data[i].usr_col_idx; + hw_data[i].empty_addr = (__u64)(uintptr_t)usr_data[usr_col_idx].empty; + hw_data[i].empty_size = usr_data[usr_col_idx].empty_size; + hw_data[i].value_addr = (__u64)(uintptr_t)usr_data[usr_col_idx].value; + hw_data[i].value_size = usr_data[usr_col_idx].value_size; + } +} + +static void fill_join_gather_ext_addr(struct dae_sqe *sqe, struct dae_ext_sqe *ext_sqe, + struct dae_addr_list *addr_list) +{ + memset(ext_sqe, 0, DAE_EXT_SQE_SIZE); + memset(addr_list, 0, sizeof(struct dae_addr_list)); + sqe->addr_list = (__u64)(uintptr_t)addr_list; + addr_list->ext_sqe_addr = (__u64)(uintptr_t)ext_sqe; + addr_list->ext_sqe_size = DAE_EXT_SQE_SIZE; +} + +static void fill_join_gather_info(struct dae_sqe *sqe, struct dae_ext_sqe *ext_sqe, + struct dae_addr_list *addr_list, + struct wd_join_gather_msg *msg) +{ + struct join_gather_ctx *ctx = (struct join_gather_ctx *)msg->priv; + + fill_join_gather_ext_addr(sqe, ext_sqe, addr_list); + sqe->bd_type = DAE_BD_TYPE_V2; + + switch (msg->op_type) { + case WD_JOIN_BUILD_HASH: + case WD_JOIN_PROBE: + case WD_JOIN_REHASH: 
+ fill_join_table_data(sqe, addr_list, msg, ctx); + fill_join_key_data(sqe, ext_sqe, addr_list, msg, ctx); + break; + case WD_GATHER_CONVERT: + case WD_GATHER_COMPLETE: + fill_gather_col_data(sqe, ext_sqe, addr_list, msg, ctx); + break; + default: + break; + } +} + +static int check_join_gather_param(struct wd_join_gather_msg *msg) +{ + struct wd_probe_out_info *output; + struct wd_gather_req *greq; + __u64 row_num, size; + + if (!msg) { + WD_ERR("invalid: input join gather msg is NULL!\n"); + return -WD_EINVAL; + } + + output = &msg->req.join_req.probe_output; + greq = &msg->req.gather_req; + + switch (msg->op_type) { + case WD_JOIN_BUILD_HASH: + if (msg->req.input_row_num > DAE_JOIN_MAX_ROW_NUN) { + WD_ERR("invalid: build table row count %u is more than %d!\n", + msg->req.input_row_num, DAE_JOIN_MAX_ROW_NUN); + return -WD_EINVAL; + } + if (msg->index_type == WD_BATCH_NUMBER_INDEX) { + if (msg->req.join_req.build_batch_index >= DAE_JOIN_MAX_BATCH_NUM) { + WD_ERR("invalid: input join batch index is more than %d!\n", + DAE_JOIN_MAX_BATCH_NUM - 1); + return -WD_EINVAL; + } + } else { + if (!msg->req.join_req.build_batch_addr.addr || + !msg->req.join_req.build_batch_addr.row_num || + !msg->req.join_req.build_batch_addr.row_size) { + WD_ERR("invalid: input join build batch addr is NULL!\n"); + return -WD_EINVAL; + } + } + break; + case WD_JOIN_PROBE: + size = (__u64)output->breakpoint.row_size * output->breakpoint.row_num; + if (!output->breakpoint.addr || size < DAE_BREAKPOINT_SIZE) { + WD_ERR("invalid probe breakpoint size: %llu\n", size); + return -WD_EINVAL; + } + if (msg->index_type == WD_BATCH_ADDR_INDEX) { + row_num = msg->req.output_row_num << DAE_ADDR_INDEX_SHIFT; + if (output->build_index.row_num < row_num) { + WD_ERR("build index row number is less than: %llu\n", + row_num); + return -WD_EINVAL; + } + } + + if (output->probe_index.row_size != PROBE_INDEX_ROW_SIZE || + output->build_index.row_size != BUILD_INDEX_ROW_SIZE) { + WD_ERR("build and probe index row 
size need be %d, %d!\n", + BUILD_INDEX_ROW_SIZE, PROBE_INDEX_ROW_SIZE); + return -WD_EINVAL; + } + break; + case WD_JOIN_REHASH: + case WD_GATHER_CONVERT: + break; + case WD_GATHER_COMPLETE: + if (!msg->multi_batch_en) { + if (greq->index.row_size != PROBE_INDEX_ROW_SIZE) { + WD_ERR("invalid: probe index row size need be %d!\n", + PROBE_INDEX_ROW_SIZE); + return -WD_EINVAL; + } + break; + } + + if (greq->index.row_size != BUILD_INDEX_ROW_SIZE) { + WD_ERR("invalid: build index row size need be %d!\n", + BUILD_INDEX_ROW_SIZE); + return -WD_EINVAL; + } + if (msg->index_type == WD_BATCH_NUMBER_INDEX) { + if (greq->row_batchs.batch_num > DAE_JOIN_MAX_BATCH_NUM) { + WD_ERR("invalid: gather row batch num is more than %d!\n", + DAE_JOIN_MAX_BATCH_NUM); + return -WD_EINVAL; + } + } else { + row_num = msg->req.output_row_num << DAE_ADDR_INDEX_SHIFT; + if (greq->index.row_num < row_num) { + WD_ERR("build index row number is less than: %llu\n", + row_num); + return -WD_EINVAL; + } + } + break; + default: + break; + } + + return WD_SUCCESS; +} + +static int join_gather_send(struct wd_alg_driver *drv, handle_t ctx, void *send_msg) +{ + handle_t h_qp = (handle_t)wd_ctx_get_priv(ctx); + struct hisi_qp *qp = (struct hisi_qp *)h_qp; + struct dae_extend_addr *ext_addr = qp->priv; + struct wd_join_gather_msg *msg = send_msg; + struct dae_addr_list *addr_list; + struct dae_ext_sqe *ext_sqe; + struct dae_sqe sqe = {0}; + __u16 send_cnt = 0; + int ret, idx; + + ret = check_join_gather_param(msg); + if (ret) + return ret; + + fill_join_gather_misc_field(msg, &sqe); + + idx = get_free_ext_addr(ext_addr); + if (idx < 0) + return -WD_EBUSY; + addr_list = &ext_addr->addr_list[idx]; + ext_sqe = &ext_addr->ext_sqe[idx]; + + fill_join_gather_info(&sqe, ext_sqe, addr_list, msg); + + hisi_set_msg_id(h_qp, &msg->tag); + sqe.low_tag = msg->tag; + sqe.hi_tag = idx; + + ret = hisi_qm_send(h_qp, &sqe, 1, &send_cnt); + if (ret) { + if (ret != -WD_EBUSY) + WD_ERR("failed to send to hardware, ret = 
%d!\n", ret); + put_ext_addr(ext_addr, idx); + return ret; + } + + return WD_SUCCESS; +} + +static void fill_join_gather_task_done(struct dae_sqe *sqe, struct wd_join_gather_msg *msg) +{ + if (sqe->task_type == DAE_HASH_JOIN) { + if (sqe->task_type_ext == DAE_JOIN_PROBE) { + msg->consumed_row_num = sqe->data_row_offset; + msg->produced_row_num = sqe->out_raw_num; + msg->output_done = sqe->output_end; + } else if (sqe->task_type_ext == DAE_JOIN_REHASH) { + msg->output_done = sqe->output_end; + } + } +} + +static void fill_join_gather_task_err(struct dae_sqe *sqe, struct wd_join_gather_msg *msg) +{ + switch (sqe->err_type) { + case DAE_TASK_BD_ERROR_MIN ... DAE_TASK_BD_ERROR_MAX: + WD_ERR("failed to do join gather task, bd error=0x%x!\n", sqe->err_type); + msg->result = WD_JOIN_GATHER_PARSE_ERROR; + break; + case DAE_HASH_TABLE_NEED_REHASH: + msg->result = WD_JOIN_GATHER_NEED_REHASH; + break; + case DAE_HASH_TABLE_INVALID: + msg->result = WD_JOIN_GATHER_INVALID_HASH_TABLE; + break; + case DAE_TASK_BUS_ERROR: + WD_ERR("failed to do join gather task, bus error %u!\n", sqe->err_type); + msg->result = WD_JOIN_GATHER_BUS_ERROR; + break; + default: + WD_ERR("failed to do dae task! 
done_flag=0x%x, etype=0x%x, ext_type = 0x%x!\n", + (__u32)sqe->done_flag, (__u32)sqe->err_type, (__u32)sqe->ext_err_type); + msg->result = WD_JOIN_GATHER_PARSE_ERROR; + break; + } + + if (sqe->task_type == DAE_HASH_JOIN && sqe->task_type_ext == DAE_JOIN_PROBE) { + msg->produced_row_num = sqe->out_raw_num; + msg->consumed_row_num = sqe->data_row_offset; + msg->output_done = sqe->output_end; + } +} + +static int join_gather_recv(struct wd_alg_driver *drv, handle_t hctx, void *recv_msg) +{ + handle_t h_qp = (handle_t)wd_ctx_get_priv(hctx); + struct hisi_qp *qp = (struct hisi_qp *)h_qp; + struct dae_extend_addr *ext_addr = qp->priv; + struct wd_join_gather_msg *msg = recv_msg; + struct wd_join_gather_msg *send_msg; + struct dae_sqe sqe = {0}; + __u16 recv_cnt = 0; + int ret; + + ret = hisi_qm_recv(h_qp, &sqe, 1, &recv_cnt); + if (ret) + return ret; + + ret = hisi_check_bd_id(h_qp, msg->tag, sqe.low_tag); + if (ret) + goto out; + + msg->tag = sqe.low_tag; + if (qp->q_info.qp_mode == CTX_MODE_ASYNC) { + send_msg = wd_join_gather_get_msg(qp->q_info.idx, msg->tag); + if (!send_msg) { + msg->result = WD_JOIN_GATHER_IN_EPARA; + WD_ERR("failed to get send msg! 
idx = %u, tag = %u.\n", + qp->q_info.idx, msg->tag); + ret = -WD_EINVAL; + goto out; + } + } + + msg->result = WD_JOIN_GATHER_TASK_DONE; + msg->consumed_row_num = 0; + + if (likely(sqe.done_flag == DAE_HW_TASK_DONE)) { + fill_join_gather_task_done(&sqe, msg); + } else if (sqe.done_flag == DAE_HW_TASK_ERR) { + fill_join_gather_task_err(&sqe, msg); + } else { + msg->result = WD_JOIN_GATHER_PARSE_ERROR; + WD_ERR("failed to do join gather task, hardware doesn't process the task!\n"); + } + +out: + put_ext_addr(ext_addr, sqe.hi_tag); + return ret; +} + +static int join_check_params(struct wd_join_gather_col_info *key_info, __u32 cols_num) +{ + __u32 i; + int ret; + + if (cols_num > DAE_MAX_KEY_COLS) { + WD_ERR("invalid: join key cols num %u is more than device support %d!\n", + cols_num, DAE_MAX_KEY_COLS); + return -WD_EINVAL; + } + + for (i = 0; i < cols_num; i++) { + switch (key_info[i].data_type) { + case WD_DAE_SHORT_DECIMAL: + ret = dae_decimal_precision_check(key_info[i].data_info, false); + if (ret) + return ret; + break; + case WD_DAE_LONG_DECIMAL: + ret = dae_decimal_precision_check(key_info[i].data_info, true); + if (ret) + return ret; + break; + case WD_DAE_CHAR: + case WD_DAE_VARCHAR: + WD_ERR("invalid: key col %u, char or varchar isn't supported!\n", i); + return -WD_EINVAL; + default: + break; + } + } + + return WD_SUCCESS; +} + +static int gather_check_params(struct wd_join_gather_sess_setup *setup) +{ + struct wd_gather_table_info *table = setup->gather_tables; + struct wd_join_gather_col_info *col; + __u32 i, j; + int ret; + + if (setup->gather_table_num > DAE_MAX_TABLE_NUM) { + WD_ERR("invalid: gather table num %u is more than device support %d!\n", + setup->gather_table_num, DAE_MAX_TABLE_NUM); + return -WD_EINVAL; + } + + for (i = 0; i < setup->gather_table_num; i++) { + col = table[i].cols; + if (table[i].cols_num > DAE_MAX_KEY_COLS) { + WD_ERR("invalid: gather cols num %u is more than device support %d!\n", + table[i].cols_num, DAE_MAX_KEY_COLS); + 
return -WD_EINVAL; + } + for (j = 0; j < table[i].cols_num; j++) { + switch (col[j].data_type) { + case WD_DAE_SHORT_DECIMAL: + ret = dae_decimal_precision_check(col[j].data_info, false); + if (ret) + return ret; + break; + case WD_DAE_LONG_DECIMAL: + ret = dae_decimal_precision_check(col[j].data_info, true); + if (ret) + return ret; + break; + case WD_DAE_CHAR: + if (col[j].data_info > DAE_MAX_CHAR_SIZE) { + WD_ERR("gather col %u, char size isn't supported!\n", j); + return -WD_EINVAL; + } + break; + case WD_DAE_VARCHAR: + WD_ERR("invalid: gather col %u, varchar isn't supported!\n", j); + return -WD_EINVAL; + default: + break; + } + } + } + + return WD_SUCCESS; +} + +static int join_gather_param_check(struct wd_join_gather_sess_setup *setup) +{ + int ret; + + switch (setup->alg) { + case WD_JOIN: + return join_check_params(setup->join_table.build_key_cols, + setup->join_table.build_key_cols_num); + case WD_GATHER: + return gather_check_params(setup); + case WD_JOIN_GATHER: + ret = join_check_params(setup->join_table.build_key_cols, + setup->join_table.build_key_cols_num); + if (ret) + return ret; + + return gather_check_params(setup); + default: + return -WD_EINVAL; + } +} + +static int transfer_col_info(struct wd_join_gather_col_info *cols, + struct hw_join_gather_data *data, __u32 col_num) +{ + __u32 i; + + for (i = 0; i < col_num; i++) { + switch (cols[i].data_type) { + case WD_DAE_CHAR: + data[i].hw_type = DAE_CHAR; + data[i].data_info = cols[i].data_info; + break; + case WD_DAE_LONG_DECIMAL: + data[i].hw_type = DAE_DECIMAL128; + break; + case WD_DAE_SHORT_DECIMAL: + data[i].hw_type = DAE_DECIMAL64; + break; + case WD_DAE_LONG: + data[i].hw_type = DAE_SINT64; + break; + case WD_DAE_INT: + case WD_DAE_DATE: + data[i].hw_type = DAE_SINT32; + break; + default: + return -WD_EINVAL; + } + } + + return WD_SUCCESS; +} + +static int transfer_cols_to_hw_type(struct wd_join_gather_col_info *cols, + struct hw_join_gather_data *hw_data, __u32 cols_num) +{ + struct 
hw_join_gather_data tmp_data[DAE_MAX_KEY_COLS] = {0}; + __u32 type_num = ARRAY_SIZE(hw_data_type_order); + __u32 i, j, k = 0; + int ret; + + ret = transfer_col_info(cols, tmp_data, cols_num); + if (ret) + return ret; + + for (i = 0; i < type_num; i++) { + for (j = 0; j < cols_num; j++) { + if (hw_data_type_order[i] != tmp_data[j].hw_type) + continue; + hw_data[k].usr_col_idx = j; + hw_data[k].hw_type = tmp_data[j].hw_type; + hw_data[k++].data_info = tmp_data[j].data_info; + } + } + + return WD_SUCCESS; +} + +static int transfer_data_to_hw_type(struct join_gather_col_data *cols_data, + struct wd_join_gather_sess_setup *setup) +{ + struct wd_gather_table_info *tables = setup->gather_tables; + struct wd_join_gather_col_info *gather_cols; + struct hw_join_gather_data *hw_data; + __u32 n, j; + int ret; + + for (n = 0; n < setup->gather_table_num; n++) { + gather_cols = tables[n].cols; + hw_data = cols_data->gather_data[n]; + ret = transfer_cols_to_hw_type(gather_cols, hw_data, tables[n].cols_num); + if (ret) + return ret; + + cols_data->gather_cols_num[n] = tables[n].cols_num; + for (j = 0; j < tables[n].cols_num; j++) + if (gather_cols[j].has_empty) + cols_data->has_empty[n] |= (1 << j); + } + + return WD_SUCCESS; +} + +static int transfer_key_to_hw_type(struct join_gather_col_data *cols_data, + struct wd_join_gather_sess_setup *setup) +{ + struct wd_join_gather_col_info *key_cols = setup->join_table.build_key_cols; + struct hw_join_gather_data *hw_key_data = cols_data->key_data; + __u32 cols_num = setup->join_table.build_key_cols_num; + int ret; + + ret = transfer_cols_to_hw_type(key_cols, hw_key_data, cols_num); + if (ret) + return ret; + + cols_data->key_num = cols_num; + + return WD_SUCCESS; +} + +static int join_get_table_rowsize(struct join_gather_col_data *cols_data, + struct wd_join_gather_sess_setup *setup) +{ + struct hw_join_gather_data *key_data = cols_data->key_data; + __u32 key_num = cols_data->key_num; + __u64 row_count_size = 0; + __u32 i; + + 
cols_data->index_num = setup->join_table.hash_table_index_num; + + if (cols_data->index_num > HASH_TABLE_MAX_INDEX_NUM) { + WD_ERR("invalid: hash table index num is not supported!\n"); + return -WD_EINVAL; + } else if (!cols_data->index_num) { + WD_INFO("Hash table index num is not set, set to default: 1!\n"); + cols_data->index_num = HASH_TABLE_INDEX_NUM; + } + + /* With a restriction on the col number, the sum lengths will not overflow. */ + for (i = 0; i < key_num; i++) + row_count_size += get_data_type_size(key_data[i].hw_type, 0); + + row_count_size = ALIGN(row_count_size, DAE_KEY_ALIGN_SIZE); + row_count_size += HASH_TABLE_HEAD_TAIL_SIZE + + cols_data->index_num * HASH_TABLE_INDEX_SIZE; + if (row_count_size > DAE_MAX_ROW_SIZE) { + WD_ERR("invalid: hash table row size %llu, hash_table_index_num %u!\n", + row_count_size, cols_data->index_num); + return -WD_EINVAL; + } + + if (row_count_size <= ROW_SIZE32) + return ROW_SIZE32; + + if (row_count_size <= ROW_SIZE64) + return ROW_SIZE64; + + if (row_count_size <= ROW_SIZE128) + return ROW_SIZE128; + + if (row_count_size <= ROW_SIZE256) + return ROW_SIZE256; + + return ROW_SIZE512; +} + +static void gather_get_batch_rowsize(struct join_gather_col_data *cols_data, + struct wd_join_gather_sess_setup *setup, + __u32 *batch_row_size) +{ + struct wd_gather_table_info *tables = setup->gather_tables; + struct hw_join_gather_data *gather_data; + __u32 row_count_size = 0; + __u32 n, i; + + cols_data->gather_table_num = setup->gather_table_num; + for (n = 0; n < setup->gather_table_num; n++) { + row_count_size = 0; + gather_data = cols_data->gather_data[n]; + + /* With a restriction on the col number, the sum length will not overflow. 
*/ + for (i = 0; i < tables[n].cols_num; i++) + row_count_size += get_data_type_size(gather_data[i].hw_type, + gather_data[i].data_info); + + batch_row_size[n] = row_count_size + GATHER_ROW_BATCH_EMPTY_SIZE; + } +} + +static int join_gather_fill_ctx(struct join_gather_ctx *ctx, + struct wd_join_gather_sess_setup *setup) +{ + struct join_gather_col_data *cols_data = &ctx->cols_data; + int ret; + + if (setup->alg != WD_GATHER) { + ret = transfer_key_to_hw_type(cols_data, setup); + if (ret) + return ret; + + ret = join_get_table_rowsize(cols_data, setup); + if (ret < 0) + return -WD_EINVAL; + ctx->hash_table_row_size = ret; + } + + if (setup->alg != WD_JOIN) { + ret = transfer_data_to_hw_type(cols_data, setup); + if (ret) + return ret; + + gather_get_batch_rowsize(cols_data, setup, ctx->batch_row_size); + } + + return WD_SUCCESS; +} + +static void join_gather_sess_priv_uninit(struct wd_alg_driver *drv, void *priv) +{ + struct join_gather_ctx *ctx = priv; + + if (!ctx) { + WD_ERR("invalid: dae sess uninit priv is NULL!\n"); + return; + } + + pthread_spin_destroy(&ctx->lock); + free(ctx); +} + +static int join_gather_sess_priv_init(struct wd_alg_driver *drv, + struct wd_join_gather_sess_setup *setup, void **priv) +{ + struct join_gather_ctx *ctx; + int ret; + + if (!drv || !drv->priv) { + WD_ERR("invalid: dae drv is NULL!\n"); + return -WD_EINVAL; + } + + if (!setup || !priv) { + WD_ERR("invalid: dae sess priv is NULL!\n"); + return -WD_EINVAL; + } + + ret = join_gather_param_check(setup); + if (ret) + return -WD_EINVAL; + + ctx = calloc(1, sizeof(struct join_gather_ctx)); + if (!ctx) + return -WD_ENOMEM; + + ret = join_gather_fill_ctx(ctx, setup); + if (ret) + goto free_ctx; + + ret = pthread_spin_init(&ctx->lock, PTHREAD_PROCESS_SHARED); + if (ret) + goto free_ctx; + + *priv = ctx; + + return WD_SUCCESS; + +free_ctx: + free(ctx); + return ret; +} + +static int join_get_table_row_size(struct wd_alg_driver *drv, void *param) +{ + struct join_gather_ctx *ctx = param; + + 
if (!ctx) + return -WD_EINVAL; + + return ctx->hash_table_row_size; +} + +static int gather_get_batch_row_size(struct wd_alg_driver *drv, void *param, + __u32 *row_size, __u32 size) +{ + struct join_gather_ctx *ctx = param; + + if (!ctx) + return -WD_EINVAL; + + if (!size || size > DAE_MAX_TABLE_NUM * sizeof(__u32)) + return -WD_EINVAL; + + memcpy(row_size, ctx->batch_row_size, size); + + return 0; +} + +static int join_hash_table_init(struct wd_alg_driver *drv, + struct wd_dae_hash_table *table, void *priv) +{ + struct join_gather_ctx *ctx = priv; + + if (!ctx || !table) + return -WD_EINVAL; + + return dae_hash_table_init(&ctx->table_data, &ctx->rehash_table, + table, ctx->hash_table_row_size); +} + +static int join_gather_get_extend_ops(void *ops) +{ + struct wd_join_gather_ops *join_gather_ops = (struct wd_join_gather_ops *)ops; + + if (!join_gather_ops) + return -WD_EINVAL; + + join_gather_ops->get_table_row_size = join_get_table_row_size; + join_gather_ops->get_batch_row_size = gather_get_batch_row_size; + join_gather_ops->hash_table_init = join_hash_table_init; + join_gather_ops->sess_init = join_gather_sess_priv_init; + join_gather_ops->sess_uninit = join_gather_sess_priv_uninit; + + return WD_SUCCESS; +} + + +#define GEN_JOIN_GATHER_DRIVER(dae_alg_name) \ +{\ + .drv_name = "hisi_zip",\ + .alg_name = (dae_alg_name),\ + .calc_type = UADK_ALG_HW,\ + .priority = 100,\ + .queue_num = DAE_CTX_Q_NUM_DEF,\ + .op_type_num = 1,\ + .fallback = 0,\ + .init = dae_init,\ + .exit = dae_exit,\ + .send = join_gather_send,\ + .recv = join_gather_recv,\ + .get_extend_ops = join_gather_get_extend_ops,\ +} + +static struct wd_alg_driver join_gather_driver[] = { + GEN_JOIN_GATHER_DRIVER("hashjoin"), + GEN_JOIN_GATHER_DRIVER("gather"), + GEN_JOIN_GATHER_DRIVER("join-gather"), +}; + +#ifdef WD_STATIC_DRV +void hisi_dae_join_gather_probe(void) +#else +static void __attribute__((constructor)) hisi_dae_join_gather_probe(void) +#endif +{ + __u32 alg_num = 
ARRAY_SIZE(join_gather_driver); + int ret; + __u32 i; + + WD_INFO("Info: register DAE hashjoin and gather alg drivers!\n"); + for (i = 0; i < alg_num; i++) { + ret = wd_alg_driver_register(&join_gather_driver[i]); + if (ret && ret != -WD_ENODEV) + WD_ERR("Error: register %s failed!\n", + join_gather_driver[i].alg_name); + } +} + +#ifdef WD_STATIC_DRV +void hisi_dae_join_gather_remove(void) +#else +static void __attribute__((destructor)) hisi_dae_join_gather_remove(void) +#endif +{ + __u32 alg_num = ARRAY_SIZE(join_gather_driver); + __u32 i; + + WD_INFO("Info: unregister DAE alg drivers!\n"); + for (i = 0; i < alg_num; i++) + wd_alg_driver_unregister(&join_gather_driver[i]); +} diff --git a/include/drv/wd_join_gather_drv.h b/include/drv/wd_join_gather_drv.h new file mode 100644 index 00000000..80fb9322 --- /dev/null +++ b/include/drv/wd_join_gather_drv.h @@ -0,0 +1,52 @@ +/* SPDX-License-Identifier: Apache-2.0 */ +/* + * Copyright 2025 Huawei Technologies Co.,Ltd. All rights reserved. + */ + +#ifndef __WD_JOIN_GATHER_DRV_H +#define __WD_JOIN_GATHER_DRV_H + +#include <asm/types.h> +#include "wd_join_gather.h" +#include "wd_util.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct wd_join_gather_msg { + __u32 tag; + __u32 key_cols_num; + __u32 result; + __u32 input_row_num; + __u32 output_row_num; + __u32 consumed_row_num; + __u32 produced_row_num; + enum wd_join_gather_op_type op_type; + enum multi_batch_index_type index_type; + bool output_done; + bool key_out_en; + bool multi_batch_en; + struct wd_join_gather_req req; + struct wd_dae_hash_table hash_table; + void *priv; +}; + +struct wd_join_gather_ops { + int (*get_table_row_size)(struct wd_alg_driver *drv, void *priv); + int (*get_batch_row_size)(struct wd_alg_driver *drv, void *priv, + __u32 *batch_row_size, __u32 size); + int (*sess_init)(struct wd_alg_driver *drv, + struct wd_join_gather_sess_setup *setup, void **priv); + void (*sess_uninit)(struct wd_alg_driver *drv, void *priv); + int 
(*hash_table_init)(struct wd_alg_driver *drv, + struct wd_dae_hash_table *hash_table, void *priv); +}; + +struct wd_join_gather_msg *wd_join_gather_get_msg(__u32 idx, __u32 tag); + +#ifdef __cplusplus +} +#endif + +#endif /* __WD_JOIN_GATHER_DRV_H */ diff --git a/include/wd_alg.h b/include/wd_alg.h index 441b3bef..2fc350af 100644 --- a/include/wd_alg.h +++ b/include/wd_alg.h @@ -205,12 +205,14 @@ void hisi_hpre_probe(void); void hisi_zip_probe(void); void hisi_dae_probe(void); void hisi_udma_probe(void); +void hisi_dae_join_gather_probe(void); void hisi_sec2_remove(void); void hisi_hpre_remove(void); void hisi_zip_remove(void); void hisi_dae_remove(void); void hisi_udma_remove(void); +void hisi_dae_join_gather_remove(void); #endif diff --git a/include/wd_dae.h b/include/wd_dae.h index aa9f966c..64f17dc4 100644 --- a/include/wd_dae.h +++ b/include/wd_dae.h @@ -57,6 +57,18 @@ struct wd_dae_col_addr { __u64 offset_size; }; +/** + * wd_dae_row_addr - information of row memory. + * @addr: The start address of row memory. + * @row_size: Memory size occupied by a row. + * @row_num: Total number of rows. + */ +struct wd_dae_row_addr { + void *addr; + __u32 row_size; + __u32 row_num; +}; + /** * wd_dae_hash_table - Hash table information of DAE. * @std_table: Address of standard hash table. diff --git a/include/wd_join_gather.h b/include/wd_join_gather.h new file mode 100644 index 00000000..4962ee35 --- /dev/null +++ b/include/wd_join_gather.h @@ -0,0 +1,352 @@ +/* SPDX-License-Identifier: Apache-2.0 */ +/* + * Copyright 2025 Huawei Technologies Co.,Ltd. All rights reserved. + */ + +#ifndef __WD_JOIN_GATHER_H +#define __WD_JOIN_GATHER_H + +#include <dlfcn.h> +#include <asm/types.h> +#include "wd_dae.h" + +#ifdef __cplusplus +extern "C" { +#endif + +enum wd_join_gather_alg { + WD_JOIN, + WD_GATHER, + WD_JOIN_GATHER, + WD_JOIN_GATHER_ALG_MAX, +}; + +/** + * wd_join_gather_op_type - operation type for hash join and gather. 
+ */ +enum wd_join_gather_op_type { + WD_JOIN_BUILD_HASH, + WD_JOIN_PROBE, + WD_JOIN_REHASH, + WD_GATHER_CONVERT, + WD_GATHER_COMPLETE, + WD_JOIN_GATHER_OP_TYPE_MAX, +}; + +/** + * wd_join_gather_task_error_type - hash join and gather task error type. + */ +enum wd_join_gather_task_error_type { + WD_JOIN_GATHER_TASK_DONE, + WD_JOIN_GATHER_IN_EPARA, + WD_JOIN_GATHER_NEED_REHASH, + WD_JOIN_GATHER_INVALID_HASH_TABLE, + WD_JOIN_GATHER_PARSE_ERROR, + WD_JOIN_GATHER_BUS_ERROR, +}; + +enum multi_batch_index_type { + WD_BATCH_NUMBER_INDEX, /* index holds batch number and row number */ + WD_BATCH_ADDR_INDEX, /* index holds batch address and row number */ + WD_BATCH_INDEX_TYPE_MAX, +}; + +/** + * wd_join_gather_col_info - column information. + * @data_type: column data type. + * @data_info: For CHAR, it is the data size in bytes (at least 1). + * For DECIMAL, it is the data precision: high 8 bits: decimal part precision, + * low 8 bits: precision of the whole value. + * @has_empty: indicates whether the column contains empty data. + */ +struct wd_join_gather_col_info { + enum wd_dae_data_type data_type; + __u16 data_info; + bool has_empty; +}; + +/** + * wd_gather_table_info - gather table information. + * @cols: Information of gather table columns. + * @cols_num: Number of gather table columns. + * @is_multi_batch: indicates whether the task is a single or multi batch task. + */ +struct wd_gather_table_info { + struct wd_join_gather_col_info *cols; + __u32 cols_num; + bool is_multi_batch; +}; + +/** + * wd_join_table_info - join table information. + * @build_key_cols: Information of build table key columns. + * @probe_key_cols: Information of probe table key columns. + * @build_key_cols_num: Number of build table key columns. + * @probe_key_cols_num: Number of probe table key columns. + * @key_output_enable: Indicates whether to output key columns. + * @hash_table_index_num: Number of original rows that can be stored
+ */ +struct wd_join_table_info { + struct wd_join_gather_col_info *build_key_cols; + struct wd_join_gather_col_info *probe_key_cols; + __u32 build_key_cols_num; + __u32 probe_key_cols_num; + bool key_output_enable; + __u32 hash_table_index_num; +}; + +/** + * wd_join_gather_sess_setup - Hash join and gather session setup information. + * @join_table: Information of join table. + * @gather_tables: Information of gather table. + * @gather_table_num: Number of gather table. + * @alg: Alg for this session. + * @index_type: Indicates the index type, 0 for batch number and row number, + * 1 for batch address and row number. + * @charset_info: Charset information + * @sched_param: Parameters of the scheduling policy, + * usually allocated according to struct sched_params. + */ +struct wd_join_gather_sess_setup { + struct wd_join_table_info join_table; + struct wd_gather_table_info *gather_tables; + __u32 gather_table_num; + + enum wd_join_gather_alg alg; + enum multi_batch_index_type index_type; + struct wd_dae_charset charset_info; + void *sched_param; +}; + +struct wd_join_gather_req; +typedef void *wd_join_gather_cb_t(struct wd_join_gather_req *req, void *cb_param); + +/** + * wd_probe_out_info - Hash join probe output info. + * @build_index: address information of multi batch index. + * @probe_index: address information of single batch index. + * @breakpoint: address information of probe breakpoint. + * @key_cols: address information of output key columns. + * @key_cols_num: number of output key columns. + */ +struct wd_probe_out_info { + struct wd_dae_row_addr build_index; + struct wd_dae_row_addr probe_index; + struct wd_dae_row_addr breakpoint; + struct wd_dae_col_addr *key_cols; + __u32 key_cols_num; +}; + +/** + * wd_join_req - Hash join request. + * @build_batch_addr: Row-storaged batch address, the batch is used to store build + * table data cols in row format. This field is only used for batch addr index. 
+ * + * @probe_output: The information for hash join probe stage. + * @key_cols: key columns from build table or probe table. + * @key_cols_num: key columns number. + * @batch_row_offset: Indicates the start row number of the input column. + * @build_batch_index: build table batch index, start from 0. + */ +struct wd_join_req { + struct wd_dae_row_addr build_batch_addr; + struct wd_probe_out_info probe_output; + struct wd_dae_col_addr *key_cols; + __u32 key_cols_num; + __u32 batch_row_offset; + __u32 build_batch_index; +}; + +/** + * wd_row_batch_info - Information of some row-storaged batchs. + * @batch_addr: Addr list of row batchs. + * @batch_row_size: Row size of each row batch. + * @batch_row_num: Row number of each row batch. + * @batch_num: Total number of row batchs. + */ +struct wd_row_batch_info { + void **batch_addr; + __u32 *batch_row_size; + __u32 *batch_row_num; + __u32 batch_num; +}; + +/** + * wd_gather_req - Hash join and gather operation request. + * @index: address information of multi batch index or single batch index. + * @row_batchs: address information of row batchs. + * @data_cols: data columns from gather table. + * @data_cols_num: columns number from gather table. + * @table_index: The table index from the session's gather_tables to do tasks. + */ +struct wd_gather_req { + struct wd_dae_row_addr index; + struct wd_row_batch_info row_batchs; + struct wd_dae_col_addr *data_cols; + __u32 data_cols_num; + __u32 table_index; +}; + +/** + * wd_join_gather_req - Hash join and gather operation request. + * @op_type: The operation type for hash join or gather task. + * @join_req: The request for hash join. + * @gather_req: The request for gather. + * @input_row_num: Row count of input column. + * @output_row_num: Expected row count of output column. + * @consumed_row_num: Row count of input data that has been processed. + * @produced_row_num: Real row count of output column. + * @cb: Callback function for the asynchronous mode. 
+ * @cb_param: Parameters of the callback function. + * @state: Error information written back by the hardware. + * @output_done: For rehash, it indicates whether all data in hash table has been output. + * For probe task, it indicates whether all data of one probe batch has been processed. + * @priv: Private data from user(reserved). + */ +struct wd_join_gather_req { + /* user fill-in fields */ + enum wd_join_gather_op_type op_type; + struct wd_join_req join_req; + struct wd_gather_req gather_req; + __u32 input_row_num; + __u32 output_row_num; + wd_join_gather_cb_t *cb; + void *cb_param; + void *priv; + + /* uadk driver writeback fields */ + enum wd_join_gather_task_error_type state; + __u32 consumed_row_num; + __u32 produced_row_num; + bool output_done; +}; + +/** + * wd_join_gather_init() - A simplify interface to initializate uadk. + * Users just need to descripe the deployment of business scenarios. + * Then the initialization will request appropriate + * resources to support the business scenarios. + * To make the initializate simpler, ctx_params support set NULL. + * And then the function will set them as driver's default. + * + * @alg: Supported algorithms: hashjoin, gather, join-gather. + * @sched_type: The scheduling type users want to use. + * @task_type: Task types, including soft computing, hardware and hybrid computing. + * @ctx_params: The ctxs resources users want to use. Include per operation + * type ctx numbers and business process run numa. + * + * Return 0 if succeed and others if fail. + */ +int wd_join_gather_init(char *alg, __u32 sched_type, int task_type, + struct wd_ctx_params *ctx_params); + +/** + * wd_join_gather_uninit() - Uninitialise ctx configuration and scheduler. + */ +void wd_join_gather_uninit(void); + +/** + * wd_join_gather_alloc_sess() - Allocate a hash join or gather session + * @setup: Parameters to setup this session. + * + * Return 0 if fail and others if succeed. 
+ */ +handle_t wd_join_gather_alloc_sess(struct wd_join_gather_sess_setup *setup); + +/** + * wd_join_gather_free_sess() - Free a hash join or gather session + * @sess: The session need to be freed. + */ +void wd_join_gather_free_sess(handle_t h_sess); + +/** + * wd_join_set_hash_table() - Set hash table to the wd session + * @sess, Session to be initialized. + * @info, Hash table information to set. + * + * Return 0 if succeed and others if fail. + */ +int wd_join_set_hash_table(handle_t h_sess, struct wd_dae_hash_table *info); + +/** + * wd_join_build_hash_sync()/wd_join_build_hash_async() - Build the hash table. + * @sess: Wd session + * @req: Operational data. + * + * Return 0 if succeed and others if fail. + */ +int wd_join_build_hash_sync(handle_t h_sess, struct wd_join_gather_req *req); +int wd_join_build_hash_async(handle_t h_sess, struct wd_join_gather_req *req); + +/** + * wd_join_probe_sync()/wd_join_probe_async() - Probe and output the index or key. + * @sess: Wd session + * @req: Operational data. + * + * Return 0 if succeed and others if fail. + */ +int wd_join_probe_sync(handle_t h_sess, struct wd_join_gather_req *req); +int wd_join_probe_async(handle_t h_sess, struct wd_join_gather_req *req); + +/** + * wd_gather_convert_sync()/wd_gather_convert_async() - Convert a column batch to a row batch. + * @sess: Wd session + * @req: Operational data. + * + * Return 0 if succeed and others if fail. + */ +int wd_gather_convert_sync(handle_t h_sess, struct wd_join_gather_req *req); +int wd_gather_convert_async(handle_t h_sess, struct wd_join_gather_req *req); + +/** + * wd_gather_complete_sync()/wd_gather_complete_async() - map the index with a row batch + * and output the result to a column batch. + * @sess: Wd session + * @req: Operational data. + * + * Return 0 if succeed and others if fail. 
+ */ +int wd_gather_complete_sync(handle_t h_sess, struct wd_join_gather_req *req); +int wd_gather_complete_async(handle_t h_sess, struct wd_join_gather_req *req); + +/** + * wd_join_rehash_sync - Rehash operation, only the synchronous mode is supported. + * @sess: Wd hash join session + * @req: Operational data. + * + * Return 0 if succeed and others if fail. + */ +int wd_join_rehash_sync(handle_t h_sess, struct wd_join_gather_req *req); + +/** + * wd_join_gather_poll() - Poll finished request. + * This function will call poll_policy function which is registered to wd + * by user. + * + * Return 0 if succeed and others if fail. + */ +int wd_join_gather_poll(__u32 expt, __u32 *count); + +/** + * wd_join_get_table_rowsize - Get the hash table's row size. + * @h_sess: Wd session handler. + * + * Return negative value if fail and others if succeed. + */ +int wd_join_get_table_rowsize(handle_t h_sess); + +/** + * wd_gather_get_batch_rowsize - Get the batch row size. + * @h_sess: Wd session handler. + * @table_index: The table index from the session's gather_tables. + * + * Return negative value if fail and others if succeed. 
+ */ +int wd_gather_get_batch_rowsize(handle_t h_sess, __u8 table_index); + +#ifdef __cplusplus +} +#endif + +#endif /* __WD_JOIN_GATHER_H */ diff --git a/include/wd_util.h b/include/wd_util.h index bbb18a7c..4a5204de 100644 --- a/include/wd_util.h +++ b/include/wd_util.h @@ -43,6 +43,7 @@ enum wd_type { WD_ECC_TYPE, WD_AGG_TYPE, WD_UDMA_TYPE, + WD_JOIN_GATHER_TYPE, WD_TYPE_MAX, }; diff --git a/libwd_dae.map b/libwd_dae.map index 6597ff98..f3b06337 100644 --- a/libwd_dae.map +++ b/libwd_dae.map @@ -1,5 +1,24 @@ UADK_DAE_2.0 { global: + wd_join_gather_alloc_sess; + wd_join_gather_free_sess; + wd_join_get_table_rowsize; + wd_gather_get_batch_rowsize; + wd_join_set_hash_table; + wd_join_gather_init; + wd_join_gather_uninit; + wd_join_build_hash_sync; + wd_join_build_hash_async; + wd_join_probe_sync; + wd_join_probe_async; + wd_join_rehash_sync; + wd_join_gather_get_msg; + wd_join_gather_poll; + wd_gather_convert_sync; + wd_gather_complete_sync; + wd_gather_convert_async; + wd_gather_complete_async; + wd_agg_alloc_sess; wd_agg_free_sess; wd_agg_get_table_rowsize; diff --git a/wd_join_gather.c b/wd_join_gather.c new file mode 100644 index 00000000..0a1b2d1f --- /dev/null +++ b/wd_join_gather.c @@ -0,0 +1,1823 @@ +/* SPDX-License-Identifier: Apache-2.0 */ +/* + * Copyright 2025 Huawei Technologies Co.,Ltd. All rights reserved. 
+ */ + +#include <stdlib.h> +#include <pthread.h> +#include <sched.h> +#include <limits.h> +#include "include/drv/wd_join_gather_drv.h" +#include "wd_join_gather.h" + +#define DECIMAL_PRECISION_OFFSET 8 +#define DAE_INT_SIZE 4 +#define DAE_LONG_SIZE 8 +#define DAE_LONG_DECIMAL_SIZE 16 + +/* Sum of the max row number of standard and external hash table */ +#define MAX_HASH_TABLE_ROW_NUM 0x1FFFFFFFE + +enum wd_join_sess_state { + WD_JOIN_SESS_UNINIT, /* Uninit session */ + WD_JOIN_SESS_INIT, /* Hash table has been set */ + WD_JOIN_SESS_BUILD_HASH, /* Input stage has started */ + WD_JOIN_SESS_PREPARE_REHASH, /* New hash table has been set */ + WD_JOIN_SESS_REHASH, /* Rehash stage has started */ + WD_JOIN_SESS_PROBE, /* Output stage has started */ +}; + +struct wd_join_gather_setting { + enum wd_status status; + struct wd_ctx_config_internal config; + struct wd_sched sched; + struct wd_async_msg_pool pool; + struct wd_alg_driver *driver; + void *priv; + void *dlhandle; + void *dlh_list; +}; + +struct wd_join_cols_conf { + struct wd_join_gather_col_info *cols; + __u64 *data_size; + __u32 cols_num; + bool key_output_enable; +}; + +struct wd_gather_tables_conf { + struct wd_gather_table_info *tables; + __u32 *batch_row_size; + __u64 **data_size; + __u32 table_num; +}; + +struct wd_join_gather_sess { + enum multi_batch_index_type index_type; + enum wd_join_sess_state state; + enum wd_join_gather_alg alg; + struct wd_join_gather_ops ops; + struct wd_join_cols_conf join_conf; + struct wd_gather_tables_conf gather_conf; + struct wd_dae_hash_table hash_table; + wd_dev_mask_t *dev_mask; + void *sched_key; + void *priv; +}; + +static const char *wd_join_gather_alg[WD_JOIN_GATHER_ALG_MAX] = { + "hashjoin", "gather", "join-gather" +}; + +static struct wd_init_attrs wd_join_gather_init_attrs; +static struct wd_join_gather_setting wd_join_gather_setting; +static int wd_join_gather_poll_ctx(__u32 idx, __u32 expt, __u32 *count); + +static void wd_join_gather_close_driver(void) +{ 
+#ifndef WD_STATIC_DRV + wd_dlclose_drv(wd_join_gather_setting.dlh_list); +#else + wd_release_drv(wd_join_gather_setting.driver); + hisi_dae_join_gather_remove(); +#endif +} + +static int wd_join_gather_open_driver(void) +{ +#ifndef WD_STATIC_DRV + /* + * Driver lib file path could set by env param. + * then open tham by wd_dlopen_drv() + * use NULL means dynamic query path + */ + wd_join_gather_setting.dlh_list = wd_dlopen_drv(NULL); + if (!wd_join_gather_setting.dlh_list) { + WD_ERR("fail to open driver lib files.\n"); + return -WD_EINVAL; + } +#else + hisi_dae_join_gather_probe(); +#endif + return WD_SUCCESS; +} + +static bool wd_join_gather_check_inner(void) +{ + struct uacce_dev_list *list; + + list = wd_get_accel_list("hashjoin"); + if (!list) + goto out; + wd_free_list_accels(list); + + list = wd_get_accel_list("gather"); + if (!list) + goto out; + wd_free_list_accels(list); + + return true; +out: + WD_ERR("invalid: the device cannot support hashjoin and gather!\n"); + return false; +} + +static bool wd_join_gather_alg_check(const char *alg_name) +{ + __u32 i; + + /* Check for the virtual algorithms */ + if (!strcmp(alg_name, "join-gather")) + return wd_join_gather_check_inner(); + + for (i = 0; i < WD_JOIN_GATHER_ALG_MAX; i++) { + /* Some algorithms do not support all modes */ + if (!wd_join_gather_alg[i] || !strlen(wd_join_gather_alg[i])) + continue; + if (!strcmp(alg_name, wd_join_gather_alg[i])) + return true; + } + + return false; +} + +static int check_col_data_info(enum wd_dae_data_type type, __u16 col_data_info) +{ + __u8 all_precision, decimal_precision; + + switch (type) { + case WD_DAE_DATE: + case WD_DAE_INT: + case WD_DAE_LONG: + case WD_DAE_VARCHAR: + break; + case WD_DAE_SHORT_DECIMAL: + case WD_DAE_LONG_DECIMAL: + /* High 8 bit: decimal part precision, low 8 bit: the whole data precision */ + all_precision = col_data_info; + decimal_precision = col_data_info >> DECIMAL_PRECISION_OFFSET; + if (!all_precision || decimal_precision > 
all_precision) { + WD_ERR("failed to check data precision, all: %u, decimal: %u!\n", + all_precision, decimal_precision); + return -WD_EINVAL; + } + break; + case WD_DAE_CHAR: + if (!col_data_info) { + WD_ERR("invalid: char length is zero!\n"); + return -WD_EINVAL; + } + break; + default: + WD_ERR("invalid: data type %u is not supported!\n", type); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int get_data_type_size(enum wd_dae_data_type type, __u16 col_data_info, + __u64 *col, __u32 idx) +{ + switch (type) { + case WD_DAE_DATE: + case WD_DAE_INT: + col[idx] = DAE_INT_SIZE; + break; + case WD_DAE_LONG: + case WD_DAE_SHORT_DECIMAL: + col[idx] = DAE_LONG_SIZE; + break; + case WD_DAE_LONG_DECIMAL: + col[idx] = DAE_LONG_DECIMAL_SIZE; + break; + case WD_DAE_CHAR: + col[idx] = col_data_info; + break; + case WD_DAE_VARCHAR: + col[idx] = 0; + break; + default: + return -WD_EINVAL; + } + return WD_SUCCESS; +} + +static int check_key_cols_info(struct wd_join_gather_sess_setup *setup) +{ + struct wd_join_table_info *table = &setup->join_table; + struct wd_join_gather_col_info *build = table->build_key_cols; + __u32 i; + int ret; + + if (table->build_key_cols_num != table->probe_key_cols_num) { + WD_ERR("invalid: build key_cols_num: %u, probe key_cols_num: %u!\n", + table->build_key_cols_num, table->probe_key_cols_num); + return -WD_EINVAL; + } + + ret = memcmp(table->build_key_cols, table->probe_key_cols, + table->build_key_cols_num * sizeof(struct wd_join_gather_col_info)); + if (ret) { + WD_ERR("invalid: build and probe table key infomation is not same!\n"); + return -WD_EINVAL; + } + + for (i = 0; i < table->build_key_cols_num; i++) { + if (!build[i].has_empty) { + WD_ERR("invalid: key col has no empty data! col: %u\n", i); + return -WD_EINVAL; + } + ret = check_col_data_info(build[i].data_type, build[i].data_info); + if (ret) { + WD_ERR("failed to check key col data info! 
col: %u\n", i); + return ret; + } + } + + return WD_SUCCESS; +} + +static int wd_join_check_params(struct wd_join_gather_sess_setup *setup) +{ + struct wd_join_table_info *table = &setup->join_table; + + if (!table->build_key_cols_num || !table->build_key_cols) { + WD_ERR("invalid: build key cols is NULL or key_cols_num is 0!\n"); + return -WD_EINVAL; + } + + if (!table->probe_key_cols_num || !table->probe_key_cols) { + WD_ERR("invalid: probe key cols is NULL or key_cols_num is 0!\n"); + return -WD_EINVAL; + } + + if (setup->index_type >= WD_BATCH_INDEX_TYPE_MAX) { + WD_ERR("failed to check batch index type!\n"); + return -WD_EINVAL; + } + + if (check_key_cols_info(setup)) { + WD_ERR("failed to check join setup key cols info!\n"); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int wd_gather_check_params(struct wd_join_gather_sess_setup *setup) +{ + struct wd_gather_table_info *table = setup->gather_tables; + struct wd_join_gather_col_info *col; + __u32 i, j; + int ret; + + if (!setup->gather_tables || !setup->gather_table_num) { + WD_ERR("invalid: gather table is NULL, table num: %u\n", setup->gather_table_num); + return -WD_EINVAL; + } + + if (setup->index_type >= WD_BATCH_INDEX_TYPE_MAX) { + WD_ERR("failed to check gather batch index type!\n"); + return -WD_EINVAL; + } + + for (i = 0; i < setup->gather_table_num; i++) { + if (!table[i].cols || !table[i].cols_num) { + WD_ERR("failed to check gather table cols, num: %u\n", table[i].cols_num); + return -WD_EINVAL; + } + col = table[i].cols; + for (j = 0; j < table[i].cols_num; j++) { + ret = check_col_data_info(col[j].data_type, col[j].data_info); + if (ret) { + WD_ERR("failed to check gather info! 
col: %u, table: %u\n", j, i); + return ret; + } + } + } + + return WD_SUCCESS; +} + +static int wd_join_gather_check_params(struct wd_join_gather_sess_setup *setup) +{ + int ret; + + if (!setup) { + WD_ERR("invalid: hashjoin or gather sess setup is NULL!\n"); + return -WD_EINVAL; + } + + switch (setup->alg) { + case WD_JOIN: + return wd_join_check_params(setup); + case WD_GATHER: + return wd_gather_check_params(setup); + case WD_JOIN_GATHER: + ret = wd_join_check_params(setup); + if (ret) + return ret; + + return wd_gather_check_params(setup); + default: + WD_ERR("invalid: hashjoin sess setup alg is wrong!\n"); + return -WD_EINVAL; + } +} + +static void sess_data_size_uninit(struct wd_join_gather_sess *sess) +{ + __u32 i; + + if (sess->join_conf.cols) + free(sess->join_conf.cols); + + if (sess->gather_conf.tables) { + for (i = 0; i < sess->gather_conf.table_num; i++) + free(sess->gather_conf.data_size[i]); + + free(sess->gather_conf.tables); + } +} + +static int sess_data_size_init(struct wd_join_gather_sess *sess, + struct wd_join_gather_sess_setup *setup) +{ + struct wd_gather_table_info *gtable = setup->gather_tables; + struct wd_join_table_info *jtable = &setup->join_table; + struct wd_join_gather_col_info *key = jtable->build_key_cols; + __u64 key_size, key_data_size, gather_size, gather_data_size; + __u32 i, j; + + __atomic_store_n(&sess->state, WD_JOIN_SESS_UNINIT, __ATOMIC_RELEASE); + + if (setup->alg != WD_GATHER) { + key_size = jtable->build_key_cols_num * sizeof(struct wd_join_gather_col_info); + key_data_size = jtable->build_key_cols_num * sizeof(__u64); + sess->join_conf.cols = malloc(key_size + key_data_size); + if (!sess->join_conf.cols) + return -WD_ENOMEM; + memcpy(sess->join_conf.cols, key, key_size); + + sess->join_conf.data_size = (void *)sess->join_conf.cols + key_size; + for (i = 0; i < jtable->build_key_cols_num; i++) + (void)get_data_type_size(key[i].data_type, key[i].data_info, + sess->join_conf.data_size, i); + sess->join_conf.cols_num = 
jtable->build_key_cols_num; + + if (setup->alg == WD_JOIN) + return WD_SUCCESS; + } + + gather_size = setup->gather_table_num * sizeof(struct wd_gather_table_info); + gather_data_size = setup->gather_table_num * sizeof(__u64 *); + sess->gather_conf.tables = malloc(gather_size + gather_data_size); + if (!sess->gather_conf.tables) + goto free_join; + memcpy(sess->gather_conf.tables, gtable, gather_size); + + sess->gather_conf.data_size = (void *)sess->gather_conf.tables + gather_size; + for (i = 0; i < setup->gather_table_num; i++) { + sess->gather_conf.data_size[i] = malloc(gtable[i].cols_num * sizeof(__u64)); + if (!sess->gather_conf.data_size[i]) + goto free_gather; + } + + for (i = 0; i < setup->gather_table_num; i++) + for (j = 0; j < gtable[i].cols_num; j++) + (void)get_data_type_size(gtable[i].cols[j].data_type, + gtable[i].cols[j].data_info, + sess->gather_conf.data_size[i], j); + sess->gather_conf.table_num = setup->gather_table_num; + + return WD_SUCCESS; + +free_gather: + for (j = 0; j < i; j++) + free(sess->gather_conf.data_size[j]); + free(sess->gather_conf.tables); +free_join: + if (setup->alg != WD_GATHER) + free(sess->join_conf.cols); + return -WD_ENOMEM; +} + +static void wd_join_gather_uninit_sess(struct wd_join_gather_sess *sess) +{ + if (sess->gather_conf.batch_row_size) + free(sess->gather_conf.batch_row_size); + + if (sess->ops.sess_uninit) + sess->ops.sess_uninit(wd_join_gather_setting.driver, sess->priv); +} + +static int wd_join_gather_init_sess(struct wd_join_gather_sess *sess, + struct wd_join_gather_sess_setup *setup) +{ + struct wd_alg_driver *drv = wd_join_gather_setting.driver; + __u32 array_size; + int ret; + + if (sess->ops.sess_init) { + if (!sess->ops.sess_uninit) { + WD_ERR("failed to get session uninit ops!\n"); + return -WD_EINVAL; + } + ret = sess->ops.sess_init(drv, setup, &sess->priv); + if (ret) { + WD_ERR("failed to init session priv!\n"); + return ret; + } + } + + if (sess->ops.get_table_row_size && setup->alg != WD_GATHER) 
{ + ret = sess->ops.get_table_row_size(drv, sess->priv); + if (ret <= 0) { + WD_ERR("failed to get hash table row size: %d!\n", ret); + goto uninit; + } + sess->hash_table.table_row_size = ret; + } + + if (sess->ops.get_batch_row_size && setup->alg != WD_JOIN) { + array_size = setup->gather_table_num * sizeof(__u32); + sess->gather_conf.batch_row_size = malloc(array_size); + if (!sess->gather_conf.batch_row_size) + goto uninit; + + ret = sess->ops.get_batch_row_size(drv, sess->priv, + sess->gather_conf.batch_row_size, + array_size); + if (ret) { + WD_ERR("failed to get batch table row size!\n"); + goto free_batch; + } + } + + return WD_SUCCESS; + +free_batch: + free(sess->gather_conf.batch_row_size); +uninit: + if (sess->ops.sess_uninit) + sess->ops.sess_uninit(drv, sess->priv); + return -WD_EINVAL; +} + +handle_t wd_join_gather_alloc_sess(struct wd_join_gather_sess_setup *setup) +{ + struct wd_join_gather_sess *sess; + int ret; + + ret = wd_join_gather_check_params(setup); + if (ret) + return (handle_t)0; + + sess = malloc(sizeof(struct wd_join_gather_sess)); + if (!sess) { + WD_ERR("failed to alloc join gather session memory!\n"); + return (handle_t)0; + } + memset(sess, 0, sizeof(struct wd_join_gather_sess)); + + sess->alg = setup->alg; + sess->index_type = setup->index_type; + sess->join_conf.key_output_enable = setup->join_table.key_output_enable; + + ret = wd_drv_alg_support(wd_join_gather_alg[sess->alg], wd_join_gather_setting.driver); + if (!ret) { + WD_ERR("failed to check driver alg: %s!\n", wd_join_gather_alg[sess->alg]); + goto free_sess; + } + + /* Some simple scheduler don't need scheduling parameters */ + sess->sched_key = (void *)wd_join_gather_setting.sched.sched_init( + wd_join_gather_setting.sched.h_sched_ctx, setup->sched_param); + if (WD_IS_ERR(sess->sched_key)) { + WD_ERR("failed to init join_gather session schedule key!\n"); + goto free_sess; + } + + if (wd_join_gather_setting.driver->get_extend_ops) { + ret = 
wd_join_gather_setting.driver->get_extend_ops(&sess->ops); + if (ret) { + WD_ERR("failed to get join gather extend ops!\n"); + goto free_key; + } + } + + ret = wd_join_gather_init_sess(sess, setup); + if (ret) + goto free_key; + + ret = sess_data_size_init(sess, setup); + if (ret) { + WD_ERR("failed to init join gather session data size!\n"); + goto uninit_sess; + } + + return (handle_t)sess; + +uninit_sess: + wd_join_gather_uninit_sess(sess); +free_key: + free(sess->sched_key); +free_sess: + free(sess); + return (handle_t)0; +} + +void wd_join_gather_free_sess(handle_t h_sess) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + + if (!sess) { + WD_ERR("invalid: join gather input sess is NULL!\n"); + return; + } + + sess_data_size_uninit(sess); + + wd_join_gather_uninit_sess(sess); + + if (sess->sched_key) + free(sess->sched_key); + + free(sess); +} + +int wd_gather_get_batch_rowsize(handle_t h_sess, __u8 table_index) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + + if (!sess || !sess->gather_conf.batch_row_size) { + WD_ERR("invalid: gather sess or batch_row_size is NULL!\n"); + return -WD_EINVAL; + } + + if (table_index >= sess->gather_conf.table_num) { + WD_ERR("invalid: gather table index(%u) is larger than %u!\n", + table_index, sess->gather_conf.table_num); + return -WD_EINVAL; + } + + return sess->gather_conf.batch_row_size[table_index]; +} + +int wd_join_get_table_rowsize(handle_t h_sess) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + + if (!sess) { + WD_ERR("invalid: hashjoin input sess is NULL!\n"); + return -WD_EINVAL; + } + + if (sess->alg != WD_JOIN && sess->alg != WD_JOIN_GATHER) { + WD_ERR("invalid: the session is not used for hashjoin!\n"); + return -WD_EINVAL; + } + + if (!sess->hash_table.table_row_size) { + WD_ERR("invalid: hashjoin sess hash table row size is 0!\n"); + return -WD_EINVAL; + } + + return sess->hash_table.table_row_size; +} + +static int 
wd_join_init_sess_state(struct wd_join_gather_sess *sess, + enum wd_join_sess_state *expected) +{ + enum wd_join_sess_state next; + int ret; + + if (sess->hash_table.std_table) { + *expected = WD_JOIN_SESS_BUILD_HASH; + next = WD_JOIN_SESS_PREPARE_REHASH; + } else { + *expected = WD_JOIN_SESS_UNINIT; + next = WD_JOIN_SESS_INIT; + } + + ret = __atomic_compare_exchange_n(&sess->state, expected, next, + false, __ATOMIC_ACQUIRE, __ATOMIC_RELAXED); + if (!ret) { + WD_ERR("invalid: join sess state is %u!\n", *expected); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +int wd_join_set_hash_table(handle_t h_sess, struct wd_dae_hash_table *info) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + enum wd_join_sess_state expected; + int ret; + + if (!sess || !info) { + WD_ERR("invalid: hashjoin sess or hash table is NULL!\n"); + return -WD_EINVAL; + } + + if (sess->alg != WD_JOIN && sess->alg != WD_JOIN_GATHER) { + WD_ERR("invalid: the session is not used for hashjoin!\n"); + return -WD_EINVAL; + } + + ret = wd_join_init_sess_state(sess, &expected); + if (ret) + return ret; + + if (info->table_row_size != sess->hash_table.table_row_size) { + WD_ERR("invalid: hash table row size is not equal, expt: %u, real: %u!\n", + sess->hash_table.table_row_size, info->table_row_size); + ret = -WD_EINVAL; + goto out; + } + + if (!info->std_table) { + WD_ERR("invalid: standard hash table is NULL!\n"); + ret = -WD_EINVAL; + goto out; + } + + if (info->std_table_row_num < sess->hash_table.std_table_row_num) { + WD_ERR("invalid: standard hash table is too small, expt: %u, real: %u!\n", + sess->hash_table.std_table_row_num, info->std_table_row_num); + ret = -WD_EINVAL; + goto out; + } + + if (!info->ext_table_row_num || !info->ext_table) + WD_INFO("info: extern hash table is NULL!\n"); + + if (sess->ops.hash_table_init) { + ret = sess->ops.hash_table_init(wd_join_gather_setting.driver, + info, sess->priv); + if (ret) + goto out; + } + + 
memcpy(&sess->hash_table, info, sizeof(struct wd_dae_hash_table)); + + return WD_SUCCESS; + +out: + __atomic_store_n(&sess->state, expected, __ATOMIC_RELEASE); + return ret; +} + +static void wd_join_gather_clear_status(void) +{ + wd_alg_clear_init(&wd_join_gather_setting.status); +} + +static int wd_join_gather_alg_init(struct wd_ctx_config *config, struct wd_sched *sched) +{ + int ret; + + ret = wd_set_epoll_en("WD_JOIN_GATHER_EPOLL_EN", &wd_join_gather_setting.config.epoll_en); + if (ret < 0) + return ret; + + ret = wd_init_ctx_config(&wd_join_gather_setting.config, config); + if (ret < 0) + return ret; + + ret = wd_init_sched(&wd_join_gather_setting.sched, sched); + if (ret < 0) + goto out_clear_ctx_config; + + /* Allocate async pool for every ctx */ + ret = wd_init_async_request_pool(&wd_join_gather_setting.pool, config, WD_POOL_MAX_ENTRIES, + sizeof(struct wd_join_gather_msg)); + if (ret < 0) + goto out_clear_sched; + + ret = wd_alg_init_driver(&wd_join_gather_setting.config, wd_join_gather_setting.driver); + if (ret) + goto out_clear_pool; + + return WD_SUCCESS; + +out_clear_pool: + wd_uninit_async_request_pool(&wd_join_gather_setting.pool); +out_clear_sched: + wd_clear_sched(&wd_join_gather_setting.sched); +out_clear_ctx_config: + wd_clear_ctx_config(&wd_join_gather_setting.config); + return ret; +} + +static int wd_join_gather_alg_uninit(void) +{ + enum wd_status status; + + wd_alg_get_init(&wd_join_gather_setting.status, &status); + if (status == WD_UNINIT) + return -WD_EINVAL; + + /* Uninit async request pool */ + wd_uninit_async_request_pool(&wd_join_gather_setting.pool); + + /* Unset config, sched, driver */ + wd_clear_sched(&wd_join_gather_setting.sched); + + wd_alg_uninit_driver(&wd_join_gather_setting.config, wd_join_gather_setting.driver); + + return WD_SUCCESS; +} + +int wd_join_gather_init(char *alg, __u32 sched_type, int task_type, + struct wd_ctx_params *ctx_params) +{ + struct wd_ctx_params join_gather_ctx_params = {0}; + struct wd_ctx_nums 
join_gather_ctx_num = {0}; + int ret = -WD_EINVAL; + int state; + bool flag; + + pthread_atfork(NULL, NULL, wd_join_gather_clear_status); + + state = wd_alg_try_init(&wd_join_gather_setting.status); + if (state) + return state; + + if (!alg || sched_type >= SCHED_POLICY_BUTT || + task_type < 0 || task_type >= TASK_MAX_TYPE) { + WD_ERR("invalid: join_gathe init input param is wrong!\n"); + goto out_uninit; + } + + flag = wd_join_gather_alg_check(alg); + if (!flag) { + WD_ERR("invalid: alg: %s is unsupported!\n", alg); + goto out_uninit; + } + + state = wd_join_gather_open_driver(); + if (state) + goto out_uninit; + + while (ret != 0) { + memset(&wd_join_gather_setting.config, 0, sizeof(struct wd_ctx_config_internal)); + + /* Get alg driver and dev name */ + wd_join_gather_setting.driver = wd_alg_drv_bind(task_type, alg); + if (!wd_join_gather_setting.driver) { + WD_ERR("failed to bind %s driver.\n", alg); + goto out_dlopen; + } + + join_gather_ctx_params.ctx_set_num = &join_gather_ctx_num; + ret = wd_ctx_param_init(&join_gather_ctx_params, ctx_params, + wd_join_gather_setting.driver, + WD_JOIN_GATHER_TYPE, 1); + if (ret) { + if (ret == -WD_EAGAIN) { + wd_disable_drv(wd_join_gather_setting.driver); + wd_alg_drv_unbind(wd_join_gather_setting.driver); + continue; + } + goto out_driver; + } + + (void)strcpy(wd_join_gather_init_attrs.alg, alg); + wd_join_gather_init_attrs.sched_type = sched_type; + wd_join_gather_init_attrs.driver = wd_join_gather_setting.driver; + wd_join_gather_init_attrs.ctx_params = &join_gather_ctx_params; + wd_join_gather_init_attrs.alg_init = wd_join_gather_alg_init; + wd_join_gather_init_attrs.alg_poll_ctx = wd_join_gather_poll_ctx; + ret = wd_alg_attrs_init(&wd_join_gather_init_attrs); + if (ret) { + if (ret == -WD_ENODEV) { + wd_disable_drv(wd_join_gather_setting.driver); + wd_alg_drv_unbind(wd_join_gather_setting.driver); + wd_ctx_param_uninit(&join_gather_ctx_params); + continue; + } + WD_ERR("fail to init alg attrs.\n"); + goto 
out_params_uninit; + } + } + + wd_alg_set_init(&wd_join_gather_setting.status); + wd_ctx_param_uninit(&join_gather_ctx_params); + + return WD_SUCCESS; + +out_params_uninit: + wd_ctx_param_uninit(&join_gather_ctx_params); +out_driver: + wd_alg_drv_unbind(wd_join_gather_setting.driver); +out_dlopen: + wd_join_gather_close_driver(); +out_uninit: + wd_alg_clear_init(&wd_join_gather_setting.status); + return ret; +} + +void wd_join_gather_uninit(void) +{ + int ret; + + ret = wd_join_gather_alg_uninit(); + if (ret) + return; + + wd_alg_attrs_uninit(&wd_join_gather_init_attrs); + wd_alg_drv_unbind(wd_join_gather_setting.driver); + wd_join_gather_close_driver(); + wd_join_gather_setting.dlh_list = NULL; + wd_alg_clear_init(&wd_join_gather_setting.status); +} + +static void fill_build_hash_msg(struct wd_join_gather_msg *msg, + struct wd_join_gather_sess *sess) +{ + msg->index_type = sess->index_type; + msg->key_cols_num = sess->join_conf.cols_num; +} + +static void fill_probe_msg(struct wd_join_gather_msg *msg, + struct wd_join_gather_sess *sess) +{ + msg->key_cols_num = sess->join_conf.cols_num; + msg->index_type = sess->index_type; + msg->key_out_en = sess->join_conf.key_output_enable; +} + +static void fill_rehash_msg(struct wd_join_gather_msg *msg, + struct wd_join_gather_sess *sess) +{ + msg->key_cols_num = sess->join_conf.cols_num; +} + +static void fill_complete_msg(struct wd_join_gather_msg *msg, + struct wd_join_gather_sess *sess) +{ + __u32 table_index = msg->req.gather_req.table_index; + + msg->index_type = sess->index_type; + msg->multi_batch_en = sess->gather_conf.tables[table_index].is_multi_batch; +} + + +static void fill_join_gather_msg(struct wd_join_gather_msg *msg, struct wd_join_gather_req *req, + struct wd_join_gather_sess *sess) +{ + memcpy(&msg->req, req, sizeof(struct wd_join_gather_req)); + msg->priv = sess->priv; + msg->op_type = req->op_type; + + switch (req->op_type) { + case WD_JOIN_BUILD_HASH: + fill_build_hash_msg(msg, sess); + break; + case 
WD_JOIN_PROBE: + fill_probe_msg(msg, sess); + break; + case WD_JOIN_REHASH: + fill_rehash_msg(msg, sess); + break; + case WD_GATHER_CONVERT: + break; + case WD_GATHER_COMPLETE: + fill_complete_msg(msg, sess); + break; + default: + break; + } +} + +static int wd_join_gather_check_common(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req, + __u8 mode, bool is_join) +{ + if (!sess) { + WD_ERR("invalid: join or gather session is NULL!\n"); + return -WD_EINVAL; + } + + if (!req) { + WD_ERR("invalid: join input req is NULL!\n"); + return -WD_EINVAL; + } + + if (mode == CTX_MODE_ASYNC && !req->cb) { + WD_ERR("invalid: join gather req cb is NULL!\n"); + return -WD_EINVAL; + } + + switch (sess->alg) { + case WD_JOIN: + if (!is_join || !sess->join_conf.data_size) { + WD_ERR("invalid: join session data size is NULL!\n"); + return -WD_EINVAL; + } + break; + case WD_GATHER: + if (is_join || !sess->gather_conf.data_size) { + WD_ERR("invalid: gather session data size is NULL!\n"); + return -WD_EINVAL; + } + break; + case WD_JOIN_GATHER: + if (mode == CTX_MODE_ASYNC) { + WD_ERR("join-gather session does not support the async mode!\n"); + return -WD_EINVAL; + } + + if (!sess->join_conf.data_size || !sess->gather_conf.data_size) { + WD_ERR("invalid: join or gather session data size is NULL!\n"); + return -WD_EINVAL; + } + break; + default: + WD_ERR("invalid: session alg is not supported!\n"); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int check_in_col_addr(struct wd_dae_col_addr *col, __u32 row_count, + enum wd_dae_data_type type, __u64 data_size) +{ + if (!col->empty || col->empty_size != row_count * sizeof(col->empty[0])) { + WD_ERR("failed to check input empty col, size: %llu!\n", col->empty_size); + return -WD_EINVAL; + } + + if (!col->value || col->value_size != row_count * data_size) { + WD_ERR("failed to check input value col size: %llu!\n", col->value_size); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int 
check_out_col_addr(struct wd_dae_col_addr *col, __u32 row_count, + enum wd_dae_data_type type, __u64 data_size) +{ + if (!col->empty || col->empty_size < row_count * sizeof(col->empty[0])) { + WD_ERR("failed to check output empty col, size: %llu!\n", col->empty_size); + return -WD_EINVAL; + } + + if (!col->value || col->value_size < row_count * data_size) { + WD_ERR("failed to check output value col size: %llu!\n", col->value_size); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int check_key_col_addr(struct wd_dae_col_addr *cols, __u32 cols_num, + struct wd_join_gather_sess *sess, __u32 row_count, bool is_input) +{ + int (*func)(struct wd_dae_col_addr *col, __u32 row_count, + enum wd_dae_data_type type, __u64 data_size); + __u32 i; + int ret; + + func = is_input ? check_in_col_addr : check_out_col_addr; + + for (i = 0; i < cols_num; i++) { + ret = func(cols + i, row_count, sess->join_conf.cols[i].data_type, + sess->join_conf.data_size[i]); + if (ret) { + WD_ERR("failed to check req key col! col idx: %u\n", i); + return ret; + } + } + + return WD_SUCCESS; +} + +static int check_data_col_addr(struct wd_gather_req *req, struct wd_join_gather_sess *sess, + __u32 row_count, bool is_input) +{ + struct wd_gather_table_info *table = &sess->gather_conf.tables[req->table_index]; + __u64 *data_size = sess->gather_conf.data_size[req->table_index]; + int (*func)(struct wd_dae_col_addr *col, __u32 row_count, + enum wd_dae_data_type type, __u64 data_size); + __u32 i; + int ret; + + if (!data_size) { + WD_ERR("invalid: gather session data size is NULL!\n"); + return -WD_EINVAL; + } + + if (!row_count) { + WD_ERR("invalid: gather data row number is 0!\n"); + return -WD_EINVAL; + } + + func = is_input ? check_in_col_addr : check_out_col_addr; + + for (i = 0; i < req->data_cols_num; i++) { + ret = func(&req->data_cols[i], row_count, table->cols[i].data_type, + data_size[i]); + if (ret) { + WD_ERR("failed to check req data col! 
col idx: %u\n", i); + return ret; + } + } + + return WD_SUCCESS; +} + +static int check_probe_out_addr(struct wd_probe_out_info *output, + struct wd_join_gather_sess *sess, __u32 row_num) +{ + if (!output->build_index.addr || !output->build_index.row_size) { + WD_ERR("probe multi index is not set!\n"); + return -WD_EINVAL; + } + + if (!output->probe_index.addr || !output->probe_index.row_size) { + WD_ERR("probe single index is not set!\n"); + return -WD_EINVAL; + } + + if (output->build_index.row_num < row_num || output->probe_index.row_num < row_num) { + WD_ERR("build: %u, probe: %u, row num is less than output row_num: %u!\n", + output->build_index.row_num, output->probe_index.row_num, row_num); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int wd_join_common_check_req(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + struct wd_join_req *join_req = &req->join_req; + int ret; + + if (join_req->key_cols_num != sess->join_conf.cols_num) { + WD_ERR("invalid: join table key_cols_num is not equal!\n"); + return -WD_EINVAL; + } + + if (!join_req->key_cols) { + WD_ERR("invalid: join table key_cols is NULL!\n"); + return -WD_EINVAL; + } + + if (!req->input_row_num) { + WD_ERR("invalid: join table input row number is zero!\n"); + return -WD_EINVAL; + } + + ret = check_key_col_addr(join_req->key_cols, join_req->key_cols_num, sess, + req->input_row_num, true); + if (ret) { + WD_ERR("failed to check join table key cols addr!\n"); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int wd_build_hash_check_params(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req, __u8 mode) +{ + int ret; + + ret = wd_join_gather_check_common(sess, req, mode, true); + if (ret) + return ret; + + if (req->op_type != WD_JOIN_BUILD_HASH) { + WD_ERR("failed to check req op_type for build hash task!\n"); + return -WD_EINVAL; + } + + ret = wd_join_common_check_req(sess, req); + if (ret) + WD_ERR("failed to check join req for build hash 
task!\n"); + + return ret; +} + +static int wd_join_probe_check_req(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + struct wd_join_req *jreq = &req->join_req; + struct wd_probe_out_info *probe_output = &jreq->probe_output; + int ret; + + if (req->op_type != WD_JOIN_PROBE) { + WD_ERR("failed to check req op_type for probe task!\n"); + return -WD_EINVAL; + } + + ret = wd_join_common_check_req(sess, req); + if (ret) { + WD_ERR("failed to check join req for probe task!\n"); + return ret; + } + + if (!req->output_row_num) { + WD_ERR("probe output row number is zero!\n"); + return -WD_EINVAL; + } + + if (sess->join_conf.key_output_enable) { + if (probe_output->key_cols_num != sess->join_conf.cols_num || + !probe_output->key_cols) { + WD_ERR("invalid: probe out key_cols_num is not equal!\n"); + return -WD_EINVAL; + } + ret = check_key_col_addr(probe_output->key_cols, probe_output->key_cols_num, + sess, req->output_row_num, false); + if (ret) { + WD_ERR("failed to check porbe output key cols addr!\n"); + return -WD_EINVAL; + } + } + + ret = check_probe_out_addr(probe_output, sess, req->output_row_num); + if (ret) { + WD_ERR("failed to check porbe output addr!\n"); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int wd_join_probe_check_params(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req, __u8 mode) +{ + int ret; + + ret = wd_join_gather_check_common(sess, req, mode, true); + if (ret) + return ret; + + return wd_join_probe_check_req(sess, req); +} + +static int wd_join_rehash_check_params(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + int ret; + + ret = wd_join_gather_check_common(sess, req, CTX_MODE_SYNC, true); + if (ret) + return ret; + + if (req->op_type != WD_JOIN_REHASH) { + WD_ERR("failed to check req op_type for rehash task!\n"); + return -WD_EINVAL; + } + + if (!req->output_row_num) { + WD_ERR("invalid: req output_row_num is 0 for join rehash!\n"); + return -WD_EINVAL; + } + + 
return WD_SUCCESS; +} + +static int wd_join_gather_sync_job(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req, + struct wd_join_gather_msg *msg) +{ + struct wd_join_gather_setting *setting = &wd_join_gather_setting; + struct wd_ctx_config_internal *config = &setting->config; + struct wd_msg_handle msg_handle; + struct wd_ctx_internal *ctx; + __u32 idx; + int ret; + + memset(msg, 0, sizeof(struct wd_join_gather_msg)); + fill_join_gather_msg(msg, req, sess); + req->state = 0; + + idx = setting->sched.pick_next_ctx(setting->sched.h_sched_ctx, + sess->sched_key, CTX_MODE_SYNC); + ret = wd_check_ctx(config, CTX_MODE_SYNC, idx); + if (ret) + return ret; + + wd_dfx_msg_cnt(config, WD_CTX_CNT_NUM, idx); + ctx = config->ctxs + idx; + + msg_handle.send = setting->driver->send; + msg_handle.recv = setting->driver->recv; + + pthread_spin_lock(&ctx->lock); + ret = wd_handle_msg_sync(setting->driver, &msg_handle, ctx->ctx, + msg, NULL, config->epoll_en); + pthread_spin_unlock(&ctx->lock); + + return ret; +} + +static int wd_build_hash_try_init(struct wd_join_gather_sess *sess, + enum wd_join_sess_state *expected) +{ + enum wd_join_sess_state state; + + (void)__atomic_compare_exchange_n(&sess->state, expected, WD_JOIN_SESS_BUILD_HASH, + false, __ATOMIC_ACQUIRE, __ATOMIC_RELAXED); + state = __atomic_load_n(&sess->state, __ATOMIC_RELAXED); + if (state != WD_JOIN_SESS_BUILD_HASH) { + WD_ERR("failed to set join sess state: %u!\n", state); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int wd_join_gather_check_result(__u32 result) +{ + switch (result) { + case WD_JOIN_GATHER_TASK_DONE: + return WD_SUCCESS; + case WD_JOIN_GATHER_IN_EPARA: + case WD_JOIN_GATHER_NEED_REHASH: + case WD_JOIN_GATHER_INVALID_HASH_TABLE: + case WD_JOIN_GATHER_PARSE_ERROR: + case WD_JOIN_GATHER_BUS_ERROR: + WD_ERR("failed to check join gather message state: %u!\n", result); + return -WD_EIO; + default: + return -WD_EINVAL; + } +} + +int wd_join_build_hash_sync(handle_t h_sess, 
struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + enum wd_join_sess_state expected = WD_JOIN_SESS_INIT; + struct wd_join_gather_msg msg; + int ret; + + ret = wd_build_hash_check_params(sess, req, CTX_MODE_SYNC); + if (unlikely(ret)) { + WD_ERR("failed to check hashjoin build hash params!\n"); + return ret; + } + + ret = wd_build_hash_try_init(sess, &expected); + if (unlikely(ret)) + return ret; + + ret = wd_join_gather_sync_job(sess, req, &msg); + if (unlikely(ret)) { + if (expected == WD_JOIN_SESS_INIT) + __atomic_store_n(&sess->state, expected, __ATOMIC_RELEASE); + WD_ERR("failed to do hashjoin build hash sync job!\n"); + return ret; + } + + req->consumed_row_num = msg.consumed_row_num; + req->state = msg.result; + + return WD_SUCCESS; +} + +static int wd_join_gather_async_job(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + struct wd_join_gather_setting *setting = &wd_join_gather_setting; + struct wd_ctx_config_internal *config = &setting->config; + struct wd_join_gather_msg *msg; + struct wd_ctx_internal *ctx; + int msg_id, ret; + __u32 idx; + + idx = setting->sched.pick_next_ctx(setting->sched.h_sched_ctx, + sess->sched_key, CTX_MODE_ASYNC); + ret = wd_check_ctx(config, CTX_MODE_ASYNC, idx); + if (ret) + return ret; + + ctx = config->ctxs + idx; + msg_id = wd_get_msg_from_pool(&setting->pool, idx, (void **)&msg); + if (msg_id < 0) { + WD_ERR("failed to get join gather msg from pool!\n"); + return msg_id; + } + + fill_join_gather_msg(msg, req, sess); + msg->tag = msg_id; + ret = wd_alg_driver_send(setting->driver, ctx->ctx, msg); + if (ret < 0) { + if (ret != -WD_EBUSY) + WD_ERR("wd join gather async send err!\n"); + + goto fail_with_msg; + } + + wd_dfx_msg_cnt(config, WD_CTX_CNT_NUM, idx); + + return WD_SUCCESS; + +fail_with_msg: + wd_put_msg_to_pool(&setting->pool, idx, msg->tag); + return ret; +} + +int wd_join_build_hash_async(handle_t h_sess, struct wd_join_gather_req 
*req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + enum wd_join_sess_state expected = WD_JOIN_SESS_INIT; + int ret; + + ret = wd_build_hash_check_params(sess, req, CTX_MODE_ASYNC); + if (unlikely(ret)) { + WD_ERR("failed to check build hash async params!\n"); + return ret; + } + + ret = wd_build_hash_try_init(sess, &expected); + if (unlikely(ret)) + return ret; + + ret = wd_join_gather_async_job(sess, req); + if (unlikely(ret)) { + if (expected == WD_JOIN_SESS_INIT) + __atomic_store_n(&sess->state, expected, __ATOMIC_RELEASE); + WD_ERR("failed to do join build hash async job!\n"); + } + + return ret; +} + +static int wd_join_probe_try_init(struct wd_join_gather_sess *sess, + enum wd_join_sess_state *expected) +{ + enum wd_join_sess_state state; + + (void)__atomic_compare_exchange_n(&sess->state, expected, WD_JOIN_SESS_PROBE, + false, __ATOMIC_ACQUIRE, __ATOMIC_RELAXED); + state = __atomic_load_n(&sess->state, __ATOMIC_RELAXED); + if (state != WD_JOIN_SESS_PROBE) { + WD_ERR("failed to set join sess state: %u!\n", state); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +int wd_join_probe_sync(handle_t h_sess, struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + enum wd_join_sess_state expected = WD_JOIN_SESS_BUILD_HASH; + struct wd_join_gather_msg msg; + int ret; + + ret = wd_join_probe_check_params(sess, req, CTX_MODE_SYNC); + if (unlikely(ret)) { + WD_ERR("failed to check join probe params!\n"); + return ret; + } + + ret = wd_join_probe_try_init(sess, &expected); + if (unlikely(ret)) + return ret; + + ret = wd_join_gather_sync_job(sess, req, &msg); + if (unlikely(ret)) { + if (expected == WD_JOIN_SESS_BUILD_HASH) + __atomic_store_n(&sess->state, expected, __ATOMIC_RELEASE); + WD_ERR("failed to do join probe sync job!\n"); + return ret; + } + + req->consumed_row_num = msg.consumed_row_num; + req->produced_row_num = msg.produced_row_num; + req->output_done = 
msg.output_done; + req->state = msg.result; + + return WD_SUCCESS; +} + +int wd_join_probe_async(handle_t h_sess, struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + enum wd_join_sess_state expected = WD_JOIN_SESS_BUILD_HASH; + int ret; + + ret = wd_join_probe_check_params(sess, req, CTX_MODE_ASYNC); + if (unlikely(ret)) { + WD_ERR("failed to check join probe params!\n"); + return ret; + } + + ret = wd_join_probe_try_init(sess, &expected); + if (unlikely(ret)) + return ret; + + ret = wd_join_gather_async_job(sess, req); + if (unlikely(ret)) { + if (expected == WD_JOIN_SESS_BUILD_HASH) + __atomic_store_n(&sess->state, expected, __ATOMIC_RELEASE); + WD_ERR("failed to do join probe async job!\n"); + } + + return ret; +} + +static int wd_join_rehash_sync_inner(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + struct wd_join_gather_msg msg = {0}; + int ret; + + ret = wd_join_gather_sync_job(sess, req, &msg); + if (ret) + return ret; + + ret = wd_join_gather_check_result(msg.result); + if (ret) + return ret; + + req->output_done = msg.output_done; + + return WD_SUCCESS; +} + +static int wd_join_rehash_try_init(struct wd_join_gather_sess *sess, + enum wd_join_sess_state *expected) +{ + int ret; + + ret = __atomic_compare_exchange_n(&sess->state, expected, WD_JOIN_SESS_REHASH, + false, __ATOMIC_ACQUIRE, __ATOMIC_RELAXED); + if (!ret) { + WD_ERR("invalid: join rehash sess state is %u!\n", *expected); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +int wd_join_rehash_sync(handle_t h_sess, struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + enum wd_join_sess_state expected = WD_JOIN_SESS_PREPARE_REHASH; + __u64 max_cnt, cnt = 0; + int ret; + + ret = wd_join_rehash_check_params(sess, req); + if (unlikely(ret)) { + WD_ERR("failed to check join rehash params!\n"); + return ret; + } + + ret = wd_join_rehash_try_init(sess, &expected); 
+ if (unlikely(ret)) + return ret; + + max_cnt = MAX_HASH_TABLE_ROW_NUM / req->output_row_num; + while (cnt < max_cnt) { + ret = wd_join_rehash_sync_inner(sess, req); + if (unlikely(ret)) { + __atomic_store_n(&sess->state, WD_JOIN_SESS_PREPARE_REHASH, + __ATOMIC_RELEASE); + WD_ERR("failed to do join rehash task!\n"); + return ret; + } + if (req->output_done) + break; + cnt++; + } + + __atomic_store_n(&sess->state, WD_JOIN_SESS_BUILD_HASH, __ATOMIC_RELEASE); + return WD_SUCCESS; +} + +static int wd_gather_common_check_req(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + struct wd_gather_req *gather_req = &req->gather_req; + struct wd_gather_table_info *tables; + __u32 table_index; + + if (!sess->gather_conf.tables) { + WD_ERR("invalid: session gather tables is NULL!\n"); + return -WD_EINVAL; + } + tables = sess->gather_conf.tables; + table_index = gather_req->table_index; + + if (table_index >= sess->gather_conf.table_num) { + WD_ERR("invalid: gather table index(%u) is too big!\n", table_index); + return -WD_EINVAL; + } + + if (gather_req->data_cols_num != tables[table_index].cols_num) { + WD_ERR("invalid: gather table data_cols_num is not equal!\n"); + return -WD_EINVAL; + } + + if (!gather_req->data_cols) { + WD_ERR("invalid: gather table data_cols is NULL!\n"); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int wd_gather_convert_check_req(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + struct wd_gather_req *gather_req = &req->gather_req; + __u32 expt_size, table_index; + int ret; + + if (req->op_type != WD_GATHER_CONVERT) { + WD_ERR("failed to check req op_type for gather convert task!\n"); + return -WD_EINVAL; + } + + ret = wd_gather_common_check_req(sess, req); + if (ret) + return ret; + + table_index = gather_req->table_index; + + ret = check_data_col_addr(gather_req, sess, req->input_row_num, true); + if (ret) { + WD_ERR("failed to check gather convert data cols addr!\n"); + return -WD_EINVAL; + } 
+ + if (gather_req->row_batchs.batch_num != 1 || !gather_req->row_batchs.batch_addr || + !gather_req->row_batchs.batch_addr[0]) { + WD_ERR("invalid: gather convert only support one batch!\n"); + return -WD_EINVAL; + } + + if (!gather_req->row_batchs.batch_row_num || !gather_req->row_batchs.batch_row_size) { + WD_ERR("invalid: gather convert batchs row_num or row_size is NULL!\n"); + return -WD_EINVAL; + } + + expt_size = sess->gather_conf.batch_row_size[table_index]; + if (gather_req->row_batchs.batch_row_num[0] != req->input_row_num || + gather_req->row_batchs.batch_row_size[0] != expt_size) { + WD_ERR("invalid: gather convert row batchs, row_size: %u, row_num: %u\n", + gather_req->row_batchs.batch_row_size[0], + gather_req->row_batchs.batch_row_num[0]); + return -WD_EINVAL; + } + + return WD_SUCCESS; +} + +static int wd_gather_complete_check_req(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req) +{ + struct wd_gather_req *gather_req = &req->gather_req; + struct wd_gather_table_info *tables; + struct wd_dae_row_addr *index_addr; + __u32 table_index, expt_size, i; + int ret; + + if (req->op_type != WD_GATHER_COMPLETE) { + WD_ERR("failed to check req op_type for gather complete task!\n"); + return -WD_EINVAL; + } + + ret = wd_gather_common_check_req(sess, req); + if (ret) + return ret; + + tables = sess->gather_conf.tables; + table_index = gather_req->table_index; + + ret = check_data_col_addr(gather_req, sess, req->output_row_num, false); + if (ret) { + WD_ERR("failed to check gather complete data cols addr!\n"); + return -WD_EINVAL; + } + + index_addr = &gather_req->index; + if (!index_addr->addr || index_addr->row_num < req->output_row_num) { + WD_ERR("invalid: gather index is NULL or index row number is small!\n"); + return -WD_EINVAL; + } + + /* The row batch information is stored to index, no need to check. 
*/ + if (sess->index_type == WD_BATCH_ADDR_INDEX && tables[table_index].is_multi_batch) + return WD_SUCCESS; + + if (!gather_req->row_batchs.batch_num || !gather_req->row_batchs.batch_addr) { + WD_ERR("invalid: gather row batch is NULL or batch addr number is 0!\n"); + return -WD_EINVAL; + } + + if (!gather_req->row_batchs.batch_row_num || !gather_req->row_batchs.batch_row_size) { + WD_ERR("invalid: gather row batch row_num or row_size is NULL!\n"); + return -WD_EINVAL; + } + + if (!tables[table_index].is_multi_batch) { + if (gather_req->row_batchs.batch_num != 1) { + WD_ERR("invalid: single gather row batch addr num should be 1!\n"); + return -WD_EINVAL; + } + } + + for (i = 0; i < gather_req->row_batchs.batch_num; i++) { + if (!gather_req->row_batchs.batch_addr[i] || + !gather_req->row_batchs.batch_row_num[i]) { + WD_ERR("invalid: row batch addr or row_num is null! idx: %u\n", i); + return -WD_EINVAL; + } + expt_size = sess->gather_conf.batch_row_size[table_index]; + if (gather_req->row_batchs.batch_row_size[i] != expt_size) { + WD_ERR("invalid row batch row_size: %u, batch idx: %u\n", + gather_req->row_batchs.batch_row_size[i], i); + return -WD_EINVAL; + } + } + + return WD_SUCCESS; +} + +static int wd_gather_convert_check_params(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req, __u8 mode) +{ + int ret; + + ret = wd_join_gather_check_common(sess, req, mode, false); + if (ret) + return ret; + + return wd_gather_convert_check_req(sess, req); +} + +static int wd_gather_complete_check_params(struct wd_join_gather_sess *sess, + struct wd_join_gather_req *req, __u8 mode) +{ + int ret; + + ret = wd_join_gather_check_common(sess, req, mode, false); + if (ret) + return ret; + + return wd_gather_complete_check_req(sess, req); +} + +int wd_gather_convert_sync(handle_t h_sess, struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + struct wd_join_gather_msg msg; + int ret; + + ret = 
wd_gather_convert_check_params(sess, req, CTX_MODE_SYNC); + if (unlikely(ret)) { + WD_ERR("failed to check gather convert params!\n"); + return ret; + } + + ret = wd_join_gather_sync_job(sess, req, &msg); + if (unlikely(ret)) { + WD_ERR("failed to do gather convert sync job!\n"); + return ret; + } + + req->consumed_row_num = msg.consumed_row_num; + req->state = msg.result; + + return WD_SUCCESS; +} + +int wd_gather_convert_async(handle_t h_sess, struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + int ret; + + ret = wd_gather_convert_check_params(sess, req, CTX_MODE_ASYNC); + if (unlikely(ret)) { + WD_ERR("failed to check gather convert async params!\n"); + return ret; + } + + ret = wd_join_gather_async_job(sess, req); + if (unlikely(ret)) + WD_ERR("failed to do gather convert async job!\n"); + + return ret; +} + +int wd_gather_complete_sync(handle_t h_sess, struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + struct wd_join_gather_msg msg; + int ret; + + ret = wd_gather_complete_check_params(sess, req, CTX_MODE_SYNC); + if (unlikely(ret)) { + WD_ERR("failed to check gather complete params!\n"); + return ret; + } + + ret = wd_join_gather_sync_job(sess, req, &msg); + if (unlikely(ret)) { + WD_ERR("failed to do gather complete sync job!\n"); + return ret; + } + + req->produced_row_num = msg.produced_row_num; + req->state = msg.result; + + return WD_SUCCESS; +} + +int wd_gather_complete_async(handle_t h_sess, struct wd_join_gather_req *req) +{ + struct wd_join_gather_sess *sess = (struct wd_join_gather_sess *)h_sess; + int ret; + + ret = wd_gather_complete_check_params(sess, req, CTX_MODE_ASYNC); + if (unlikely(ret)) { + WD_ERR("failed to check gather complete params!\n"); + return ret; + } + + ret = wd_join_gather_async_job(sess, req); + if (unlikely(ret)) + WD_ERR("failed to do gather complete async job!\n"); + + return ret; +} + +struct 
wd_join_gather_msg *wd_join_gather_get_msg(__u32 idx, __u32 tag) +{ + return wd_find_msg_in_pool(&wd_join_gather_setting.pool, idx, tag); +} + +static int wd_join_gather_poll_ctx(__u32 idx, __u32 expt, __u32 *count) +{ + struct wd_ctx_config_internal *config = &wd_join_gather_setting.config; + struct wd_join_gather_msg resp_msg = {0}; + struct wd_join_gather_msg *msg; + struct wd_ctx_internal *ctx; + struct wd_join_gather_req *req; + __u64 recv_count = 0; + __u32 tmp = expt; + int ret; + + *count = 0; + + ret = wd_check_ctx(config, CTX_MODE_ASYNC, idx); + if (unlikely(ret)) + return ret; + + ctx = config->ctxs + idx; + + do { + ret = wd_alg_driver_recv(wd_join_gather_setting.driver, ctx->ctx, &resp_msg); + if (ret == -WD_EAGAIN) { + return ret; + } else if (ret < 0) { + WD_ERR("wd join_gather recv hw err!\n"); + return ret; + } + recv_count++; + msg = wd_find_msg_in_pool(&wd_join_gather_setting.pool, idx, resp_msg.tag); + if (!msg) { + WD_ERR("failed to get join gather msg from pool!\n"); + return -WD_EINVAL; + } + + msg->req.state = resp_msg.result; + msg->req.consumed_row_num = resp_msg.consumed_row_num; + msg->req.produced_row_num = resp_msg.produced_row_num; + msg->req.output_done = resp_msg.output_done; + req = &msg->req; + + req->cb(req, req->cb_param); + /* Free msg cache to msg_pool */ + wd_put_msg_to_pool(&wd_join_gather_setting.pool, idx, resp_msg.tag); + *count = recv_count; + } while (--tmp); + + return ret; +} + +int wd_join_gather_poll(__u32 expt, __u32 *count) +{ + handle_t h_ctx = wd_join_gather_setting.sched.h_sched_ctx; + struct wd_sched *sched = &wd_join_gather_setting.sched; + + if (!expt || !count) { + WD_ERR("invalid: join gather poll input param is NULL!\n"); + return -WD_EINVAL; + } + + return sched->poll_policy(h_ctx, expt, count); +} diff --git a/wd_util.c b/wd_util.c index 2f0ab198..f9440f26 100644 --- a/wd_util.c +++ b/wd_util.c @@ -66,6 +66,7 @@ static const char *wd_env_name[WD_TYPE_MAX] = { "WD_ECC_CTX_NUM", "WD_AGG_CTX_NUM", 
"WD_UDMA_CTX_NUM", + "WD_JOIN_GATHER_CTX_NUM", }; struct async_task { @@ -113,6 +114,9 @@ static struct acc_alg_item alg_options[] = { {"lz77_only", "lz77_only"}, {"hashagg", "hashagg"}, {"udma", "udma"}, + {"hashjoin", "hashjoin"}, + {"gather", "gather"}, + {"join-gather", "hashjoin"}, {"rsa", "rsa"}, {"dh", "dh"}, @@ -2603,7 +2607,7 @@ static int wd_alg_ctx_init(struct wd_init_attrs *attrs) list = wd_get_accel_list(attrs->alg); if (!list) { - WD_ERR("failed to get devices!\n"); + WD_ERR("failed to get devices for alg: %s\n", attrs->alg); return -WD_ENODEV; } -- 2.33.0

From: Qinxin Xia <xiaqinxin@huawei.com> Add support for the 'lz4' algorithm to the HiSilicon driver. LZ4 is supported for compression only, in block (stateless) mode. Signed-off-by: Qinxin Xia <xiaqinxin@huawei.com> --- drv/hisi_comp.c | 120 +++++++++++++++++++++++++++++++++++++++++++++- include/wd_comp.h | 1 + 2 files changed, 119 insertions(+), 2 deletions(-) diff --git a/drv/hisi_comp.c b/drv/hisi_comp.c index 0c07bb82..69256d26 100644 --- a/drv/hisi_comp.c +++ b/drv/hisi_comp.c @@ -109,6 +109,7 @@ enum alg_type { HW_DEFLATE = 0x1, HW_ZLIB, HW_GZIP, + HW_LZ4, HW_LZ77_ZSTD_PRICE = 0x42, HW_LZ77_ZSTD, HW_LZ77_ONLY = 0x40, @@ -504,6 +505,60 @@ static int fill_buf_gzip(handle_t h_qp, struct hisi_zip_sqe *sqe, return fill_buf_deflate_generic(sqe, msg, GZIP_HEADER, GZIP_HEADER_SZ); } +static void fill_buf_addr_lz4(struct hisi_zip_sqe *sqe, void *src, void *dst) +{ + sqe->source_addr_l = lower_32_bits(src); + sqe->source_addr_h = upper_32_bits(src); + sqe->dest_addr_l = lower_32_bits(dst); + sqe->dest_addr_h = upper_32_bits(dst); +} + +static int check_lz4_msg(struct wd_comp_msg *msg, enum wd_buff_type buf_type) +{ + /* LZ4 only support for compress and block mode */ + if (unlikely(msg->req.op_type != WD_DIR_COMPRESS)) { + WD_ERR("invalid: lz4 only support compress!\n"); + return -WD_EINVAL; + } + + if (unlikely(msg->stream_mode == WD_COMP_STATEFUL)) { + WD_ERR("invalid: lz4 does not support the stream mode!\n"); + return -WD_EINVAL; + } + + if (buf_type != WD_FLAT_BUF) + return 0; + + if (unlikely(msg->req.src_len == 0 || msg->req.src_len > HZ_MAX_SIZE)) { + WD_ERR("invalid: lz4 input size can't be zero or more than 8M size max!\n"); + return -WD_EINVAL; + } + + if (unlikely(msg->avail_out > HZ_MAX_SIZE)) + msg->avail_out = HZ_MAX_SIZE; + + return 0; + +} + +static int fill_buf_lz4(handle_t h_qp, struct hisi_zip_sqe *sqe, + struct wd_comp_msg *msg) +{ + void *src = msg->req.src; + void *dst = msg->req.dst; + int ret; + + ret = check_lz4_msg(msg, WD_FLAT_BUF); + if (unlikely(ret)) + return ret; + + fill_comp_buf_size(sqe, msg->req.src_len,
msg->avail_out); + + fill_buf_addr_lz4(sqe, src, dst); + + return 0; +} + static void fill_buf_type_sgl(struct hisi_zip_sqe *sqe) { __u32 val; @@ -665,7 +720,7 @@ static int lz77_zstd_buf_check(struct wd_comp_msg *msg) if (unlikely(in_size > ZSTD_MAX_SIZE)) { WD_ERR("invalid: in_len(%u) of lz77_zstd is out of range!\n", in_size); - return -WD_EINVAL; + return -WD_EINVAL; } if (unlikely(msg->stream_mode == WD_COMP_STATEFUL && msg->comp_lv < WD_COMP_L9 && @@ -945,6 +1000,44 @@ static int fill_buf_lz77_zstd_sgl(handle_t h_qp, struct hisi_zip_sqe *sqe, return 0; } +static int fill_buf_addr_lz4_sgl(handle_t h_qp, struct hisi_zip_sqe *sqe, + struct wd_datalist *list_src, + struct wd_datalist *list_dst) +{ + struct comp_sgl c_sgl; + int ret; + + c_sgl.list_src = list_src; + c_sgl.list_dst = list_dst; + c_sgl.seq_start = NULL; + + ret = get_sgl_from_pool(h_qp, &c_sgl); + if (unlikely(ret)) + return ret; + + fill_buf_addr_lz4(sqe, c_sgl.in, c_sgl.out); + + return 0; +} + +static int fill_buf_lz4_sgl(handle_t h_qp, struct hisi_zip_sqe *sqe, + struct wd_comp_msg *msg) +{ + struct wd_datalist *list_src = msg->req.list_src; + struct wd_datalist *list_dst = msg->req.list_dst; + int ret; + + ret = check_lz4_msg(msg, WD_SGL_BUF); + if (unlikely(ret)) + return ret; + + fill_buf_type_sgl(sqe); + + fill_comp_buf_size(sqe, msg->req.src_len, msg->avail_out); + + return fill_buf_addr_lz4_sgl(h_qp, sqe, list_src, list_dst); +} + static void fill_sqe_type_v1(struct hisi_zip_sqe *sqe) { __u32 val; @@ -1003,6 +1096,15 @@ static void fill_alg_lz77_only(struct hisi_zip_sqe *sqe) sqe->dw9 = val; } +static void fill_alg_lz4(struct hisi_zip_sqe *sqe) +{ + __u32 val; + + val = sqe->dw9 & ~HZ_REQ_TYPE_MASK; + val |= HW_LZ4; + sqe->dw9 = val; +} + static void fill_tag_v1(struct hisi_zip_sqe *sqe, __u32 tag) { sqe->dw13 = tag; @@ -1157,6 +1259,16 @@ struct hisi_zip_sqe_ops ops[] = { { .fill_comp_level = fill_comp_level_lz77_zstd, .get_data_size = get_data_size_lz77_zstd, .get_tag = get_tag_v3, + }, 
{ + .alg_name = "lz4", + .fill_buf[WD_FLAT_BUF] = fill_buf_lz4, + .fill_buf[WD_SGL_BUF] = fill_buf_lz4_sgl, + .fill_sqe_type = fill_sqe_type_v3, + .fill_alg = fill_alg_lz4, + .fill_tag = fill_tag_v3, + .fill_comp_level = fill_comp_level, + .get_data_size = get_comp_data_size, + .get_tag = get_tag_v3, }, { .alg_name = "lz77_only", .fill_buf[WD_FLAT_BUF] = fill_buf_lz77_zstd, @@ -1383,6 +1495,9 @@ static int get_alg_type(__u32 type) case HW_GZIP: alg_type = WD_GZIP; break; + case HW_LZ4: + alg_type = WD_LZ4; + break; case HW_LZ77_ZSTD: case HW_LZ77_ZSTD_PRICE: alg_type = WD_LZ77_ZSTD; @@ -1570,7 +1685,8 @@ static struct wd_alg_driver zip_alg_driver[] = { GEN_ZIP_ALG_DRIVER("deflate"), GEN_ZIP_ALG_DRIVER("lz77_zstd"), - GEN_ZIP_ALG_DRIVER("lz77_only"), + GEN_ZIP_ALG_DRIVER("lz4"), + GEN_ZIP_ALG_DRIVER("lz77_only") }; #ifdef WD_STATIC_DRV diff --git a/include/wd_comp.h b/include/wd_comp.h index 0012ef6b..8e056d1c 100644 --- a/include/wd_comp.h +++ b/include/wd_comp.h @@ -20,6 +20,7 @@ enum wd_comp_alg_type { WD_ZLIB, WD_GZIP, WD_LZ77_ZSTD, + WD_LZ4, WD_LZ77_ONLY, WD_COMP_ALG_MAX, }; -- 2.33.0
participants (1)
-
Qi Tao