rjw | 1f88458 | 2022-01-06 17:20:42 +0800 | [diff] [blame^] | 1 | /* |
| 2 | * Copyright (c) 2014 Travis Geiselbrecht |
| 3 | * |
| 4 | * Permission is hereby granted, free of charge, to any person obtaining |
| 5 | * a copy of this software and associated documentation files |
| 6 | * (the "Software"), to deal in the Software without restriction, |
| 7 | * including without limitation the rights to use, copy, modify, merge, |
| 8 | * publish, distribute, sublicense, and/or sell copies of the Software, |
| 9 | * and to permit persons to whom the Software is furnished to do so, |
| 10 | * subject to the following conditions: |
| 11 | * |
| 12 | * The above copyright notice and this permission notice shall be |
| 13 | * included in all copies or substantial portions of the Software. |
| 14 | * |
| 15 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
| 16 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
| 17 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. |
| 18 | * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
| 19 | * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
| 20 | * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE |
| 21 | * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
| 22 | */ |
| 23 | |
| 24 | #include "minip-internal.h" |
| 25 | |
| 26 | #include <trace.h> |
| 27 | #include <assert.h> |
| 28 | #include <compiler.h> |
| 29 | #include <stdlib.h> |
| 30 | #include <err.h> |
| 31 | #include <string.h> |
| 32 | #include <sys/types.h> |
| 33 | #include <lib/console.h> |
| 34 | #include <lib/cbuf.h> |
| 35 | #include <kernel/mutex.h> |
| 36 | #include <kernel/semaphore.h> |
| 37 | #include <arch/ops.h> |
| 38 | #include <platform.h> |
| 39 | |
| 40 | #define LOCAL_TRACE 0 |
| 41 | |
/* An IPv4 address held as a raw 32-bit value. */
typedef uint32_t ipv4_addr;

/*
 * Fixed part of a TCP header (20 bytes when packed).
 *
 * tcp_send() fills the fields in network byte order; tcp_input() byte swaps
 * every multi-byte field except the checksum in place before using them.
 */
struct tcp_header {
    uint16_t source_port;
    uint16_t dest_port;
    uint32_t seq_num;
    uint32_t ack_num;
    uint16_t length_flags;  /* top 4 bits: header length in 32-bit words, low 6 bits: tcp_flags_t */
    uint16_t win_size;
    uint16_t checksum;
    uint16_t urg_pointer;
} __PACKED;

typedef struct tcp_header tcp_header_t;
| 54 | |
| 55 | typedef struct tcp_pseudo_header { |
| 56 | ipv4_addr source_addr; |
| 57 | ipv4_addr dest_addr; |
| 58 | uint8_t zero; |
| 59 | uint8_t protocol; |
| 60 | uint16_t tcp_length; |
| 61 | } __PACKED tcp_pseudo_header_t; |
| 62 | |
/* Wire layout of the 4-byte "maximum segment size" TCP option. */
struct tcp_mss_option {
    uint8_t  kind;  /* 0x2 */
    uint8_t  len;   /* 0x4 */
    uint16_t mss;
} __PACKED;

typedef struct tcp_mss_option tcp_mss_option_t;
| 68 | |
/* Connection states tracked per socket; values are spelled out for easy cross-reference with traces. */
typedef enum tcp_state {
    STATE_CLOSED      = 0,
    STATE_LISTEN      = 1,
    STATE_SYN_SENT    = 2,
    STATE_SYN_RCVD    = 3,
    STATE_ESTABLISHED = 4,
    STATE_CLOSE_WAIT  = 5,
    STATE_LAST_ACK    = 6,
    STATE_CLOSING     = 7,
    STATE_FIN_WAIT_1  = 8,
    STATE_FIN_WAIT_2  = 9,
    STATE_TIME_WAIT   = 10
} tcp_state_t;
| 82 | |
/* Flag bits carried in the low 6 bits of tcp_header.length_flags. */
typedef enum tcp_flags {
    PKT_FIN = 1 << 0,
    PKT_SYN = 1 << 1,
    PKT_RST = 1 << 2,
    PKT_PSH = 1 << 3,
    PKT_ACK = 1 << 4,
    PKT_URG = 1 << 5
} tcp_flags_t;
| 91 | |
| 92 | typedef struct tcp_socket { |
| 93 | struct list_node node; |
| 94 | |
| 95 | mutex_t lock; |
| 96 | volatile int ref; |
| 97 | |
| 98 | tcp_state_t state; |
| 99 | ipv4_addr local_ip; |
| 100 | ipv4_addr remote_ip; |
| 101 | uint16_t local_port; |
| 102 | uint16_t remote_port; |
| 103 | |
| 104 | uint32_t mss; |
| 105 | |
| 106 | /* rx */ |
| 107 | uint32_t rx_win_size; |
| 108 | uint32_t rx_win_low; |
| 109 | uint32_t rx_win_high; |
| 110 | uint8_t *rx_buffer_raw; |
| 111 | cbuf_t rx_buffer; |
| 112 | event_t rx_event; |
| 113 | int rx_full_mss_count; // number of packets we have received in a row with a full mss |
| 114 | net_timer_t ack_delay_timer; |
| 115 | |
| 116 | /* tx */ |
| 117 | uint32_t tx_win_low; // low side of the acked window |
| 118 | uint32_t tx_win_high; // tx_win_low + their advertised window size |
| 119 | uint32_t tx_highest_seq; // highest sequence we have txed them |
| 120 | uint8_t *tx_buffer; // our outgoing buffer |
| 121 | uint32_t tx_buffer_size; // size of tx_buffer |
| 122 | uint32_t tx_buffer_offset; // offset into the buffer to append new data to |
| 123 | event_t tx_event; |
| 124 | net_timer_t retransmit_timer; |
| 125 | |
| 126 | /* listen accept */ |
| 127 | semaphore_t accept_sem; |
| 128 | struct tcp_socket *accepted; |
| 129 | |
| 130 | net_timer_t time_wait_timer; |
| 131 | } tcp_socket_t; |
| 132 | |
/* default sizes, in bytes */
#define DEFAULT_MSS             1460
#define DEFAULT_RX_WINDOW_SIZE  8192
#define DEFAULT_TX_BUFFER_SIZE  8192

/* delays handed to tcp_timer_set() as lk_time_t values */
#define RETRANSMIT_TIMEOUT      50
#define DELAYED_ACK_TIMEOUT     50
#define TIME_WAIT_TIMEOUT       60000   /* 1 minute */

/* when true, checksums are verified even if the pktbuf is flagged as already checked */
#define FORCE_TCP_CHECKSUM      false
| 142 | |
/*
 * Wrap-safe comparison of 32-bit sequence numbers: the unsigned difference
 * is reinterpreted as signed and compared against zero.
 */
#define SEQUENCE_DELTA(a, b) ((int32_t)((a) - (b)))

