From: "fu.lin" fu.lin10@huawei.com
Conflict:NA Reference:https://gitee.com/src-openeuler/criu/pulls/21 Signed-off-by: fu.lin fu.lin10@huawei.com --- criu/Makefile.crtools | 1 + criu/include/nftables.h | 138 +++++++ criu/nftables.c | 823 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 962 insertions(+) create mode 100644 criu/nftables.c
diff --git a/criu/Makefile.crtools b/criu/Makefile.crtools index ff6b597..cda5b82 100644 --- a/criu/Makefile.crtools +++ b/criu/Makefile.crtools @@ -91,6 +91,7 @@ obj-y += vdso.o obj-y += timens.o obj-y += devname.o obj-y += mnl.o +obj-y += nftables.o obj-$(CONFIG_HAS_LIBBPF) += bpfmap.o obj-$(CONFIG_COMPAT) += pie-util-vdso-elf32.o CFLAGS_pie-util-vdso-elf32.o += -DCONFIG_VDSO_32 diff --git a/criu/include/nftables.h b/criu/include/nftables.h index 0bdab31..3b51a3d 100644 --- a/criu/include/nftables.h +++ b/criu/include/nftables.h @@ -3,6 +3,99 @@
#include <libmnl/libmnl.h>
+#include <libnftnl/table.h>
+#include <libnftnl/chain.h>
+#include <libnftnl/set.h>
+#include <libnftnl/rule.h>
+#include <libnftnl/expr.h>
+
+#define construct_buf(buf, type, family, flags, seq, payload, cb_prefix) \
+	({ \
+		struct nlmsghdr *_nlh; \
+		/* build one message in buf; the payload object is freed here */ \
+		_nlh = nftnl_##cb_prefix##_nlmsg_build_hdr((buf), \
+				(type), (family), (flags), (seq)); \
+		nftnl_##cb_prefix##_nlmsg_build_payload(_nlh, (payload)); \
+		nftnl_##cb_prefix##_free((payload)); \
+		_nlh; \
+	})
+
+#define construct_table_buf(buf, type, family, flags, seq, payload) \
+	construct_buf((buf), (type), (family), (flags), (seq), \
+			(payload), table)
+
+#define construct_chain_buf(buf, type, family, flags, seq, payload) \
+	construct_buf((buf), (type), (family), (flags), (seq), \
+			(payload), chain)
+
+#define construct_batch(batch, type, family, flags, seq, payload, cb_prefix) \
+	do { \
+		struct nlmsghdr *_nlh; \
+		/* append one message to the batch; the payload object is freed here */ \
+		_nlh = nftnl_##cb_prefix##_nlmsg_build_hdr( \
+				mnl_nlmsg_batch_current((batch)), \
+				(type), (family), (flags), (seq)); \
+		nftnl_##cb_prefix##_nlmsg_build_payload(_nlh, (payload)); \
+		nftnl_##cb_prefix##_free((payload)); \
+		mnl_nlmsg_batch_next((batch)); \
+	} while (0)
+
+#define construct_table_batch(batch, type, family, flags, seq, payload) \
+	construct_batch((batch), (type), (family), (flags), (seq), \
+			(payload), table)
+
+#define construct_chain_batch(batch, type, family, flags, seq, payload) \
+	construct_batch((batch), (type), (family), (flags), (seq), \
+			(payload), chain)
+
+#define construct_set_batch(batch, type, family, flags, seq, payload) \
+	construct_batch((batch), (type), (family), (flags), (seq), \
+			(payload), set)
+
+#define construct_rule_batch(batch, type, family, flags, seq, payload) \
+	construct_batch((batch), (type), (family), (flags), (seq), \
+			(payload), rule)
+
+#define construct_set_elems_batch(batch, type, family, flags, seq, payload) \
+	do { \
+		struct nlmsghdr *_nlh; \
+		/* payload is a struct nftnl_set carrying the elements; freed here */ \
+		_nlh = nftnl_nlmsg_build_hdr( \
+				mnl_nlmsg_batch_current((batch)), \
+				(type), (family), (flags), (seq)); \
+		nftnl_set_elems_nlmsg_build_payload(_nlh, (payload)); \
+		nftnl_set_free((payload)); \
+		mnl_nlmsg_batch_next((batch)); \
+	} while (0)
+
+#define TABLE_NAME "filter"
+#define INPUT_CHAIN_NAME "criu-input"
+#define OUTPUT_CHAIN_NAME "criu-output"
+#define INPUT_IPV4_SET_NAME "criu-input-ipv4-blacklist-%d"
+#define INPUT_IPV6_SET_NAME "criu-input-ipv6-blacklist-%d"
+#define OUTPUT_IPV4_SET_NAME "criu-output-ipv4-blacklist-%d"
+#define OUTPUT_IPV6_SET_NAME "criu-output-ipv6-blacklist-%d"
+
+/* Set key type, see nftables/include/datatypes.h.
+ * A concatenated key type packs one 6-bit datatype id per component:
+ * - ipaddr: 7, 4 bytes
+ * - ip6addr: 8, 16 bytes
+ * - inet_service: 13, 2 bytes (padded to 4 bytes)
+ * The keys used here are "addr . port . addr . port":
+ *
+ * 0x1cd1cd: 0b 000111 001101 000111 001101
+ * 0x20d20d: 0b 001000 001101 001000 001101
+ */
+#define INET_SERVICE_LEN 2
+#define IPADDR_LEN 4
+#define IP6ADDR_LEN 16
+#define div_round_up(n, d) (((n) + (d) - 1) / (d))
+
+#define IPv4_KEY_TYPE 0x1cd1cd
+#define IPv4_KEY_LEN (div_round_up(IPADDR_LEN + INET_SERVICE_LEN, 4) * 4 * 2) /* 16 */
+#define IPv6_KEY_TYPE 0x20d20d
+#define IPv6_KEY_LEN (div_round_up(IP6ADDR_LEN + INET_SERVICE_LEN, 4) * 4 * 2) /* 40 */
+
 struct mnl_params {
 	struct mnl_socket *nl;
 	char *buf;
@@ -25,4 +118,49 @@ int mnl_common(mnl_func_t mnl_cb, void *arg1, void *arg2);
 int mnl_batch_send_and_recv(struct mnl_params *mnl_params, batch_func_t cb, void *args, int *result);
 int mnl_buf_send_and_recv(struct mnl_params *mnl_params, buf_func_t cb, void *args, int *result);
 
+struct nft_chain_params {
+	char *name;
+	uint32_t hooknum;
+	char *type;
+	uint32_t prio;
+	uint32_t policy;
+};
+
+struct nft_set_params {
+	char name[128];
+	uint32_t id;
+	uint32_t datatype;	/* concatenated key type, e.g. IPv4_KEY_TYPE */
+	uint32_t key_len;	/* key size in bytes, e.g. IPv4_KEY_LEN */
+};
+
+struct nft_rule_params {
+	char *chain_name;
+	char set_name[128];
+	uint32_t mark;		/* 0: no mark match is added to the rule */
+	uint16_t mark_op;	/* comparison operator of the mark match */
+	uint32_t nfproto;
+	uint8_t l4proto;
+	unsigned int stmt;	/* verdict code of the rule */
+	bool ipv6;
+};
+
+struct nft_set_elem_params {
+	char set_name[128];
+	char data[40];		/* large enough for IPv6_KEY_LEN */
+	size_t data_len;
+};
+
+struct nf_conn_params {
+	uint8_t family;
+	uint32_t *src_addr;
+	uint16_t src_port;
+	uint32_t *dst_addr;
+	uint16_t dst_port;
+	bool lock;		/* true: add the set element, false: delete it */
+	pid_t tree_id;
+};
+
+struct inet_sk_desc;
+int nft_connection_switch(struct inet_sk_desc *sk, bool lock, pid_t tree_id);
+
 #endif /* __CR_NFTABLES_H__ */
diff --git a/criu/nftables.c b/criu/nftables.c
new file mode 100644
index 0000000..57774e6
--- /dev/null
+++ b/criu/nftables.c
@@ -0,0 +1,867 @@
+#include <libmnl/libmnl.h>
+#include <stddef.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <time.h>
+
+#include <netinet/in.h>
+#include <arpa/inet.h>
+
+#include <linux/netfilter.h>
+#include <linux/netfilter_ipv4.h>
+#include <linux/netfilter/nf_tables.h>
+#include <linux/if_ether.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+
+#include "sk-inet.h"
+#include "nftables.h"
+
+#include "../soccr/soccr.h"
+
+#include "log.h"
+
+static struct nftnl_table *setup_table(uint8_t family, const char *table)
+{
+	struct nftnl_table *t;
+
+	t = nftnl_table_alloc();
+	if (t == NULL)
+		return NULL;
+
+	nftnl_table_set_u32(t, NFTNL_TABLE_FAMILY, family);
+	if (nftnl_table_set_str(t, NFTNL_TABLE_NAME, table) < 0)
+		goto err;
+
+	return t;
+err:
+	nftnl_table_free(t);
+	return NULL;
+}
+
+static struct nftnl_chain *setup_chain(const char *table,
+				       struct nft_chain_params *params,
+				       bool create)
+{
+	struct nftnl_chain *c;
+
+	c = nftnl_chain_alloc();
+	if (c == NULL)
+		return NULL;
+
+	if (nftnl_chain_set_str(c, NFTNL_CHAIN_TABLE, table) < 0)
+		goto err;
+	if (nftnl_chain_set_str(c, NFTNL_CHAIN_NAME, params->name) < 0)
+		goto err;
+	if (create) {
+		nftnl_chain_set_u32(c, NFTNL_CHAIN_HOOKNUM, params->hooknum);
+		if (nftnl_chain_set_str(c, NFTNL_CHAIN_TYPE, params->type) < 0)
+			goto err;
+		nftnl_chain_set_u32(c, NFTNL_CHAIN_PRIO, params->prio);
+		nftnl_chain_set_u32(c, NFTNL_CHAIN_POLICY, params->policy);
+	}
+
+	return c;
+err:
+	nftnl_chain_free(c);
+	return NULL;
+}
+
+static struct nftnl_set *setup_set(uint8_t family, const char *table,
+				   struct nft_set_params *params,
+				   bool create)
+{
+	struct nftnl_set *s;
+
+	s = nftnl_set_alloc();
+	if (s == NULL)
+		return NULL;
+
+	if (nftnl_set_set_str(s, NFTNL_SET_TABLE, table) < 0)
+		goto err;
+	if (nftnl_set_set_str(s, NFTNL_SET_NAME, params->name) < 0)
+		goto err;
+	if (create) {
+		nftnl_set_set_u32(s, NFTNL_SET_FAMILY, family);
+		nftnl_set_set_u32(s, NFTNL_SET_ID, params->id);
+
+		nftnl_set_set_u32(s, NFTNL_SET_KEY_TYPE, params->datatype);
+		nftnl_set_set_u32(s, NFTNL_SET_KEY_LEN, params->key_len);
+	}
+
+	return s;
+err:
+	nftnl_set_free(s);
+	return NULL;
+}
+
+static int add_mark(struct nftnl_rule *r, uint32_t meta_key, enum nft_registers dreg)
+{
+	struct nftnl_expr *e;
+
+	e = nftnl_expr_alloc("meta");
+	if (e == NULL)
+		return -1;
+
+	nftnl_expr_set_u32(e, NFTNL_EXPR_META_KEY, meta_key);
+	nftnl_expr_set_u32(e, NFTNL_EXPR_META_DREG, dreg);
+
+	nftnl_rule_add_expr(r, e);
+
+	return 0;
+}
+
+static int add_proto(struct nftnl_rule *r, enum nft_registers dreg)
+{
+	struct nftnl_expr *e;
+
+	e = nftnl_expr_alloc("meta");
+	if (e == NULL)
+		return -1;
+
+	nftnl_expr_set_u32(e, NFTNL_EXPR_META_KEY, NFT_META_L4PROTO);
+	nftnl_expr_set_u32(e, NFTNL_EXPR_META_DREG, dreg);
+
+	nftnl_rule_add_expr(r, e);
+
+	return 0;
+}
+
+static int add_payload(struct nftnl_rule *r, uint32_t base, uint32_t dreg,
+		       uint32_t offset, uint32_t len)
+{
+	struct nftnl_expr *e;
+
+	e = nftnl_expr_alloc("payload");
+	if (e == NULL)
+		return -1;
+
+	nftnl_expr_set_u32(e, NFTNL_EXPR_PAYLOAD_BASE, base);
+	nftnl_expr_set_u32(e, NFTNL_EXPR_PAYLOAD_DREG, dreg);
+	nftnl_expr_set_u32(e, NFTNL_EXPR_PAYLOAD_OFFSET, offset);
+	nftnl_expr_set_u32(e, NFTNL_EXPR_PAYLOAD_LEN, len);
+
+	nftnl_rule_add_expr(r, e);
+
+	return 0;
+}
+
+static int add_cmp(struct nftnl_rule *r, enum nft_registers sreg, uint32_t op,
+		   const void *data, uint32_t data_len)
+{
+	struct nftnl_expr *e;
+
+	e = nftnl_expr_alloc("cmp");
+	if (e == NULL)
+		return -1;
+
+	nftnl_expr_set_u32(e, NFTNL_EXPR_CMP_SREG, sreg);
+	nftnl_expr_set_u32(e, NFTNL_EXPR_CMP_OP, op);
+	nftnl_expr_set(e, NFTNL_EXPR_CMP_DATA, data, data_len);
+
+	nftnl_rule_add_expr(r, e);
+
+	return 0;
+}
+
+static int add_lookup(struct nftnl_rule *r, enum nft_registers sreg,
+		      const char *set)
+{
+	struct nftnl_expr *e;
+
+	e = nftnl_expr_alloc("lookup");
+	if (e == NULL)
+		return -1;
+
+	if (nftnl_expr_set_str(e, NFTNL_EXPR_LOOKUP_SET, set) < 0)
+		goto err;
+	nftnl_expr_set_u32(e, NFTNL_EXPR_LOOKUP_SREG, sreg);
+
+	nftnl_rule_add_expr(r, e);
+
+	return 0;
+err:
+	nftnl_expr_free(e);
+	return -1;
+}
+
+static int add_counter(struct nftnl_rule *r)
+{
+	struct nftnl_expr *e;
+
+	e = nftnl_expr_alloc("counter");
+	if (e == NULL)
+		return -1;
+
+	nftnl_rule_add_expr(r, e);
+
+	return 0;
+}
+
+static int add_verdict(struct nftnl_rule *r, const char *chain, int verdict)
+{
+	struct nftnl_expr *e;
+
+	e = nftnl_expr_alloc("immediate");
+	if (e == NULL)
+		return -1;
+
+	nftnl_expr_set_u32(e, NFTNL_EXPR_IMM_DREG, NFT_REG_VERDICT);
+	nftnl_expr_set_u32(e, NFTNL_EXPR_IMM_VERDICT, verdict);
+
+	nftnl_rule_add_expr(r, e);
+
+	return 0;
+}
+
+/*
+ * Append the connection match to rule @r:
+ *   meta protocol <nfproto>, meta l4proto <l4proto>,
+ *   saddr . sport . daddr . dport @<set>, counter
+ * The key fields are loaded into 32-bit registers starting at NFT_REG32_00.
+ */
+static int __setup_rule(struct nftnl_rule *r, struct nft_rule_params *params)
+{
+	/* meta protocol == <nfproto> */
+	if (add_mark(r, NFT_META_PROTOCOL, NFT_REG32_00) < 0)
+		return -1;
+	if (add_cmp(r, NFT_REG32_00, NFT_CMP_EQ, &params->nfproto, sizeof(uint32_t)) < 0)
+		return -1;
+
+	/* meta l4proto == <l4proto> */
+	if (add_proto(r, NFT_REG32_00) < 0)
+		return -1;
+	if (add_cmp(r, NFT_REG32_00, NFT_CMP_EQ, &params->l4proto, sizeof(uint8_t)) < 0)
+		return -1;
+
+	/* ip saddr . sport . daddr . dport @<set> */
+	if (params->ipv6 == false) {
+		if (add_payload(r, NFT_PAYLOAD_NETWORK_HEADER, NFT_REG32_00,
+				offsetof(struct iphdr, saddr), IPADDR_LEN) < 0)
+			return -1;
+		if (add_payload(r, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG32_01,
+				offsetof(struct tcphdr, source), INET_SERVICE_LEN) < 0)
+			return -1;
+		if (add_payload(r, NFT_PAYLOAD_NETWORK_HEADER, NFT_REG32_02,
+				offsetof(struct iphdr, daddr), IPADDR_LEN) < 0)
+			return -1;
+		if (add_payload(r, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG32_03,
+				offsetof(struct tcphdr, dest), INET_SERVICE_LEN) < 0)
+			return -1;
+
+		if (add_lookup(r, NFT_REG32_00, params->set_name) < 0)
+			return -1;
+	} else {
+		if (add_payload(r, NFT_PAYLOAD_NETWORK_HEADER, NFT_REG32_00,
+				offsetof(struct ipv6hdr, saddr), IP6ADDR_LEN) < 0)
+			return -1;
+		if (add_payload(r, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG32_04,
+				offsetof(struct tcphdr, source), INET_SERVICE_LEN) < 0)
+			return -1;
+		if (add_payload(r, NFT_PAYLOAD_NETWORK_HEADER, NFT_REG32_05,
+				offsetof(struct ipv6hdr, daddr), IP6ADDR_LEN) < 0)
+			return -1;
+		if (add_payload(r, NFT_PAYLOAD_TRANSPORT_HEADER, NFT_REG32_09,
+				offsetof(struct tcphdr, dest), INET_SERVICE_LEN) < 0)
+			return -1;
+
+		if (add_lookup(r, NFT_REG32_00, params->set_name) < 0)
+			return -1;
+	}
+
+	/* counter */
+	if (add_counter(r) < 0)
+		return -1;
+
+	return 0;
+}
+
+static struct nftnl_rule *setup_rule(uint8_t family, const char *table,
+				     struct nft_rule_params *params,
+				     bool create, bool ns)
+{
+	struct nftnl_rule *r = NULL;
+
+	r = nftnl_rule_alloc();
+	if (r == NULL)
+		return NULL;
+
+	if (nftnl_rule_set_str(r, NFTNL_RULE_TABLE, table) < 0)
+		goto err;
+	nftnl_rule_set_u32(r, NFTNL_RULE_FAMILY, family);
+	if (nftnl_rule_set_str(r, NFTNL_RULE_CHAIN, params->chain_name) < 0)
+		goto err;
+
+	if (params->mark != 0) {
+		/* meta mark <mark_op> <mark> */
+		if (add_mark(r, NFT_META_MARK, NFT_REG32_00) < 0)
+			goto err;
+		if (add_cmp(r, NFT_REG32_00, params->mark_op, &params->mark, sizeof(uint32_t)) < 0)
+			goto err;
+	}
+
+	if (!ns && __setup_rule(r, params) < 0)
+		goto err;
+
+	/* final verdict, taken from params->stmt */
+	if (add_verdict(r, params->chain_name, params->stmt) < 0)
+		goto err;
+
+	return r;
+
+err:
+	nftnl_rule_free(r);
+	return NULL;
+}
+
+static struct nlmsghdr *nft_table_detect(struct mnl_params *mnl_params, void *args)
+{
+	struct nftnl_table *table;
+
+	table = setup_table(NFPROTO_INET, TABLE_NAME);
+	if (table == NULL)
+		return NULL;
+
+	return construct_table_buf(mnl_params->buf, NFT_MSG_GETTABLE, NFPROTO_INET,
+				   NLM_F_ACK, mnl_params->seq++, table);
+}
+
+static int nft_table_create(struct mnl_params *mnl_params, void *args)
+{
+	struct nftnl_table *table;
+
+	table = setup_table(NFPROTO_INET, TABLE_NAME);
+	if (table == NULL)
+		return -1;
+
+	construct_table_batch(mnl_params->batch, NFT_MSG_NEWTABLE, NFPROTO_INET,
+			      NLM_F_CREATE|NLM_F_EXCL|NLM_F_ACK,
+			      mnl_params->seq++, table);
+
+	return 0;
+}
+
+/*
+ * Make sure the inet TABLE_NAME table exists: look it up and create it
+ * when the lookup reports ENOENT. EEXIST from the create step is treated
+ * as success. Returns 0 on success, -1 on error.
+ */
+static int nft_table_prepare(struct mnl_params *mnl_params)
+{
+	int result = 0;
+
+	if (mnl_buf_send_and_recv(mnl_params, nft_table_detect, NULL, &result) == 0)
+		return 0;
+
+	pr_debug("%s: detect table result %d\n", __func__, result);
+
+	if (result == ENOENT) {
+		if (mnl_batch_send_and_recv(mnl_params, nft_table_create, NULL, &result) < 0 &&
+		    result != 0 && result != EEXIST) {
+			pr_err("%s: create nftables table failed!\n", __func__);
+			return -1;
+		}
+	} else if (result != 0) {
+		pr_err("%s: detect table result %d\n", __func__, -result);
+		return -1;
+	}
+
+	return 0;
+}
+
+static struct nlmsghdr *nft_chain_detect(struct mnl_params *mnl_params, void *args)
+{
+	struct nftnl_chain *chain;
+
+	chain = setup_chain(TABLE_NAME, args, false);
+	if (chain == NULL)
+		return NULL;
+
+	return construct_chain_buf(mnl_params->buf, NFT_MSG_GETCHAIN, NFPROTO_INET,
+				   NLM_F_ACK, mnl_params->seq++, chain);
+}
+
+static int nft_chain_create(struct mnl_params *mnl_params, void *args)
+{
+	struct nftnl_chain *chain;
+
+	chain = setup_chain(TABLE_NAME, args, true);
+	if (chain == NULL)
+		return -1;
+
+	construct_chain_batch(mnl_params->batch, NFT_MSG_NEWCHAIN, NFPROTO_INET,
+			      NLM_F_CREATE|NLM_F_EXCL|NLM_F_ACK, mnl_params->seq++, chain);
+
+	return 0;
+}
+
+/*
+ * Make sure the chain described by @params exists: look it up, create it
+ * on ENOENT and treat EEXIST from the create step as success.
+ */
+static int nft_chain_prepare_internal(struct mnl_params *mnl_params,
+				      struct nft_chain_params *params)
+{
+	int result = 0;
+
+	if (mnl_buf_send_and_recv(mnl_params, nft_chain_detect, params, &result) == 0)
+		return 0;
+
+	pr_debug("%s: detect chain result %d\n", __func__, result);
+
+	if (result == ENOENT) {
+		if (mnl_batch_send_and_recv(mnl_params, nft_chain_create, params, &result) < 0 &&
+		    result != 0 && result != EEXIST) {
+			pr_err("%s: nftables create chain %s failed!\n",
+			       __func__, params->name);
+			return -1;
+		}
+	} else if (result != 0) {
+		pr_err("%s: detect chain result %d\n", __func__, -result);
+		return -1;
+	}
+
+	return 0;
+}
+
+static int nft_chain_prepare(struct mnl_params *mnl_params)
+{
+	struct nft_chain_params params = {
+		.type = "filter",
+		.prio = NF_IP_PRI_FILTER,
+		.policy = NF_ACCEPT,
+	};
+
+	/* prepare the input chain in the inet filter table */
+	params.name = INPUT_CHAIN_NAME;
+	params.hooknum = NF_INET_LOCAL_IN;
+
+	if (nft_chain_prepare_internal(mnl_params, &params) < 0)
+		return -1;
+
+	/* prepare the output chain in the inet filter table */
+	params.name = OUTPUT_CHAIN_NAME;
+	params.hooknum = NF_INET_LOCAL_OUT;
+
+	if (nft_chain_prepare_internal(mnl_params, &params) < 0)
+		return -1;
+
+	return 0;
+}
+
+static int nft_set_internal(uint8_t family, struct mnl_params *mnl_params,
+			    struct nft_set_params *params, bool create)
+{
+	struct nftnl_set *set;
+
+	set = setup_set(family, TABLE_NAME, params, create);
+	if (set == NULL)
+		return -1;
+
+	if (create) {
+		construct_set_batch(mnl_params->batch, NFT_MSG_NEWSET, family,
+				    NLM_F_CREATE|NLM_F_EXCL|NLM_F_ACK, mnl_params->seq++, set);
+	} else {
+		construct_set_batch(mnl_params->batch, NFT_MSG_DELSET, family,
+				    0, mnl_params->seq++, set);
+	}
+
+	return 0;
+}
+
+static int nft_set_raw(struct mnl_params *mnl_params,
+		       struct mnl_cb_params *args, bool input)
+{
+	const uint32_t set_id_base = input ? 0x12315 : 0x17173;
+	const uint8_t family = NFPROTO_INET;
+	struct nft_set_params params = { 0 };
+	char *set_name;
+	int idx = 0;
+
+	if (!args->ipv6) {
+		params.datatype = IPv4_KEY_TYPE;
+		params.key_len = IPv4_KEY_LEN;
+		idx = 4;
+	} else {
+		params.datatype = IPv6_KEY_TYPE;
+		params.key_len = IPv6_KEY_LEN;
+		idx = 6;
+	}
+
+	if (args->ipv6 && input)
+		set_name = INPUT_IPV6_SET_NAME;
+	else if (args->ipv6 && !input)
+		set_name = OUTPUT_IPV6_SET_NAME;
+	else if (!args->ipv6 && input)
+		set_name = INPUT_IPV4_SET_NAME;
+	else
+		set_name = OUTPUT_IPV4_SET_NAME;
+
+	snprintf(params.name, sizeof(params.name)-1, set_name, args->tree_id);
+	params.id = set_id_base + args->tree_id + idx;
+
+	if (nft_set_internal(family, mnl_params, &params, args->create) < 0) {
+		pr_err("%s: create nftables %s %s set failed!\n", __func__,
+		       args->ipv6 ? "ipv6" : "ipv4",
+		       input ? "input" : "output");
+		return -1;
+	}
+
+	return 0;
+}
+
+static int nft_set(struct mnl_params *mnl_params, void *args)
+{
+	struct mnl_cb_params *params = args;
+
+	params->ipv6 = false;
+	if (nft_set_raw(mnl_params, params, true) < 0)
+		return -1;
+
+	if (nft_set_raw(mnl_params, params, false) < 0)
+		return -1;
+
+	params->ipv6 = true;
+	if (nft_set_raw(mnl_params, params, true) < 0)
+		return -1;
+
+	if (nft_set_raw(mnl_params, params, false) < 0)
+		return -1;
+
+	return 0;
+}
+
+static int nft_set_common(struct mnl_params *mnl_params, pid_t tree_id, bool create)
+{
+	struct mnl_cb_params params = {
+		.tree_id = tree_id,
+		.create = create,
+	};
+	int result = 0;
+
+	if (create &&
+	    (mnl_batch_send_and_recv(mnl_params, nft_set, &params, &result) < 0
+	     && (result != 0 && result != EEXIST))) {
+		pr_err("%s: create set failed!\n", __func__);
+		return -1;
+	} else if (!create &&
+		   mnl_batch_send_and_recv(mnl_params, nft_set, &params, NULL) < 0) {
+		pr_err("%s: delete set failed!\n", __func__);
+		return -1;
+	}
+
+	return 0;
+}
+
+static int nft_rule_internal(uint8_t family, struct mnl_params *mnl_params,
+			     struct nft_rule_params *params, bool create)
+{
+	struct nftnl_rule *rule;
+
+	rule = setup_rule(family, TABLE_NAME, params, create, false);
+	if (rule == NULL)
+		return -1;
+
+	if (create) {
+		construct_rule_batch(mnl_params->batch, NFT_MSG_NEWRULE, family,
+				     NLM_F_CREATE|NLM_F_EXCL|NLM_F_ACK,
+				     mnl_params->seq++, rule);
+	} else {
+		construct_rule_batch(mnl_params->batch, NFT_MSG_DELRULE, family,
+				     0, mnl_params->seq++, rule);
+	}
+
+	return 0;
+}
+
+static int nft_rule_raw(struct mnl_params *mnl_params, struct mnl_cb_params *args,
+			struct nft_rule_params *params)
+{
+	char *set_name;
+
+	params->nfproto = params->ipv6 ? htons(ETH_P_IPV6) : htons(ETH_P_IP);
+
+	set_name = params->ipv6 ? INPUT_IPV6_SET_NAME : INPUT_IPV4_SET_NAME;
+	params->chain_name = INPUT_CHAIN_NAME;
+	snprintf(params->set_name, sizeof(params->set_name)-1, set_name, args->tree_id);
+	if (nft_rule_internal(NFPROTO_INET, mnl_params, params, args->create) < 0) {
+		pr_err("%s: create nft %s input rule failed!\n",
+		       __func__, params->ipv6 ? "ipv6" : "ipv4");
+		return -1;
+	}
+
+	set_name = params->ipv6 ? OUTPUT_IPV6_SET_NAME : OUTPUT_IPV4_SET_NAME;
+	params->chain_name = OUTPUT_CHAIN_NAME;
+	snprintf(params->set_name, sizeof(params->set_name)-1, set_name, args->tree_id);
+	if (nft_rule_internal(NFPROTO_INET, mnl_params, params, args->create) < 0) {
+		pr_err("%s: create nftables %s output rule failed!\n",
+		       __func__, params->ipv6 ? "ipv6" : "ipv4");
+		return -1;
+	}
+
+	return 0;
+}
+
+static int nft_rule(struct mnl_params *mnl_params, void *args)
+{
+	struct nft_rule_params params = {
+		.l4proto = IPPROTO_TCP,
+		.mark = SOCCR_MARK,
+		.mark_op = NFT_CMP_NEQ,
+		.stmt = NF_DROP,
+	};
+
+	params.ipv6 = false;
+	if (nft_rule_raw(mnl_params, args, &params) < 0)
+		return -1;
+
+	params.ipv6 = true;
+	if (nft_rule_raw(mnl_params, args, &params) < 0)
+		return -1;
+
+	return 0;
+}
+
+static int nft_rule_common(struct mnl_params *mnl_params, pid_t tree_id, bool create)
+{
+	struct mnl_cb_params params = {
+		.tree_id = tree_id,
+		.create = create,
+	};
+	int result = 0;
+
+	if (create &&
+	    (mnl_batch_send_and_recv(mnl_params, nft_rule, &params, &result) < 0
+	     && (result != 0 && result != EEXIST))) {
+		pr_err("%s: create rule failed!\n", __func__);
+		return -1;
+	} else if (!create &&
+		   mnl_batch_send_and_recv(mnl_params, nft_rule, &params, NULL) < 0) {
+		pr_err("%s: delete rule failed!\n", __func__);
+		return -1;
+	}
+
+	return 0;
+}
+
+static int network_prepare_internal(struct mnl_params *params, batch_func_t _, void *args)
+{
+	pid_t tree_id = *(pid_t *)args;
+
+	if (nft_table_prepare(params) < 0)
+		return -1;
+
+	if (nft_chain_prepare(params) < 0)
+		return -1;
+
+	if (nft_set_common(params, tree_id, true) < 0)
+		return -1;
+
+	if (nft_rule_common(params, tree_id, true) < 0)
+		return -1;
+
+	return 0;
+}
+
+/* Set up the table, the chains and the sets and rules named after @tree_id. */
+int network_prepare(pid_t tree_id)
+{
+	pr_info("Prepare network\n");
+
+	return mnl_common(network_prepare_internal, NULL, &tree_id);
+}
+
+static int network_unprepare_internal(struct mnl_params *params,
+				      batch_func_t _, void *args)
+{
+	pid_t tree_id = *(pid_t *)args;
+
+	if (nft_rule_common(params, tree_id, false) < 0)
+		return -1;
+
+	if (nft_set_common(params, tree_id, false) < 0)
+		return -1;
+
+	return 0;
+}
+
+/* Delete the rules and sets of @tree_id; the table and chains are not touched. */
+void network_unprepare(pid_t tree_id)
+{
+	pr_info("Unprepare network\n");
+
+	mnl_common(network_unprepare_internal, NULL, &tree_id);
+}
+
+static int add_set_elem_internal(struct nftnl_set *s, void *data, size_t len)
+{
+	struct nftnl_set_elem *e;
+
+	e = nftnl_set_elem_alloc();
+	if (e == NULL)
+		return -1;
+
+	nftnl_set_elem_set(e, NFTNL_SET_ELEM_KEY, data, len);
+
+	nftnl_set_elem_add(s, e);
+
+	return 0;
+}
+
+static struct nftnl_set *add_set_elem(const char *table, const char *set,
+				      void *data, size_t len)
+{
+	struct nftnl_set *s;
+
+	s = nftnl_set_alloc();
+	if (s == NULL)
+		return NULL;
+
+	if (nftnl_set_set_str(s, NFTNL_SET_TABLE, table) < 0)
+		goto err;
+	if (nftnl_set_set_str(s, NFTNL_SET_NAME, set) < 0)
+		goto err;
+
+	if (add_set_elem_internal(s, data, len) < 0)
+		goto err;
+
+	return s;
+
+err:
+	nftnl_set_free(s);
+	return NULL;
+}
+
+static int nft_set_elem(uint8_t family, struct mnl_params *mnl_param,
+			struct nft_set_elem_params *elem_param,
+			bool lock)
+{
+	struct nftnl_set *set;
+
+	set = add_set_elem(TABLE_NAME, elem_param->set_name,
+			   elem_param->data, elem_param->data_len);
+	if (set == NULL)
+		return -1;
+
+	if (lock) {
+		construct_set_elems_batch(mnl_param->batch, NFT_MSG_NEWSETELEM,
+					  family, NLM_F_CREATE|NLM_F_EXCL,
+					  mnl_param->seq++, set);
+	} else {
+		construct_set_elems_batch(mnl_param->batch, NFT_MSG_DELSETELEM,
+					  family, 0, mnl_param->seq++, set);
+	}
+
+	return 0;
+}
+
+/*
+ * Fill @data with the concatenated set key "addr . port . addr . port".
+ * Each port is stored in network byte order in the first two bytes of a
+ * zero-padded 4-byte slot, independent of host endianness and alignment.
+ * With @output == false the key starts with the destination address/port
+ * of @param, with @output == true it starts with the source address/port.
+ */
+static void construct_set_elem_key(void *data, struct nf_conn_params *param, bool output)
+{
+	const size_t addr_len = param->family == AF_INET ? IPADDR_LEN : IP6ADDR_LEN;
+	const uint32_t *addrs[2];
+	uint16_t ports[2];
+	uint8_t *p = data;
+	int i;
+
+	addrs[0] = output ? param->src_addr : param->dst_addr;
+	ports[0] = htons(output ? param->src_port : param->dst_port);
+	addrs[1] = output ? param->dst_addr : param->src_addr;
+	ports[1] = htons(output ? param->dst_port : param->src_port);
+
+	for (i = 0; i < 2; i++) {
+		memcpy(p, addrs[i], addr_len);
+		p += addr_len;
+
+		/* 2-byte port followed by 2 bytes of zero padding */
+		memset(p, 0, sizeof(uint32_t));
+		memcpy(p, &ports[i], sizeof(ports[i]));
+		p += sizeof(uint32_t);
+	}
+}
+
+static int nf_connection_switch_raw(struct mnl_params *mnl_params, void *args)
+{
+	struct nf_conn_params *param = args;
+	char *input_set_name, *output_set_name;
+	struct nft_set_elem_params elem;
+
+	switch (param->family) {
+	case AF_INET:
+		input_set_name = INPUT_IPV4_SET_NAME;
+		output_set_name = OUTPUT_IPV4_SET_NAME;
+		elem.data_len = IPv4_KEY_LEN;
+		break;
+	case AF_INET6:
+		input_set_name = INPUT_IPV6_SET_NAME;
+		output_set_name = OUTPUT_IPV6_SET_NAME;
+		elem.data_len = IPv6_KEY_LEN;
+		break;
+	default:
+		pr_err("Unknown socket family %d\n", param->family);
+		return -1;
+	}
+
+	construct_set_elem_key(elem.data, param, false);
+	snprintf(elem.set_name, sizeof(elem.set_name)-1, input_set_name, param->tree_id);
+	if (nft_set_elem(NFPROTO_INET, mnl_params, &elem, param->lock) < 0)
+		return -1;
+
+	construct_set_elem_key(elem.data, param, true);
+	snprintf(elem.set_name, sizeof(elem.set_name)-1, output_set_name, param->tree_id);
+	if (nft_set_elem(NFPROTO_INET, mnl_params, &elem, param->lock) < 0)
+		return -1;
+
+	return 0;
+}
+
+/*
+ * IPv4-mapped IPv6 address (::ffff:a.b.c.d, RFC 4291): the first 80 bits
+ * are zero and the following 16 bits are all ones.
+ */
+static int ipv6_addr_mapped(uint32_t *addr)
+{
+	return addr[0] == 0 && addr[1] == 0 && addr[2] == htonl(0x0000ffff);
+}
+
+/*
+ * Add (@lock == true) or delete (@lock == false) the connection of @sk
+ * in the input and output sets of @tree_id. An IPv6 socket whose peer is
+ * an IPv4-mapped address is handled as the IPv4 connection it carries.
+ */
+int nft_connection_switch(struct inet_sk_desc *sk, bool lock, pid_t tree_id)
+{
+	char sip[INET_ADDR_LEN], dip[INET_ADDR_LEN];
+	struct nf_conn_params param = {
+		.family = sk->sd.family,
+		.src_addr = sk->src_addr,
+		.src_port = sk->src_port,
+		.dst_addr = sk->dst_addr,
+		.dst_port = sk->dst_port,
+		.lock = lock,
+		.tree_id = tree_id,
+	};
+
+	if (param.family == AF_INET6 && ipv6_addr_mapped(param.dst_addr)) {
+		param.family = AF_INET;
+		param.src_addr = &param.src_addr[3];
+		param.dst_addr = &param.dst_addr[3];
+	}
+
+	if (!inet_ntop(param.family, (void *)param.src_addr, sip, INET_ADDR_LEN) ||
+	    !inet_ntop(param.family, (void *)param.dst_addr, dip, INET_ADDR_LEN)) {
+		pr_perror("nf: Can't translate ip addr");
+		return -1;
+	}
+
+	pr_info("%s %s:%d - %s:%d connection\n", lock ? "Locked" : "Unlocked",
+		sip, (int)param.src_port, dip, (int)param.dst_port);
+
+	return mnl_sendmsg(nf_connection_switch_raw, &param);
+}