/*
 * Copyright (c) 2014 Travis Geiselbrecht
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
23
24#include "minip-internal.h"
25
26#include <trace.h>
27#include <assert.h>
28#include <compiler.h>
29#include <stdlib.h>
30#include <err.h>
31#include <string.h>
32#include <sys/types.h>
33#include <lib/console.h>
34#include <lib/cbuf.h>
35#include <kernel/mutex.h>
36#include <kernel/semaphore.h>
37#include <arch/ops.h>
38#include <platform.h>
39
40#define LOCAL_TRACE 0
41
/* An IPv4 address stored in one 32-bit word. */
typedef uint32_t ipv4_addr;
43
44typedef struct tcp_header {
45 uint16_t source_port;
46 uint16_t dest_port;
47 uint32_t seq_num;
48 uint32_t ack_num;
49 uint16_t length_flags;
50 uint16_t win_size;
51 uint16_t checksum;
52 uint16_t urg_pointer;
53} __PACKED tcp_header_t;
54
55typedef struct tcp_pseudo_header {
56 ipv4_addr source_addr;
57 ipv4_addr dest_addr;
58 uint8_t zero;
59 uint8_t protocol;
60 uint16_t tcp_length;
61} __PACKED tcp_pseudo_header_t;
62
63typedef struct tcp_mss_option {
64 uint8_t kind; /* 0x2 */
65 uint8_t len; /* 0x4 */
66 uint16_t mss;
67} __PACKED tcp_mss_option_t;
68
/* Connection states of a socket; the numeric values show up in debug output. */
typedef enum tcp_state {
    STATE_CLOSED = 0,
    STATE_LISTEN,

    /* connection setup */
    STATE_SYN_SENT,
    STATE_SYN_RCVD,

    STATE_ESTABLISHED,

    /* teardown */
    STATE_CLOSE_WAIT,
    STATE_LAST_ACK,
    STATE_CLOSING,
    STATE_FIN_WAIT_1,
    STATE_FIN_WAIT_2,
    STATE_TIME_WAIT
} tcp_state_t;
82
/* Flag bits carried in the low six bits of tcp_header.length_flags. */
typedef enum tcp_flags {
    PKT_FIN = 1 << 0,
    PKT_SYN = 1 << 1,
    PKT_RST = 1 << 2,
    PKT_PSH = 1 << 3,
    PKT_ACK = 1 << 4,
    PKT_URG = 1 << 5
} tcp_flags_t;
91
92typedef struct tcp_socket {
93 struct list_node node;
94
95 mutex_t lock;
96 volatile int ref;
97
98 tcp_state_t state;
99 ipv4_addr local_ip;
100 ipv4_addr remote_ip;
101 uint16_t local_port;
102 uint16_t remote_port;
103
104 uint32_t mss;
105
106 /* rx */
107 uint32_t rx_win_size;
108 uint32_t rx_win_low;
109 uint32_t rx_win_high;
110 uint8_t *rx_buffer_raw;
111 cbuf_t rx_buffer;
112 event_t rx_event;
113 int rx_full_mss_count; // number of packets we have received in a row with a full mss
114 net_timer_t ack_delay_timer;
115
116 /* tx */
117 uint32_t tx_win_low; // low side of the acked window
118 uint32_t tx_win_high; // tx_win_low + their advertised window size
119 uint32_t tx_highest_seq; // highest sequence we have txed them
120 uint8_t *tx_buffer; // our outgoing buffer
121 uint32_t tx_buffer_size; // size of tx_buffer
122 uint32_t tx_buffer_offset; // offset into the buffer to append new data to
123 event_t tx_event;
124 net_timer_t retransmit_timer;
125
126 /* listen accept */
127 semaphore_t accept_sem;
128 struct tcp_socket *accepted;
129
130 net_timer_t time_wait_timer;
131} tcp_socket_t;
132
/* sizes, in bytes, given to every newly created socket */
#define DEFAULT_MSS             (1460)
#define DEFAULT_RX_WINDOW_SIZE  (8192)
#define DEFAULT_TX_BUFFER_SIZE  (8192)

/* timer delays handed to net_timer_set() (presumably milliseconds, given the 1 minute note) */
#define RETRANSMIT_TIMEOUT      (50)
#define DELAYED_ACK_TIMEOUT     (50)
#define TIME_WAIT_TIMEOUT       (60000) /* 1 minute */

/* when true, verify the rx checksum even if the packet is already flagged as good */
#define FORCE_TCP_CHECKSUM      (false)
142
/*
 * Wrap-aware comparisons of 32-bit TCP sequence numbers.
 *
 * Both operands are converted to uint32_t before subtracting so the
 * difference is always computed modulo 2^32 (no signed-overflow or
 * integer-promotion surprises if a caller passes a signed or wider type).
 * The difference is then reinterpreted as a signed 32-bit distance, so the
 * result is meaningful while the two values are less than 2^31 apart.
 */
#define SEQUENCE_DIFF(a, b) ((int32_t)((uint32_t)(a) - (uint32_t)(b)))