#define SEQUENCE_GTE(a, b) (SEQUENCE_DELTA(a, b) >= 0)
#define SEQUENCE_LTE(a, b) (SEQUENCE_DELTA(a, b) <= 0)
#define SEQUENCE_GT(a, b)  (SEQUENCE_DELTA(a, b) > 0)
#define SEQUENCE_LT(a, b)  (SEQUENCE_DELTA(a, b) < 0)
| 147 | |
| 148 | static mutex_t tcp_socket_list_lock = MUTEX_INITIAL_VALUE(tcp_socket_list_lock); |
| 149 | static struct list_node tcp_socket_list = LIST_INITIAL_VALUE(tcp_socket_list); |
| 150 | |
| 151 | static bool tcp_debug = false; |
| 152 | |
| 153 | /* local routines */ |
| 154 | static tcp_socket_t *lookup_socket(ipv4_addr remote_ip, ipv4_addr local_ip, uint16_t remote_port, uint16_t local_port); |
| 155 | static void add_socket_to_list(tcp_socket_t *s); |
| 156 | static void remove_socket_from_list(tcp_socket_t *s); |
| 157 | static tcp_socket_t *create_tcp_socket(bool alloc_buffers); |
| 158 | static status_t tcp_send(ipv4_addr dest_ip, uint16_t dest_port, ipv4_addr src_ip, uint16_t src_port, const void *buf, |
| 159 | size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t ack, uint32_t sequence, uint16_t window_size); |
| 160 | static status_t tcp_socket_send(tcp_socket_t *s, const void *data, size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t sequence); |
| 161 | static void handle_data(tcp_socket_t *s, const void *data, size_t len, uint32_t sequence); |
| 162 | static void send_ack(tcp_socket_t *s); |
| 163 | static void handle_ack(tcp_socket_t *s, uint32_t sequence, uint32_t win_size); |
| 164 | static void handle_retransmit_timeout(void *_s); |
| 165 | static void handle_time_wait_timeout(void *_s); |
| 166 | static void handle_delayed_ack_timeout(void *_s); |
| 167 | static void tcp_remote_close(tcp_socket_t *s); |
| 168 | static void tcp_wakeup_waiters(tcp_socket_t *s); |
| 169 | static void inc_socket_ref(tcp_socket_t *s); |
| 170 | static bool dec_socket_ref(tcp_socket_t *s); |
| 171 | |
| 172 | static uint16_t cksum_pheader(const tcp_pseudo_header_t *pheader, const void *buf, size_t len) |
| 173 | { |
| 174 | uint16_t checksum = ones_sum16(0, pheader, sizeof(*pheader)); |
| 175 | return ~ones_sum16(checksum, buf, len); |
| 176 | } |
| 177 | |
| 178 | __NO_INLINE static void dump_tcp_header(const tcp_header_t *header) |
| 179 | { |
| 180 | printf("TCP: src_port %u, dest_port %u, seq %u, ack %u, win %u, flags %c%c%c%c%c%c\n", |
| 181 | ntohs(header->source_port), ntohs(header->dest_port), ntohl(header->seq_num), ntohl(header->ack_num), |
| 182 | ntohs(header->win_size), |
| 183 | (ntohs(header->length_flags) & PKT_FIN) ? 'F' : ' ', |
| 184 | (ntohs(header->length_flags) & PKT_SYN) ? 'S' : ' ', |
| 185 | (ntohs(header->length_flags) & PKT_RST) ? 'R' : ' ', |
| 186 | (ntohs(header->length_flags) & PKT_PSH) ? 'P' : ' ', |
| 187 | (ntohs(header->length_flags) & PKT_ACK) ? 'A' : ' ', |
| 188 | (ntohs(header->length_flags) & PKT_URG) ? 'U' : ' '); |
| 189 | } |
| 190 | |
| 191 | static const char *tcp_state_to_string(tcp_state_t state) |
| 192 | { |
| 193 | switch (state) { |
| 194 | default: |
| 195 | case STATE_CLOSED: return "CLOSED"; |
| 196 | case STATE_LISTEN: return "LISTEN"; |
| 197 | case STATE_SYN_SENT: return "SYN_SENT"; |
| 198 | case STATE_SYN_RCVD: return "SYN_RCVD"; |
| 199 | case STATE_ESTABLISHED: return "ESTABLISHED"; |
| 200 | case STATE_CLOSE_WAIT: return "CLOSE_WAIT"; |
| 201 | case STATE_LAST_ACK: return "LAST_ACK"; |
| 202 | case STATE_CLOSING: return "CLOSING"; |
| 203 | case STATE_FIN_WAIT_1: return "FIN_WAIT_1"; |
| 204 | case STATE_FIN_WAIT_2: return "FIN_WAIT_2"; |
| 205 | case STATE_TIME_WAIT: return "TIME_WAIT"; |
| 206 | } |
| 207 | } |
| 208 | |
| 209 | static void dump_socket(tcp_socket_t *s) |
| 210 | { |
| 211 | printf("socket %p: state %d (%s), local 0x%x:%hu, remote 0x%x:%hu, ref %d\n", |
| 212 | s, s->state, tcp_state_to_string(s->state), |
| 213 | s->local_ip, s->local_port, s->remote_ip, s->remote_port, s->ref); |
| 214 | if (s->state == STATE_ESTABLISHED || s->state == STATE_CLOSE_WAIT) { |
| 215 | printf("\trx: wsize %u wlo %u whi %u (%u)\n", |
| 216 | s->rx_win_size, s->rx_win_low, s->rx_win_high, |
| 217 | s->rx_win_high - s->rx_win_low); |
| 218 | printf("\ttx: wlo %u whi %u (%u) highest_seq %u (%u) bufsize %u bufoff %u\n", |
| 219 | s->tx_win_low, s->tx_win_high, s->tx_win_high - s->tx_win_low, |
| 220 | s->tx_highest_seq, s->tx_highest_seq - s->tx_win_low, |
| 221 | s->tx_buffer_size, s->tx_buffer_offset); |
| 222 | } |
| 223 | } |
| 224 | |
| 225 | static tcp_socket_t *lookup_socket(ipv4_addr remote_ip, ipv4_addr local_ip, uint16_t remote_port, uint16_t local_port) |
| 226 | { |
| 227 | LTRACEF("remote ip 0x%x local ip 0x%x remote port %u local port %u\n", remote_ip, local_ip, remote_port, local_port); |
| 228 | |
| 229 | mutex_acquire(&tcp_socket_list_lock); |
| 230 | |
| 231 | /* XXX replace with something faster, like a hash table */ |
| 232 | tcp_socket_t *s = NULL; |
| 233 | list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) { |
| 234 | if (s->state == STATE_CLOSED || s->state == STATE_LISTEN) { |
| 235 | continue; |
| 236 | } else { |
| 237 | /* full check */ |
| 238 | if (s->remote_ip == remote_ip && |
| 239 | s->local_ip == local_ip && |
| 240 | s->remote_port == remote_port && |
| 241 | s->local_port == local_port) { |
| 242 | goto out; |
| 243 | } |
| 244 | } |
| 245 | } |
| 246 | |
| 247 | /* walk the list again, looking only for listen matches */ |
| 248 | list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) { |
| 249 | if (s->state == STATE_LISTEN) { |
| 250 | /* sockets in listen state only care about local port */ |
| 251 | if (s->local_port == local_port) { |
| 252 | goto out; |
| 253 | } |
| 254 | } |
| 255 | } |
| 256 | |
| 257 | /* fall through case returns null */ |
| 258 | s = NULL; |
| 259 | |
| 260 | out: |
| 261 | /* bump the ref before returning it */ |
| 262 | if (s) |
| 263 | inc_socket_ref(s); |
| 264 | |
| 265 | mutex_release(&tcp_socket_list_lock); |
| 266 | |
| 267 | return s; |
| 268 | } |
| 269 | |
| 270 | static void add_socket_to_list(tcp_socket_t *s) |
| 271 | { |
| 272 | DEBUG_ASSERT(s); |
| 273 | DEBUG_ASSERT(s->ref > 0); // we should have implicitly bumped the ref when creating the socket |
| 274 | |
| 275 | mutex_acquire(&tcp_socket_list_lock); |
| 276 | |
| 277 | list_add_head(&tcp_socket_list, &s->node); |
| 278 | |
| 279 | mutex_release(&tcp_socket_list_lock); |
| 280 | } |
| 281 | |
| 282 | static void remove_socket_from_list(tcp_socket_t *s) |
| 283 | { |
| 284 | DEBUG_ASSERT(s); |
| 285 | DEBUG_ASSERT(s->ref > 0); |
| 286 | |
| 287 | mutex_acquire(&tcp_socket_list_lock); |
| 288 | |
| 289 | DEBUG_ASSERT(list_in_list(&s->node)); |
| 290 | list_delete(&s->node); |
| 291 | |
| 292 | mutex_release(&tcp_socket_list_lock); |
| 293 | } |
| 294 | |
| 295 | static void inc_socket_ref(tcp_socket_t *s) |
| 296 | { |
| 297 | DEBUG_ASSERT(s); |
| 298 | |
| 299 | __UNUSED int oldval = atomic_add(&s->ref, 1); |
| 300 | LTRACEF("caller %p, thread %p, socket %p, ref now %d\n", __GET_CALLER(), get_current_thread(), s, oldval + 1); |
| 301 | DEBUG_ASSERT(oldval > 0); |
| 302 | } |
| 303 | |
| 304 | static bool dec_socket_ref(tcp_socket_t *s) |
| 305 | { |
| 306 | DEBUG_ASSERT(s); |
| 307 | |
| 308 | int oldval = atomic_add(&s->ref, -1); |
| 309 | LTRACEF("caller %p, thread %p, socket %p, ref now %d\n", __GET_CALLER(), get_current_thread(), s, oldval - 1); |
| 310 | |
| 311 | if (oldval == 1) { |
| 312 | LTRACEF("destroying socket\n"); |
| 313 | event_destroy(&s->tx_event); |
| 314 | event_destroy(&s->rx_event); |
| 315 | |
| 316 | free(s->rx_buffer_raw); |
| 317 | free(s->tx_buffer); |
| 318 | |
| 319 | free(s); |
| 320 | } |
| 321 | return (oldval == 1); |
| 322 | } |
| 323 | |
| 324 | static void tcp_timer_set(tcp_socket_t *s, net_timer_t *timer, net_timer_callback_t cb, lk_time_t delay) |
| 325 | { |
| 326 | DEBUG_ASSERT(s); |
| 327 | DEBUG_ASSERT(timer); |
| 328 | |
| 329 | if (net_timer_set(timer, cb, s, delay)) |
| 330 | inc_socket_ref(s); |
| 331 | } |
| 332 | |
| 333 | static void tcp_timer_cancel(tcp_socket_t *s, net_timer_t *timer) |
| 334 | { |
| 335 | |
| 336 | DEBUG_ASSERT(s); |
| 337 | DEBUG_ASSERT(timer); |
| 338 | |
| 339 | if (net_timer_cancel(timer)) |
| 340 | dec_socket_ref(s); |
| 341 | } |
| 342 | |
| 343 | void tcp_input(pktbuf_t *p, uint32_t src_ip, uint32_t dst_ip) |
| 344 | { |
| 345 | if (unlikely(tcp_debug)) |
| 346 | TRACEF("p %p (len %u), src_ip 0x%x, dst_ip 0x%x\n", p, p->dlen, src_ip, dst_ip); |
| 347 | |
| 348 | tcp_header_t *header = (tcp_header_t *)p->data; |
| 349 | |
| 350 | /* reject if too small */ |
| 351 | if (p->dlen < sizeof(tcp_header_t)) |
| 352 | return; |
| 353 | |
| 354 | if (unlikely(tcp_debug) || LOCAL_TRACE) { |
| 355 | dump_tcp_header(header); |
| 356 | } |
| 357 | |
| 358 | /* compute the actual header length (+ options) */ |
| 359 | size_t header_len = ((ntohs(header->length_flags) >> 12) & 0xf) * 4; |
| 360 | if (p->dlen < header_len) { |
| 361 | TRACEF("REJECT: packet too large for buffer\n"); |
| 362 | return; |
| 363 | } |
| 364 | |
| 365 | /* checksum */ |
| 366 | if (FORCE_TCP_CHECKSUM || (p->flags & PKTBUF_FLAG_CKSUM_TCP_GOOD) == 0) { |
| 367 | tcp_pseudo_header_t pheader; |
| 368 | |
| 369 | // set up the pseudo header for checksum purposes |
| 370 | pheader.source_addr = src_ip; |
| 371 | pheader.dest_addr = dst_ip; |
| 372 | pheader.zero = 0; |
| 373 | pheader.protocol = IP_PROTO_TCP; |
| 374 | pheader.tcp_length = htons(p->dlen); |
| 375 | |
| 376 | uint16_t checksum = cksum_pheader(&pheader, p->data, p->dlen); |
| 377 | if(checksum != 0) { |
| 378 | TRACEF("REJECT: failed checksum, header says 0x%x, we got 0x%x\n", header->checksum, checksum); |
| 379 | return; |
| 380 | } |
| 381 | } |
| 382 | |
| 383 | /* byte swap header in place */ |
| 384 | header->source_port = ntohs(header->source_port); |
| 385 | header->dest_port = ntohs(header->dest_port); |
| 386 | header->seq_num = ntohl(header->seq_num); |
| 387 | header->ack_num = ntohl(header->ack_num); |
| 388 | header->length_flags = ntohs(header->length_flags); |
| 389 | header->win_size = ntohs(header->win_size); |
| 390 | header->urg_pointer = ntohs(header->urg_pointer); |
| 391 | |
| 392 | /* get some data from the packet */ |
| 393 | uint8_t packet_flags = header->length_flags & 0x3f; |
| 394 | size_t data_len = p->dlen - header_len; |
| 395 | uint32_t highest_sequence = header->seq_num + ((data_len > 0) ? (data_len - 1) : 0); |
| 396 | |
| 397 | /* see if it matches a socket we have */ |
| 398 | tcp_socket_t *s = lookup_socket(src_ip, dst_ip, header->source_port, header->dest_port); |
| 399 | if (!s) { |
| 400 | /* send a RST packet */ |
| 401 | goto send_reset; |
| 402 | } |
| 403 | |
| 404 | if (unlikely(tcp_debug)) |
| 405 | TRACEF("got socket %p, state %d (%s), ref %d\n", s, s->state, tcp_state_to_string(s->state), s->ref); |
| 406 | |
| 407 | /* remove the header */ |
| 408 | pktbuf_consume(p, header_len); |
| 409 | |
| 410 | mutex_acquire(&s->lock); |
| 411 | |
| 412 | /* check to see if they're resetting us */ |
| 413 | if (packet_flags & PKT_RST) { |
| 414 | if (s->state != STATE_CLOSED && s->state != STATE_LISTEN) { |
| 415 | tcp_remote_close(s); |
| 416 | } |
| 417 | goto done; |
| 418 | } |
| 419 | |
| 420 | switch (s->state) { |
| 421 | case STATE_CLOSED: |
| 422 | /* socket closed, send RST */ |
| 423 | goto send_reset; |
| 424 | |
| 425 | /* passive connect states */ |
| 426 | case STATE_LISTEN: { |
| 427 | /* we're in listen and they want to talk to us */ |
| 428 | if (!(packet_flags & PKT_SYN)) { |
| 429 | /* not a SYN, send RST */ |
| 430 | goto send_reset; |
| 431 | } |
| 432 | |
| 433 | /* see if we have a slot to accept */ |
| 434 | if (s->accepted != NULL) |
| 435 | goto done; |
| 436 | |
| 437 | /* make a new accept socket */ |
| 438 | tcp_socket_t *accept_socket = create_tcp_socket(true); |
| 439 | if (!accept_socket) |
| 440 | goto done; |
| 441 | |
| 442 | /* set it up */ |
| 443 | accept_socket->local_ip = minip_get_ipaddr(); |
| 444 | accept_socket->local_port = s->local_port; |
| 445 | accept_socket->remote_ip = src_ip; |
| 446 | accept_socket->remote_port = header->source_port; |
| 447 | accept_socket->state = STATE_SYN_RCVD; |
| 448 | |
| 449 | mutex_acquire(&accept_socket->lock); |
| 450 | |
| 451 | add_socket_to_list(accept_socket); |
| 452 | |
| 453 | /* remember their sequence */ |
| 454 | accept_socket->rx_win_low = header->seq_num + 1; |
| 455 | accept_socket->rx_win_high = accept_socket->rx_win_low + accept_socket->rx_win_size - 1; |
| 456 | |
| 457 | /* save this socket and wake anyone up that is waiting to accept */ |
| 458 | s->accepted = accept_socket; |
| 459 | sem_post(&s->accept_sem, true); |
| 460 | |
| 461 | /* set up a mss option for sending back */ |
| 462 | tcp_mss_option_t mss_option; |
| 463 | mss_option.kind = 0x2; |
| 464 | mss_option.len = 0x4; |
| 465 | mss_option.mss = ntohs(s->mss); // XXX make sure we fit in their mss |
| 466 | |
| 467 | /* send a response */ |
| 468 | tcp_socket_send(accept_socket, NULL, 0, PKT_ACK|PKT_SYN, &mss_option, sizeof(mss_option), |
| 469 | accept_socket->tx_win_low); |
| 470 | |
| 471 | /* SYN consumed a sequence */ |
| 472 | accept_socket->tx_win_low++; |
| 473 | |
| 474 | mutex_release(&accept_socket->lock); |
| 475 | break; |
| 476 | } |
| 477 | case STATE_SYN_RCVD: |
| 478 | if (packet_flags & PKT_SYN) { |
| 479 | /* they must have not seen our ack of their original syn, retransmit */ |
| 480 | // XXX implement |
| 481 | goto send_reset; |
| 482 | } |
| 483 | |
| 484 | /* if they ack our SYN, we can move on to ESTABLISHED */ |
| 485 | if (packet_flags & PKT_ACK) { |
| 486 | if (header->ack_num != s->tx_win_low) { |
| 487 | goto send_reset; |
| 488 | } |
| 489 | |
| 490 | s->tx_win_high = s->tx_win_low + header->win_size; |
| 491 | s->tx_highest_seq = s->tx_win_low; |
| 492 | |
| 493 | s->state = STATE_ESTABLISHED; |
| 494 | } else { |
| 495 | goto send_reset; |
| 496 | } |
| 497 | |
| 498 | break; |
| 499 | |
| 500 | case STATE_ESTABLISHED: |
| 501 | if (packet_flags & PKT_ACK) { |
| 502 | /* they're acking us */ |
| 503 | handle_ack(s, header->ack_num, header->win_size); |
| 504 | } |
| 505 | |
| 506 | if (data_len > 0) { |
| 507 | LTRACEF("new data, len %zu\n", data_len); |
| 508 | handle_data(s, p->data, p->dlen, header->seq_num); |
| 509 | } |
| 510 | |
| 511 | if ((packet_flags & PKT_FIN) && SEQUENCE_GTE(s->rx_win_low, highest_sequence)) { |
| 512 | /* they're closing with us, and there's no outstanding data */ |
| 513 | |
| 514 | /* FIN consumed a sequence */ |
| 515 | s->rx_win_low++; |
| 516 | |
| 517 | /* ack them and transition to new state */ |
| 518 | send_ack(s); |
| 519 | s->state = STATE_CLOSE_WAIT; |
| 520 | |
| 521 | /* wake up any read waiters */ |
| 522 | event_signal(&s->rx_event, true); |
| 523 | } |
| 524 | break; |
| 525 | |
| 526 | case STATE_CLOSE_WAIT: |
| 527 | if (packet_flags & PKT_ACK) { |
| 528 | /* they're acking us */ |
| 529 | handle_ack(s, header->ack_num, header->win_size); |
| 530 | } |
| 531 | if (packet_flags & PKT_FIN) { |
| 532 | /* they must have missed our ack, ack them again */ |
| 533 | send_ack(s); |
| 534 | } |
| 535 | break; |
| 536 | case STATE_LAST_ACK: |
| 537 | if (packet_flags & PKT_ACK) { |
| 538 | /* they're acking our FIN, probably */ |
| 539 | tcp_remote_close(s); |
| 540 | |
| 541 | /* tcp_close() was already called on us, remove us from the list and drop the ref */ |
| 542 | remove_socket_from_list(s); |
| 543 | dec_socket_ref(s); |
| 544 | } |
| 545 | break; |
| 546 | case STATE_FIN_WAIT_1: |
| 547 | if (packet_flags & PKT_ACK) { |
| 548 | /* they're acking our FIN, probably */ |
| 549 | s->state = STATE_FIN_WAIT_2; |
| 550 | /* drop into fin_wait_2 state logic, in case they were FINning us too */ |
| 551 | goto fin_wait_2; |
| 552 | } else if (packet_flags & PKT_FIN) { |
| 553 | /* simultaneous close. they finned us without acking our fin */ |
| 554 | s->rx_win_low++; |
| 555 | send_ack(s); |
| 556 | s->state = STATE_CLOSING; |
| 557 | } |
| 558 | break; |
| 559 | case STATE_FIN_WAIT_2: |
| 560 | fin_wait_2: |
| 561 | if (packet_flags & PKT_FIN) { |
| 562 | /* they're FINning us, ack them */ |
| 563 | s->rx_win_low++; |
| 564 | send_ack(s); |
| 565 | s->state = STATE_TIME_WAIT; |
| 566 | |
| 567 | /* set timed wait timer */ |
| 568 | tcp_timer_set(s, &s->time_wait_timer, &handle_time_wait_timeout, TIME_WAIT_TIMEOUT); |
| 569 | } |
| 570 | break; |
| 571 | case STATE_CLOSING: |
| 572 | if (packet_flags & PKT_ACK) { |
| 573 | /* they're acking our FIN, probably */ |
| 574 | s->state = STATE_TIME_WAIT; |
| 575 | |
| 576 | /* set timed wait timer */ |
| 577 | tcp_timer_set(s, &s->time_wait_timer, &handle_time_wait_timeout, TIME_WAIT_TIMEOUT); |
| 578 | } |
| 579 | break; |
| 580 | case STATE_TIME_WAIT: |
| 581 | /* /dev/null of packets */ |
| 582 | break; |
| 583 | |
| 584 | case STATE_SYN_SENT: |
| 585 | PANIC_UNIMPLEMENTED; |
| 586 | } |
| 587 | |
| 588 | done: |
| 589 | mutex_release(&s->lock); |
| 590 | dec_socket_ref(s); |
| 591 | return; |
| 592 | |
| 593 | send_reset: |
| 594 | if (s) { |
| 595 | mutex_release(&s->lock); |
| 596 | dec_socket_ref(s); |
| 597 | } |
| 598 | |
| 599 | LTRACEF("SEND RST\n"); |
| 600 | if (!(packet_flags & PKT_RST)) { |
| 601 | tcp_send(src_ip, header->source_port, dst_ip, header->dest_port, |
| 602 | NULL, 0, PKT_RST, NULL, 0, 0, header->ack_num, 0); |
| 603 | } |
| 604 | } |
| 605 | |
| 606 | static void handle_data(tcp_socket_t *s, const void *data, size_t len, uint32_t sequence) |
| 607 | { |
| 608 | if (unlikely(tcp_debug)) |
| 609 | TRACEF("data %p, len %zu, sequence %u\n", data, len, sequence); |
| 610 | |
| 611 | DEBUG_ASSERT(s); |
| 612 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 613 | DEBUG_ASSERT(data); |
| 614 | DEBUG_ASSERT(len > 0); |
| 615 | |
| 616 | /* see if it matches our current window */ |
| 617 | uint32_t sequence_top = sequence + len - 1; |
| 618 | if (SEQUENCE_LTE(sequence, s->rx_win_low) && SEQUENCE_GTE(sequence_top, s->rx_win_low)) { |
| 619 | /* it intersects the bottom of our window, so it's in order */ |
| 620 | |
| 621 | /* copy the data we need to our cbuf */ |
| 622 | size_t offset = sequence - s->rx_win_low; |
| 623 | size_t copy_len = MIN(s->rx_win_high - s->rx_win_low, len - offset); |
| 624 | |
| 625 | DEBUG_ASSERT(offset < len); |
| 626 | |
| 627 | LTRACEF("copying from offset %zu, len %zu\n", offset, copy_len); |
| 628 | |
| 629 | s->rx_win_low += copy_len; |
| 630 | |
| 631 | cbuf_write(&s->rx_buffer, (uint8_t *)data + offset, copy_len, false); |
| 632 | event_signal(&s->rx_event, true); |
| 633 | |
| 634 | /* keep a counter if they've been sending a full mss */ |
| 635 | if (copy_len >= s->mss) { |
| 636 | s->rx_full_mss_count++; |
| 637 | } else { |
| 638 | s->rx_full_mss_count = 0; |
| 639 | } |
| 640 | |
| 641 | /* immediately ack if we're more than halfway into our buffer or they've sent 2 or more full packets */ |
| 642 | if (s->rx_full_mss_count >= 2 || |
| 643 | (int)(s->rx_win_low + s->rx_win_size - s->rx_win_high) > (int)s->rx_win_size / 2) { |
| 644 | send_ack(s); |
| 645 | s->rx_full_mss_count = 0; |
| 646 | } else { |
| 647 | tcp_timer_set(s, &s->ack_delay_timer, &handle_delayed_ack_timeout, DELAYED_ACK_TIMEOUT); |
| 648 | } |
| 649 | } else { |
| 650 | // either out of order or completely out of our window, drop |
| 651 | // duplicately ack the last thing we really got |
| 652 | send_ack(s); |
| 653 | } |
| 654 | } |
| 655 | |
| 656 | static status_t tcp_socket_send(tcp_socket_t *s, const void *data, size_t len, tcp_flags_t flags, |
| 657 | const void *options, size_t options_length, uint32_t sequence) |
| 658 | { |
| 659 | DEBUG_ASSERT(s); |
| 660 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 661 | DEBUG_ASSERT(len == 0 || data); |
| 662 | DEBUG_ASSERT(options_length == 0 || options); |
| 663 | DEBUG_ASSERT((options_length % 4) == 0); |
| 664 | |
| 665 | // calculate the new right edge of the rx window |
| 666 | uint32_t rx_win_high = s->rx_win_low + s->rx_win_size - cbuf_space_used(&s->rx_buffer) - 1; |
| 667 | |
| 668 | LTRACEF("rx_win_low %u rx_win_size %u read_buf_len %zu, new win high %u\n", |
| 669 | s->rx_win_low, s->rx_win_size, cbuf_space_used(&s->rx_buffer), rx_win_high); |
| 670 | |
| 671 | uint16_t win_size; |
| 672 | if (SEQUENCE_GTE(rx_win_high, s->rx_win_high)) { |
| 673 | s->rx_win_high = rx_win_high; |
| 674 | win_size = rx_win_high - s->rx_win_low; |
| 675 | } else { |
| 676 | // the window size has shrunk, but we can't move the |
| 677 | // right edge of the window backwards |
| 678 | win_size = s->rx_win_high - s->rx_win_low; |
| 679 | } |
| 680 | |
| 681 | // we are piggybacking a pending ACK, so clear the delayed ACK timer |
| 682 | if (flags & PKT_ACK) { |
| 683 | tcp_timer_cancel(s, &s->ack_delay_timer); |
| 684 | } |
| 685 | |
| 686 | status_t err = tcp_send(s->remote_ip, s->remote_port, s->local_ip, s->local_port, data, len, flags, |
| 687 | options, options_length, (flags & PKT_ACK) ? s->rx_win_low : 0, sequence, win_size); |
| 688 | |
| 689 | return err; |
| 690 | } |
| 691 | |
| 692 | static void send_ack(tcp_socket_t *s) |
| 693 | { |
| 694 | DEBUG_ASSERT(s); |
| 695 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 696 | |
| 697 | if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT && s->state != STATE_FIN_WAIT_2) |
| 698 | return; |
| 699 | |
| 700 | tcp_socket_send(s, NULL, 0, PKT_ACK, NULL, 0, s->tx_win_low); |
| 701 | } |
| 702 | |
| 703 | static status_t tcp_send(ipv4_addr dest_ip, uint16_t dest_port, ipv4_addr src_ip, uint16_t src_port, const void *buf, |
| 704 | size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t ack, uint32_t sequence, uint16_t window_size) |
| 705 | { |
| 706 | DEBUG_ASSERT(len == 0 || buf); |
| 707 | DEBUG_ASSERT(options_length == 0 || options); |
| 708 | DEBUG_ASSERT((options_length % 4) == 0); |
| 709 | |
| 710 | pktbuf_t *p = pktbuf_alloc(); |
| 711 | if (!p) |
| 712 | return ERR_NO_MEMORY; |
| 713 | |
| 714 | tcp_header_t *header = pktbuf_prepend(p, sizeof(tcp_header_t) + options_length); |
| 715 | DEBUG_ASSERT(header); |
| 716 | |
| 717 | /* fill in the header */ |
| 718 | header->source_port = htons(src_port); |
| 719 | header->dest_port = htons(dest_port); |
| 720 | header->seq_num = htonl(sequence); |
| 721 | header->ack_num = htonl(ack); |
| 722 | header->length_flags = htons(((sizeof(tcp_header_t) + options_length) / 4) << 12 | flags); |
| 723 | header->win_size = htons(window_size); |
| 724 | header->checksum = 0; |
| 725 | header->urg_pointer = 0; |
| 726 | if (options) |
| 727 | memcpy(header + 1, options, options_length); |
| 728 | |
| 729 | /* append the data */ |
| 730 | if (len > 0) |
| 731 | pktbuf_append_data(p, buf, len); |
| 732 | |
| 733 | /* compute the checksum */ |
| 734 | /* XXX get the tx ckecksum capability from the nic */ |
| 735 | if (FORCE_TCP_CHECKSUM || true) { |
| 736 | tcp_pseudo_header_t pheader; |
| 737 | pheader.source_addr = src_ip; |
| 738 | pheader.dest_addr = dest_ip; |
| 739 | pheader.zero = 0; |
| 740 | pheader.protocol = IP_PROTO_TCP; |
| 741 | pheader.tcp_length = htons(p->dlen); |
| 742 | |
| 743 | header->checksum = cksum_pheader(&pheader, p->data, p->dlen); |
| 744 | } |
| 745 | |
| 746 | if (LOCAL_TRACE) { |
| 747 | printf("sending "); |
| 748 | dump_tcp_header(header); |
| 749 | } |
| 750 | |
| 751 | status_t err = minip_ipv4_send(p, dest_ip, IP_PROTO_TCP); |
| 752 | |
| 753 | return err; |
| 754 | } |
| 755 | |
| 756 | static void handle_ack(tcp_socket_t *s, uint32_t sequence, uint32_t win_size) |
| 757 | { |
| 758 | LTRACEF("socket %p ack sequence %u, win_size %u\n", s, sequence, win_size); |
| 759 | |
| 760 | DEBUG_ASSERT(s); |
| 761 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 762 | |
| 763 | LTRACEF("s %p, tx_win_low %u tx_win_high %u tx_highest_seq %u bufsize %u offset %u\n", |
| 764 | s, s->tx_win_low, s->tx_win_high, s->tx_highest_seq, s->tx_buffer_size, s->tx_buffer_offset); |
| 765 | if (SEQUENCE_LTE(sequence, s->tx_win_low)) { |
| 766 | /* they're acking stuff we've already received an ack for */ |
| 767 | return; |
| 768 | } else if (SEQUENCE_GT(sequence, s->tx_highest_seq)) { |
| 769 | /* they're acking stuff we haven't sent */ |
| 770 | return; |
| 771 | } else { |
| 772 | /* their ack is somewhere in our window */ |
| 773 | uint32_t acked_len; |
| 774 | |
| 775 | acked_len = (sequence - s->tx_win_low); |
| 776 | |
| 777 | LTRACEF("acked len %u\n", acked_len); |
| 778 | |
| 779 | DEBUG_ASSERT(acked_len <= s->tx_buffer_size); |
| 780 | DEBUG_ASSERT(acked_len <= s->tx_buffer_offset); |
| 781 | |
| 782 | memmove(s->tx_buffer, s->tx_buffer + acked_len, s->tx_buffer_offset - acked_len); |
| 783 | |
| 784 | s->tx_buffer_offset -= acked_len; |
| 785 | s->tx_win_low += acked_len; |
| 786 | s->tx_win_high = s->tx_win_low + win_size; |
| 787 | |
| 788 | /* cancel or reset our retransmit timer */ |
| 789 | if (s->tx_win_low == s->tx_highest_seq) { |
| 790 | tcp_timer_cancel(s, &s->retransmit_timer); |
| 791 | } else { |
| 792 | tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT); |
| 793 | } |
| 794 | |
| 795 | /* we have opened the transmit buffer */ |
| 796 | event_signal(&s->tx_event, true); |
| 797 | } |
| 798 | } |
| 799 | |
| 800 | static ssize_t tcp_write_pending_data(tcp_socket_t *s) |
| 801 | { |
| 802 | LTRACEF("s %p, tx_win_low %u tx_win_high %u tx_highest_seq %u bufsize %u offset %u\n", |
| 803 | s, s->tx_win_low, s->tx_win_high, s->tx_highest_seq, s->tx_buffer_size, s->tx_buffer_offset); |
| 804 | |
| 805 | DEBUG_ASSERT(s); |
| 806 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 807 | DEBUG_ASSERT(s->tx_buffer_size > 0); |
| 808 | DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size); |
| 809 | |
| 810 | /* do we have any new data to send? */ |
| 811 | uint32_t outstanding = (s->tx_highest_seq - s->tx_win_low); |
| 812 | uint32_t pending = s->tx_buffer_offset - outstanding; |
| 813 | LTRACEF("outstanding %u, pending %u\n", outstanding, pending); |
| 814 | |
| 815 | /* send packets that cover the pending area of the window */ |
| 816 | uint32_t offset = 0; |
| 817 | while (offset < pending) { |
| 818 | uint32_t tosend = MIN(s->mss, pending - offset); |
| 819 | |
| 820 | tcp_socket_send(s, s->tx_buffer + outstanding + offset, tosend, PKT_ACK|PKT_PSH, NULL, 0, s->tx_highest_seq); |
| 821 | s->tx_highest_seq += tosend; |
| 822 | offset += tosend; |
| 823 | } |
| 824 | |
| 825 | /* reset the retransmit timer if we sent anything */ |
| 826 | if (offset > 0) { |
| 827 | tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT); |
| 828 | } |
| 829 | |
| 830 | return offset; |
| 831 | } |
| 832 | |
| 833 | static ssize_t tcp_retransmit(tcp_socket_t *s) |
| 834 | { |
| 835 | DEBUG_ASSERT(s); |
| 836 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 837 | |
| 838 | if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT) |
| 839 | return 0; |
| 840 | |
| 841 | /* how much data have we sent but not gotten an ack for? */ |
| 842 | uint32_t outstanding = (s->tx_highest_seq - s->tx_win_low); |
| 843 | if (outstanding == 0) |
| 844 | return 0; |
| 845 | |
| 846 | uint32_t tosend = MIN(s->mss, outstanding); |
| 847 | |
| 848 | LTRACEF("s %p, tosend %u seq %u\n", s, tosend, s->tx_win_low); |
| 849 | tcp_socket_send(s, s->tx_buffer, tosend, PKT_ACK|PKT_PSH, NULL, 0, s->tx_win_low); |
| 850 | |
| 851 | return tosend; |
| 852 | } |
| 853 | |
| 854 | static void handle_retransmit_timeout(void *_s) |
| 855 | { |
| 856 | tcp_socket_t *s = _s; |
| 857 | |
| 858 | LTRACEF("s %p\n", s); |
| 859 | |
| 860 | DEBUG_ASSERT(s); |
| 861 | |
| 862 | mutex_acquire(&s->lock); |
| 863 | |
| 864 | if (tcp_retransmit(s) == 0) |
| 865 | goto done; |
| 866 | |
| 867 | tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT); |
| 868 | |
| 869 | done: |
| 870 | mutex_release(&s->lock); |
| 871 | dec_socket_ref(s); |
| 872 | } |
| 873 | |
| 874 | static void handle_delayed_ack_timeout(void *_s) |
| 875 | { |
| 876 | tcp_socket_t *s = _s; |
| 877 | |
| 878 | LTRACEF("s %p\n", s); |
| 879 | |
| 880 | DEBUG_ASSERT(s); |
| 881 | |
| 882 | mutex_acquire(&s->lock); |
| 883 | send_ack(s); |
| 884 | mutex_release(&s->lock); |
| 885 | dec_socket_ref(s); |
| 886 | } |
| 887 | |
| 888 | static void handle_time_wait_timeout(void *_s) |
| 889 | { |
| 890 | tcp_socket_t *s = _s; |
| 891 | |
| 892 | LTRACEF("s %p\n", s); |
| 893 | |
| 894 | DEBUG_ASSERT(s); |
| 895 | |
| 896 | mutex_acquire(&s->lock); |
| 897 | |
| 898 | DEBUG_ASSERT(s->state == STATE_TIME_WAIT); |
| 899 | |
| 900 | /* remove us from the list and drop the last ref */ |
| 901 | remove_socket_from_list(s); |
| 902 | dec_socket_ref(s); |
| 903 | |
| 904 | mutex_release(&s->lock); |
| 905 | dec_socket_ref(s); |
| 906 | } |
| 907 | |
| 908 | static void tcp_wakeup_waiters(tcp_socket_t *s) |
| 909 | { |
| 910 | DEBUG_ASSERT(s); |
| 911 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 912 | |
| 913 | // wake up any waiters |
| 914 | event_signal(&s->rx_event, true); |
| 915 | event_signal(&s->tx_event, true); |
| 916 | } |
| 917 | |
| 918 | static void tcp_remote_close(tcp_socket_t *s) |
| 919 | { |
| 920 | LTRACEF("s %p, ref %d\n", s, s->ref); |
| 921 | |
| 922 | DEBUG_ASSERT(s); |
| 923 | DEBUG_ASSERT(is_mutex_held(&s->lock)); |
| 924 | DEBUG_ASSERT(s->ref > 0); |
| 925 | |
| 926 | if (s->state == STATE_CLOSED) |
| 927 | return; |
| 928 | |
| 929 | s->state = STATE_CLOSED; |
| 930 | |
| 931 | tcp_timer_cancel(s, &s->retransmit_timer); |
| 932 | tcp_timer_cancel(s, &s->ack_delay_timer); |
| 933 | |
| 934 | tcp_wakeup_waiters(s); |
| 935 | } |
| 936 | |
| 937 | static tcp_socket_t *create_tcp_socket(bool alloc_buffers) |
| 938 | { |
| 939 | tcp_socket_t *s; |
| 940 | |
| 941 | s = calloc(1, sizeof(tcp_socket_t)); |
| 942 | if (!s) |
| 943 | return NULL; |
| 944 | |
| 945 | mutex_init(&s->lock); |
| 946 | s->ref = 1; // start with the ref already bumped |
| 947 | |
| 948 | s->state = STATE_CLOSED; |
| 949 | s->rx_win_size = DEFAULT_RX_WINDOW_SIZE; |
| 950 | event_init(&s->rx_event, false, 0); |
| 951 | |
| 952 | s->mss = DEFAULT_MSS; |
| 953 | |
| 954 | s->tx_win_low = rand(); |
| 955 | s->tx_win_high = s->tx_win_low; |
| 956 | s->tx_highest_seq = s->tx_win_low; |
| 957 | event_init(&s->tx_event, true, 0); |
| 958 | |
| 959 | if (alloc_buffers) { |
| 960 | // XXX check for error |
| 961 | s->rx_buffer_raw = malloc(s->rx_win_size); |
| 962 | cbuf_initialize_etc(&s->rx_buffer, s->rx_win_size, s->rx_buffer_raw); |
| 963 | |
| 964 | s->tx_buffer_size = DEFAULT_TX_BUFFER_SIZE; |
| 965 | s->tx_buffer = malloc(s->tx_buffer_size); |
| 966 | } |
| 967 | |
| 968 | sem_init(&s->accept_sem, 0); |
| 969 | |
| 970 | return s; |
| 971 | } |
| 972 | |
| 973 | /* user api */ |
| 974 | |
| 975 | status_t tcp_open_listen(tcp_socket_t **handle, uint16_t port) |
| 976 | { |
| 977 | tcp_socket_t *s; |
| 978 | |
| 979 | if (!handle) |
| 980 | return ERR_INVALID_ARGS; |
| 981 | |
| 982 | s = create_tcp_socket(false); |
| 983 | if (!s) |
| 984 | return ERR_NO_MEMORY; |
| 985 | |
| 986 | // XXX see if there's another listen socket already on this port |
| 987 | |
| 988 | s->local_port = port; |
| 989 | |
| 990 | /* go to listen state */ |
| 991 | s->state = STATE_LISTEN; |
| 992 | |
| 993 | add_socket_to_list(s); |
| 994 | |
| 995 | *handle = s; |
| 996 | |
| 997 | return NO_ERROR; |
| 998 | } |
| 999 | |
| 1000 | status_t tcp_accept_timeout(tcp_socket_t *listen_socket, tcp_socket_t **accept_socket, lk_time_t timeout) |
| 1001 | { |
| 1002 | if (!listen_socket || !accept_socket) |
| 1003 | return ERR_INVALID_ARGS; |
| 1004 | |
| 1005 | tcp_socket_t *s = listen_socket; |
| 1006 | inc_socket_ref(s); |
| 1007 | |
| 1008 | /* block to accept a socket for an amount of time */ |
| 1009 | if (sem_timedwait(&s->accept_sem, timeout) == ERR_TIMED_OUT) { |
| 1010 | dec_socket_ref(s); |
| 1011 | return ERR_TIMED_OUT; |
| 1012 | } |
| 1013 | |
| 1014 | mutex_acquire(&s->lock); |
| 1015 | |
| 1016 | /* we got here, grab the accepted socket and return */ |
| 1017 | DEBUG_ASSERT(s->accepted); |
| 1018 | *accept_socket = s->accepted; |
| 1019 | s->accepted = NULL; |
| 1020 | |
| 1021 | mutex_release(&s->lock); |
| 1022 | dec_socket_ref(s); |
| 1023 | |
| 1024 | return NO_ERROR; |
| 1025 | } |
| 1026 | |
| 1027 | ssize_t tcp_read(tcp_socket_t *socket, void *buf, size_t len) |
| 1028 | { |
| 1029 | LTRACEF("socket %p, buf %p, len %zu\n", socket, buf, len); |
| 1030 | if (!socket) |
| 1031 | return ERR_INVALID_ARGS; |
| 1032 | if (len == 0) |
| 1033 | return 0; |
| 1034 | if (!buf) |
| 1035 | return ERR_INVALID_ARGS; |
| 1036 | |
| 1037 | tcp_socket_t *s = socket; |
| 1038 | inc_socket_ref(s); |
| 1039 | |
| 1040 | ssize_t ret = 0; |
| 1041 | retry: |
| 1042 | /* block on available data */ |
| 1043 | event_wait(&s->rx_event); |
| 1044 | |
| 1045 | mutex_acquire(&s->lock); |
| 1046 | |
| 1047 | /* try to read some data from the receive buffer, even if we're closed */ |
| 1048 | ret = cbuf_read(&s->rx_buffer, buf, len, false); |
| 1049 | if (ret == 0) { |
| 1050 | /* check to see if we've closed */ |
| 1051 | if (s->state != STATE_ESTABLISHED) { |
| 1052 | ret = ERR_CHANNEL_CLOSED; |
| 1053 | goto out; |
| 1054 | } |
| 1055 | |
| 1056 | /* we must have raced with another thread */ |
| 1057 | event_unsignal(&s->rx_event); |
| 1058 | mutex_release(&s->lock); |
| 1059 | goto retry; |
| 1060 | } |
| 1061 | |
| 1062 | /* if we've used up the last byte in the read buffer, unsignal the read event */ |
| 1063 | size_t remaining_bytes = cbuf_space_used(&s->rx_buffer); |
| 1064 | if (s->state == STATE_ESTABLISHED && remaining_bytes == 0) { |
| 1065 | event_unsignal(&s->rx_event); |
| 1066 | } |
| 1067 | |
| 1068 | /* we've read something, make sure the other end knows that our window is opening */ |
| 1069 | uint32_t new_rx_win_size = s->rx_win_size - remaining_bytes; |
| 1070 | |
| 1071 | /* if we've opened it enough, send an ack */ |
| 1072 | if (new_rx_win_size >= s->mss && s->rx_win_high - s->rx_win_low < s->mss) |
| 1073 | send_ack(s); |
| 1074 | |
| 1075 | out: |
| 1076 | mutex_release(&s->lock); |
| 1077 | dec_socket_ref(s); |
| 1078 | |
| 1079 | return ret; |
| 1080 | } |
| 1081 | |
| 1082 | ssize_t tcp_write(tcp_socket_t *socket, const void *buf, size_t len) |
| 1083 | { |
| 1084 | LTRACEF("socket %p, buf %p, len %zu\n", socket, buf, len); |
| 1085 | if (!socket) |
| 1086 | return ERR_INVALID_ARGS; |
| 1087 | if (len == 0) |
| 1088 | return 0; |
| 1089 | if (!buf) |
| 1090 | return ERR_INVALID_ARGS; |
| 1091 | |
| 1092 | tcp_socket_t *s = socket; |
| 1093 | inc_socket_ref(s); |
| 1094 | |
| 1095 | size_t off = 0; |
| 1096 | while (off < len) { |
| 1097 | LTRACEF("off %zu, len %zu\n", off, len); |
| 1098 | |
| 1099 | /* wait for the tx buffer to open up */ |
| 1100 | event_wait(&s->tx_event); |
| 1101 | LTRACEF("after event_wait\n"); |
| 1102 | |
| 1103 | mutex_acquire(&s->lock); |
| 1104 | |
| 1105 | /* check to see if we've closed */ |
| 1106 | if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT) { |
| 1107 | mutex_release(&s->lock); |
| 1108 | dec_socket_ref(s); |
| 1109 | return ERR_CHANNEL_CLOSED; |
| 1110 | } |
| 1111 | |
| 1112 | DEBUG_ASSERT(s->tx_buffer_size > 0); |
| 1113 | DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size); |
| 1114 | |
| 1115 | /* figure out how much data to copy in */ |
| 1116 | size_t to_copy = MIN(s->tx_buffer_size - s->tx_buffer_offset, len - off); |
| 1117 | if (to_copy == 0) { |
| 1118 | mutex_release(&s->lock); |
| 1119 | continue; |
| 1120 | } |
| 1121 | |
| 1122 | memcpy(s->tx_buffer + s->tx_buffer_offset, (uint8_t *)buf + off, to_copy); |
| 1123 | s->tx_buffer_offset += to_copy; |
| 1124 | |
| 1125 | /* if this has completely filled it, unsignal the event */ |
| 1126 | DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size); |
| 1127 | if (s->tx_buffer_offset == s->tx_buffer_size) { |
| 1128 | event_unsignal(&s->tx_event); |
| 1129 | } |
| 1130 | |
| 1131 | /* send as much data as we can */ |
| 1132 | tcp_write_pending_data(s); |
| 1133 | |
| 1134 | off += to_copy; |
| 1135 | |
| 1136 | mutex_release(&s->lock); |
| 1137 | } |
| 1138 | |
| 1139 | dec_socket_ref(s); |
| 1140 | return len; |
| 1141 | } |
| 1142 | |
| 1143 | status_t tcp_close(tcp_socket_t *socket) |
| 1144 | { |
| 1145 | if (!socket) |
| 1146 | return ERR_INVALID_ARGS; |
| 1147 | |
| 1148 | tcp_socket_t *s = socket; |
| 1149 | |
| 1150 | inc_socket_ref(s); |
| 1151 | mutex_acquire(&s->lock); |
| 1152 | |
| 1153 | LTRACEF("socket %p, state %d (%s), ref %d\n", s, s->state, tcp_state_to_string(s->state), s->ref); |
| 1154 | |
| 1155 | status_t err; |
| 1156 | switch (s->state) { |
| 1157 | case STATE_CLOSED: |
| 1158 | case STATE_LISTEN: |
| 1159 | /* we can directly remove this socket */ |
| 1160 | remove_socket_from_list(s); |
| 1161 | |
| 1162 | /* drop any timers that may be pending on this */ |
| 1163 | tcp_timer_cancel(s, &s->ack_delay_timer); |
| 1164 | tcp_timer_cancel(s, &s->retransmit_timer); |
| 1165 | |
| 1166 | s->state = STATE_CLOSED; |
| 1167 | |
| 1168 | /* drop the extra ref that was held when the socket was created */ |
| 1169 | dec_socket_ref(s); |
| 1170 | break; |
| 1171 | case STATE_SYN_RCVD: |
| 1172 | case STATE_ESTABLISHED: |
| 1173 | s->state = STATE_FIN_WAIT_1; |
| 1174 | tcp_socket_send(s, NULL, 0, PKT_ACK|PKT_FIN, NULL, 0, s->tx_win_low); |
| 1175 | s->tx_win_low++; |
| 1176 | |
| 1177 | /* stick around and wait for them to FIN us */ |
| 1178 | break; |
| 1179 | case STATE_CLOSE_WAIT: |
| 1180 | s->state = STATE_LAST_ACK; |
| 1181 | tcp_socket_send(s, NULL, 0, PKT_ACK|PKT_FIN, NULL, 0, s->tx_win_low); |
| 1182 | s->tx_win_low++; |
| 1183 | |
| 1184 | // XXX set up fin retransmit timer here |
| 1185 | break; |
| 1186 | case STATE_FIN_WAIT_1: |
| 1187 | case STATE_FIN_WAIT_2: |
| 1188 | case STATE_CLOSING: |
| 1189 | case STATE_TIME_WAIT: |
| 1190 | case STATE_LAST_ACK: |
| 1191 | /* these states are all post tcp_close(), so it's invalid to call it here */ |
| 1192 | err = ERR_CHANNEL_CLOSED; |
| 1193 | goto out; |
| 1194 | default: |
| 1195 | PANIC_UNIMPLEMENTED; |
| 1196 | } |
| 1197 | |
| 1198 | /* make sure anyone blocked on this wakes up */ |
| 1199 | tcp_wakeup_waiters(s); |
| 1200 | |
| 1201 | mutex_release(&s->lock); |
| 1202 | |
| 1203 | err = NO_ERROR; |
| 1204 | |
| 1205 | out: |
| 1206 | /* if this was the last ref, it should destroy the socket */ |
| 1207 | dec_socket_ref(s); |
| 1208 | |
| 1209 | return err; |
| 1210 | } |
| 1211 | |
| 1212 | /* debug stuff */ |
| 1213 | static int cmd_tcp(int argc, const cmd_args *argv) |
| 1214 | { |
| 1215 | status_t err; |
| 1216 | |
| 1217 | if (argc < 2) { |
| 1218 | notenoughargs: |
| 1219 | printf("ERROR not enough arguments\n"); |
| 1220 | usage: |
| 1221 | printf("usage: %s sockets\n", argv[0].str); |
| 1222 | printf("usage: %s listenclose <port>\n", argv[0].str); |
| 1223 | printf("usage: %s listen <port>\n", argv[0].str); |
| 1224 | printf("usage: %s debug\n", argv[0].str); |
| 1225 | return ERR_INVALID_ARGS; |
| 1226 | } |
| 1227 | |
| 1228 | if (!strcmp(argv[1].str, "sockets")) { |
| 1229 | |
| 1230 | mutex_acquire(&tcp_socket_list_lock); |
| 1231 | tcp_socket_t *s = NULL; |
| 1232 | list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) { |
| 1233 | dump_socket(s); |
| 1234 | } |
| 1235 | mutex_release(&tcp_socket_list_lock); |
| 1236 | } else if (!strcmp(argv[1].str, "listenclose")) { |
| 1237 | /* listen for a connection, accept it, then immediately close it */ |
| 1238 | if (argc < 3) goto notenoughargs; |
| 1239 | |
| 1240 | tcp_socket_t *handle; |
| 1241 | |
| 1242 | err = tcp_open_listen(&handle, argv[2].u); |
| 1243 | printf("tcp_open_listen returns %d, handle %p\n", err, handle); |
| 1244 | |
| 1245 | tcp_socket_t *accepted; |
| 1246 | err = tcp_accept(handle, &accepted); |
| 1247 | printf("tcp_accept returns returns %d, handle %p\n", err, accepted); |
| 1248 | |
| 1249 | err = tcp_close(accepted); |
| 1250 | printf("tcp_close returns %d\n", err); |
| 1251 | |
| 1252 | err = tcp_close(handle); |
| 1253 | printf("tcp_close returns %d\n", err); |
| 1254 | } else if (!strcmp(argv[1].str, "listen")) { |
| 1255 | if (argc < 3) goto notenoughargs; |
| 1256 | |
| 1257 | tcp_socket_t *handle; |
| 1258 | |
| 1259 | err = tcp_open_listen(&handle, argv[2].u); |
| 1260 | printf("tcp_open_listen returns %d, handle %p\n", err, handle); |
| 1261 | |
| 1262 | tcp_socket_t *accepted; |
| 1263 | err = tcp_accept(handle, &accepted); |
| 1264 | printf("tcp_accept returns returns %d, handle %p\n", err, accepted); |
| 1265 | |
| 1266 | for (;;) { |
| 1267 | uint8_t buf[512]; |
| 1268 | |
| 1269 | ssize_t err = tcp_read(accepted, buf, sizeof(buf)); |
| 1270 | printf("tcp_read returns %ld\n", err); |
| 1271 | if (err < 0) |
| 1272 | break; |
| 1273 | if (err > 0) { |
| 1274 | hexdump8(buf, err); |
| 1275 | } |
| 1276 | |
| 1277 | err = tcp_write(accepted, buf, err); |
| 1278 | printf("tcp_write returns %ld\n", err); |
| 1279 | if (err < 0) |
| 1280 | break; |
| 1281 | } |
| 1282 | |
| 1283 | err = tcp_close(accepted); |
| 1284 | printf("tcp_close returns %d\n", err); |
| 1285 | |
| 1286 | err = tcp_close(handle); |
| 1287 | printf("tcp_close returns %d\n", err); |
| 1288 | } else if (!strcmp(argv[1].str, "debug")) { |
| 1289 | tcp_debug = !tcp_debug; |
| 1290 | printf("tcp debug now %u\n", tcp_debug); |
| 1291 | } else { |
| 1292 | printf("ERROR unknown command\n"); |
| 1293 | goto usage; |
| 1294 | } |
| 1295 | |
| 1296 | return NO_ERROR; |
| 1297 | } |
| 1298 | |
/*
 * Console command table for this file: a single entry named "tcp" with the
 * help text "tcp commands" and cmd_tcp as its handler.
 * NOTE(review): the STATIC_COMMAND* macros presumably come from <lib/console.h>.
 */
| 1299 | STATIC_COMMAND_START |
| 1300 | STATIC_COMMAND("tcp", "tcp commands", &cmd_tcp) |
| 1301 | STATIC_COMMAND_END(tcp); |
| 1302 | |
| 1303 | |
| 1304 | // vim: set ts=4 sw=4 expandtab: |