From: Wei Yongjun <weiyongjun1@huawei.com>
hulk inclusion
category: feature
bugzilla: NA
DTS: #659
CVE: NA
-------------------------------------------------
This patch implements software-level compression for sending TCP messages. All of the TCP payload is compressed before xmit.
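The TX path mirrors the kernel's tls_sw transmit design: sendmsg() copies
user data into a per-socket plaintext scratch buffer, runs it through a
zstd cstream into a bounce buffer, copies the result into newly allocated
scatterlist pages, and pushes those out with do_tcp_sendpages(). A minimal
sketch of the compression step (error handling condensed; plaintext_data,
compressed_data and cstream are the per-socket TX fields introduced below):

    ZSTD_inBuffer in = {
            .src  = plaintext_data,
            .size = bytes,
            .pos  = 0,
    };
    ZSTD_outBuffer out = {
            .dst  = compressed_data,
            .size = TCP_COMP_MAX_CSIZE,
            .pos  = 0,
    };

    if (ZSTD_isError(ZSTD_compressStream(cstream, &out, &in)) ||
        ZSTD_isError(ZSTD_flushStream(cstream, &out)))
            return -EIO;
    /* out.pos now holds the compressed byte count copied into sg pages */

A compression failure is treated as fatal: sendmsg() aborts the socket
with EBADMSG instead of falling back to sending plaintext.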
Signed-off-by: Wei Yongjun <weiyongjun1@huawei.com>
Signed-off-by: Wang Yufen <wangyufen@huawei.com>
---
 net/ipv4/Kconfig    |   2 +-
 net/ipv4/tcp_comp.c | 448 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 446 insertions(+), 4 deletions(-)
diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig
index dc4a6d1..96f2311 100644
--- a/net/ipv4/Kconfig
+++ b/net/ipv4/Kconfig
@@ -757,7 +757,7 @@ config TCP_MD5SIG
 config TCP_COMP
         bool "TCP: Transport Layer Compression support"
-        depends on !SMC
+        depends on !SMC && ZSTD_COMPRESS=y
         ---help---
           Enable kernel payload compression support for TCP protocol. This allows
           payload compression handling of the TCP protocol to be done in-kernel.
diff --git a/net/ipv4/tcp_comp.c b/net/ipv4/tcp_comp.c
index b7d3ac2..5fe8acd 100644
--- a/net/ipv4/tcp_comp.c
+++ b/net/ipv4/tcp_comp.c
@@ -6,6 +6,14 @@
 #include <net/tcp.h>
+#include <linux/zstd.h>
+
+#define TCP_COMP_MAX_PADDING    64
+#define TCP_COMP_SCRATCH_SIZE   65400
+#define TCP_COMP_MAX_CSIZE      (TCP_COMP_SCRATCH_SIZE + TCP_COMP_MAX_PADDING)
+
+#define TCP_COMP_SEND_PENDING   1
+#define ZSTD_COMP_DEFAULT_LEVEL 1
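+
+/*
+ * Plaintext is staged in chunks of at most TCP_COMP_SCRATCH_SIZE bytes;
+ * TCP_COMP_MAX_PADDING leaves headroom in the output buffer for zstd
+ * frame overhead when a chunk turns out to be incompressible.
+ */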
 
 static unsigned long tcp_compression_ports[65536 / 8];
@@ -14,11 +22,42 @@
 static struct proto tcp_prot_override;
 
+struct tcp_comp_context_tx {
+        ZSTD_CStream *cstream;
+        void *cworkspace;
+        void *plaintext_data;
+        void *compressed_data;
+
+        struct scatterlist sg_data[MAX_SKB_FRAGS];
+        unsigned int sg_size;
+        int sg_num;
+
+        struct scatterlist *partially_send;
+        bool in_tcp_sendpages;
+};
+
 struct tcp_comp_context {
-        struct proto *sk_proto;
         struct rcu_head rcu;
+
+        struct proto *sk_proto;
+        void (*sk_write_space)(struct sock *sk);
+
+        struct tcp_comp_context_tx tx;
+
+        unsigned long flags;
 };
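+
+/*
+ * TCP_COMP_SEND_PENDING marks compressed data still queued in tx.sg_data:
+ * the write_space callback and the next sendmsg() retry that pending send
+ * before any new user data is accepted.
+ */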
+static bool tcp_comp_is_write_pending(struct tcp_comp_context *ctx)
+{
+        return test_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
+}
+
+static void tcp_comp_err_abort(struct sock *sk, int err)
+{
+        sk->sk_err = err;
+        sk->sk_error_report(sk);
+}
+
 static bool tcp_comp_enabled(__be32 saddr, __be32 daddr, int port)
 {
         if (!sysctl_tcp_compression_local &&
@@ -55,11 +94,359 @@ static struct tcp_comp_context *comp_get_ctx(const struct sock *sk)
         return (__force void *)icsk->icsk_ulp_data;
 }
 
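+/*
+ * The zstd cstream, its workspace and both scratch buffers are allocated
+ * lazily on the first send via tcp_comp_get_tx_stream(), so connections
+ * that never send payload do not pay the memory cost.
+ */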
-static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+static int tcp_comp_tx_context_init(struct tcp_comp_context *ctx)
+{
+        ZSTD_parameters params;
+        int csize;
+
+        params = ZSTD_getParams(ZSTD_COMP_DEFAULT_LEVEL, PAGE_SIZE, 0);
+        csize = ZSTD_CStreamWorkspaceBound(params.cParams);
+        if (csize <= 0)
+                return -EINVAL;
+
+        ctx->tx.cworkspace = kmalloc(csize, GFP_KERNEL);
+        if (!ctx->tx.cworkspace)
+                return -ENOMEM;
+
+        ctx->tx.cstream = ZSTD_initCStream(params, 0, ctx->tx.cworkspace,
+                                           csize);
+        if (!ctx->tx.cstream)
+                goto err_cstream;
+
+        ctx->tx.plaintext_data = kvmalloc(TCP_COMP_SCRATCH_SIZE, GFP_KERNEL);
+        if (!ctx->tx.plaintext_data)
+                goto err_cstream;
+
+        ctx->tx.compressed_data = kvmalloc(TCP_COMP_MAX_CSIZE, GFP_KERNEL);
+        if (!ctx->tx.compressed_data)
+                goto err_compressed;
+
+        return 0;
+
+err_compressed:
+        kvfree(ctx->tx.plaintext_data);
+        ctx->tx.plaintext_data = NULL;
+err_cstream:
+        kfree(ctx->tx.cworkspace);
+        ctx->tx.cworkspace = NULL;
+
+        return -ENOMEM;
+}
+
+static void *tcp_comp_get_tx_stream(struct sock *sk)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+
+        if (!ctx->tx.plaintext_data)
+                tcp_comp_tx_context_init(ctx);
+
+        return ctx->tx.plaintext_data;
+}
+
+static int alloc_compressed_sg(struct sock *sk, int len)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        int rc = 0;
+
+        rc = sk_alloc_sg(sk, len, ctx->tx.sg_data, 0,
+                         &ctx->tx.sg_num, &ctx->tx.sg_size, 0);
+        if (rc == -ENOSPC)
+                ctx->tx.sg_num = ARRAY_SIZE(ctx->tx.sg_data);
+
+        return rc;
+}
+
+static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, int copy)
+{
+        void *dest;
+        int rc;
+
+        dest = tcp_comp_get_tx_stream(sk);
+        if (!dest)
+                return -ENOSPC;
+
+        if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
+                rc = copy_from_iter_nocache(dest, copy, from);
+        else
+                rc = copy_from_iter(dest, copy, from);
+
+        if (rc != copy)
+                rc = -EFAULT;
+
+        return rc;
+}
+
+static int memcopy_to_sg(struct sock *sk, int bytes)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        struct scatterlist *sg = ctx->tx.sg_data;
+        char *from, *to;
+        int copy;
+
+        from = ctx->tx.compressed_data;
+        while (bytes && sg) {
+                to = sg_virt(sg);
+                copy = min_t(int, sg->length, bytes);
+                memcpy(to, from, copy);
+                bytes -= copy;
+                from += copy;
+                sg = sg_next(sg);
+        }
+
+        return bytes;
+}
+
+static void trim_sg(struct sock *sk, int target_size)
 {
         struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        struct scatterlist *sg = ctx->tx.sg_data;
+        int trim = ctx->tx.sg_size - target_size;
+        int i = ctx->tx.sg_num - 1;
+
+        if (trim <= 0) {
+                WARN_ON_ONCE(trim < 0);
+                return;
+        }
+
+        ctx->tx.sg_size = target_size;
+        while (trim >= sg[i].length) {
+                trim -= sg[i].length;
+                sk_mem_uncharge(sk, sg[i].length);
+                put_page(sg_page(&sg[i]));
+                i--;
 
-        return ctx->sk_proto->sendmsg(sk, msg, size);
+                if (i < 0)
+                        goto out;
+        }
+
+        sg[i].length -= trim;
+        sk_mem_uncharge(sk, trim);
+
+out:
+        ctx->tx.sg_num = i + 1;
+        sg_mark_end(ctx->tx.sg_data + ctx->tx.sg_num - 1);
+}
+
+static int tcp_comp_compress_to_sg(struct sock *sk, int bytes)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        ZSTD_outBuffer outbuf;
+        ZSTD_inBuffer inbuf;
+        size_t ret;
+
+        inbuf.src = ctx->tx.plaintext_data;
+        outbuf.dst = ctx->tx.compressed_data;
+        inbuf.size = bytes;
+        outbuf.size = TCP_COMP_MAX_CSIZE;
+        inbuf.pos = 0;
+        outbuf.pos = 0;
+
+        ret = ZSTD_compressStream(ctx->tx.cstream, &outbuf, &inbuf);
+        if (ZSTD_isError(ret))
+                return -EIO;
+
+        ret = ZSTD_flushStream(ctx->tx.cstream, &outbuf);
+        if (ZSTD_isError(ret))
+                return -EIO;
+
+        if (inbuf.pos != inbuf.size)
+                return -EIO;
+
+        if (memcopy_to_sg(sk, outbuf.pos))
+                return -EIO;
+
+        trim_sg(sk, outbuf.pos);
+
+        return 0;
+}
+
+static int tcp_comp_push_sg(struct sock *sk, struct scatterlist *sg, int flags)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        int ret, offset;
+        struct page *p;
+        size_t size;
+
+        ctx->tx.in_tcp_sendpages = true;
+        while (sg) {
+                offset = sg->offset;
+                size = sg->length;
+                p = sg_page(sg);
+retry:
+                ret = do_tcp_sendpages(sk, p, offset, size, flags);
+                if (ret != size) {
+                        if (ret > 0) {
+                                sk_mem_uncharge(sk, ret);
+                                sg->offset += ret;
+                                sg->length -= ret;
+                                size -= ret;
+                                offset += ret;
+                                goto retry;
+                        }
+                        ctx->tx.partially_send = (void *)sg;
+                        ctx->tx.in_tcp_sendpages = false;
+                        return ret;
+                }
+
+                sk_mem_uncharge(sk, ret);
+                put_page(p);
+                sg = sg_next(sg);
+        }
+
+        clear_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
+        ctx->tx.in_tcp_sendpages = false;
+        ctx->tx.sg_size = 0;
+        ctx->tx.sg_num = 0;
+
+        return 0;
+}
+
+static int tcp_comp_push(struct sock *sk, int bytes, int flags)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        int ret;
+
+        ret = tcp_comp_compress_to_sg(sk, bytes);
+        if (ret < 0) {
+                pr_debug("%s: failed to compress sg\n", __func__);
+                return ret;
+        }
+
+        set_bit(TCP_COMP_SEND_PENDING, &ctx->flags);
+
+        ret = tcp_comp_push_sg(sk, ctx->tx.sg_data, flags);
+        if (ret) {
+                pr_debug("%s: failed to tcp_comp_push_sg\n", __func__);
+                return ret;
+        }
+
+        return 0;
+}
+
+static int wait_on_pending_writer(struct sock *sk, long *timeo)
+{
+        DEFINE_WAIT_FUNC(wait, woken_wake_function);
+        int ret = 0;
+
+        add_wait_queue(sk_sleep(sk), &wait);
+        while (1) {
+                if (!*timeo) {
+                        ret = -EAGAIN;
+                        break;
+                }
+
+                if (signal_pending(current)) {
+                        ret = sock_intr_errno(*timeo);
+                        break;
+                }
+
+                if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
+                        break;
+        }
+        remove_wait_queue(sk_sleep(sk), &wait);
+
+        return ret;
+}
+
+static int tcp_comp_push_pending_sg(struct sock *sk, int flags)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        struct scatterlist *sg;
+
+        if (!ctx->tx.partially_send)
+                return 0;
+
+        sg = ctx->tx.partially_send;
+        ctx->tx.partially_send = NULL;
+
+        return tcp_comp_push_sg(sk, sg, flags);
+}
+
+static int tcp_comp_complete_pending_work(struct sock *sk, int flags,
+                                          long *timeo)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        int ret = 0;
+
+        if (unlikely(sk->sk_write_pending))
+                ret = wait_on_pending_writer(sk, timeo);
+
+        if (!ret && tcp_comp_is_write_pending(ctx))
+                ret = tcp_comp_push_pending_sg(sk, flags);
+
+        return ret;
+}
+
+static int tcp_comp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+        int copied = 0, err = 0;
+        size_t try_to_copy;
+        int required_size;
+        long timeo;
+
+        lock_sock(sk);
+
+        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+
+        err = tcp_comp_complete_pending_work(sk, msg->msg_flags, &timeo);
+        if (err)
+                goto out_err;
+
+        while (msg_data_left(msg)) {
+                if (sk->sk_err) {
+                        err = -sk->sk_err;
+                        goto out_err;
+                }
+
+                try_to_copy = msg_data_left(msg);
+                if (try_to_copy > TCP_COMP_SCRATCH_SIZE)
+                        try_to_copy = TCP_COMP_SCRATCH_SIZE;
+                required_size = try_to_copy + TCP_COMP_MAX_PADDING;
+
+                if (!sk_stream_memory_free(sk))
+                        goto wait_for_sndbuf;
+
+alloc_compressed:
+                err = alloc_compressed_sg(sk, required_size);
+                if (err) {
+                        if (err != -ENOSPC)
+                                goto wait_for_memory;
+                        goto out_err;
+                }
+
+                err = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
+                if (err < 0)
+                        goto out_err;
+
+                copied += try_to_copy;
+
+                err = tcp_comp_push(sk, try_to_copy, msg->msg_flags);
+                if (err < 0) {
+                        if (err == -ENOMEM)
+                                goto wait_for_memory;
+                        if (err != -EAGAIN)
+                                tcp_comp_err_abort(sk, EBADMSG);
+                        goto out_err;
+                }
+
+                continue;
+wait_for_sndbuf:
+                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+                err = sk_stream_wait_memory(sk, &timeo);
+                if (err)
+                        goto out_err;
+                if (ctx->tx.sg_size < required_size)
+                        goto alloc_compressed;
+        }
+
+out_err:
+        err = sk_stream_error(sk, msg->msg_flags, err);
+
+        release_sock(sk);
+
+        return copied ? copied : err;
 }
 static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
@@ -70,6 +457,30 @@ static int tcp_comp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
         return ctx->sk_proto->recvmsg(sk, msg, len, nonblock, flags, addr_len);
 }
 
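+/*
+ * The write_space callback can be invoked from bh context when transmit
+ * space opens up, so the pending scatterlist is re-pushed with GFP_ATOMIC
+ * allocations and MSG_DONTWAIT | MSG_NOSIGNAL instead of sleeping.
+ */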
+static void tcp_comp_write_space(struct sock *sk)
+{
+        struct tcp_comp_context *ctx = comp_get_ctx(sk);
+
+        if (ctx->tx.in_tcp_sendpages) {
+                ctx->sk_write_space(sk);
+                return;
+        }
+
+        if (!sk->sk_write_pending && tcp_comp_is_write_pending(ctx)) {
+                gfp_t sk_allocation = sk->sk_allocation;
+                int rc;
+
+                sk->sk_allocation = GFP_ATOMIC;
+                rc = tcp_comp_push_pending_sg(sk, MSG_DONTWAIT | MSG_NOSIGNAL);
+                sk->sk_allocation = sk_allocation;
+
+                if (rc < 0)
+                        return;
+        }
+
+        ctx->sk_write_space(sk);
+}
+
 void tcp_init_compression(struct sock *sk)
 {
         struct inet_connection_sock *icsk = inet_csk(sk);
@@ -83,20 +494,46 @@ void tcp_init_compression(struct sock *sk)
         if (!ctx)
                 return;
 
+        sg_init_table(ctx->tx.sg_data, ARRAY_SIZE(ctx->tx.sg_data));
+
+        ctx->sk_write_space = sk->sk_write_space;
         ctx->sk_proto = sk->sk_prot;
         WRITE_ONCE(sk->sk_prot, &tcp_prot_override);
+        sk->sk_write_space = tcp_comp_write_space;
 
         rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
 
         sock_set_flag(sk, SOCK_COMP);
 }
 
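+/*
+ * Scatterlist pages from an interrupted send are still charged to the
+ * socket; uncharge and release them before the context itself is freed.
+ */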
+static void free_sg(struct sock *sk, struct scatterlist *sg)
+{
+        while (sg) {
+                sk_mem_uncharge(sk, sg->length);
+                put_page(sg_page(sg));
+                sg = sg_next(sg);
+        }
+}
+
+static void tcp_comp_context_tx_free(struct tcp_comp_context *ctx)
+{
+        kfree(ctx->tx.cworkspace);
+        ctx->tx.cworkspace = NULL;
+
+        kvfree(ctx->tx.plaintext_data);
+        ctx->tx.plaintext_data = NULL;
+
+        kvfree(ctx->tx.compressed_data);
+        ctx->tx.compressed_data = NULL;
+}
+
 static void tcp_comp_context_free(struct rcu_head *head)
 {
         struct tcp_comp_context *ctx;
 
         ctx = container_of(head, struct tcp_comp_context, rcu);
 
+        tcp_comp_context_tx_free(ctx);
+
         kfree(ctx);
 }
@@ -108,6 +545,11 @@ void tcp_cleanup_compression(struct sock *sk)
         if (!ctx || !sock_flag(sk, SOCK_COMP))
                 return;
 
+        if (ctx->tx.partially_send) {
+                free_sg(sk, ctx->tx.partially_send);
+                ctx->tx.partially_send = NULL;
+        }
+
         rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
         call_rcu(&ctx->rcu, tcp_comp_context_free);
 }