// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

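/* Installed as the ->stream_memory_read hook below: report the socket
 * as readable whenever the psock ingress_msg list is non-empty.
 */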
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

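/* Wait up to @timeo for data to appear on either the psock ingress_msg
 * list or the socket receive queue. Returns non-zero if data became
 * available before the timeout expired.
 */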
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

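/* Copy up to @len bytes from the psock ingress_msg list into the user
 * iovec. Without MSG_PEEK the consumed bytes are removed from each
 * sk_msg (uncharging memory for msgs that do not wrap an skb) and fully
 * drained msgs are unlinked and freed; with MSG_PEEK the queue is left
 * untouched. Callers in this file hold the socket lock.
 */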
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	struct sk_msg *msg_rx;
	int i, copied = 0;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			copy = copy_page_to_iter(page, sge->offset, copy, iter);
			if (!copy)
				return copied ? copied : -EFAULT;

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				if (!msg_rx->skb)
					sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				/* Don't bother optimizing the peek case: if
				 * copy_page_to_iter() didn't copy the entire
				 * length, just stop here.
				 */
				if (copy != sge->length)
					return copied;
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			if (msg_rx == list_last_entry(&psock->ingress_msg,
						      struct sk_msg, list))
				break;
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

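/* ->recvmsg replacement while a psock is attached. Falls back to
 * tcp_recvmsg() when no psock exists or only the normal receive queue
 * has data; otherwise serves data from the ingress list, sleeping in
 * tcp_bpf_wait_data() when nothing is queued yet.
 */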
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

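/* Ingress redirect: move up to @apply_bytes of @msg's scatterlist into
 * a newly allocated sk_msg, charge @sk for it, queue it on the psock
 * ingress list and wake up readers.
 */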
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes) {
				if (sge->length)
					sk_msg_iter_var_prev(i);
				break;
			}
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

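/* Egress path: push up to @apply_bytes of @msg out through the TCP
 * stack page by page. When a TLS TX context is active the data goes via
 * kernel_sendpage_locked() with MSG_SENDPAGE_NOPOLICY, otherwise via
 * do_tcp_sendpages(). Called with the socket lock held.
 */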
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

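/* Send @bytes of @msg on the redirect target @sk: either queue it on
 * that socket's ingress list or transmit it, depending on whether the
 * msg carries the ingress flag.
 */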
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

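/* Run the msg verdict program (unless a cached verdict or cork state
 * applies) and act on the result: __SK_PASS transmits on @sk,
 * __SK_REDIRECT hands the data to tcp_bpf_sendmsg_redir(), __SK_DROP
 * frees it and returns -EACCES. apply_bytes and cork accounting are
 * handled here as well.
 */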
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, origsize, sent, delta = 0;
	u32 eval;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track the delta in msg size so it can be subtracted from
		 * the copied count returned to the user on SK_DROP. This
		 * ensures the user doesn't get a positive return code when
		 * the program uses msg_cut_data together with an SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;
	eval = __SK_NONE;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		release_sock(sk);

		origsize = msg->sg.size;
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		sent = origsize - msg->sg.size;

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		sk_mem_uncharge(sk, sent);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free(sk, msg);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

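/* ->sendmsg replacement: copy user data into an sk_msg (or the pending
 * cork buffer) and feed it to tcp_bpf_send_verdict(), waiting for
 * socket memory as needed.
 */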
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied > 0 ? copied : err;
}

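/* ->sendpage replacement: attach the page to an sk_msg (or the pending
 * cork buffer) without copying and run it through the verdict path.
 */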
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

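/* ->unhash replacement: drop all psock links, then call the original
 * unhash hook that was saved when the psock was attached.
 */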
static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

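/* Build the sockmap proto variants from @base: TCP_BPF_BASE overrides
 * the receive-side and teardown hooks, TCP_BPF_TX additionally
 * overrides sendmsg/sendpage for sockets with a msg parser program.
 */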
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change, e.g. TCP_BPF_TX is removed
	 * or added, requiring sk_prot hook updates. We keep the original
	 * saved hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

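/* Switch @sk to the sockmap proto ops. Fails when no psock is attached,
 * a proto was already saved, or the socket is not using the stock TCP
 * ops that the fallback paths assume.
 */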
int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}