| From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 | 
 | From: "Jason A. Donenfeld" <Jason@zx2c4.com> | 
 | Date: Fri, 4 Jun 2021 17:17:38 +0200 | 
 | Subject: [PATCH] wireguard: allowedips: free empty intermediate nodes when | 
 |  removing single node | 
 |  | 
 | commit bf7b042dc62a31f66d3a41dd4dfc7806f267b307 upstream. | 
 |  | 
 | When removing single nodes, it's possible that that node's parent is an | 
 | empty intermediate node, in which case, it too should be removed. | 
 | Otherwise the trie fills up and never is fully emptied, leading to | 
 | gradual memory leaks over time for tries that are modified often. There | 
 | was originally code to do this, but was removed during refactoring in | 
 | 2016 and never reworked. Now that we have proper parent pointers from | 
 | the previous commits, we can implement this properly. | 
 |  | 
 | In order to reduce branching and expensive comparisons, we want to keep | 
 | the double pointer for parent assignment (which lets us easily chain up | 
 | to the root), but we still need to actually get the parent's base | 
 | address. So encode the bit number into the last two bits of the pointer, | 
 | and pack and unpack it as needed. This is a little bit clumsy but is the | 
 | fastest and less memory wasteful of the compromises. Note that we align | 
 | the root struct here to a minimum of 4, because it's embedded into a | 
 | larger struct, and we're relying on having the bottom two bits for our | 
 | flag, which would only be 16-bit aligned on m68k. | 
 |  | 
 | The existing macro-based helpers were a bit unwieldy for adding the bit | 
 | packing to, so this commit replaces them with safer and clearer ordinary | 
 | functions. | 
 |  | 
 | We add a test to the randomized/fuzzer part of the selftests, to free | 
 | the randomized tries by-peer, refuzz it, and repeat, until it's supposed | 
 | to be empty, and then then see if that actually resulted in the whole | 
 | thing being emptied. That combined with kmemcheck should hopefully make | 
 | sure this commit is doing what it should. Along the way this resulted in | 
 | various other cleanups of the tests and fixes for recent graphviz. | 
 |  | 
 | Fixes: e7096c131e51 ("net: WireGuard secure network tunnel") | 
 | Cc: stable@vger.kernel.org | 
 | Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> | 
 | Signed-off-by: David S. Miller <davem@davemloft.net> | 
 | Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> | 
 | --- | 
 |  drivers/net/wireguard/allowedips.c          | 102 ++++++------ | 
 |  drivers/net/wireguard/allowedips.h          |   4 +- | 
 |  drivers/net/wireguard/selftest/allowedips.c | 162 ++++++++++---------- | 
 |  3 files changed, 137 insertions(+), 131 deletions(-) | 
 |  | 
 | --- a/drivers/net/wireguard/allowedips.c | 
 | +++ b/drivers/net/wireguard/allowedips.c | 
 | @@ -30,8 +30,11 @@ static void copy_and_assign_cidr(struct | 
 |  	node->bitlen = bits; | 
 |  	memcpy(node->bits, src, bits / 8U); | 
 |  } | 
 | -#define CHOOSE_NODE(parent, key) \ | 
 | -	parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] | 
 | + | 
 | +static inline u8 choose(struct allowedips_node *node, const u8 *key) | 
 | +{ | 
 | +	return (key[node->bit_at_a] >> node->bit_at_b) & 1; | 
 | +} | 
 |   | 
 |  static void push_rcu(struct allowedips_node **stack, | 
 |  		     struct allowedips_node __rcu *p, unsigned int *len) | 
 | @@ -112,7 +115,7 @@ static struct allowedips_node *find_node | 
 |  			found = node; | 
 |  		if (node->cidr == bits) | 
 |  			break; | 
 | -		node = rcu_dereference_bh(CHOOSE_NODE(node, key)); | 
 | +		node = rcu_dereference_bh(node->bit[choose(node, key)]); | 
 |  	} | 
 |  	return found; | 
 |  } | 
 | @@ -144,8 +147,7 @@ static bool node_placement(struct allowe | 
 |  			   u8 cidr, u8 bits, struct allowedips_node **rnode, | 
 |  			   struct mutex *lock) | 
 |  { | 
 | -	struct allowedips_node *node = rcu_dereference_protected(trie, | 
 | -						lockdep_is_held(lock)); | 
 | +	struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); | 
 |  	struct allowedips_node *parent = NULL; | 
 |  	bool exact = false; | 
 |   | 
 | @@ -155,13 +157,24 @@ static bool node_placement(struct allowe | 
 |  			exact = true; | 
 |  			break; | 
 |  		} | 
 | -		node = rcu_dereference_protected(CHOOSE_NODE(parent, key), | 
 | -						 lockdep_is_held(lock)); | 
 | +		node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock)); | 
 |  	} | 
 |  	*rnode = parent; | 
 |  	return exact; | 
 |  } | 
 |   | 
 | +static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node) | 
 | +{ | 
 | +	node->parent_bit_packed = (unsigned long)parent | bit; | 
 | +	rcu_assign_pointer(*parent, node); | 
 | +} | 
 | + | 
 | +static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node) | 
 | +{ | 
 | +	u8 bit = choose(parent, node->bits); | 
 | +	connect_node(&parent->bit[bit], bit, node); | 
 | +} | 
 | + | 
 |  static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, | 
 |  	       u8 cidr, struct wg_peer *peer, struct mutex *lock) | 
 |  { | 
 | @@ -177,8 +190,7 @@ static int add(struct allowedips_node __ | 
 |  		RCU_INIT_POINTER(node->peer, peer); | 
 |  		list_add_tail(&node->peer_list, &peer->allowedips_list); | 
 |  		copy_and_assign_cidr(node, key, cidr, bits); | 
 | -		rcu_assign_pointer(node->parent_bit, trie); | 
 | -		rcu_assign_pointer(*trie, node); | 
 | +		connect_node(trie, 2, node); | 
 |  		return 0; | 
 |  	} | 
 |  	if (node_placement(*trie, key, cidr, bits, &node, lock)) { | 
 | @@ -197,10 +209,10 @@ static int add(struct allowedips_node __ | 
 |  	if (!node) { | 
 |  		down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); | 
 |  	} else { | 
 | -		down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock)); | 
 | +		const u8 bit = choose(node, key); | 
 | +		down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock)); | 
 |  		if (!down) { | 
 | -			rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key)); | 
 | -			rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); | 
 | +			connect_node(&node->bit[bit], bit, newnode); | 
 |  			return 0; | 
 |  		} | 
 |  	} | 
 | @@ -208,15 +220,11 @@ static int add(struct allowedips_node __ | 
 |  	parent = node; | 
 |   | 
 |  	if (newnode->cidr == cidr) { | 
 | -		rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits)); | 
 | -		rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); | 
 | -		if (!parent) { | 
 | -			rcu_assign_pointer(newnode->parent_bit, trie); | 
 | -			rcu_assign_pointer(*trie, newnode); | 
 | -		} else { | 
 | -			rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits)); | 
 | -			rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode); | 
 | -		} | 
 | +		choose_and_connect_node(newnode, down); | 
 | +		if (!parent) | 
 | +			connect_node(trie, 2, newnode); | 
 | +		else | 
 | +			choose_and_connect_node(parent, newnode); | 
 |  		return 0; | 
 |  	} | 
 |   | 
 | @@ -229,17 +237,12 @@ static int add(struct allowedips_node __ | 
 |  	INIT_LIST_HEAD(&node->peer_list); | 
 |  	copy_and_assign_cidr(node, newnode->bits, cidr, bits); | 
 |   | 
 | -	rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits)); | 
 | -	rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); | 
 | -	rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits)); | 
 | -	rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); | 
 | -	if (!parent) { | 
 | -		rcu_assign_pointer(node->parent_bit, trie); | 
 | -		rcu_assign_pointer(*trie, node); | 
 | -	} else { | 
 | -		rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits)); | 
 | -		rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node); | 
 | -	} | 
 | +	choose_and_connect_node(node, down); | 
 | +	choose_and_connect_node(node, newnode); | 
 | +	if (!parent) | 
 | +		connect_node(trie, 2, node); | 
 | +	else | 
 | +		choose_and_connect_node(parent, node); | 
 |  	return 0; | 
 |  } | 
 |   | 
 | @@ -297,7 +300,8 @@ int wg_allowedips_insert_v6(struct allow | 
 |  void wg_allowedips_remove_by_peer(struct allowedips *table, | 
 |  				  struct wg_peer *peer, struct mutex *lock) | 
 |  { | 
 | -	struct allowedips_node *node, *child, *tmp; | 
 | +	struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; | 
 | +	bool free_parent; | 
 |   | 
 |  	if (list_empty(&peer->allowedips_list)) | 
 |  		return; | 
 | @@ -307,19 +311,29 @@ void wg_allowedips_remove_by_peer(struct | 
 |  		RCU_INIT_POINTER(node->peer, NULL); | 
 |  		if (node->bit[0] && node->bit[1]) | 
 |  			continue; | 
 | -		child = rcu_dereference_protected( | 
 | -				node->bit[!rcu_access_pointer(node->bit[0])], | 
 | -				lockdep_is_held(lock)); | 
 | +		child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], | 
 | +						  lockdep_is_held(lock)); | 
 |  		if (child) | 
 | -			child->parent_bit = node->parent_bit; | 
 | -		*rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child; | 
 | +			child->parent_bit_packed = node->parent_bit_packed; | 
 | +		parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); | 
 | +		*parent_bit = child; | 
 | +		parent = (void *)parent_bit - | 
 | +			 offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); | 
 | +		free_parent = !rcu_access_pointer(node->bit[0]) && | 
 | +			      !rcu_access_pointer(node->bit[1]) && | 
 | +			      (node->parent_bit_packed & 3) <= 1 && | 
 | +			      !rcu_access_pointer(parent->peer); | 
 | +		if (free_parent) | 
 | +			child = rcu_dereference_protected( | 
 | +					parent->bit[!(node->parent_bit_packed & 1)], | 
 | +					lockdep_is_held(lock)); | 
 |  		call_rcu(&node->rcu, node_free_rcu); | 
 | - | 
 | -		/* TODO: Note that we currently don't walk up and down in order to | 
 | -		 * free any potential filler nodes. This means that this function | 
 | -		 * doesn't free up as much as it could, which could be revisited | 
 | -		 * at some point. | 
 | -		 */ | 
 | +		if (!free_parent) | 
 | +			continue; | 
 | +		if (child) | 
 | +			child->parent_bit_packed = parent->parent_bit_packed; | 
 | +		*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; | 
 | +		call_rcu(&parent->rcu, node_free_rcu); | 
 |  	} | 
 |  } | 
 |   | 
 | --- a/drivers/net/wireguard/allowedips.h | 
 | +++ b/drivers/net/wireguard/allowedips.h | 
 | @@ -19,7 +19,7 @@ struct allowedips_node { | 
 |  	u8 bits[16] __aligned(__alignof(u64)); | 
 |   | 
 |  	/* Keep rarely used members at bottom to be beyond cache line. */ | 
 | -	struct allowedips_node *__rcu *parent_bit; | 
 | +	unsigned long parent_bit_packed; | 
 |  	union { | 
 |  		struct list_head peer_list; | 
 |  		struct rcu_head rcu; | 
 | @@ -30,7 +30,7 @@ struct allowedips { | 
 |  	struct allowedips_node __rcu *root4; | 
 |  	struct allowedips_node __rcu *root6; | 
 |  	u64 seq; | 
 | -}; | 
 | +} __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */ | 
 |   | 
 |  void wg_allowedips_init(struct allowedips *table); | 
 |  void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); | 
 | --- a/drivers/net/wireguard/selftest/allowedips.c | 
 | +++ b/drivers/net/wireguard/selftest/allowedips.c | 
 | @@ -19,32 +19,22 @@ | 
 |   | 
 |  #include <linux/siphash.h> | 
 |   | 
 | -static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, | 
 | -					      u8 cidr) | 
 | -{ | 
 | -	swap_endian(dst, src, bits); | 
 | -	memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8); | 
 | -	if (cidr) | 
 | -		dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8); | 
 | -} | 
 | - | 
 |  static __init void print_node(struct allowedips_node *node, u8 bits) | 
 |  { | 
 |  	char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; | 
 | -	char *fmt_declaration = KERN_DEBUG | 
 | -		"\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; | 
 | +	char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; | 
 | +	u8 ip1[16], ip2[16], cidr1, cidr2; | 
 |  	char *style = "dotted"; | 
 | -	u8 ip1[16], ip2[16]; | 
 |  	u32 color = 0; | 
 |   | 
 | +	if (node == NULL) | 
 | +		return; | 
 |  	if (bits == 32) { | 
 |  		fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; | 
 | -		fmt_declaration = KERN_DEBUG | 
 | -			"\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; | 
 | +		fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; | 
 |  	} else if (bits == 128) { | 
 |  		fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; | 
 | -		fmt_declaration = KERN_DEBUG | 
 | -			"\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; | 
 | +		fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; | 
 |  	} | 
 |  	if (node->peer) { | 
 |  		hsiphash_key_t key = { { 0 } }; | 
 | @@ -55,24 +45,20 @@ static __init void print_node(struct all | 
 |  			hsiphash_1u32(0xabad1dea, &key) % 200; | 
 |  		style = "bold"; | 
 |  	} | 
 | -	swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); | 
 | -	printk(fmt_declaration, ip1, node->cidr, style, color); | 
 | +	wg_allowedips_read_node(node, ip1, &cidr1); | 
 | +	printk(fmt_declaration, ip1, cidr1, style, color); | 
 |  	if (node->bit[0]) { | 
 | -		swap_endian_and_apply_cidr(ip2, | 
 | -				rcu_dereference_raw(node->bit[0])->bits, bits, | 
 | -				node->cidr); | 
 | -		printk(fmt_connection, ip1, node->cidr, ip2, | 
 | -		       rcu_dereference_raw(node->bit[0])->cidr); | 
 | -		print_node(rcu_dereference_raw(node->bit[0]), bits); | 
 | +		wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2); | 
 | +		printk(fmt_connection, ip1, cidr1, ip2, cidr2); | 
 |  	} | 
 |  	if (node->bit[1]) { | 
 | -		swap_endian_and_apply_cidr(ip2, | 
 | -				rcu_dereference_raw(node->bit[1])->bits, | 
 | -				bits, node->cidr); | 
 | -		printk(fmt_connection, ip1, node->cidr, ip2, | 
 | -		       rcu_dereference_raw(node->bit[1])->cidr); | 
 | -		print_node(rcu_dereference_raw(node->bit[1]), bits); | 
 | +		wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2); | 
 | +		printk(fmt_connection, ip1, cidr1, ip2, cidr2); | 
 |  	} | 
 | +	if (node->bit[0]) | 
 | +		print_node(rcu_dereference_raw(node->bit[0]), bits); | 
 | +	if (node->bit[1]) | 
 | +		print_node(rcu_dereference_raw(node->bit[1]), bits); | 
 |  } | 
 |   | 
 |  static __init void print_tree(struct allowedips_node __rcu *top, u8 bits) | 
 | @@ -121,8 +107,8 @@ static __init inline union nf_inet_addr | 
 |  { | 
 |  	union nf_inet_addr mask; | 
 |   | 
 | -	memset(&mask, 0x00, 128 / 8); | 
 | -	memset(&mask, 0xff, cidr / 8); | 
 | +	memset(&mask, 0, sizeof(mask)); | 
 | +	memset(&mask.all, 0xff, cidr / 8); | 
 |  	if (cidr % 32) | 
 |  		mask.all[cidr / 32] = (__force u32)htonl( | 
 |  			(0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); | 
 | @@ -149,42 +135,36 @@ horrible_mask_self(struct horrible_allow | 
 |  } | 
 |   | 
 |  static __init inline bool | 
 | -horrible_match_v4(const struct horrible_allowedips_node *node, | 
 | -		  struct in_addr *ip) | 
 | +horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip) | 
 |  { | 
 |  	return (ip->s_addr & node->mask.ip) == node->ip.ip; | 
 |  } | 
 |   | 
 |  static __init inline bool | 
 | -horrible_match_v6(const struct horrible_allowedips_node *node, | 
 | -		  struct in6_addr *ip) | 
 | +horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip) | 
 |  { | 
 | -	return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == | 
 | -		       node->ip.ip6[0] && | 
 | -	       (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == | 
 | -		       node->ip.ip6[1] && | 
 | -	       (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == | 
 | -		       node->ip.ip6[2] && | 
 | +	return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] && | 
 | +	       (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] && | 
 | +	       (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] && | 
 |  	       (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; | 
 |  } | 
 |   | 
 |  static __init void | 
 | -horrible_insert_ordered(struct horrible_allowedips *table, | 
 | -			struct horrible_allowedips_node *node) | 
 | +horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node) | 
 |  { | 
 |  	struct horrible_allowedips_node *other = NULL, *where = NULL; | 
 |  	u8 my_cidr = horrible_mask_to_cidr(node->mask); | 
 |   | 
 |  	hlist_for_each_entry(other, &table->head, table) { | 
 | -		if (!memcmp(&other->mask, &node->mask, | 
 | -			    sizeof(union nf_inet_addr)) && | 
 | -		    !memcmp(&other->ip, &node->ip, | 
 | -			    sizeof(union nf_inet_addr)) && | 
 | -		    other->ip_version == node->ip_version) { | 
 | +		if (other->ip_version == node->ip_version && | 
 | +		    !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) && | 
 | +		    !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) { | 
 |  			other->value = node->value; | 
 |  			kfree(node); | 
 |  			return; | 
 |  		} | 
 | +	} | 
 | +	hlist_for_each_entry(other, &table->head, table) { | 
 |  		where = other; | 
 |  		if (horrible_mask_to_cidr(other->mask) <= my_cidr) | 
 |  			break; | 
 | @@ -201,8 +181,7 @@ static __init int | 
 |  horrible_allowedips_insert_v4(struct horrible_allowedips *table, | 
 |  			      struct in_addr *ip, u8 cidr, void *value) | 
 |  { | 
 | -	struct horrible_allowedips_node *node = kzalloc(sizeof(*node), | 
 | -							GFP_KERNEL); | 
 | +	struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); | 
 |   | 
 |  	if (unlikely(!node)) | 
 |  		return -ENOMEM; | 
 | @@ -219,8 +198,7 @@ static __init int | 
 |  horrible_allowedips_insert_v6(struct horrible_allowedips *table, | 
 |  			      struct in6_addr *ip, u8 cidr, void *value) | 
 |  { | 
 | -	struct horrible_allowedips_node *node = kzalloc(sizeof(*node), | 
 | -							GFP_KERNEL); | 
 | +	struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); | 
 |   | 
 |  	if (unlikely(!node)) | 
 |  		return -ENOMEM; | 
 | @@ -234,39 +212,43 @@ horrible_allowedips_insert_v6(struct hor | 
 |  } | 
 |   | 
 |  static __init void * | 
 | -horrible_allowedips_lookup_v4(struct horrible_allowedips *table, | 
 | -			      struct in_addr *ip) | 
 | +horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip) | 
 |  { | 
 |  	struct horrible_allowedips_node *node; | 
 | -	void *ret = NULL; | 
 |   | 
 |  	hlist_for_each_entry(node, &table->head, table) { | 
 | -		if (node->ip_version != 4) | 
 | -			continue; | 
 | -		if (horrible_match_v4(node, ip)) { | 
 | -			ret = node->value; | 
 | -			break; | 
 | -		} | 
 | +		if (node->ip_version == 4 && horrible_match_v4(node, ip)) | 
 | +			return node->value; | 
 |  	} | 
 | -	return ret; | 
 | +	return NULL; | 
 |  } | 
 |   | 
 |  static __init void * | 
 | -horrible_allowedips_lookup_v6(struct horrible_allowedips *table, | 
 | -			      struct in6_addr *ip) | 
 | +horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip) | 
 |  { | 
 |  	struct horrible_allowedips_node *node; | 
 | -	void *ret = NULL; | 
 |   | 
 |  	hlist_for_each_entry(node, &table->head, table) { | 
 | -		if (node->ip_version != 6) | 
 | +		if (node->ip_version == 6 && horrible_match_v6(node, ip)) | 
 | +			return node->value; | 
 | +	} | 
 | +	return NULL; | 
 | +} | 
 | + | 
 | + | 
 | +static __init void | 
 | +horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value) | 
 | +{ | 
 | +	struct horrible_allowedips_node *node; | 
 | +	struct hlist_node *h; | 
 | + | 
 | +	hlist_for_each_entry_safe(node, h, &table->head, table) { | 
 | +		if (node->value != value) | 
 |  			continue; | 
 | -		if (horrible_match_v6(node, ip)) { | 
 | -			ret = node->value; | 
 | -			break; | 
 | -		} | 
 | +		hlist_del(&node->table); | 
 | +		kfree(node); | 
 |  	} | 
 | -	return ret; | 
 | + | 
 |  } | 
 |   | 
 |  static __init bool randomized_test(void) | 
 | @@ -397,23 +379,33 @@ static __init bool randomized_test(void) | 
 |  		print_tree(t.root6, 128); | 
 |  	} | 
 |   | 
 | -	for (i = 0; i < NUM_QUERIES; ++i) { | 
 | -		prandom_bytes(ip, 4); | 
 | -		if (lookup(t.root4, 32, ip) != | 
 | -		    horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { | 
 | -			pr_err("allowedips random self-test: FAIL\n"); | 
 | -			goto free; | 
 | +	for (j = 0;; ++j) { | 
 | +		for (i = 0; i < NUM_QUERIES; ++i) { | 
 | +			prandom_bytes(ip, 4); | 
 | +			if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { | 
 | +				horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); | 
 | +				pr_err("allowedips random v4 self-test: FAIL\n"); | 
 | +				goto free; | 
 | +			} | 
 | +			prandom_bytes(ip, 16); | 
 | +			if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { | 
 | +				pr_err("allowedips random v6 self-test: FAIL\n"); | 
 | +				goto free; | 
 | +			} | 
 |  		} | 
 | +		if (j >= NUM_PEERS) | 
 | +			break; | 
 | +		mutex_lock(&mutex); | 
 | +		wg_allowedips_remove_by_peer(&t, peers[j], &mutex); | 
 | +		mutex_unlock(&mutex); | 
 | +		horrible_allowedips_remove_by_value(&h, peers[j]); | 
 |  	} | 
 |   | 
 | -	for (i = 0; i < NUM_QUERIES; ++i) { | 
 | -		prandom_bytes(ip, 16); | 
 | -		if (lookup(t.root6, 128, ip) != | 
 | -		    horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { | 
 | -			pr_err("allowedips random self-test: FAIL\n"); | 
 | -			goto free; | 
 | -		} | 
 | +	if (t.root4 || t.root6) { | 
 | +		pr_err("allowedips random self-test removal: FAIL\n"); | 
 | +		goto free; | 
 |  	} | 
 | + | 
 |  	ret = true; | 
 |   | 
 |  free: |