#define SEQUENCE_GTE(a, b) (SEQUENCE_DIFF(a, b) >= 0)
#define SEQUENCE_LTE(a, b) (SEQUENCE_DIFF(a, b) <= 0)
#define SEQUENCE_GT(a, b)  (SEQUENCE_DIFF(a, b) > 0)
#define SEQUENCE_LT(a, b)  (SEQUENCE_DIFF(a, b) < 0)
147
148static mutex_t tcp_socket_list_lock = MUTEX_INITIAL_VALUE(tcp_socket_list_lock);
149static struct list_node tcp_socket_list = LIST_INITIAL_VALUE(tcp_socket_list);
150
151static bool tcp_debug = false;
152
153/* local routines */
154static tcp_socket_t *lookup_socket(ipv4_addr remote_ip, ipv4_addr local_ip, uint16_t remote_port, uint16_t local_port);
155static void add_socket_to_list(tcp_socket_t *s);
156static void remove_socket_from_list(tcp_socket_t *s);
157static tcp_socket_t *create_tcp_socket(bool alloc_buffers);
158static status_t tcp_send(ipv4_addr dest_ip, uint16_t dest_port, ipv4_addr src_ip, uint16_t src_port, const void *buf,
159 size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t ack, uint32_t sequence, uint16_t window_size);
160static status_t tcp_socket_send(tcp_socket_t *s, const void *data, size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t sequence);
161static void handle_data(tcp_socket_t *s, const void *data, size_t len, uint32_t sequence);
162static void send_ack(tcp_socket_t *s);
163static void handle_ack(tcp_socket_t *s, uint32_t sequence, uint32_t win_size);
164static void handle_retransmit_timeout(void *_s);
165static void handle_time_wait_timeout(void *_s);
166static void handle_delayed_ack_timeout(void *_s);
167static void tcp_remote_close(tcp_socket_t *s);
168static void tcp_wakeup_waiters(tcp_socket_t *s);
169static void inc_socket_ref(tcp_socket_t *s);
170static bool dec_socket_ref(tcp_socket_t *s);
171
172static uint16_t cksum_pheader(const tcp_pseudo_header_t *pheader, const void *buf, size_t len)
173{
174 uint16_t checksum = ones_sum16(0, pheader, sizeof(*pheader));
175 return ~ones_sum16(checksum, buf, len);
176}
177
178__NO_INLINE static void dump_tcp_header(const tcp_header_t *header)
179{
180 printf("TCP: src_port %u, dest_port %u, seq %u, ack %u, win %u, flags %c%c%c%c%c%c\n",
181 ntohs(header->source_port), ntohs(header->dest_port), ntohl(header->seq_num), ntohl(header->ack_num),
182 ntohs(header->win_size),
183 (ntohs(header->length_flags) & PKT_FIN) ? 'F' : ' ',
184 (ntohs(header->length_flags) & PKT_SYN) ? 'S' : ' ',
185 (ntohs(header->length_flags) & PKT_RST) ? 'R' : ' ',
186 (ntohs(header->length_flags) & PKT_PSH) ? 'P' : ' ',
187 (ntohs(header->length_flags) & PKT_ACK) ? 'A' : ' ',
188 (ntohs(header->length_flags) & PKT_URG) ? 'U' : ' ');
189}
190
191static const char *tcp_state_to_string(tcp_state_t state)
192{
193 switch (state) {
194 default:
195 case STATE_CLOSED: return "CLOSED";
196 case STATE_LISTEN: return "LISTEN";
197 case STATE_SYN_SENT: return "SYN_SENT";
198 case STATE_SYN_RCVD: return "SYN_RCVD";
199 case STATE_ESTABLISHED: return "ESTABLISHED";
200 case STATE_CLOSE_WAIT: return "CLOSE_WAIT";
201 case STATE_LAST_ACK: return "LAST_ACK";
202 case STATE_CLOSING: return "CLOSING";
203 case STATE_FIN_WAIT_1: return "FIN_WAIT_1";
204 case STATE_FIN_WAIT_2: return "FIN_WAIT_2";
205 case STATE_TIME_WAIT: return "TIME_WAIT";
206 }
207}
208
209static void dump_socket(tcp_socket_t *s)
210{
211 printf("socket %p: state %d (%s), local 0x%x:%hu, remote 0x%x:%hu, ref %d\n",
212 s, s->state, tcp_state_to_string(s->state),
213 s->local_ip, s->local_port, s->remote_ip, s->remote_port, s->ref);
214 if (s->state == STATE_ESTABLISHED || s->state == STATE_CLOSE_WAIT) {
215 printf("\trx: wsize %u wlo %u whi %u (%u)\n",
216 s->rx_win_size, s->rx_win_low, s->rx_win_high,
217 s->rx_win_high - s->rx_win_low);
218 printf("\ttx: wlo %u whi %u (%u) highest_seq %u (%u) bufsize %u bufoff %u\n",
219 s->tx_win_low, s->tx_win_high, s->tx_win_high - s->tx_win_low,
220 s->tx_highest_seq, s->tx_highest_seq - s->tx_win_low,
221 s->tx_buffer_size, s->tx_buffer_offset);
222 }
223}
224
225static tcp_socket_t *lookup_socket(ipv4_addr remote_ip, ipv4_addr local_ip, uint16_t remote_port, uint16_t local_port)
226{
227 LTRACEF("remote ip 0x%x local ip 0x%x remote port %u local port %u\n", remote_ip, local_ip, remote_port, local_port);
228
229 mutex_acquire(&tcp_socket_list_lock);
230
231 /* XXX replace with something faster, like a hash table */
232 tcp_socket_t *s = NULL;
233 list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) {
234 if (s->state == STATE_CLOSED || s->state == STATE_LISTEN) {
235 continue;
236 } else {
237 /* full check */
238 if (s->remote_ip == remote_ip &&
239 s->local_ip == local_ip &&
240 s->remote_port == remote_port &&
241 s->local_port == local_port) {
242 goto out;
243 }
244 }
245 }
246
247 /* walk the list again, looking only for listen matches */
248 list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) {
249 if (s->state == STATE_LISTEN) {
250 /* sockets in listen state only care about local port */
251 if (s->local_port == local_port) {
252 goto out;
253 }
254 }
255 }
256
257 /* fall through case returns null */
258 s = NULL;
259
260out:
261 /* bump the ref before returning it */
262 if (s)
263 inc_socket_ref(s);
264
265 mutex_release(&tcp_socket_list_lock);
266
267 return s;
268}
269
270static void add_socket_to_list(tcp_socket_t *s)
271{
272 DEBUG_ASSERT(s);
273 DEBUG_ASSERT(s->ref > 0); // we should have implicitly bumped the ref when creating the socket
274
275 mutex_acquire(&tcp_socket_list_lock);
276
277 list_add_head(&tcp_socket_list, &s->node);
278
279 mutex_release(&tcp_socket_list_lock);
280}
281
282static void remove_socket_from_list(tcp_socket_t *s)
283{
284 DEBUG_ASSERT(s);
285 DEBUG_ASSERT(s->ref > 0);
286
287 mutex_acquire(&tcp_socket_list_lock);
288
289 DEBUG_ASSERT(list_in_list(&s->node));
290 list_delete(&s->node);
291
292 mutex_release(&tcp_socket_list_lock);
293}
294
295static void inc_socket_ref(tcp_socket_t *s)
296{
297 DEBUG_ASSERT(s);
298
299 __UNUSED int oldval = atomic_add(&s->ref, 1);
300 LTRACEF("caller %p, thread %p, socket %p, ref now %d\n", __GET_CALLER(), get_current_thread(), s, oldval + 1);
301 DEBUG_ASSERT(oldval > 0);
302}
303
304static bool dec_socket_ref(tcp_socket_t *s)
305{
306 DEBUG_ASSERT(s);
307
308 int oldval = atomic_add(&s->ref, -1);
309 LTRACEF("caller %p, thread %p, socket %p, ref now %d\n", __GET_CALLER(), get_current_thread(), s, oldval - 1);
310
311 if (oldval == 1) {
312 LTRACEF("destroying socket\n");
313 event_destroy(&s->tx_event);
314 event_destroy(&s->rx_event);
315
316 free(s->rx_buffer_raw);
317 free(s->tx_buffer);
318
319 free(s);
320 }
321 return (oldval == 1);
322}
323
324static void tcp_timer_set(tcp_socket_t *s, net_timer_t *timer, net_timer_callback_t cb, lk_time_t delay)
325{
326 DEBUG_ASSERT(s);
327 DEBUG_ASSERT(timer);
328
329 if (net_timer_set(timer, cb, s, delay))
330 inc_socket_ref(s);
331}
332
333static void tcp_timer_cancel(tcp_socket_t *s, net_timer_t *timer)
334{
335
336 DEBUG_ASSERT(s);
337 DEBUG_ASSERT(timer);
338
339 if (net_timer_cancel(timer))
340 dec_socket_ref(s);
341}
342
343void tcp_input(pktbuf_t *p, uint32_t src_ip, uint32_t dst_ip)
344{
345 if (unlikely(tcp_debug))
346 TRACEF("p %p (len %u), src_ip 0x%x, dst_ip 0x%x\n", p, p->dlen, src_ip, dst_ip);
347
348 tcp_header_t *header = (tcp_header_t *)p->data;
349
350 /* reject if too small */
351 if (p->dlen < sizeof(tcp_header_t))
352 return;
353
354 if (unlikely(tcp_debug) || LOCAL_TRACE) {
355 dump_tcp_header(header);
356 }
357
358 /* compute the actual header length (+ options) */
359 size_t header_len = ((ntohs(header->length_flags) >> 12) & 0xf) * 4;
360 if (p->dlen < header_len) {
361 TRACEF("REJECT: packet too large for buffer\n");
362 return;
363 }
364
365 /* checksum */
366 if (FORCE_TCP_CHECKSUM || (p->flags & PKTBUF_FLAG_CKSUM_TCP_GOOD) == 0) {
367 tcp_pseudo_header_t pheader;
368
369 // set up the pseudo header for checksum purposes
370 pheader.source_addr = src_ip;
371 pheader.dest_addr = dst_ip;
372 pheader.zero = 0;
373 pheader.protocol = IP_PROTO_TCP;
374 pheader.tcp_length = htons(p->dlen);
375
376 uint16_t checksum = cksum_pheader(&pheader, p->data, p->dlen);
377 if(checksum != 0) {
378 TRACEF("REJECT: failed checksum, header says 0x%x, we got 0x%x\n", header->checksum, checksum);
379 return;
380 }
381 }
382
383 /* byte swap header in place */
384 header->source_port = ntohs(header->source_port);
385 header->dest_port = ntohs(header->dest_port);
386 header->seq_num = ntohl(header->seq_num);
387 header->ack_num = ntohl(header->ack_num);
388 header->length_flags = ntohs(header->length_flags);
389 header->win_size = ntohs(header->win_size);
390 header->urg_pointer = ntohs(header->urg_pointer);
391
392 /* get some data from the packet */
393 uint8_t packet_flags = header->length_flags & 0x3f;
394 size_t data_len = p->dlen - header_len;
395 uint32_t highest_sequence = header->seq_num + ((data_len > 0) ? (data_len - 1) : 0);
396
397 /* see if it matches a socket we have */
398 tcp_socket_t *s = lookup_socket(src_ip, dst_ip, header->source_port, header->dest_port);
399 if (!s) {
400 /* send a RST packet */
401 goto send_reset;
402 }
403
404 if (unlikely(tcp_debug))
405 TRACEF("got socket %p, state %d (%s), ref %d\n", s, s->state, tcp_state_to_string(s->state), s->ref);
406
407 /* remove the header */
408 pktbuf_consume(p, header_len);
409
410 mutex_acquire(&s->lock);
411
412 /* check to see if they're resetting us */
413 if (packet_flags & PKT_RST) {
414 if (s->state != STATE_CLOSED && s->state != STATE_LISTEN) {
415 tcp_remote_close(s);
416 }
417 goto done;
418 }
419
420 switch (s->state) {
421 case STATE_CLOSED:
422 /* socket closed, send RST */
423 goto send_reset;
424
425 /* passive connect states */
426 case STATE_LISTEN: {
427 /* we're in listen and they want to talk to us */
428 if (!(packet_flags & PKT_SYN)) {
429 /* not a SYN, send RST */
430 goto send_reset;
431 }
432
433 /* see if we have a slot to accept */
434 if (s->accepted != NULL)
435 goto done;
436
437 /* make a new accept socket */
438 tcp_socket_t *accept_socket = create_tcp_socket(true);
439 if (!accept_socket)
440 goto done;
441
442 /* set it up */
443 accept_socket->local_ip = minip_get_ipaddr();
444 accept_socket->local_port = s->local_port;
445 accept_socket->remote_ip = src_ip;
446 accept_socket->remote_port = header->source_port;
447 accept_socket->state = STATE_SYN_RCVD;
448
449 mutex_acquire(&accept_socket->lock);
450
451 add_socket_to_list(accept_socket);
452
453 /* remember their sequence */
454 accept_socket->rx_win_low = header->seq_num + 1;
455 accept_socket->rx_win_high = accept_socket->rx_win_low + accept_socket->rx_win_size - 1;
456
457 /* save this socket and wake anyone up that is waiting to accept */
458 s->accepted = accept_socket;
459 sem_post(&s->accept_sem, true);
460
461 /* set up a mss option for sending back */
462 tcp_mss_option_t mss_option;
463 mss_option.kind = 0x2;
464 mss_option.len = 0x4;
465 mss_option.mss = ntohs(s->mss); // XXX make sure we fit in their mss
466
467 /* send a response */
468 tcp_socket_send(accept_socket, NULL, 0, PKT_ACK|PKT_SYN, &mss_option, sizeof(mss_option),
469 accept_socket->tx_win_low);
470
471 /* SYN consumed a sequence */
472 accept_socket->tx_win_low++;
473
474 mutex_release(&accept_socket->lock);
475 break;
476 }
477 case STATE_SYN_RCVD:
478 if (packet_flags & PKT_SYN) {
479 /* they must have not seen our ack of their original syn, retransmit */
480 // XXX implement
481 goto send_reset;
482 }
483
484 /* if they ack our SYN, we can move on to ESTABLISHED */
485 if (packet_flags & PKT_ACK) {
486 if (header->ack_num != s->tx_win_low) {
487 goto send_reset;
488 }
489
490 s->tx_win_high = s->tx_win_low + header->win_size;
491 s->tx_highest_seq = s->tx_win_low;
492
493 s->state = STATE_ESTABLISHED;
494 } else {
495 goto send_reset;
496 }
497
498 break;
499
500 case STATE_ESTABLISHED:
501 if (packet_flags & PKT_ACK) {
502 /* they're acking us */
503 handle_ack(s, header->ack_num, header->win_size);
504 }
505
506 if (data_len > 0) {
507 LTRACEF("new data, len %zu\n", data_len);
508 handle_data(s, p->data, p->dlen, header->seq_num);
509 }
510
511 if ((packet_flags & PKT_FIN) && SEQUENCE_GTE(s->rx_win_low, highest_sequence)) {
512 /* they're closing with us, and there's no outstanding data */
513
514 /* FIN consumed a sequence */
515 s->rx_win_low++;
516
517 /* ack them and transition to new state */
518 send_ack(s);
519 s->state = STATE_CLOSE_WAIT;
520
521 /* wake up any read waiters */
522 event_signal(&s->rx_event, true);
523 }
524 break;
525
526 case STATE_CLOSE_WAIT:
527 if (packet_flags & PKT_ACK) {
528 /* they're acking us */
529 handle_ack(s, header->ack_num, header->win_size);
530 }
531 if (packet_flags & PKT_FIN) {
532 /* they must have missed our ack, ack them again */
533 send_ack(s);
534 }
535 break;
536 case STATE_LAST_ACK:
537 if (packet_flags & PKT_ACK) {
538 /* they're acking our FIN, probably */
539 tcp_remote_close(s);
540
541 /* tcp_close() was already called on us, remove us from the list and drop the ref */
542 remove_socket_from_list(s);
543 dec_socket_ref(s);
544 }
545 break;
546 case STATE_FIN_WAIT_1:
547 if (packet_flags & PKT_ACK) {
548 /* they're acking our FIN, probably */
549 s->state = STATE_FIN_WAIT_2;
550 /* drop into fin_wait_2 state logic, in case they were FINning us too */
551 goto fin_wait_2;
552 } else if (packet_flags & PKT_FIN) {
553 /* simultaneous close. they finned us without acking our fin */
554 s->rx_win_low++;
555 send_ack(s);
556 s->state = STATE_CLOSING;
557 }
558 break;
559 case STATE_FIN_WAIT_2:
560fin_wait_2:
561 if (packet_flags & PKT_FIN) {
562 /* they're FINning us, ack them */
563 s->rx_win_low++;
564 send_ack(s);
565 s->state = STATE_TIME_WAIT;
566
567 /* set timed wait timer */
568 tcp_timer_set(s, &s->time_wait_timer, &handle_time_wait_timeout, TIME_WAIT_TIMEOUT);
569 }
570 break;
571 case STATE_CLOSING:
572 if (packet_flags & PKT_ACK) {
573 /* they're acking our FIN, probably */
574 s->state = STATE_TIME_WAIT;
575
576 /* set timed wait timer */
577 tcp_timer_set(s, &s->time_wait_timer, &handle_time_wait_timeout, TIME_WAIT_TIMEOUT);
578 }
579 break;
580 case STATE_TIME_WAIT:
581 /* /dev/null of packets */
582 break;
583
584 case STATE_SYN_SENT:
585 PANIC_UNIMPLEMENTED;
586 }
587
588done:
589 mutex_release(&s->lock);
590 dec_socket_ref(s);
591 return;
592
593send_reset:
594 if (s) {
595 mutex_release(&s->lock);
596 dec_socket_ref(s);
597 }
598
599 LTRACEF("SEND RST\n");
600 if (!(packet_flags & PKT_RST)) {
601 tcp_send(src_ip, header->source_port, dst_ip, header->dest_port,
602 NULL, 0, PKT_RST, NULL, 0, 0, header->ack_num, 0);
603 }
604}
605
606static void handle_data(tcp_socket_t *s, const void *data, size_t len, uint32_t sequence)
607{
608 if (unlikely(tcp_debug))
609 TRACEF("data %p, len %zu, sequence %u\n", data, len, sequence);
610
611 DEBUG_ASSERT(s);
612 DEBUG_ASSERT(is_mutex_held(&s->lock));
613 DEBUG_ASSERT(data);
614 DEBUG_ASSERT(len > 0);
615
616 /* see if it matches our current window */
617 uint32_t sequence_top = sequence + len - 1;
618 if (SEQUENCE_LTE(sequence, s->rx_win_low) && SEQUENCE_GTE(sequence_top, s->rx_win_low)) {
619 /* it intersects the bottom of our window, so it's in order */
620
621 /* copy the data we need to our cbuf */
622 size_t offset = sequence - s->rx_win_low;
623 size_t copy_len = MIN(s->rx_win_high - s->rx_win_low, len - offset);
624
625 DEBUG_ASSERT(offset < len);
626
627 LTRACEF("copying from offset %zu, len %zu\n", offset, copy_len);
628
629 s->rx_win_low += copy_len;
630
631 cbuf_write(&s->rx_buffer, (uint8_t *)data + offset, copy_len, false);
632 event_signal(&s->rx_event, true);
633
634 /* keep a counter if they've been sending a full mss */
635 if (copy_len >= s->mss) {
636 s->rx_full_mss_count++;
637 } else {
638 s->rx_full_mss_count = 0;
639 }
640
641 /* immediately ack if we're more than halfway into our buffer or they've sent 2 or more full packets */
642 if (s->rx_full_mss_count >= 2 ||
643 (int)(s->rx_win_low + s->rx_win_size - s->rx_win_high) > (int)s->rx_win_size / 2) {
644 send_ack(s);
645 s->rx_full_mss_count = 0;
646 } else {
647 tcp_timer_set(s, &s->ack_delay_timer, &handle_delayed_ack_timeout, DELAYED_ACK_TIMEOUT);
648 }
649 } else {
650 // either out of order or completely out of our window, drop
651 // duplicately ack the last thing we really got
652 send_ack(s);
653 }
654}
655
656static status_t tcp_socket_send(tcp_socket_t *s, const void *data, size_t len, tcp_flags_t flags,
657 const void *options, size_t options_length, uint32_t sequence)
658{
659 DEBUG_ASSERT(s);
660 DEBUG_ASSERT(is_mutex_held(&s->lock));
661 DEBUG_ASSERT(len == 0 || data);
662 DEBUG_ASSERT(options_length == 0 || options);
663 DEBUG_ASSERT((options_length % 4) == 0);
664
665 // calculate the new right edge of the rx window
666 uint32_t rx_win_high = s->rx_win_low + s->rx_win_size - cbuf_space_used(&s->rx_buffer) - 1;
667
668 LTRACEF("rx_win_low %u rx_win_size %u read_buf_len %zu, new win high %u\n",
669 s->rx_win_low, s->rx_win_size, cbuf_space_used(&s->rx_buffer), rx_win_high);
670
671 uint16_t win_size;
672 if (SEQUENCE_GTE(rx_win_high, s->rx_win_high)) {
673 s->rx_win_high = rx_win_high;
674 win_size = rx_win_high - s->rx_win_low;
675 } else {
676 // the window size has shrunk, but we can't move the
677 // right edge of the window backwards
678 win_size = s->rx_win_high - s->rx_win_low;
679 }
680
681 // we are piggybacking a pending ACK, so clear the delayed ACK timer
682 if (flags & PKT_ACK) {
683 tcp_timer_cancel(s, &s->ack_delay_timer);
684 }
685
686 status_t err = tcp_send(s->remote_ip, s->remote_port, s->local_ip, s->local_port, data, len, flags,
687 options, options_length, (flags & PKT_ACK) ? s->rx_win_low : 0, sequence, win_size);
688
689 return err;
690}
691
692static void send_ack(tcp_socket_t *s)
693{
694 DEBUG_ASSERT(s);
695 DEBUG_ASSERT(is_mutex_held(&s->lock));
696
697 if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT && s->state != STATE_FIN_WAIT_2)
698 return;
699
700 tcp_socket_send(s, NULL, 0, PKT_ACK, NULL, 0, s->tx_win_low);
701}
702
703static status_t tcp_send(ipv4_addr dest_ip, uint16_t dest_port, ipv4_addr src_ip, uint16_t src_port, const void *buf,
704 size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t ack, uint32_t sequence, uint16_t window_size)
705{
706 DEBUG_ASSERT(len == 0 || buf);
707 DEBUG_ASSERT(options_length == 0 || options);
708 DEBUG_ASSERT((options_length % 4) == 0);
709
710 pktbuf_t *p = pktbuf_alloc();
711 if (!p)
712 return ERR_NO_MEMORY;
713
714 tcp_header_t *header = pktbuf_prepend(p, sizeof(tcp_header_t) + options_length);
715 DEBUG_ASSERT(header);
716
717 /* fill in the header */
718 header->source_port = htons(src_port);
719 header->dest_port = htons(dest_port);
720 header->seq_num = htonl(sequence);
721 header->ack_num = htonl(ack);
722 header->length_flags = htons(((sizeof(tcp_header_t) + options_length) / 4) << 12 | flags);
723 header->win_size = htons(window_size);
724 header->checksum = 0;
725 header->urg_pointer = 0;
726 if (options)
727 memcpy(header + 1, options, options_length);
728
729 /* append the data */
730 if (len > 0)
731 pktbuf_append_data(p, buf, len);
732
733 /* compute the checksum */
734 /* XXX get the tx ckecksum capability from the nic */
735 if (FORCE_TCP_CHECKSUM || true) {
736 tcp_pseudo_header_t pheader;
737 pheader.source_addr = src_ip;
738 pheader.dest_addr = dest_ip;
739 pheader.zero = 0;
740 pheader.protocol = IP_PROTO_TCP;
741 pheader.tcp_length = htons(p->dlen);
742
743 header->checksum = cksum_pheader(&pheader, p->data, p->dlen);
744 }
745
746 if (LOCAL_TRACE) {
747 printf("sending ");
748 dump_tcp_header(header);
749 }
750
751 status_t err = minip_ipv4_send(p, dest_ip, IP_PROTO_TCP);
752
753 return err;
754}
755
756static void handle_ack(tcp_socket_t *s, uint32_t sequence, uint32_t win_size)
757{
758 LTRACEF("socket %p ack sequence %u, win_size %u\n", s, sequence, win_size);
759
760 DEBUG_ASSERT(s);
761 DEBUG_ASSERT(is_mutex_held(&s->lock));
762
763 LTRACEF("s %p, tx_win_low %u tx_win_high %u tx_highest_seq %u bufsize %u offset %u\n",
764 s, s->tx_win_low, s->tx_win_high, s->tx_highest_seq, s->tx_buffer_size, s->tx_buffer_offset);
765 if (SEQUENCE_LTE(sequence, s->tx_win_low)) {
766 /* they're acking stuff we've already received an ack for */
767 return;
768 } else if (SEQUENCE_GT(sequence, s->tx_highest_seq)) {
769 /* they're acking stuff we haven't sent */
770 return;
771 } else {
772 /* their ack is somewhere in our window */
773 uint32_t acked_len;
774
775 acked_len = (sequence - s->tx_win_low);
776
777 LTRACEF("acked len %u\n", acked_len);
778
779 DEBUG_ASSERT(acked_len <= s->tx_buffer_size);
780 DEBUG_ASSERT(acked_len <= s->tx_buffer_offset);
781
782 memmove(s->tx_buffer, s->tx_buffer + acked_len, s->tx_buffer_offset - acked_len);
783
784 s->tx_buffer_offset -= acked_len;
785 s->tx_win_low += acked_len;
786 s->tx_win_high = s->tx_win_low + win_size;
787
788 /* cancel or reset our retransmit timer */
789 if (s->tx_win_low == s->tx_highest_seq) {
790 tcp_timer_cancel(s, &s->retransmit_timer);
791 } else {
792 tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT);
793 }
794
795 /* we have opened the transmit buffer */
796 event_signal(&s->tx_event, true);
797 }
798}
799
800static ssize_t tcp_write_pending_data(tcp_socket_t *s)
801{
802 LTRACEF("s %p, tx_win_low %u tx_win_high %u tx_highest_seq %u bufsize %u offset %u\n",
803 s, s->tx_win_low, s->tx_win_high, s->tx_highest_seq, s->tx_buffer_size, s->tx_buffer_offset);
804
805 DEBUG_ASSERT(s);
806 DEBUG_ASSERT(is_mutex_held(&s->lock));
807 DEBUG_ASSERT(s->tx_buffer_size > 0);
808 DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size);
809
810 /* do we have any new data to send? */
811 uint32_t outstanding = (s->tx_highest_seq - s->tx_win_low);
812 uint32_t pending = s->tx_buffer_offset - outstanding;
813 LTRACEF("outstanding %u, pending %u\n", outstanding, pending);
814
815 /* send packets that cover the pending area of the window */
816 uint32_t offset = 0;
817 while (offset < pending) {
818 uint32_t tosend = MIN(s->mss, pending - offset);
819
820 tcp_socket_send(s, s->tx_buffer + outstanding + offset, tosend, PKT_ACK|PKT_PSH, NULL, 0, s->tx_highest_seq);
821 s->tx_highest_seq += tosend;
822 offset += tosend;
823 }
824
825 /* reset the retransmit timer if we sent anything */
826 if (offset > 0) {
827 tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT);
828 }
829
830 return offset;
831}
832
833static ssize_t tcp_retransmit(tcp_socket_t *s)
834{
835 DEBUG_ASSERT(s);
836 DEBUG_ASSERT(is_mutex_held(&s->lock));
837
838 if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT)
839 return 0;
840
841 /* how much data have we sent but not gotten an ack for? */
842 uint32_t outstanding = (s->tx_highest_seq - s->tx_win_low);
843 if (outstanding == 0)
844 return 0;
845
846 uint32_t tosend = MIN(s->mss, outstanding);
847
848 LTRACEF("s %p, tosend %u seq %u\n", s, tosend, s->tx_win_low);
849 tcp_socket_send(s, s->tx_buffer, tosend, PKT_ACK|PKT_PSH, NULL, 0, s->tx_win_low);
850
851 return tosend;
852}
853
854static void handle_retransmit_timeout(void *_s)
855{
856 tcp_socket_t *s = _s;
857
858 LTRACEF("s %p\n", s);
859
860 DEBUG_ASSERT(s);
861
862 mutex_acquire(&s->lock);
863
864 if (tcp_retransmit(s) == 0)
865 goto done;
866
867 tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT);
868
869done:
870 mutex_release(&s->lock);
871 dec_socket_ref(s);
872}
873
874static void handle_delayed_ack_timeout(void *_s)
875{
876 tcp_socket_t *s = _s;
877
878 LTRACEF("s %p\n", s);
879
880 DEBUG_ASSERT(s);
881
882 mutex_acquire(&s->lock);
883 send_ack(s);
884 mutex_release(&s->lock);
885 dec_socket_ref(s);
886}
887
888static void handle_time_wait_timeout(void *_s)
889{
890 tcp_socket_t *s = _s;
891
892 LTRACEF("s %p\n", s);
893
894 DEBUG_ASSERT(s);
895
896 mutex_acquire(&s->lock);
897
898 DEBUG_ASSERT(s->state == STATE_TIME_WAIT);
899
900 /* remove us from the list and drop the last ref */
901 remove_socket_from_list(s);
902 dec_socket_ref(s);
903
904 mutex_release(&s->lock);
905 dec_socket_ref(s);
906}
907
908static void tcp_wakeup_waiters(tcp_socket_t *s)
909{
910 DEBUG_ASSERT(s);
911 DEBUG_ASSERT(is_mutex_held(&s->lock));
912
913 // wake up any waiters
914 event_signal(&s->rx_event, true);
915 event_signal(&s->tx_event, true);
916}
917
918static void tcp_remote_close(tcp_socket_t *s)
919{
920 LTRACEF("s %p, ref %d\n", s, s->ref);
921
922 DEBUG_ASSERT(s);
923 DEBUG_ASSERT(is_mutex_held(&s->lock));
924 DEBUG_ASSERT(s->ref > 0);
925
926 if (s->state == STATE_CLOSED)
927 return;
928
929 s->state = STATE_CLOSED;
930
931 tcp_timer_cancel(s, &s->retransmit_timer);
932 tcp_timer_cancel(s, &s->ack_delay_timer);
933
934 tcp_wakeup_waiters(s);
935}
936
937static tcp_socket_t *create_tcp_socket(bool alloc_buffers)
938{
939 tcp_socket_t *s;
940
941 s = calloc(1, sizeof(tcp_socket_t));
942 if (!s)
943 return NULL;
944
945 mutex_init(&s->lock);
946 s->ref = 1; // start with the ref already bumped
947
948 s->state = STATE_CLOSED;
949 s->rx_win_size = DEFAULT_RX_WINDOW_SIZE;
950 event_init(&s->rx_event, false, 0);
951
952 s->mss = DEFAULT_MSS;
953
954 s->tx_win_low = rand();
955 s->tx_win_high = s->tx_win_low;
956 s->tx_highest_seq = s->tx_win_low;
957 event_init(&s->tx_event, true, 0);
958
959 if (alloc_buffers) {
960 // XXX check for error
961 s->rx_buffer_raw = malloc(s->rx_win_size);
962 cbuf_initialize_etc(&s->rx_buffer, s->rx_win_size, s->rx_buffer_raw);
963
964 s->tx_buffer_size = DEFAULT_TX_BUFFER_SIZE;
965 s->tx_buffer = malloc(s->tx_buffer_size);
966 }
967
968 sem_init(&s->accept_sem, 0);
969
970 return s;
971}
972
/* user API */
974
975status_t tcp_open_listen(tcp_socket_t **handle, uint16_t port)
976{
977 tcp_socket_t *s;
978
979 if (!handle)
980 return ERR_INVALID_ARGS;
981
982 s = create_tcp_socket(false);
983 if (!s)
984 return ERR_NO_MEMORY;
985
986 // XXX see if there's another listen socket already on this port
987
988 s->local_port = port;
989
990 /* go to listen state */
991 s->state = STATE_LISTEN;
992
993 add_socket_to_list(s);
994
995 *handle = s;
996
997 return NO_ERROR;
998}
999
1000status_t tcp_accept_timeout(tcp_socket_t *listen_socket, tcp_socket_t **accept_socket, lk_time_t timeout)
1001{
1002 if (!listen_socket || !accept_socket)
1003 return ERR_INVALID_ARGS;
1004
1005 tcp_socket_t *s = listen_socket;
1006 inc_socket_ref(s);
1007
1008 /* block to accept a socket for an amount of time */
1009 if (sem_timedwait(&s->accept_sem, timeout) == ERR_TIMED_OUT) {
1010 dec_socket_ref(s);
1011 return ERR_TIMED_OUT;
1012 }
1013
1014 mutex_acquire(&s->lock);
1015
1016 /* we got here, grab the accepted socket and return */
1017 DEBUG_ASSERT(s->accepted);
1018 *accept_socket = s->accepted;
1019 s->accepted = NULL;
1020
1021 mutex_release(&s->lock);
1022 dec_socket_ref(s);
1023
1024 return NO_ERROR;
1025}
1026
1027ssize_t tcp_read(tcp_socket_t *socket, void *buf, size_t len)
1028{
1029 LTRACEF("socket %p, buf %p, len %zu\n", socket, buf, len);
1030 if (!socket)
1031 return ERR_INVALID_ARGS;
1032 if (len == 0)
1033 return 0;
1034 if (!buf)
1035 return ERR_INVALID_ARGS;
1036
1037 tcp_socket_t *s = socket;
1038 inc_socket_ref(s);
1039
1040 ssize_t ret = 0;
1041retry:
1042 /* block on available data */
1043 event_wait(&s->rx_event);
1044
1045 mutex_acquire(&s->lock);
1046
1047 /* try to read some data from the receive buffer, even if we're closed */
1048 ret = cbuf_read(&s->rx_buffer, buf, len, false);
1049 if (ret == 0) {
1050 /* check to see if we've closed */
1051 if (s->state != STATE_ESTABLISHED) {
1052 ret = ERR_CHANNEL_CLOSED;
1053 goto out;
1054 }
1055
1056 /* we must have raced with another thread */
1057 event_unsignal(&s->rx_event);
1058 mutex_release(&s->lock);
1059 goto retry;
1060 }
1061
1062 /* if we've used up the last byte in the read buffer, unsignal the read event */
1063 size_t remaining_bytes = cbuf_space_used(&s->rx_buffer);
1064 if (s->state == STATE_ESTABLISHED && remaining_bytes == 0) {
1065 event_unsignal(&s->rx_event);
1066 }
1067
1068 /* we've read something, make sure the other end knows that our window is opening */
1069 uint32_t new_rx_win_size = s->rx_win_size - remaining_bytes;
1070
1071 /* if we've opened it enough, send an ack */
1072 if (new_rx_win_size >= s->mss && s->rx_win_high - s->rx_win_low < s->mss)
1073 send_ack(s);
1074
1075out:
1076 mutex_release(&s->lock);
1077 dec_socket_ref(s);
1078
1079 return ret;
1080}
1081
1082ssize_t tcp_write(tcp_socket_t *socket, const void *buf, size_t len)
1083{
1084 LTRACEF("socket %p, buf %p, len %zu\n", socket, buf, len);
1085 if (!socket)
1086 return ERR_INVALID_ARGS;
1087 if (len == 0)
1088 return 0;
1089 if (!buf)
1090 return ERR_INVALID_ARGS;
1091
1092 tcp_socket_t *s = socket;
1093 inc_socket_ref(s);
1094
1095 size_t off = 0;
1096 while (off < len) {
1097 LTRACEF("off %zu, len %zu\n", off, len);
1098
1099 /* wait for the tx buffer to open up */
1100 event_wait(&s->tx_event);
1101 LTRACEF("after event_wait\n");
1102
1103 mutex_acquire(&s->lock);
1104
1105 /* check to see if we've closed */
1106 if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT) {
1107 mutex_release(&s->lock);
1108 dec_socket_ref(s);
1109 return ERR_CHANNEL_CLOSED;
1110 }
1111
1112 DEBUG_ASSERT(s->tx_buffer_size > 0);
1113 DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size);
1114
1115 /* figure out how much data to copy in */
1116 size_t to_copy = MIN(s->tx_buffer_size - s->tx_buffer_offset, len - off);
1117 if (to_copy == 0) {
1118 mutex_release(&s->lock);
1119 continue;
1120 }
1121
1122 memcpy(s->tx_buffer + s->tx_buffer_offset, (uint8_t *)buf + off, to_copy);
1123 s->tx_buffer_offset += to_copy;
1124
1125 /* if this has completely filled it, unsignal the event */
1126 DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size);
1127 if (s->tx_buffer_offset == s->tx_buffer_size) {
1128 event_unsignal(&s->tx_event);
1129 }
1130
1131 /* send as much data as we can */
1132 tcp_write_pending_data(s);
1133
1134 off += to_copy;
1135
1136 mutex_release(&s->lock);
1137 }
1138
1139 dec_socket_ref(s);
1140 return len;
1141}
1142
1143status_t tcp_close(tcp_socket_t *socket)
1144{
1145 if (!socket)
1146 return ERR_INVALID_ARGS;
1147
1148 tcp_socket_t *s = socket;
1149
1150 inc_socket_ref(s);
1151 mutex_acquire(&s->lock);
1152
1153 LTRACEF("socket %p, state %d (%s), ref %d\n", s, s->state, tcp_state_to_string(s->state), s->ref);
1154
1155 status_t err;
1156 switch (s->state) {
1157 case STATE_CLOSED:
1158 case STATE_LISTEN:
1159 /* we can directly remove this socket */
1160 remove_socket_from_list(s);
1161
1162 /* drop any timers that may be pending on this */
1163 tcp_timer_cancel(s, &s->ack_delay_timer);
1164 tcp_timer_cancel(s, &s->retransmit_timer);
1165
1166 s->state = STATE_CLOSED;
1167
1168 /* drop the extra ref that was held when the socket was created */
1169 dec_socket_ref(s);
1170 break;
1171 case STATE_SYN_RCVD:
1172 case STATE_ESTABLISHED:
1173 s->state = STATE_FIN_WAIT_1;
1174 tcp_socket_send(s, NULL, 0, PKT_ACK|PKT_FIN, NULL, 0, s->tx_win_low);
1175 s->tx_win_low++;
1176
1177 /* stick around and wait for them to FIN us */
1178 break;
1179 case STATE_CLOSE_WAIT:
1180 s->state = STATE_LAST_ACK;
1181 tcp_socket_send(s, NULL, 0, PKT_ACK|PKT_FIN, NULL, 0, s->tx_win_low);
1182 s->tx_win_low++;
1183
1184 // XXX set up fin retransmit timer here
1185 break;
1186 case STATE_FIN_WAIT_1:
1187 case STATE_FIN_WAIT_2:
1188 case STATE_CLOSING:
1189 case STATE_TIME_WAIT:
1190 case STATE_LAST_ACK:
1191 /* these states are all post tcp_close(), so it's invalid to call it here */
1192 err = ERR_CHANNEL_CLOSED;
1193 goto out;
1194 default:
1195 PANIC_UNIMPLEMENTED;
1196 }
1197
1198 /* make sure anyone blocked on this wakes up */
1199 tcp_wakeup_waiters(s);
1200
1201 mutex_release(&s->lock);
1202
1203 err = NO_ERROR;
1204
1205out:
1206 /* if this was the last ref, it should destroy the socket */
1207 dec_socket_ref(s);
1208
1209 return err;
1210}
1211
1212/* debug stuff */
1213static int cmd_tcp(int argc, const cmd_args *argv)
1214{
1215 status_t err;
1216
1217 if (argc < 2) {
1218notenoughargs:
1219 printf("ERROR not enough arguments\n");
1220usage:
1221 printf("usage: %s sockets\n", argv[0].str);
1222 printf("usage: %s listenclose <port>\n", argv[0].str);
1223 printf("usage: %s listen <port>\n", argv[0].str);
1224 printf("usage: %s debug\n", argv[0].str);
1225 return ERR_INVALID_ARGS;
1226 }
1227
1228 if (!strcmp(argv[1].str, "sockets")) {
1229
1230 mutex_acquire(&tcp_socket_list_lock);
1231 tcp_socket_t *s = NULL;
1232 list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) {
1233 dump_socket(s);
1234 }
1235 mutex_release(&tcp_socket_list_lock);
1236 } else if (!strcmp(argv[1].str, "listenclose")) {
1237 /* listen for a connection, accept it, then immediately close it */
1238 if (argc < 3) goto notenoughargs;
1239
1240 tcp_socket_t *handle;
1241
1242 err = tcp_open_listen(&handle, argv[2].u);
1243 printf("tcp_open_listen returns %d, handle %p\n", err, handle);
1244
1245 tcp_socket_t *accepted;
1246 err = tcp_accept(handle, &accepted);
1247 printf("tcp_accept returns returns %d, handle %p\n", err, accepted);
1248
1249 err = tcp_close(accepted);
1250 printf("tcp_close returns %d\n", err);
1251
1252 err = tcp_close(handle);
1253 printf("tcp_close returns %d\n", err);
1254 } else if (!strcmp(argv[1].str, "listen")) {
1255 if (argc < 3) goto notenoughargs;
1256
1257 tcp_socket_t *handle;
1258
1259 err = tcp_open_listen(&handle, argv[2].u);
1260 printf("tcp_open_listen returns %d, handle %p\n", err, handle);
1261
1262 tcp_socket_t *accepted;
1263 err = tcp_accept(handle, &accepted);
1264 printf("tcp_accept returns returns %d, handle %p\n", err, accepted);
1265
1266 for (;;) {
1267 uint8_t buf[512];
1268
1269 ssize_t err = tcp_read(accepted, buf, sizeof(buf));
1270 printf("tcp_read returns %ld\n", err);
1271 if (err < 0)
1272 break;
1273 if (err > 0) {
1274 hexdump8(buf, err);
1275 }
1276
1277 err = tcp_write(accepted, buf, err);
1278 printf("tcp_write returns %ld\n", err);
1279 if (err < 0)
1280 break;
1281 }
1282
1283 err = tcp_close(accepted);
1284 printf("tcp_close returns %d\n", err);
1285
1286 err = tcp_close(handle);
1287 printf("tcp_close returns %d\n", err);
1288 } else if (!strcmp(argv[1].str, "debug")) {
1289 tcp_debug = !tcp_debug;
1290 printf("tcp debug now %u\n", tcp_debug);
1291 } else {
1292 printf("ERROR unknown command\n");
1293 goto usage;
1294 }
1295
1296 return NO_ERROR;
1297}
1298
/* command table for this file: the "tcp" command is handled by cmd_tcp() */
STATIC_COMMAND_START
STATIC_COMMAND("tcp", "tcp commands", &cmd_tcp)
STATIC_COMMAND_END(tcp);
1302
1303
1304// vim: set ts=4 sw=4 expandtab: