| From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 |
| From: "Jason A. Donenfeld" <Jason@zx2c4.com> |
| Date: Fri, 4 Jun 2021 17:17:38 +0200 |
| Subject: [PATCH] wireguard: allowedips: free empty intermediate nodes when |
| removing single node |
| |
| commit bf7b042dc62a31f66d3a41dd4dfc7806f267b307 upstream. |
| |
| When removing single nodes, it's possible that that node's parent is an |
| empty intermediate node, in which case, it too should be removed. |
| Otherwise the trie fills up and never is fully emptied, leading to |
| gradual memory leaks over time for tries that are modified often. There |
| was originally code to do this, but it was removed during refactoring in |
| 2016 and never reworked. Now that we have proper parent pointers from |
| the previous commits, we can implement this properly. |
| |
| In order to reduce branching and expensive comparisons, we want to keep |
| the double pointer for parent assignment (which lets us easily chain up |
| to the root), but we still need to actually get the parent's base |
| address. So encode the bit number into the last two bits of the pointer, |
| and pack and unpack it as needed. This is a little bit clumsy but is the |
| fastest and least memory wasteful of the compromises. Note that we align |
| the root struct here to a minimum of 4, because it's embedded into a |
| larger struct, and we're relying on having the bottom two bits for our |
| flag, which would only be 16-bit aligned on m68k. |
| |
| The existing macro-based helpers were a bit unwieldy for adding the bit |
| packing to, so this commit replaces them with safer and clearer ordinary |
| functions. |
| |
| We add a test to the randomized/fuzzer part of the selftests, to free |
| the randomized tries by-peer, refuzz it, and repeat, until it's supposed |
| to be empty, and then see if that actually resulted in the whole |
| thing being emptied. That combined with kmemcheck should hopefully make |
| sure this commit is doing what it should. Along the way this resulted in |
| various other cleanups of the tests and fixes for recent graphviz. |
| |
| Fixes: e7096c131e51 ("net: WireGuard secure network tunnel") |
| Cc: stable@vger.kernel.org |
| Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> |
| Signed-off-by: David S. Miller <davem@davemloft.net> |
| Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> |
| --- |
| drivers/net/wireguard/allowedips.c | 102 ++++++------ |
| drivers/net/wireguard/allowedips.h | 4 +- |
| drivers/net/wireguard/selftest/allowedips.c | 162 ++++++++++---------- |
| 3 files changed, 137 insertions(+), 131 deletions(-) |
| |
| --- a/drivers/net/wireguard/allowedips.c |
| +++ b/drivers/net/wireguard/allowedips.c |
| @@ -30,8 +30,11 @@ static void copy_and_assign_cidr(struct |
| node->bitlen = bits; |
| memcpy(node->bits, src, bits / 8U); |
| } |
| -#define CHOOSE_NODE(parent, key) \ |
| - parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] |
| + |
| +static inline u8 choose(struct allowedips_node *node, const u8 *key) |
| +{ |
| + return (key[node->bit_at_a] >> node->bit_at_b) & 1; |
| +} |
| |
| static void push_rcu(struct allowedips_node **stack, |
| struct allowedips_node __rcu *p, unsigned int *len) |
| @@ -112,7 +115,7 @@ static struct allowedips_node *find_node |
| found = node; |
| if (node->cidr == bits) |
| break; |
| - node = rcu_dereference_bh(CHOOSE_NODE(node, key)); |
| + node = rcu_dereference_bh(node->bit[choose(node, key)]); |
| } |
| return found; |
| } |
| @@ -144,8 +147,7 @@ static bool node_placement(struct allowe |
| u8 cidr, u8 bits, struct allowedips_node **rnode, |
| struct mutex *lock) |
| { |
| - struct allowedips_node *node = rcu_dereference_protected(trie, |
| - lockdep_is_held(lock)); |
| + struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); |
| struct allowedips_node *parent = NULL; |
| bool exact = false; |
| |
| @@ -155,13 +157,24 @@ static bool node_placement(struct allowe |
| exact = true; |
| break; |
| } |
| - node = rcu_dereference_protected(CHOOSE_NODE(parent, key), |
| - lockdep_is_held(lock)); |
| + node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock)); |
| } |
| *rnode = parent; |
| return exact; |
| } |
| |
| +static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node) |
| +{ |
| + node->parent_bit_packed = (unsigned long)parent | bit; |
| + rcu_assign_pointer(*parent, node); |
| +} |
| + |
| +static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node) |
| +{ |
| + u8 bit = choose(parent, node->bits); |
| + connect_node(&parent->bit[bit], bit, node); |
| +} |
| + |
| static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, |
| u8 cidr, struct wg_peer *peer, struct mutex *lock) |
| { |
| @@ -177,8 +190,7 @@ static int add(struct allowedips_node __ |
| RCU_INIT_POINTER(node->peer, peer); |
| list_add_tail(&node->peer_list, &peer->allowedips_list); |
| copy_and_assign_cidr(node, key, cidr, bits); |
| - rcu_assign_pointer(node->parent_bit, trie); |
| - rcu_assign_pointer(*trie, node); |
| + connect_node(trie, 2, node); |
| return 0; |
| } |
| if (node_placement(*trie, key, cidr, bits, &node, lock)) { |
| @@ -197,10 +209,10 @@ static int add(struct allowedips_node __ |
| if (!node) { |
| down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); |
| } else { |
| - down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock)); |
| + const u8 bit = choose(node, key); |
| + down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock)); |
| if (!down) { |
| - rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key)); |
| - rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); |
| + connect_node(&node->bit[bit], bit, newnode); |
| return 0; |
| } |
| } |
| @@ -208,15 +220,11 @@ static int add(struct allowedips_node __ |
| parent = node; |
| |
| if (newnode->cidr == cidr) { |
| - rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits)); |
| - rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); |
| - if (!parent) { |
| - rcu_assign_pointer(newnode->parent_bit, trie); |
| - rcu_assign_pointer(*trie, newnode); |
| - } else { |
| - rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits)); |
| - rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode); |
| - } |
| + choose_and_connect_node(newnode, down); |
| + if (!parent) |
| + connect_node(trie, 2, newnode); |
| + else |
| + choose_and_connect_node(parent, newnode); |
| return 0; |
| } |
| |
| @@ -229,17 +237,12 @@ static int add(struct allowedips_node __ |
| INIT_LIST_HEAD(&node->peer_list); |
| copy_and_assign_cidr(node, newnode->bits, cidr, bits); |
| |
| - rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits)); |
| - rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); |
| - rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits)); |
| - rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); |
| - if (!parent) { |
| - rcu_assign_pointer(node->parent_bit, trie); |
| - rcu_assign_pointer(*trie, node); |
| - } else { |
| - rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits)); |
| - rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node); |
| - } |
| + choose_and_connect_node(node, down); |
| + choose_and_connect_node(node, newnode); |
| + if (!parent) |
| + connect_node(trie, 2, node); |
| + else |
| + choose_and_connect_node(parent, node); |
| return 0; |
| } |
| |
| @@ -297,7 +300,8 @@ int wg_allowedips_insert_v6(struct allow |
| void wg_allowedips_remove_by_peer(struct allowedips *table, |
| struct wg_peer *peer, struct mutex *lock) |
| { |
| - struct allowedips_node *node, *child, *tmp; |
| + struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; |
| + bool free_parent; |
| |
| if (list_empty(&peer->allowedips_list)) |
| return; |
| @@ -307,19 +311,29 @@ void wg_allowedips_remove_by_peer(struct |
| RCU_INIT_POINTER(node->peer, NULL); |
| if (node->bit[0] && node->bit[1]) |
| continue; |
| - child = rcu_dereference_protected( |
| - node->bit[!rcu_access_pointer(node->bit[0])], |
| - lockdep_is_held(lock)); |
| + child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], |
| + lockdep_is_held(lock)); |
| if (child) |
| - child->parent_bit = node->parent_bit; |
| - *rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child; |
| + child->parent_bit_packed = node->parent_bit_packed; |
| + parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); |
| + *parent_bit = child; |
| + parent = (void *)parent_bit - |
| + offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); |
| + free_parent = !rcu_access_pointer(node->bit[0]) && |
| + !rcu_access_pointer(node->bit[1]) && |
| + (node->parent_bit_packed & 3) <= 1 && |
| + !rcu_access_pointer(parent->peer); |
| + if (free_parent) |
| + child = rcu_dereference_protected( |
| + parent->bit[!(node->parent_bit_packed & 1)], |
| + lockdep_is_held(lock)); |
| call_rcu(&node->rcu, node_free_rcu); |
| - |
| - /* TODO: Note that we currently don't walk up and down in order to |
| - * free any potential filler nodes. This means that this function |
| - * doesn't free up as much as it could, which could be revisited |
| - * at some point. |
| - */ |
| + if (!free_parent) |
| + continue; |
| + if (child) |
| + child->parent_bit_packed = parent->parent_bit_packed; |
| + *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; |
| + call_rcu(&parent->rcu, node_free_rcu); |
| } |
| } |
| |
| --- a/drivers/net/wireguard/allowedips.h |
| +++ b/drivers/net/wireguard/allowedips.h |
| @@ -19,7 +19,7 @@ struct allowedips_node { |
| u8 bits[16] __aligned(__alignof(u64)); |
| |
| /* Keep rarely used members at bottom to be beyond cache line. */ |
| - struct allowedips_node *__rcu *parent_bit; |
| + unsigned long parent_bit_packed; |
| union { |
| struct list_head peer_list; |
| struct rcu_head rcu; |
| @@ -30,7 +30,7 @@ struct allowedips { |
| struct allowedips_node __rcu *root4; |
| struct allowedips_node __rcu *root6; |
| u64 seq; |
| -}; |
| +} __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */ |
| |
| void wg_allowedips_init(struct allowedips *table); |
| void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); |
| --- a/drivers/net/wireguard/selftest/allowedips.c |
| +++ b/drivers/net/wireguard/selftest/allowedips.c |
| @@ -19,32 +19,22 @@ |
| |
| #include <linux/siphash.h> |
| |
| -static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, |
| - u8 cidr) |
| -{ |
| - swap_endian(dst, src, bits); |
| - memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8); |
| - if (cidr) |
| - dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8); |
| -} |
| - |
| static __init void print_node(struct allowedips_node *node, u8 bits) |
| { |
| char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; |
| - char *fmt_declaration = KERN_DEBUG |
| - "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; |
| + char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; |
| + u8 ip1[16], ip2[16], cidr1, cidr2; |
| char *style = "dotted"; |
| - u8 ip1[16], ip2[16]; |
| u32 color = 0; |
| |
| + if (node == NULL) |
| + return; |
| if (bits == 32) { |
| fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; |
| - fmt_declaration = KERN_DEBUG |
| - "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; |
| + fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; |
| } else if (bits == 128) { |
| fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; |
| - fmt_declaration = KERN_DEBUG |
| - "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; |
| + fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; |
| } |
| if (node->peer) { |
| hsiphash_key_t key = { { 0 } }; |
| @@ -55,24 +45,20 @@ static __init void print_node(struct all |
| hsiphash_1u32(0xabad1dea, &key) % 200; |
| style = "bold"; |
| } |
| - swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); |
| - printk(fmt_declaration, ip1, node->cidr, style, color); |
| + wg_allowedips_read_node(node, ip1, &cidr1); |
| + printk(fmt_declaration, ip1, cidr1, style, color); |
| if (node->bit[0]) { |
| - swap_endian_and_apply_cidr(ip2, |
| - rcu_dereference_raw(node->bit[0])->bits, bits, |
| - node->cidr); |
| - printk(fmt_connection, ip1, node->cidr, ip2, |
| - rcu_dereference_raw(node->bit[0])->cidr); |
| - print_node(rcu_dereference_raw(node->bit[0]), bits); |
| + wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2); |
| + printk(fmt_connection, ip1, cidr1, ip2, cidr2); |
| } |
| if (node->bit[1]) { |
| - swap_endian_and_apply_cidr(ip2, |
| - rcu_dereference_raw(node->bit[1])->bits, |
| - bits, node->cidr); |
| - printk(fmt_connection, ip1, node->cidr, ip2, |
| - rcu_dereference_raw(node->bit[1])->cidr); |
| - print_node(rcu_dereference_raw(node->bit[1]), bits); |
| + wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2); |
| + printk(fmt_connection, ip1, cidr1, ip2, cidr2); |
| } |
| + if (node->bit[0]) |
| + print_node(rcu_dereference_raw(node->bit[0]), bits); |
| + if (node->bit[1]) |
| + print_node(rcu_dereference_raw(node->bit[1]), bits); |
| } |
| |
| static __init void print_tree(struct allowedips_node __rcu *top, u8 bits) |
| @@ -121,8 +107,8 @@ static __init inline union nf_inet_addr |
| { |
| union nf_inet_addr mask; |
| |
| - memset(&mask, 0x00, 128 / 8); |
| - memset(&mask, 0xff, cidr / 8); |
| + memset(&mask, 0, sizeof(mask)); |
| + memset(&mask.all, 0xff, cidr / 8); |
| if (cidr % 32) |
| mask.all[cidr / 32] = (__force u32)htonl( |
| (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); |
| @@ -149,42 +135,36 @@ horrible_mask_self(struct horrible_allow |
| } |
| |
| static __init inline bool |
| -horrible_match_v4(const struct horrible_allowedips_node *node, |
| - struct in_addr *ip) |
| +horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip) |
| { |
| return (ip->s_addr & node->mask.ip) == node->ip.ip; |
| } |
| |
| static __init inline bool |
| -horrible_match_v6(const struct horrible_allowedips_node *node, |
| - struct in6_addr *ip) |
| +horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip) |
| { |
| - return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == |
| - node->ip.ip6[0] && |
| - (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == |
| - node->ip.ip6[1] && |
| - (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == |
| - node->ip.ip6[2] && |
| + return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] && |
| + (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] && |
| + (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] && |
| (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; |
| } |
| |
| static __init void |
| -horrible_insert_ordered(struct horrible_allowedips *table, |
| - struct horrible_allowedips_node *node) |
| +horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node) |
| { |
| struct horrible_allowedips_node *other = NULL, *where = NULL; |
| u8 my_cidr = horrible_mask_to_cidr(node->mask); |
| |
| hlist_for_each_entry(other, &table->head, table) { |
| - if (!memcmp(&other->mask, &node->mask, |
| - sizeof(union nf_inet_addr)) && |
| - !memcmp(&other->ip, &node->ip, |
| - sizeof(union nf_inet_addr)) && |
| - other->ip_version == node->ip_version) { |
| + if (other->ip_version == node->ip_version && |
| + !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) && |
| + !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) { |
| other->value = node->value; |
| kfree(node); |
| return; |
| } |
| + } |
| + hlist_for_each_entry(other, &table->head, table) { |
| where = other; |
| if (horrible_mask_to_cidr(other->mask) <= my_cidr) |
| break; |
| @@ -201,8 +181,7 @@ static __init int |
| horrible_allowedips_insert_v4(struct horrible_allowedips *table, |
| struct in_addr *ip, u8 cidr, void *value) |
| { |
| - struct horrible_allowedips_node *node = kzalloc(sizeof(*node), |
| - GFP_KERNEL); |
| + struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); |
| |
| if (unlikely(!node)) |
| return -ENOMEM; |
| @@ -219,8 +198,7 @@ static __init int |
| horrible_allowedips_insert_v6(struct horrible_allowedips *table, |
| struct in6_addr *ip, u8 cidr, void *value) |
| { |
| - struct horrible_allowedips_node *node = kzalloc(sizeof(*node), |
| - GFP_KERNEL); |
| + struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); |
| |
| if (unlikely(!node)) |
| return -ENOMEM; |
| @@ -234,39 +212,43 @@ horrible_allowedips_insert_v6(struct hor |
| } |
| |
| static __init void * |
| -horrible_allowedips_lookup_v4(struct horrible_allowedips *table, |
| - struct in_addr *ip) |
| +horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip) |
| { |
| struct horrible_allowedips_node *node; |
| - void *ret = NULL; |
| |
| hlist_for_each_entry(node, &table->head, table) { |
| - if (node->ip_version != 4) |
| - continue; |
| - if (horrible_match_v4(node, ip)) { |
| - ret = node->value; |
| - break; |
| - } |
| + if (node->ip_version == 4 && horrible_match_v4(node, ip)) |
| + return node->value; |
| } |
| - return ret; |
| + return NULL; |
| } |
| |
| static __init void * |
| -horrible_allowedips_lookup_v6(struct horrible_allowedips *table, |
| - struct in6_addr *ip) |
| +horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip) |
| { |
| struct horrible_allowedips_node *node; |
| - void *ret = NULL; |
| |
| hlist_for_each_entry(node, &table->head, table) { |
| - if (node->ip_version != 6) |
| + if (node->ip_version == 6 && horrible_match_v6(node, ip)) |
| + return node->value; |
| + } |
| + return NULL; |
| +} |
| + |
| + |
| +static __init void |
| +horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value) |
| +{ |
| + struct horrible_allowedips_node *node; |
| + struct hlist_node *h; |
| + |
| + hlist_for_each_entry_safe(node, h, &table->head, table) { |
| + if (node->value != value) |
| continue; |
| - if (horrible_match_v6(node, ip)) { |
| - ret = node->value; |
| - break; |
| - } |
| + hlist_del(&node->table); |
| + kfree(node); |
| } |
| - return ret; |
| + |
| } |
| |
| static __init bool randomized_test(void) |
| @@ -397,23 +379,33 @@ static __init bool randomized_test(void) |
| print_tree(t.root6, 128); |
| } |
| |
| - for (i = 0; i < NUM_QUERIES; ++i) { |
| - prandom_bytes(ip, 4); |
| - if (lookup(t.root4, 32, ip) != |
| - horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { |
| - pr_err("allowedips random self-test: FAIL\n"); |
| - goto free; |
| + for (j = 0;; ++j) { |
| + for (i = 0; i < NUM_QUERIES; ++i) { |
| + prandom_bytes(ip, 4); |
| + if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { |
| + horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); |
| + pr_err("allowedips random v4 self-test: FAIL\n"); |
| + goto free; |
| + } |
| + prandom_bytes(ip, 16); |
| + if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { |
| + pr_err("allowedips random v6 self-test: FAIL\n"); |
| + goto free; |
| + } |
| } |
| + if (j >= NUM_PEERS) |
| + break; |
| + mutex_lock(&mutex); |
| + wg_allowedips_remove_by_peer(&t, peers[j], &mutex); |
| + mutex_unlock(&mutex); |
| + horrible_allowedips_remove_by_value(&h, peers[j]); |
| } |
| |
| - for (i = 0; i < NUM_QUERIES; ++i) { |
| - prandom_bytes(ip, 16); |
| - if (lookup(t.root6, 128, ip) != |
| - horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { |
| - pr_err("allowedips random self-test: FAIL\n"); |
| - goto free; |
| - } |
| + if (t.root4 || t.root6) { |
| + pr_err("allowedips random self-test removal: FAIL\n"); |
| + goto free; |
| } |
| + |
| ret = true; |
| |
| free: |