[Feature]add MT2731_MP2_MR2_SVN388 baseline version
Change-Id: Ief04314834b31e27effab435d3ca8ba33b499059
diff --git a/src/bsp/lk/lib/minip/tcp.c b/src/bsp/lk/lib/minip/tcp.c
new file mode 100644
index 0000000..6b367d6
--- /dev/null
+++ b/src/bsp/lk/lib/minip/tcp.c
@@ -0,0 +1,1304 @@
+/*
+ * Copyright (c) 2014 Travis Geiselbrecht
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining
+ * a copy of this software and associated documentation files
+ * (the "Software"), to deal in the Software without restriction,
+ * including without limitation the rights to use, copy, modify, merge,
+ * publish, distribute, sublicense, and/or sell copies of the Software,
+ * and to permit persons to whom the Software is furnished to do so,
+ * subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be
+ * included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#include "minip-internal.h"
+
+#include <trace.h>
+#include <assert.h>
+#include <compiler.h>
+#include <stdlib.h>
+#include <err.h>
+#include <string.h>
+#include <sys/types.h>
+#include <lib/console.h>
+#include <lib/cbuf.h>
+#include <kernel/mutex.h>
+#include <kernel/semaphore.h>
+#include <arch/ops.h>
+#include <platform.h>
+
+#define LOCAL_TRACE 0
+
+typedef uint32_t ipv4_addr;
+
+typedef struct tcp_header {
+ uint16_t source_port;
+ uint16_t dest_port;
+ uint32_t seq_num;
+ uint32_t ack_num;
+ uint16_t length_flags;
+ uint16_t win_size;
+ uint16_t checksum;
+ uint16_t urg_pointer;
+} __PACKED tcp_header_t;
+
+typedef struct tcp_pseudo_header {
+ ipv4_addr source_addr;
+ ipv4_addr dest_addr;
+ uint8_t zero;
+ uint8_t protocol;
+ uint16_t tcp_length;
+} __PACKED tcp_pseudo_header_t;
+
+typedef struct tcp_mss_option {
+ uint8_t kind; /* 0x2 */
+ uint8_t len; /* 0x4 */
+ uint16_t mss;
+} __PACKED tcp_mss_option_t;
+
+typedef enum tcp_state {
+ STATE_CLOSED,
+ STATE_LISTEN,
+ STATE_SYN_SENT,
+ STATE_SYN_RCVD,
+ STATE_ESTABLISHED,
+ STATE_CLOSE_WAIT,
+ STATE_LAST_ACK,
+ STATE_CLOSING,
+ STATE_FIN_WAIT_1,
+ STATE_FIN_WAIT_2,
+ STATE_TIME_WAIT
+} tcp_state_t;
+
+typedef enum tcp_flags {
+ PKT_FIN = 1,
+ PKT_SYN = 2,
+ PKT_RST = 4,
+ PKT_PSH = 8,
+ PKT_ACK = 16,
+ PKT_URG = 32
+} tcp_flags_t;
+
+typedef struct tcp_socket {
+ struct list_node node;
+
+ mutex_t lock;
+ volatile int ref;
+
+ tcp_state_t state;
+ ipv4_addr local_ip;
+ ipv4_addr remote_ip;
+ uint16_t local_port;
+ uint16_t remote_port;
+
+ uint32_t mss;
+
+ /* rx */
+ uint32_t rx_win_size;
+ uint32_t rx_win_low;
+ uint32_t rx_win_high;
+ uint8_t *rx_buffer_raw;
+ cbuf_t rx_buffer;
+ event_t rx_event;
+ int rx_full_mss_count; // number of packets we have received in a row with a full mss
+ net_timer_t ack_delay_timer;
+
+ /* tx */
+ uint32_t tx_win_low; // low side of the acked window
+ uint32_t tx_win_high; // tx_win_low + their advertised window size
+ uint32_t tx_highest_seq; // highest sequence we have txed them
+ uint8_t *tx_buffer; // our outgoing buffer
+ uint32_t tx_buffer_size; // size of tx_buffer
+ uint32_t tx_buffer_offset; // offset into the buffer to append new data to
+ event_t tx_event;
+ net_timer_t retransmit_timer;
+
+ /* listen accept */
+ semaphore_t accept_sem;
+ struct tcp_socket *accepted;
+
+ net_timer_t time_wait_timer;
+} tcp_socket_t;
+
+#define DEFAULT_MSS (1460)
+#define DEFAULT_RX_WINDOW_SIZE (8192)
+#define DEFAULT_TX_BUFFER_SIZE (8192)
+
+#define RETRANSMIT_TIMEOUT (50)
+#define DELAYED_ACK_TIMEOUT (50)
+#define TIME_WAIT_TIMEOUT (60000) // 1 minute
+
+#define FORCE_TCP_CHECKSUM (false)
+
+#define SEQUENCE_GTE(a, b) ((int32_t)((a) - (b)) >= 0)
+#define SEQUENCE_LTE(a, b) ((int32_t)((a) - (b)) <= 0)
+#define SEQUENCE_GT(a, b) ((int32_t)((a) - (b)) > 0)
+#define SEQUENCE_LT(a, b) ((int32_t)((a) - (b)) < 0)
+
+static mutex_t tcp_socket_list_lock = MUTEX_INITIAL_VALUE(tcp_socket_list_lock);
+static struct list_node tcp_socket_list = LIST_INITIAL_VALUE(tcp_socket_list);
+
+static bool tcp_debug = false;
+
+/* local routines */
+static tcp_socket_t *lookup_socket(ipv4_addr remote_ip, ipv4_addr local_ip, uint16_t remote_port, uint16_t local_port);
+static void add_socket_to_list(tcp_socket_t *s);
+static void remove_socket_from_list(tcp_socket_t *s);
+static tcp_socket_t *create_tcp_socket(bool alloc_buffers);
+static status_t tcp_send(ipv4_addr dest_ip, uint16_t dest_port, ipv4_addr src_ip, uint16_t src_port, const void *buf,
+ size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t ack, uint32_t sequence, uint16_t window_size);
+static status_t tcp_socket_send(tcp_socket_t *s, const void *data, size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t sequence);
+static void handle_data(tcp_socket_t *s, const void *data, size_t len, uint32_t sequence);
+static void send_ack(tcp_socket_t *s);
+static void handle_ack(tcp_socket_t *s, uint32_t sequence, uint32_t win_size);
+static void handle_retransmit_timeout(void *_s);
+static void handle_time_wait_timeout(void *_s);
+static void handle_delayed_ack_timeout(void *_s);
+static void tcp_remote_close(tcp_socket_t *s);
+static void tcp_wakeup_waiters(tcp_socket_t *s);
+static void inc_socket_ref(tcp_socket_t *s);
+static bool dec_socket_ref(tcp_socket_t *s);
+
+static uint16_t cksum_pheader(const tcp_pseudo_header_t *pheader, const void *buf, size_t len)
+{
+ uint16_t checksum = ones_sum16(0, pheader, sizeof(*pheader));
+ return ~ones_sum16(checksum, buf, len);
+}
+
+__NO_INLINE static void dump_tcp_header(const tcp_header_t *header)
+{
+ printf("TCP: src_port %u, dest_port %u, seq %u, ack %u, win %u, flags %c%c%c%c%c%c\n",
+ ntohs(header->source_port), ntohs(header->dest_port), ntohl(header->seq_num), ntohl(header->ack_num),
+ ntohs(header->win_size),
+ (ntohs(header->length_flags) & PKT_FIN) ? 'F' : ' ',
+ (ntohs(header->length_flags) & PKT_SYN) ? 'S' : ' ',
+ (ntohs(header->length_flags) & PKT_RST) ? 'R' : ' ',
+ (ntohs(header->length_flags) & PKT_PSH) ? 'P' : ' ',
+ (ntohs(header->length_flags) & PKT_ACK) ? 'A' : ' ',
+ (ntohs(header->length_flags) & PKT_URG) ? 'U' : ' ');
+}
+
+static const char *tcp_state_to_string(tcp_state_t state)
+{
+ switch (state) {
+ default:
+ case STATE_CLOSED: return "CLOSED";
+ case STATE_LISTEN: return "LISTEN";
+ case STATE_SYN_SENT: return "SYN_SENT";
+ case STATE_SYN_RCVD: return "SYN_RCVD";
+ case STATE_ESTABLISHED: return "ESTABLISHED";
+ case STATE_CLOSE_WAIT: return "CLOSE_WAIT";
+ case STATE_LAST_ACK: return "LAST_ACK";
+ case STATE_CLOSING: return "CLOSING";
+ case STATE_FIN_WAIT_1: return "FIN_WAIT_1";
+ case STATE_FIN_WAIT_2: return "FIN_WAIT_2";
+ case STATE_TIME_WAIT: return "TIME_WAIT";
+ }
+}
+
+static void dump_socket(tcp_socket_t *s)
+{
+ printf("socket %p: state %d (%s), local 0x%x:%hu, remote 0x%x:%hu, ref %d\n",
+ s, s->state, tcp_state_to_string(s->state),
+ s->local_ip, s->local_port, s->remote_ip, s->remote_port, s->ref);
+ if (s->state == STATE_ESTABLISHED || s->state == STATE_CLOSE_WAIT) {
+ printf("\trx: wsize %u wlo %u whi %u (%u)\n",
+ s->rx_win_size, s->rx_win_low, s->rx_win_high,
+ s->rx_win_high - s->rx_win_low);
+ printf("\ttx: wlo %u whi %u (%u) highest_seq %u (%u) bufsize %u bufoff %u\n",
+ s->tx_win_low, s->tx_win_high, s->tx_win_high - s->tx_win_low,
+ s->tx_highest_seq, s->tx_highest_seq - s->tx_win_low,
+ s->tx_buffer_size, s->tx_buffer_offset);
+ }
+}
+
+static tcp_socket_t *lookup_socket(ipv4_addr remote_ip, ipv4_addr local_ip, uint16_t remote_port, uint16_t local_port)
+{
+ LTRACEF("remote ip 0x%x local ip 0x%x remote port %u local port %u\n", remote_ip, local_ip, remote_port, local_port);
+
+ mutex_acquire(&tcp_socket_list_lock);
+
+ /* XXX replace with something faster, like a hash table */
+ tcp_socket_t *s = NULL;
+ list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) {
+ if (s->state == STATE_CLOSED || s->state == STATE_LISTEN) {
+ continue;
+ } else {
+ /* full check */
+ if (s->remote_ip == remote_ip &&
+ s->local_ip == local_ip &&
+ s->remote_port == remote_port &&
+ s->local_port == local_port) {
+ goto out;
+ }
+ }
+ }
+
+ /* walk the list again, looking only for listen matches */
+ list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) {
+ if (s->state == STATE_LISTEN) {
+ /* sockets in listen state only care about local port */
+ if (s->local_port == local_port) {
+ goto out;
+ }
+ }
+ }
+
+ /* fall through case returns null */
+ s = NULL;
+
+out:
+ /* bump the ref before returning it */
+ if (s)
+ inc_socket_ref(s);
+
+ mutex_release(&tcp_socket_list_lock);
+
+ return s;
+}
+
+static void add_socket_to_list(tcp_socket_t *s)
+{
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(s->ref > 0); // we should have implicitly bumped the ref when creating the socket
+
+ mutex_acquire(&tcp_socket_list_lock);
+
+ list_add_head(&tcp_socket_list, &s->node);
+
+ mutex_release(&tcp_socket_list_lock);
+}
+
+static void remove_socket_from_list(tcp_socket_t *s)
+{
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(s->ref > 0);
+
+ mutex_acquire(&tcp_socket_list_lock);
+
+ DEBUG_ASSERT(list_in_list(&s->node));
+ list_delete(&s->node);
+
+ mutex_release(&tcp_socket_list_lock);
+}
+
+static void inc_socket_ref(tcp_socket_t *s)
+{
+ DEBUG_ASSERT(s);
+
+ __UNUSED int oldval = atomic_add(&s->ref, 1);
+ LTRACEF("caller %p, thread %p, socket %p, ref now %d\n", __GET_CALLER(), get_current_thread(), s, oldval + 1);
+ DEBUG_ASSERT(oldval > 0);
+}
+
+static bool dec_socket_ref(tcp_socket_t *s)
+{
+ DEBUG_ASSERT(s);
+
+ int oldval = atomic_add(&s->ref, -1);
+ LTRACEF("caller %p, thread %p, socket %p, ref now %d\n", __GET_CALLER(), get_current_thread(), s, oldval - 1);
+
+ if (oldval == 1) {
+ LTRACEF("destroying socket\n");
+ event_destroy(&s->tx_event);
+ event_destroy(&s->rx_event);
+
+ free(s->rx_buffer_raw);
+ free(s->tx_buffer);
+
+ free(s);
+ }
+ return (oldval == 1);
+}
+
+static void tcp_timer_set(tcp_socket_t *s, net_timer_t *timer, net_timer_callback_t cb, lk_time_t delay)
+{
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(timer);
+
+ if (net_timer_set(timer, cb, s, delay))
+ inc_socket_ref(s);
+}
+
+static void tcp_timer_cancel(tcp_socket_t *s, net_timer_t *timer)
+{
+
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(timer);
+
+ if (net_timer_cancel(timer))
+ dec_socket_ref(s);
+}
+
+void tcp_input(pktbuf_t *p, uint32_t src_ip, uint32_t dst_ip)
+{
+ if (unlikely(tcp_debug))
+ TRACEF("p %p (len %u), src_ip 0x%x, dst_ip 0x%x\n", p, p->dlen, src_ip, dst_ip);
+
+ tcp_header_t *header = (tcp_header_t *)p->data;
+
+ /* reject if too small */
+ if (p->dlen < sizeof(tcp_header_t))
+ return;
+
+ if (unlikely(tcp_debug) || LOCAL_TRACE) {
+ dump_tcp_header(header);
+ }
+
+ /* compute the actual header length (+ options) */
+ size_t header_len = ((ntohs(header->length_flags) >> 12) & 0xf) * 4;
+ if (p->dlen < header_len) {
+ TRACEF("REJECT: packet too large for buffer\n");
+ return;
+ }
+
+ /* checksum */
+ if (FORCE_TCP_CHECKSUM || (p->flags & PKTBUF_FLAG_CKSUM_TCP_GOOD) == 0) {
+ tcp_pseudo_header_t pheader;
+
+ // set up the pseudo header for checksum purposes
+ pheader.source_addr = src_ip;
+ pheader.dest_addr = dst_ip;
+ pheader.zero = 0;
+ pheader.protocol = IP_PROTO_TCP;
+ pheader.tcp_length = htons(p->dlen);
+
+ uint16_t checksum = cksum_pheader(&pheader, p->data, p->dlen);
+ if(checksum != 0) {
+ TRACEF("REJECT: failed checksum, header says 0x%x, we got 0x%x\n", header->checksum, checksum);
+ return;
+ }
+ }
+
+ /* byte swap header in place */
+ header->source_port = ntohs(header->source_port);
+ header->dest_port = ntohs(header->dest_port);
+ header->seq_num = ntohl(header->seq_num);
+ header->ack_num = ntohl(header->ack_num);
+ header->length_flags = ntohs(header->length_flags);
+ header->win_size = ntohs(header->win_size);
+ header->urg_pointer = ntohs(header->urg_pointer);
+
+ /* get some data from the packet */
+ uint8_t packet_flags = header->length_flags & 0x3f;
+ size_t data_len = p->dlen - header_len;
+ uint32_t highest_sequence = header->seq_num + ((data_len > 0) ? (data_len - 1) : 0);
+
+ /* see if it matches a socket we have */
+ tcp_socket_t *s = lookup_socket(src_ip, dst_ip, header->source_port, header->dest_port);
+ if (!s) {
+ /* send a RST packet */
+ goto send_reset;
+ }
+
+ if (unlikely(tcp_debug))
+ TRACEF("got socket %p, state %d (%s), ref %d\n", s, s->state, tcp_state_to_string(s->state), s->ref);
+
+ /* remove the header */
+ pktbuf_consume(p, header_len);
+
+ mutex_acquire(&s->lock);
+
+ /* check to see if they're resetting us */
+ if (packet_flags & PKT_RST) {
+ if (s->state != STATE_CLOSED && s->state != STATE_LISTEN) {
+ tcp_remote_close(s);
+ }
+ goto done;
+ }
+
+ switch (s->state) {
+ case STATE_CLOSED:
+ /* socket closed, send RST */
+ goto send_reset;
+
+ /* passive connect states */
+ case STATE_LISTEN: {
+ /* we're in listen and they want to talk to us */
+ if (!(packet_flags & PKT_SYN)) {
+ /* not a SYN, send RST */
+ goto send_reset;
+ }
+
+ /* see if we have a slot to accept */
+ if (s->accepted != NULL)
+ goto done;
+
+ /* make a new accept socket */
+ tcp_socket_t *accept_socket = create_tcp_socket(true);
+ if (!accept_socket)
+ goto done;
+
+ /* set it up */
+ accept_socket->local_ip = minip_get_ipaddr();
+ accept_socket->local_port = s->local_port;
+ accept_socket->remote_ip = src_ip;
+ accept_socket->remote_port = header->source_port;
+ accept_socket->state = STATE_SYN_RCVD;
+
+ mutex_acquire(&accept_socket->lock);
+
+ add_socket_to_list(accept_socket);
+
+ /* remember their sequence */
+ accept_socket->rx_win_low = header->seq_num + 1;
+ accept_socket->rx_win_high = accept_socket->rx_win_low + accept_socket->rx_win_size - 1;
+
+ /* save this socket and wake anyone up that is waiting to accept */
+ s->accepted = accept_socket;
+ sem_post(&s->accept_sem, true);
+
+ /* set up a mss option for sending back */
+ tcp_mss_option_t mss_option;
+ mss_option.kind = 0x2;
+ mss_option.len = 0x4;
+ mss_option.mss = ntohs(s->mss); // XXX make sure we fit in their mss
+
+ /* send a response */
+ tcp_socket_send(accept_socket, NULL, 0, PKT_ACK|PKT_SYN, &mss_option, sizeof(mss_option),
+ accept_socket->tx_win_low);
+
+ /* SYN consumed a sequence */
+ accept_socket->tx_win_low++;
+
+ mutex_release(&accept_socket->lock);
+ break;
+ }
+ case STATE_SYN_RCVD:
+ if (packet_flags & PKT_SYN) {
+ /* they must have not seen our ack of their original syn, retransmit */
+ // XXX implement
+ goto send_reset;
+ }
+
+ /* if they ack our SYN, we can move on to ESTABLISHED */
+ if (packet_flags & PKT_ACK) {
+ if (header->ack_num != s->tx_win_low) {
+ goto send_reset;
+ }
+
+ s->tx_win_high = s->tx_win_low + header->win_size;
+ s->tx_highest_seq = s->tx_win_low;
+
+ s->state = STATE_ESTABLISHED;
+ } else {
+ goto send_reset;
+ }
+
+ break;
+
+ case STATE_ESTABLISHED:
+ if (packet_flags & PKT_ACK) {
+ /* they're acking us */
+ handle_ack(s, header->ack_num, header->win_size);
+ }
+
+ if (data_len > 0) {
+ LTRACEF("new data, len %zu\n", data_len);
+ handle_data(s, p->data, p->dlen, header->seq_num);
+ }
+
+ if ((packet_flags & PKT_FIN) && SEQUENCE_GTE(s->rx_win_low, highest_sequence)) {
+ /* they're closing with us, and there's no outstanding data */
+
+ /* FIN consumed a sequence */
+ s->rx_win_low++;
+
+ /* ack them and transition to new state */
+ send_ack(s);
+ s->state = STATE_CLOSE_WAIT;
+
+ /* wake up any read waiters */
+ event_signal(&s->rx_event, true);
+ }
+ break;
+
+ case STATE_CLOSE_WAIT:
+ if (packet_flags & PKT_ACK) {
+ /* they're acking us */
+ handle_ack(s, header->ack_num, header->win_size);
+ }
+ if (packet_flags & PKT_FIN) {
+ /* they must have missed our ack, ack them again */
+ send_ack(s);
+ }
+ break;
+ case STATE_LAST_ACK:
+ if (packet_flags & PKT_ACK) {
+ /* they're acking our FIN, probably */
+ tcp_remote_close(s);
+
+ /* tcp_close() was already called on us, remove us from the list and drop the ref */
+ remove_socket_from_list(s);
+ dec_socket_ref(s);
+ }
+ break;
+ case STATE_FIN_WAIT_1:
+ if (packet_flags & PKT_ACK) {
+ /* they're acking our FIN, probably */
+ s->state = STATE_FIN_WAIT_2;
+ /* drop into fin_wait_2 state logic, in case they were FINning us too */
+ goto fin_wait_2;
+ } else if (packet_flags & PKT_FIN) {
+ /* simultaneous close. they finned us without acking our fin */
+ s->rx_win_low++;
+ send_ack(s);
+ s->state = STATE_CLOSING;
+ }
+ break;
+ case STATE_FIN_WAIT_2:
+fin_wait_2:
+ if (packet_flags & PKT_FIN) {
+ /* they're FINning us, ack them */
+ s->rx_win_low++;
+ send_ack(s);
+ s->state = STATE_TIME_WAIT;
+
+ /* set timed wait timer */
+ tcp_timer_set(s, &s->time_wait_timer, &handle_time_wait_timeout, TIME_WAIT_TIMEOUT);
+ }
+ break;
+ case STATE_CLOSING:
+ if (packet_flags & PKT_ACK) {
+ /* they're acking our FIN, probably */
+ s->state = STATE_TIME_WAIT;
+
+ /* set timed wait timer */
+ tcp_timer_set(s, &s->time_wait_timer, &handle_time_wait_timeout, TIME_WAIT_TIMEOUT);
+ }
+ break;
+ case STATE_TIME_WAIT:
+ /* /dev/null of packets */
+ break;
+
+ case STATE_SYN_SENT:
+ PANIC_UNIMPLEMENTED;
+ }
+
+done:
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+ return;
+
+send_reset:
+ if (s) {
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+ }
+
+ LTRACEF("SEND RST\n");
+ if (!(packet_flags & PKT_RST)) {
+ tcp_send(src_ip, header->source_port, dst_ip, header->dest_port,
+ NULL, 0, PKT_RST, NULL, 0, 0, header->ack_num, 0);
+ }
+}
+
+static void handle_data(tcp_socket_t *s, const void *data, size_t len, uint32_t sequence)
+{
+ if (unlikely(tcp_debug))
+ TRACEF("data %p, len %zu, sequence %u\n", data, len, sequence);
+
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+ DEBUG_ASSERT(data);
+ DEBUG_ASSERT(len > 0);
+
+ /* see if it matches our current window */
+ uint32_t sequence_top = sequence + len - 1;
+ if (SEQUENCE_LTE(sequence, s->rx_win_low) && SEQUENCE_GTE(sequence_top, s->rx_win_low)) {
+ /* it intersects the bottom of our window, so it's in order */
+
+ /* copy the data we need to our cbuf */
+ size_t offset = sequence - s->rx_win_low;
+ size_t copy_len = MIN(s->rx_win_high - s->rx_win_low, len - offset);
+
+ DEBUG_ASSERT(offset < len);
+
+ LTRACEF("copying from offset %zu, len %zu\n", offset, copy_len);
+
+ s->rx_win_low += copy_len;
+
+ cbuf_write(&s->rx_buffer, (uint8_t *)data + offset, copy_len, false);
+ event_signal(&s->rx_event, true);
+
+ /* keep a counter if they've been sending a full mss */
+ if (copy_len >= s->mss) {
+ s->rx_full_mss_count++;
+ } else {
+ s->rx_full_mss_count = 0;
+ }
+
+ /* immediately ack if we're more than halfway into our buffer or they've sent 2 or more full packets */
+ if (s->rx_full_mss_count >= 2 ||
+ (int)(s->rx_win_low + s->rx_win_size - s->rx_win_high) > (int)s->rx_win_size / 2) {
+ send_ack(s);
+ s->rx_full_mss_count = 0;
+ } else {
+ tcp_timer_set(s, &s->ack_delay_timer, &handle_delayed_ack_timeout, DELAYED_ACK_TIMEOUT);
+ }
+ } else {
+ // either out of order or completely out of our window, drop
+ // duplicately ack the last thing we really got
+ send_ack(s);
+ }
+}
+
+static status_t tcp_socket_send(tcp_socket_t *s, const void *data, size_t len, tcp_flags_t flags,
+ const void *options, size_t options_length, uint32_t sequence)
+{
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+ DEBUG_ASSERT(len == 0 || data);
+ DEBUG_ASSERT(options_length == 0 || options);
+ DEBUG_ASSERT((options_length % 4) == 0);
+
+ // calculate the new right edge of the rx window
+ uint32_t rx_win_high = s->rx_win_low + s->rx_win_size - cbuf_space_used(&s->rx_buffer) - 1;
+
+ LTRACEF("rx_win_low %u rx_win_size %u read_buf_len %zu, new win high %u\n",
+ s->rx_win_low, s->rx_win_size, cbuf_space_used(&s->rx_buffer), rx_win_high);
+
+ uint16_t win_size;
+ if (SEQUENCE_GTE(rx_win_high, s->rx_win_high)) {
+ s->rx_win_high = rx_win_high;
+ win_size = rx_win_high - s->rx_win_low;
+ } else {
+ // the window size has shrunk, but we can't move the
+ // right edge of the window backwards
+ win_size = s->rx_win_high - s->rx_win_low;
+ }
+
+ // we are piggybacking a pending ACK, so clear the delayed ACK timer
+ if (flags & PKT_ACK) {
+ tcp_timer_cancel(s, &s->ack_delay_timer);
+ }
+
+ status_t err = tcp_send(s->remote_ip, s->remote_port, s->local_ip, s->local_port, data, len, flags,
+ options, options_length, (flags & PKT_ACK) ? s->rx_win_low : 0, sequence, win_size);
+
+ return err;
+}
+
+static void send_ack(tcp_socket_t *s)
+{
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+
+ if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT && s->state != STATE_FIN_WAIT_2)
+ return;
+
+ tcp_socket_send(s, NULL, 0, PKT_ACK, NULL, 0, s->tx_win_low);
+}
+
+static status_t tcp_send(ipv4_addr dest_ip, uint16_t dest_port, ipv4_addr src_ip, uint16_t src_port, const void *buf,
+ size_t len, tcp_flags_t flags, const void *options, size_t options_length, uint32_t ack, uint32_t sequence, uint16_t window_size)
+{
+ DEBUG_ASSERT(len == 0 || buf);
+ DEBUG_ASSERT(options_length == 0 || options);
+ DEBUG_ASSERT((options_length % 4) == 0);
+
+ pktbuf_t *p = pktbuf_alloc();
+ if (!p)
+ return ERR_NO_MEMORY;
+
+ tcp_header_t *header = pktbuf_prepend(p, sizeof(tcp_header_t) + options_length);
+ DEBUG_ASSERT(header);
+
+ /* fill in the header */
+ header->source_port = htons(src_port);
+ header->dest_port = htons(dest_port);
+ header->seq_num = htonl(sequence);
+ header->ack_num = htonl(ack);
+ header->length_flags = htons(((sizeof(tcp_header_t) + options_length) / 4) << 12 | flags);
+ header->win_size = htons(window_size);
+ header->checksum = 0;
+ header->urg_pointer = 0;
+ if (options)
+ memcpy(header + 1, options, options_length);
+
+ /* append the data */
+ if (len > 0)
+ pktbuf_append_data(p, buf, len);
+
+ /* compute the checksum */
+ /* XXX get the tx ckecksum capability from the nic */
+ if (FORCE_TCP_CHECKSUM || true) {
+ tcp_pseudo_header_t pheader;
+ pheader.source_addr = src_ip;
+ pheader.dest_addr = dest_ip;
+ pheader.zero = 0;
+ pheader.protocol = IP_PROTO_TCP;
+ pheader.tcp_length = htons(p->dlen);
+
+ header->checksum = cksum_pheader(&pheader, p->data, p->dlen);
+ }
+
+ if (LOCAL_TRACE) {
+ printf("sending ");
+ dump_tcp_header(header);
+ }
+
+ status_t err = minip_ipv4_send(p, dest_ip, IP_PROTO_TCP);
+
+ return err;
+}
+
+static void handle_ack(tcp_socket_t *s, uint32_t sequence, uint32_t win_size)
+{
+ LTRACEF("socket %p ack sequence %u, win_size %u\n", s, sequence, win_size);
+
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+
+ LTRACEF("s %p, tx_win_low %u tx_win_high %u tx_highest_seq %u bufsize %u offset %u\n",
+ s, s->tx_win_low, s->tx_win_high, s->tx_highest_seq, s->tx_buffer_size, s->tx_buffer_offset);
+ if (SEQUENCE_LTE(sequence, s->tx_win_low)) {
+ /* they're acking stuff we've already received an ack for */
+ return;
+ } else if (SEQUENCE_GT(sequence, s->tx_highest_seq)) {
+ /* they're acking stuff we haven't sent */
+ return;
+ } else {
+ /* their ack is somewhere in our window */
+ uint32_t acked_len;
+
+ acked_len = (sequence - s->tx_win_low);
+
+ LTRACEF("acked len %u\n", acked_len);
+
+ DEBUG_ASSERT(acked_len <= s->tx_buffer_size);
+ DEBUG_ASSERT(acked_len <= s->tx_buffer_offset);
+
+ memmove(s->tx_buffer, s->tx_buffer + acked_len, s->tx_buffer_offset - acked_len);
+
+ s->tx_buffer_offset -= acked_len;
+ s->tx_win_low += acked_len;
+ s->tx_win_high = s->tx_win_low + win_size;
+
+ /* cancel or reset our retransmit timer */
+ if (s->tx_win_low == s->tx_highest_seq) {
+ tcp_timer_cancel(s, &s->retransmit_timer);
+ } else {
+ tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT);
+ }
+
+ /* we have opened the transmit buffer */
+ event_signal(&s->tx_event, true);
+ }
+}
+
+static ssize_t tcp_write_pending_data(tcp_socket_t *s)
+{
+ LTRACEF("s %p, tx_win_low %u tx_win_high %u tx_highest_seq %u bufsize %u offset %u\n",
+ s, s->tx_win_low, s->tx_win_high, s->tx_highest_seq, s->tx_buffer_size, s->tx_buffer_offset);
+
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+ DEBUG_ASSERT(s->tx_buffer_size > 0);
+ DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size);
+
+ /* do we have any new data to send? */
+ uint32_t outstanding = (s->tx_highest_seq - s->tx_win_low);
+ uint32_t pending = s->tx_buffer_offset - outstanding;
+ LTRACEF("outstanding %u, pending %u\n", outstanding, pending);
+
+ /* send packets that cover the pending area of the window */
+ uint32_t offset = 0;
+ while (offset < pending) {
+ uint32_t tosend = MIN(s->mss, pending - offset);
+
+ tcp_socket_send(s, s->tx_buffer + outstanding + offset, tosend, PKT_ACK|PKT_PSH, NULL, 0, s->tx_highest_seq);
+ s->tx_highest_seq += tosend;
+ offset += tosend;
+ }
+
+ /* reset the retransmit timer if we sent anything */
+ if (offset > 0) {
+ tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT);
+ }
+
+ return offset;
+}
+
+static ssize_t tcp_retransmit(tcp_socket_t *s)
+{
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+
+ if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT)
+ return 0;
+
+ /* how much data have we sent but not gotten an ack for? */
+ uint32_t outstanding = (s->tx_highest_seq - s->tx_win_low);
+ if (outstanding == 0)
+ return 0;
+
+ uint32_t tosend = MIN(s->mss, outstanding);
+
+ LTRACEF("s %p, tosend %u seq %u\n", s, tosend, s->tx_win_low);
+ tcp_socket_send(s, s->tx_buffer, tosend, PKT_ACK|PKT_PSH, NULL, 0, s->tx_win_low);
+
+ return tosend;
+}
+
+static void handle_retransmit_timeout(void *_s)
+{
+ tcp_socket_t *s = _s;
+
+ LTRACEF("s %p\n", s);
+
+ DEBUG_ASSERT(s);
+
+ mutex_acquire(&s->lock);
+
+ if (tcp_retransmit(s) == 0)
+ goto done;
+
+ tcp_timer_set(s, &s->retransmit_timer, &handle_retransmit_timeout, RETRANSMIT_TIMEOUT);
+
+done:
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+}
+
+static void handle_delayed_ack_timeout(void *_s)
+{
+ tcp_socket_t *s = _s;
+
+ LTRACEF("s %p\n", s);
+
+ DEBUG_ASSERT(s);
+
+ mutex_acquire(&s->lock);
+ send_ack(s);
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+}
+
+static void handle_time_wait_timeout(void *_s)
+{
+ tcp_socket_t *s = _s;
+
+ LTRACEF("s %p\n", s);
+
+ DEBUG_ASSERT(s);
+
+ mutex_acquire(&s->lock);
+
+ DEBUG_ASSERT(s->state == STATE_TIME_WAIT);
+
+ /* remove us from the list and drop the last ref */
+ remove_socket_from_list(s);
+ dec_socket_ref(s);
+
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+}
+
+static void tcp_wakeup_waiters(tcp_socket_t *s)
+{
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+
+ // wake up any waiters
+ event_signal(&s->rx_event, true);
+ event_signal(&s->tx_event, true);
+}
+
+static void tcp_remote_close(tcp_socket_t *s)
+{
+ LTRACEF("s %p, ref %d\n", s, s->ref);
+
+ DEBUG_ASSERT(s);
+ DEBUG_ASSERT(is_mutex_held(&s->lock));
+ DEBUG_ASSERT(s->ref > 0);
+
+ if (s->state == STATE_CLOSED)
+ return;
+
+ s->state = STATE_CLOSED;
+
+ tcp_timer_cancel(s, &s->retransmit_timer);
+ tcp_timer_cancel(s, &s->ack_delay_timer);
+
+ tcp_wakeup_waiters(s);
+}
+
+static tcp_socket_t *create_tcp_socket(bool alloc_buffers)
+{
+ tcp_socket_t *s;
+
+ s = calloc(1, sizeof(tcp_socket_t));
+ if (!s)
+ return NULL;
+
+ mutex_init(&s->lock);
+ s->ref = 1; // start with the ref already bumped
+
+ s->state = STATE_CLOSED;
+ s->rx_win_size = DEFAULT_RX_WINDOW_SIZE;
+ event_init(&s->rx_event, false, 0);
+
+ s->mss = DEFAULT_MSS;
+
+ s->tx_win_low = rand();
+ s->tx_win_high = s->tx_win_low;
+ s->tx_highest_seq = s->tx_win_low;
+ event_init(&s->tx_event, true, 0);
+
+ if (alloc_buffers) {
+ // XXX check for error
+ s->rx_buffer_raw = malloc(s->rx_win_size);
+ cbuf_initialize_etc(&s->rx_buffer, s->rx_win_size, s->rx_buffer_raw);
+
+ s->tx_buffer_size = DEFAULT_TX_BUFFER_SIZE;
+ s->tx_buffer = malloc(s->tx_buffer_size);
+ }
+
+ sem_init(&s->accept_sem, 0);
+
+ return s;
+}
+
+/* user api */
+
+status_t tcp_open_listen(tcp_socket_t **handle, uint16_t port)
+{
+ tcp_socket_t *s;
+
+ if (!handle)
+ return ERR_INVALID_ARGS;
+
+ s = create_tcp_socket(false);
+ if (!s)
+ return ERR_NO_MEMORY;
+
+ // XXX see if there's another listen socket already on this port
+
+ s->local_port = port;
+
+ /* go to listen state */
+ s->state = STATE_LISTEN;
+
+ add_socket_to_list(s);
+
+ *handle = s;
+
+ return NO_ERROR;
+}
+
+status_t tcp_accept_timeout(tcp_socket_t *listen_socket, tcp_socket_t **accept_socket, lk_time_t timeout)
+{
+ if (!listen_socket || !accept_socket)
+ return ERR_INVALID_ARGS;
+
+ tcp_socket_t *s = listen_socket;
+ inc_socket_ref(s);
+
+ /* block to accept a socket for an amount of time */
+ if (sem_timedwait(&s->accept_sem, timeout) == ERR_TIMED_OUT) {
+ dec_socket_ref(s);
+ return ERR_TIMED_OUT;
+ }
+
+ mutex_acquire(&s->lock);
+
+ /* we got here, grab the accepted socket and return */
+ DEBUG_ASSERT(s->accepted);
+ *accept_socket = s->accepted;
+ s->accepted = NULL;
+
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+
+ return NO_ERROR;
+}
+
+ssize_t tcp_read(tcp_socket_t *socket, void *buf, size_t len)
+{
+ LTRACEF("socket %p, buf %p, len %zu\n", socket, buf, len);
+ if (!socket)
+ return ERR_INVALID_ARGS;
+ if (len == 0)
+ return 0;
+ if (!buf)
+ return ERR_INVALID_ARGS;
+
+ tcp_socket_t *s = socket;
+ inc_socket_ref(s);
+
+ ssize_t ret = 0;
+retry:
+ /* block on available data */
+ event_wait(&s->rx_event);
+
+ mutex_acquire(&s->lock);
+
+ /* try to read some data from the receive buffer, even if we're closed */
+ ret = cbuf_read(&s->rx_buffer, buf, len, false);
+ if (ret == 0) {
+ /* check to see if we've closed */
+ if (s->state != STATE_ESTABLISHED) {
+ ret = ERR_CHANNEL_CLOSED;
+ goto out;
+ }
+
+ /* we must have raced with another thread */
+ event_unsignal(&s->rx_event);
+ mutex_release(&s->lock);
+ goto retry;
+ }
+
+ /* if we've used up the last byte in the read buffer, unsignal the read event */
+ size_t remaining_bytes = cbuf_space_used(&s->rx_buffer);
+ if (s->state == STATE_ESTABLISHED && remaining_bytes == 0) {
+ event_unsignal(&s->rx_event);
+ }
+
+ /* we've read something, make sure the other end knows that our window is opening */
+ uint32_t new_rx_win_size = s->rx_win_size - remaining_bytes;
+
+ /* if we've opened it enough, send an ack */
+ if (new_rx_win_size >= s->mss && s->rx_win_high - s->rx_win_low < s->mss)
+ send_ack(s);
+
+out:
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+
+ return ret;
+}
+
+ssize_t tcp_write(tcp_socket_t *socket, const void *buf, size_t len)
+{
+ LTRACEF("socket %p, buf %p, len %zu\n", socket, buf, len);
+ if (!socket)
+ return ERR_INVALID_ARGS;
+ if (len == 0)
+ return 0;
+ if (!buf)
+ return ERR_INVALID_ARGS;
+
+ tcp_socket_t *s = socket;
+ inc_socket_ref(s);
+
+ size_t off = 0;
+ while (off < len) {
+ LTRACEF("off %zu, len %zu\n", off, len);
+
+ /* wait for the tx buffer to open up */
+ event_wait(&s->tx_event);
+ LTRACEF("after event_wait\n");
+
+ mutex_acquire(&s->lock);
+
+ /* check to see if we've closed */
+ if (s->state != STATE_ESTABLISHED && s->state != STATE_CLOSE_WAIT) {
+ mutex_release(&s->lock);
+ dec_socket_ref(s);
+ return ERR_CHANNEL_CLOSED;
+ }
+
+ DEBUG_ASSERT(s->tx_buffer_size > 0);
+ DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size);
+
+ /* figure out how much data to copy in */
+ size_t to_copy = MIN(s->tx_buffer_size - s->tx_buffer_offset, len - off);
+ if (to_copy == 0) {
+ mutex_release(&s->lock);
+ continue;
+ }
+
+ memcpy(s->tx_buffer + s->tx_buffer_offset, (uint8_t *)buf + off, to_copy);
+ s->tx_buffer_offset += to_copy;
+
+ /* if this has completely filled it, unsignal the event */
+ DEBUG_ASSERT(s->tx_buffer_offset <= s->tx_buffer_size);
+ if (s->tx_buffer_offset == s->tx_buffer_size) {
+ event_unsignal(&s->tx_event);
+ }
+
+ /* send as much data as we can */
+ tcp_write_pending_data(s);
+
+ off += to_copy;
+
+ mutex_release(&s->lock);
+ }
+
+ dec_socket_ref(s);
+ return len;
+}
+
+status_t tcp_close(tcp_socket_t *socket)
+{
+ if (!socket)
+ return ERR_INVALID_ARGS;
+
+ tcp_socket_t *s = socket;
+
+ inc_socket_ref(s);
+ mutex_acquire(&s->lock);
+
+ LTRACEF("socket %p, state %d (%s), ref %d\n", s, s->state, tcp_state_to_string(s->state), s->ref);
+
+ status_t err;
+ switch (s->state) {
+ case STATE_CLOSED:
+ case STATE_LISTEN:
+ /* we can directly remove this socket */
+ remove_socket_from_list(s);
+
+ /* drop any timers that may be pending on this */
+ tcp_timer_cancel(s, &s->ack_delay_timer);
+ tcp_timer_cancel(s, &s->retransmit_timer);
+
+ s->state = STATE_CLOSED;
+
+ /* drop the extra ref that was held when the socket was created */
+ dec_socket_ref(s);
+ break;
+ case STATE_SYN_RCVD:
+ case STATE_ESTABLISHED:
+ s->state = STATE_FIN_WAIT_1;
+ tcp_socket_send(s, NULL, 0, PKT_ACK|PKT_FIN, NULL, 0, s->tx_win_low);
+ s->tx_win_low++;
+
+ /* stick around and wait for them to FIN us */
+ break;
+ case STATE_CLOSE_WAIT:
+ s->state = STATE_LAST_ACK;
+ tcp_socket_send(s, NULL, 0, PKT_ACK|PKT_FIN, NULL, 0, s->tx_win_low);
+ s->tx_win_low++;
+
+ // XXX set up fin retransmit timer here
+ break;
+ case STATE_FIN_WAIT_1:
+ case STATE_FIN_WAIT_2:
+ case STATE_CLOSING:
+ case STATE_TIME_WAIT:
+ case STATE_LAST_ACK:
+ /* these states are all post tcp_close(), so it's invalid to call it here */
+ err = ERR_CHANNEL_CLOSED;
+ goto out;
+ default:
+ PANIC_UNIMPLEMENTED;
+ }
+
+ /* make sure anyone blocked on this wakes up */
+ tcp_wakeup_waiters(s);
+
+ mutex_release(&s->lock);
+
+ err = NO_ERROR;
+
+out:
+ /* if this was the last ref, it should destroy the socket */
+ dec_socket_ref(s);
+
+ return err;
+}
+
+/* debug stuff */
+static int cmd_tcp(int argc, const cmd_args *argv)
+{
+ status_t err;
+
+ if (argc < 2) {
+notenoughargs:
+ printf("ERROR not enough arguments\n");
+usage:
+ printf("usage: %s sockets\n", argv[0].str);
+ printf("usage: %s listenclose <port>\n", argv[0].str);
+ printf("usage: %s listen <port>\n", argv[0].str);
+ printf("usage: %s debug\n", argv[0].str);
+ return ERR_INVALID_ARGS;
+ }
+
+ if (!strcmp(argv[1].str, "sockets")) {
+
+ mutex_acquire(&tcp_socket_list_lock);
+ tcp_socket_t *s = NULL;
+ list_for_every_entry(&tcp_socket_list, s, tcp_socket_t, node) {
+ dump_socket(s);
+ }
+ mutex_release(&tcp_socket_list_lock);
+ } else if (!strcmp(argv[1].str, "listenclose")) {
+ /* listen for a connection, accept it, then immediately close it */
+ if (argc < 3) goto notenoughargs;
+
+ tcp_socket_t *handle;
+
+ err = tcp_open_listen(&handle, argv[2].u);
+ printf("tcp_open_listen returns %d, handle %p\n", err, handle);
+
+ tcp_socket_t *accepted;
+ err = tcp_accept(handle, &accepted);
+ printf("tcp_accept returns returns %d, handle %p\n", err, accepted);
+
+ err = tcp_close(accepted);
+ printf("tcp_close returns %d\n", err);
+
+ err = tcp_close(handle);
+ printf("tcp_close returns %d\n", err);
+ } else if (!strcmp(argv[1].str, "listen")) {
+ if (argc < 3) goto notenoughargs;
+
+ tcp_socket_t *handle;
+
+ err = tcp_open_listen(&handle, argv[2].u);
+ printf("tcp_open_listen returns %d, handle %p\n", err, handle);
+
+ tcp_socket_t *accepted;
+ err = tcp_accept(handle, &accepted);
+ printf("tcp_accept returns returns %d, handle %p\n", err, accepted);
+
+ for (;;) {
+ uint8_t buf[512];
+
+ ssize_t err = tcp_read(accepted, buf, sizeof(buf));
+ printf("tcp_read returns %ld\n", err);
+ if (err < 0)
+ break;
+ if (err > 0) {
+ hexdump8(buf, err);
+ }
+
+ err = tcp_write(accepted, buf, err);
+ printf("tcp_write returns %ld\n", err);
+ if (err < 0)
+ break;
+ }
+
+ err = tcp_close(accepted);
+ printf("tcp_close returns %d\n", err);
+
+ err = tcp_close(handle);
+ printf("tcp_close returns %d\n", err);
+ } else if (!strcmp(argv[1].str, "debug")) {
+ tcp_debug = !tcp_debug;
+ printf("tcp debug now %u\n", tcp_debug);
+ } else {
+ printf("ERROR unknown command\n");
+ goto usage;
+ }
+
+ return NO_ERROR;
+}
+
+STATIC_COMMAND_START
+STATIC_COMMAND("tcp", "tcp commands", &cmd_tcp)
+STATIC_COMMAND_END(tcp);
+
+
+// vim: set ts=4 sw=4 expandtab: