b.liu | e958203 | 2025-04-17 19:18:16 +0800 | [diff] [blame] | 1 | From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 |
| 2 | From: "Jason A. Donenfeld" <Jason@zx2c4.com> |
| 3 | Date: Tue, 19 May 2020 22:49:30 -0600 |
| 4 | Subject: [PATCH] wireguard: noise: separate receive counter from send counter |
| 5 | |
| 6 | commit a9e90d9931f3a474f04bab782ccd9d77904941e9 upstream. |
| 7 | |
| 8 | In "wireguard: queueing: preserve flow hash across packet scrubbing", we |
| 9 | were required to slightly increase the size of the receive replay |
| 10 | counter to something still fairly small, but an increase nonetheless. |
| 11 | It turns out that we can recoup some of the additional memory overhead |
| 12 | by splitting up the prior union type into two distinct types. Before, we |
| 13 | used the same "noise_counter" union for both sending and receiving, with |
| 14 | sending just using a simple atomic64_t, while receiving used the full |
| 15 | replay counter checker. This meant that most of the memory being |
| 16 | allocated for the sending counter was being wasted. Since the old |
| 17 | "noise_counter" type increased in size in the prior commit, now is a |
| 18 | good time to split up that union type into a distinct "noise_replay_ |
| 19 | counter" for receiving and a boring atomic64_t for sending, each using |
| 20 | neither more nor less memory than required. |
| 21 | |
| 22 | Also, since sometimes the replay counter is accessed without |
| 23 | necessitating additional accesses to the bitmap, we can reduce cache |
| 24 | misses by hoisting the always-necessary lock above the bitmap in the |
| 25 | struct layout. We also change a "noise_replay_counter" stack allocation |
| 26 | to kmalloc in a -DDEBUG selftest so that KASAN doesn't trigger a stack |
| 27 | frame warning. |
| 28 | |
| 29 | All and all, removing a bit of abstraction in this commit makes the code |
| 30 | simpler and smaller, in addition to the motivating memory usage |
| 31 | recuperation. For example, passing around raw "noise_symmetric_key" |
| 32 | structs is something that really only makes sense within noise.c, in the |
| 33 | one place where the sending and receiving keys can safely be thought of |
| 34 | as the same type of object; subsequent to that, it's important that we |
| 35 | uniformly access these through keypair->{sending,receiving}, where their |
| 36 | distinct roles are always made explicit. So this patch allows us to draw |
| 37 | that distinction clearly as well. |
| 38 | |
| 39 | Fixes: e7096c131e51 ("net: WireGuard secure network tunnel") |
| 40 | Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> |
| 41 | Signed-off-by: David S. Miller <davem@davemloft.net> |
| 42 | Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> |
| 43 | --- |
| 44 | drivers/net/wireguard/noise.c | 16 +++------ |
| 45 | drivers/net/wireguard/noise.h | 14 ++++---- |
| 46 | drivers/net/wireguard/receive.c | 42 ++++++++++++------------ |
| 47 | drivers/net/wireguard/selftest/counter.c | 17 +++++++--- |
| 48 | drivers/net/wireguard/send.c | 12 +++---- |
| 49 | 5 files changed, 48 insertions(+), 53 deletions(-) |
| 50 | |
| 51 | --- a/drivers/net/wireguard/noise.c |
| 52 | +++ b/drivers/net/wireguard/noise.c |
| 53 | @@ -104,6 +104,7 @@ static struct noise_keypair *keypair_cre |
| 54 | |
| 55 | if (unlikely(!keypair)) |
| 56 | return NULL; |
| 57 | + spin_lock_init(&keypair->receiving_counter.lock); |
| 58 | keypair->internal_id = atomic64_inc_return(&keypair_counter); |
| 59 | keypair->entry.type = INDEX_HASHTABLE_KEYPAIR; |
| 60 | keypair->entry.peer = peer; |
| 61 | @@ -358,25 +359,16 @@ out: |
| 62 | memzero_explicit(output, BLAKE2S_HASH_SIZE + 1); |
| 63 | } |
| 64 | |
| 65 | -static void symmetric_key_init(struct noise_symmetric_key *key) |
| 66 | -{ |
| 67 | - spin_lock_init(&key->counter.receive.lock); |
| 68 | - atomic64_set(&key->counter.counter, 0); |
| 69 | - memset(key->counter.receive.backtrack, 0, |
| 70 | - sizeof(key->counter.receive.backtrack)); |
| 71 | - key->birthdate = ktime_get_coarse_boottime_ns(); |
| 72 | - key->is_valid = true; |
| 73 | -} |
| 74 | - |
| 75 | static void derive_keys(struct noise_symmetric_key *first_dst, |
| 76 | struct noise_symmetric_key *second_dst, |
| 77 | const u8 chaining_key[NOISE_HASH_LEN]) |
| 78 | { |
| 79 | + u64 birthdate = ktime_get_coarse_boottime_ns(); |
| 80 | kdf(first_dst->key, second_dst->key, NULL, NULL, |
| 81 | NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, |
| 82 | chaining_key); |
| 83 | - symmetric_key_init(first_dst); |
| 84 | - symmetric_key_init(second_dst); |
| 85 | + first_dst->birthdate = second_dst->birthdate = birthdate; |
| 86 | + first_dst->is_valid = second_dst->is_valid = true; |
| 87 | } |
| 88 | |
| 89 | static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], |
| 90 | --- a/drivers/net/wireguard/noise.h |
| 91 | +++ b/drivers/net/wireguard/noise.h |
| 92 | @@ -15,18 +15,14 @@ |
| 93 | #include <linux/mutex.h> |
| 94 | #include <linux/kref.h> |
| 95 | |
| 96 | -union noise_counter { |
| 97 | - struct { |
| 98 | - u64 counter; |
| 99 | - unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG]; |
| 100 | - spinlock_t lock; |
| 101 | - } receive; |
| 102 | - atomic64_t counter; |
| 103 | +struct noise_replay_counter { |
| 104 | + u64 counter; |
| 105 | + spinlock_t lock; |
| 106 | + unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG]; |
| 107 | }; |
| 108 | |
| 109 | struct noise_symmetric_key { |
| 110 | u8 key[NOISE_SYMMETRIC_KEY_LEN]; |
| 111 | - union noise_counter counter; |
| 112 | u64 birthdate; |
| 113 | bool is_valid; |
| 114 | }; |
| 115 | @@ -34,7 +30,9 @@ struct noise_symmetric_key { |
| 116 | struct noise_keypair { |
| 117 | struct index_hashtable_entry entry; |
| 118 | struct noise_symmetric_key sending; |
| 119 | + atomic64_t sending_counter; |
| 120 | struct noise_symmetric_key receiving; |
| 121 | + struct noise_replay_counter receiving_counter; |
| 122 | __le32 remote_index; |
| 123 | bool i_am_the_initiator; |
| 124 | struct kref refcount; |
| 125 | --- a/drivers/net/wireguard/receive.c |
| 126 | +++ b/drivers/net/wireguard/receive.c |
| 127 | @@ -245,20 +245,20 @@ static void keep_key_fresh(struct wg_pee |
| 128 | } |
| 129 | } |
| 130 | |
| 131 | -static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key) |
| 132 | +static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair) |
| 133 | { |
| 134 | struct scatterlist sg[MAX_SKB_FRAGS + 8]; |
| 135 | struct sk_buff *trailer; |
| 136 | unsigned int offset; |
| 137 | int num_frags; |
| 138 | |
| 139 | - if (unlikely(!key)) |
| 140 | + if (unlikely(!keypair)) |
| 141 | return false; |
| 142 | |
| 143 | - if (unlikely(!READ_ONCE(key->is_valid) || |
| 144 | - wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) || |
| 145 | - key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) { |
| 146 | - WRITE_ONCE(key->is_valid, false); |
| 147 | + if (unlikely(!READ_ONCE(keypair->receiving.is_valid) || |
| 148 | + wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) || |
| 149 | + keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) { |
| 150 | + WRITE_ONCE(keypair->receiving.is_valid, false); |
| 151 | return false; |
| 152 | } |
| 153 | |
| 154 | @@ -283,7 +283,7 @@ static bool decrypt_packet(struct sk_buf |
| 155 | |
| 156 | if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0, |
| 157 | PACKET_CB(skb)->nonce, |
| 158 | - key->key)) |
| 159 | + keypair->receiving.key)) |
| 160 | return false; |
| 161 | |
| 162 | /* Another ugly situation of pushing and pulling the header so as to |
| 163 | @@ -298,41 +298,41 @@ static bool decrypt_packet(struct sk_buf |
| 164 | } |
| 165 | |
| 166 | /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */ |
| 167 | -static bool counter_validate(union noise_counter *counter, u64 their_counter) |
| 168 | +static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter) |
| 169 | { |
| 170 | unsigned long index, index_current, top, i; |
| 171 | bool ret = false; |
| 172 | |
| 173 | - spin_lock_bh(&counter->receive.lock); |
| 174 | + spin_lock_bh(&counter->lock); |
| 175 | |
| 176 | - if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || |
| 177 | + if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 || |
| 178 | their_counter >= REJECT_AFTER_MESSAGES)) |
| 179 | goto out; |
| 180 | |
| 181 | ++their_counter; |
| 182 | |
| 183 | if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < |
| 184 | - counter->receive.counter)) |
| 185 | + counter->counter)) |
| 186 | goto out; |
| 187 | |
| 188 | index = their_counter >> ilog2(BITS_PER_LONG); |
| 189 | |
| 190 | - if (likely(their_counter > counter->receive.counter)) { |
| 191 | - index_current = counter->receive.counter >> ilog2(BITS_PER_LONG); |
| 192 | + if (likely(their_counter > counter->counter)) { |
| 193 | + index_current = counter->counter >> ilog2(BITS_PER_LONG); |
| 194 | top = min_t(unsigned long, index - index_current, |
| 195 | COUNTER_BITS_TOTAL / BITS_PER_LONG); |
| 196 | for (i = 1; i <= top; ++i) |
| 197 | - counter->receive.backtrack[(i + index_current) & |
| 198 | + counter->backtrack[(i + index_current) & |
| 199 | ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; |
| 200 | - counter->receive.counter = their_counter; |
| 201 | + counter->counter = their_counter; |
| 202 | } |
| 203 | |
| 204 | index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; |
| 205 | ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), |
| 206 | - &counter->receive.backtrack[index]); |
| 207 | + &counter->backtrack[index]); |
| 208 | |
| 209 | out: |
| 210 | - spin_unlock_bh(&counter->receive.lock); |
| 211 | + spin_unlock_bh(&counter->lock); |
| 212 | return ret; |
| 213 | } |
| 214 | |
| 215 | @@ -472,12 +472,12 @@ int wg_packet_rx_poll(struct napi_struct |
| 216 | if (unlikely(state != PACKET_STATE_CRYPTED)) |
| 217 | goto next; |
| 218 | |
| 219 | - if (unlikely(!counter_validate(&keypair->receiving.counter, |
| 220 | + if (unlikely(!counter_validate(&keypair->receiving_counter, |
| 221 | PACKET_CB(skb)->nonce))) { |
| 222 | net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", |
| 223 | peer->device->dev->name, |
| 224 | PACKET_CB(skb)->nonce, |
| 225 | - keypair->receiving.counter.receive.counter); |
| 226 | + keypair->receiving_counter.counter); |
| 227 | goto next; |
| 228 | } |
| 229 | |
| 230 | @@ -511,8 +511,8 @@ void wg_packet_decrypt_worker(struct wor |
| 231 | struct sk_buff *skb; |
| 232 | |
| 233 | while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) { |
| 234 | - enum packet_state state = likely(decrypt_packet(skb, |
| 235 | - &PACKET_CB(skb)->keypair->receiving)) ? |
| 236 | + enum packet_state state = |
| 237 | + likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ? |
| 238 | PACKET_STATE_CRYPTED : PACKET_STATE_DEAD; |
| 239 | wg_queue_enqueue_per_peer_napi(skb, state); |
| 240 | if (need_resched()) |
| 241 | --- a/drivers/net/wireguard/selftest/counter.c |
| 242 | +++ b/drivers/net/wireguard/selftest/counter.c |
| 243 | @@ -6,18 +6,24 @@ |
| 244 | #ifdef DEBUG |
| 245 | bool __init wg_packet_counter_selftest(void) |
| 246 | { |
| 247 | + struct noise_replay_counter *counter; |
| 248 | unsigned int test_num = 0, i; |
| 249 | - union noise_counter counter; |
| 250 | bool success = true; |
| 251 | |
| 252 | -#define T_INIT do { \ |
| 253 | - memset(&counter, 0, sizeof(union noise_counter)); \ |
| 254 | - spin_lock_init(&counter.receive.lock); \ |
| 255 | + counter = kmalloc(sizeof(*counter), GFP_KERNEL); |
| 256 | + if (unlikely(!counter)) { |
| 257 | + pr_err("nonce counter self-test malloc: FAIL\n"); |
| 258 | + return false; |
| 259 | + } |
| 260 | + |
| 261 | +#define T_INIT do { \ |
| 262 | + memset(counter, 0, sizeof(*counter)); \ |
| 263 | + spin_lock_init(&counter->lock); \ |
| 264 | } while (0) |
| 265 | #define T_LIM (COUNTER_WINDOW_SIZE + 1) |
| 266 | #define T(n, v) do { \ |
| 267 | ++test_num; \ |
| 268 | - if (counter_validate(&counter, n) != (v)) { \ |
| 269 | + if (counter_validate(counter, n) != (v)) { \ |
| 270 | pr_err("nonce counter self-test %u: FAIL\n", \ |
| 271 | test_num); \ |
| 272 | success = false; \ |
| 273 | @@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(v |
| 274 | |
| 275 | if (success) |
| 276 | pr_info("nonce counter self-tests: pass\n"); |
| 277 | + kfree(counter); |
| 278 | return success; |
| 279 | } |
| 280 | #endif |
| 281 | --- a/drivers/net/wireguard/send.c |
| 282 | +++ b/drivers/net/wireguard/send.c |
| 283 | @@ -129,7 +129,7 @@ static void keep_key_fresh(struct wg_pee |
| 284 | rcu_read_lock_bh(); |
| 285 | keypair = rcu_dereference_bh(peer->keypairs.current_keypair); |
| 286 | send = keypair && READ_ONCE(keypair->sending.is_valid) && |
| 287 | - (atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES || |
| 288 | + (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES || |
| 289 | (keypair->i_am_the_initiator && |
| 290 | wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME))); |
| 291 | rcu_read_unlock_bh(); |
| 292 | @@ -349,7 +349,6 @@ void wg_packet_purge_staged_packets(stru |
| 293 | |
| 294 | void wg_packet_send_staged_packets(struct wg_peer *peer) |
| 295 | { |
| 296 | - struct noise_symmetric_key *key; |
| 297 | struct noise_keypair *keypair; |
| 298 | struct sk_buff_head packets; |
| 299 | struct sk_buff *skb; |
| 300 | @@ -369,10 +368,9 @@ void wg_packet_send_staged_packets(struc |
| 301 | rcu_read_unlock_bh(); |
| 302 | if (unlikely(!keypair)) |
| 303 | goto out_nokey; |
| 304 | - key = &keypair->sending; |
| 305 | - if (unlikely(!READ_ONCE(key->is_valid))) |
| 306 | + if (unlikely(!READ_ONCE(keypair->sending.is_valid))) |
| 307 | goto out_nokey; |
| 308 | - if (unlikely(wg_birthdate_has_expired(key->birthdate, |
| 309 | + if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate, |
| 310 | REJECT_AFTER_TIME))) |
| 311 | goto out_invalid; |
| 312 | |
| 313 | @@ -387,7 +385,7 @@ void wg_packet_send_staged_packets(struc |
| 314 | */ |
| 315 | PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb); |
| 316 | PACKET_CB(skb)->nonce = |
| 317 | - atomic64_inc_return(&key->counter.counter) - 1; |
| 318 | + atomic64_inc_return(&keypair->sending_counter) - 1; |
| 319 | if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES)) |
| 320 | goto out_invalid; |
| 321 | } |
| 322 | @@ -399,7 +397,7 @@ void wg_packet_send_staged_packets(struc |
| 323 | return; |
| 324 | |
| 325 | out_invalid: |
| 326 | - WRITE_ONCE(key->is_valid, false); |
| 327 | + WRITE_ONCE(keypair->sending.is_valid, false); |
| 328 | out_nokey: |
| 329 | wg_noise_keypair_put(keypair, false); |
| 330 | |