From c68c67ce7d88dcfc9db29e572d60e1c43ddb2519 Mon Sep 17 00:00:00 2001
From: "Alex Xu (Hello71)" <alex_y_xu@yahoo.ca>
Date: Wed, 6 Jul 2016 13:21:32 -0400
Subject: Miscellaneous improvements.

---
 src/checksum.c |  25 +++++----
 src/checksum.h |  40 --------------
 src/client.c   | 169 +++++++++++++++++++++++++++++++++------------------------
 3 files changed, 114 insertions(+), 120 deletions(-)

(limited to 'src')

diff --git a/src/checksum.c b/src/checksum.c
index 300237d..5c6415a 100644
--- a/src/checksum.c
+++ b/src/checksum.c
@@ -31,21 +31,24 @@
 
 /* Based on code from the Linux kernel. */
 
+#include <assert.h>
+#include <endian.h>
 #include <stdint.h>
+#include "checksum.h"
 
 /* Revised by Kenneth Albanowski for m68knommu. Basic problem: unaligned access
  kills, so most of the assembly has to go. */
 
-static inline unsigned short from32to16(unsigned int x)
+static inline uint16_t from32to16(uint32_t x)
 {
 	/* add up 16-bit and 16-bit for 16+c bit */
 	x = (x & 0xffff) + (x >> 16);
 	/* add up carry.. */
 	x = (x & 0xffff) + (x >> 16);
-	return x;
+	return (uint16_t)x;
 }
 
-static unsigned int do_csum(const unsigned char *buff, int len)
+static uint16_t do_csum(const unsigned char *buff, int len)
 {
 	int odd;
 	unsigned int result = 0;
@@ -55,24 +58,26 @@ static unsigned int do_csum(const unsigned char *buff, int len)
 	odd = 1 & (unsigned long) buff;
 	if (odd) {
 #if __BYTE_ORDER == __LITTLE_ENDIAN
-		result += (*buff << 8);
+		result += ((unsigned int)(*buff) << 8);
 #else
 		result = *buff;
 #endif
 		len--;
 		buff++;
 	}
+        assert(!((unsigned long)buff & 1));
 	if (len >= 2) {
-		if (2 & (unsigned long) buff) {
-			result += *(unsigned short *) buff;
+		if (2 & (const unsigned long) buff) {
+			result += *(const unsigned short *) buff;
 			len -= 2;
 			buff += 2;
 		}
+                assert(!((unsigned long)buff & 2));
 		if (len >= 4) {
-			const unsigned char *end = buff + ((unsigned)len & ~3);
+			const unsigned char *end = buff + ((unsigned)len & ~3u);
 			unsigned int carry = 0;
 			do {
-				unsigned int w = *(unsigned int *) buff;
+				unsigned int w = *(const unsigned int *) buff;
 				buff += 4;
 				result += carry;
 				result += w;
@@ -82,7 +87,7 @@ static unsigned int do_csum(const unsigned char *buff, int len)
 			result = (result & 0xffff) + (result >> 16);
 		}
 		if (len & 2) {
-			result += *(unsigned short *) buff;
+			result += *(const unsigned short *) buff;
 			buff += 2;
 		}
 	}
@@ -96,7 +101,7 @@ static unsigned int do_csum(const unsigned char *buff, int len)
 	if (odd)
 		result = ((result >> 8) & 0xff) | ((result & 0xff) << 8);
 out:
-	return result;
+	return (uint16_t)result;
 }
 
 /*
diff --git a/src/checksum.h b/src/checksum.h
index f788ff1..7d4a003 100644
--- a/src/checksum.h
+++ b/src/checksum.h
@@ -2,44 +2,4 @@
 #include <stdint.h>
 #include <stdlib.h>
 
-// based on code from RFCs 1071 and 1624
-
-/*
-static inline uint16_t csum_update(const void *ptr, uint16_t new_value, uint16_t wsum) {
-    uint32_t sum = *(uint16_t *)ptr + (~ntohs(*(uint16_t *)&new_value) & 0xffff) + ntohs(wsum);
-    sum = (sum & 0xffff) + (sum >> 16);
-    return htons(sum + (sum >> 16));
-}
-
-static inline uint16_t fold_sum(uint32_t sum) {
-    while (sum >> 16)
-        sum = (sum & 0xffff) + (sum >> 16);
-    return sum;
-}
-
-static inline uint16_t do_csum(const void *ptr, size_t len) {
-    uint32_t sum = 0;
-
-    while (len > 1) {
-        sum += *(uint16_t *)ptr++;
-        len -= 2;
-    }
-
-    if (len > 0)
-        sum += *(uint8_t *)ptr;
-
-    return ~fold_sum(sum);
-}
-
-static inline uint16_t csum_partial(uint16_t sum, const void *ptr, size_t len, ...) {
-    va_list ap;
-    va_start(ap, len);
-    do {
-        sum = ~fold_sum(~sum + ~do_csum(ptr, len));
-    } while ((ptr = va_arg(ap, const void *)) && (len = va_arg(ap, size_t)));
-    va_end(ap);
-    return sum;
-}
-*/
-
 uint16_t csum_partial(const void *buff, int len, uint16_t wsum);
diff --git a/src/client.c b/src/client.c
index 23d5eac..380bf42 100644
--- a/src/client.c
+++ b/src/client.c
@@ -20,6 +20,8 @@
 
 #define PORTS_IN_INT (sizeof(int) * CHAR_BIT)
 
+#define IN_ADDR_PORT(addr) (((struct sockaddr_in *)addr)->sin_port)
+
 struct c_data {
     const char *r_host;
     const char *r_port;
@@ -29,7 +31,6 @@ struct c_data {
     int s_sock;
     int i_sock;
     socklen_t s_addrlen;
-    uint16_t csum_p;
 };
 
 struct o_c_rsock {
@@ -61,8 +62,9 @@ struct o_c_sock {
 
 static struct c_data *global_c_data;
 
-static const uint8_t tcp_syn_retry_timeouts[] = { 3, 6, 12, 24, 0 };
+static const int8_t tcp_syn_retry_timeouts[] = { 0, 3, 6, 12, 24, -1 };
 
+/* check if a port offset is set in a int */
 static inline int check_resv_poff(unsigned int *used_ports, uint16_t poff) {
     if (used_ports[poff / PORTS_IN_INT] & (1 << (poff % PORTS_IN_INT)))
         return 0;
@@ -70,25 +72,23 @@ static inline int check_resv_poff(unsigned int *used_ports, uint16_t poff) {
     return poff;
 }
 
-/* reserve a local TCP port (local addr, remote addr, remote port are usually
- * fixed in the tuple) */
+/* reserve a local TCP port */
 static inline uint16_t reserve_port(unsigned int *used_ports) {
     long r;
 
-    // randomly try 16 places
+    // randomly try some places, hope this will give us reasonably uniform distribution
     for (int i = 1; i <= 16; i++) {
         r = random();
 
-        if (check_resv_poff(used_ports, r % 32768))
-            return 32768 + (r % 32768);
-
-        if (check_resv_poff(used_ports, (r >> 16) % 32768))
-            return 32768 + ((r >> 16) % 32768);
+        do {
+            if (check_resv_poff(used_ports, r % 32768))
+                return 32768 + (r % 32768);
+        } while (r >>= 16);
     }
 
     // give up and go sequentially
 
-    uint16_t ioff, spoff = (r >> 16) + 1;
+    uint16_t ioff, spoff = random();
     size_t moff, smoff = spoff / PORTS_IN_INT;
 
     /* two step process:
@@ -122,15 +122,21 @@ static void free_port(unsigned int *used_ports, uint16_t port_num) {
     used_ports[port_num / PORTS_IN_INT] ^= 1 << (port_num % PORTS_IN_INT);
 }
 
+/* prepare server address in TCP header format */
+static void c_prep_s_addr(struct o_c_sock *sock, struct tcphdr *hdr) {
+    hdr->th_sport = sock->l_port;
+    hdr->th_dport = IN_ADDR_PORT(sock->rsock->r_addr);
+    hdr->th_seq = htonl(sock->seq_num);
+    hdr->th_off = 5;
+}
+
+/* clean up a socket, don't bother freeing anything if the program is stopping */
 static void c_sock_cleanup(EV_P_ struct o_c_sock *sock, int stopping) {
     if (sock->status != TCP_SYN_SENT) {
         struct tcphdr buf = {
-            .th_sport = sock->l_port,
-            .th_dport = ((struct sockaddr_in *)sock->rsock->r_addr)->sin_port,
-            .th_seq = htonl(sock->seq_num),
-            .th_off = 5,
             .th_flags = sock->status == TCP_ESTABLISHED ? TH_FIN : TH_RST
         };
+        c_prep_s_addr(sock, &buf);
 
         ssize_t sz = send(sock->rsock->fd, &buf, sizeof(buf), 0);
         if (sz < 0) {
@@ -170,24 +176,49 @@ static void c_tm_cb(EV_P_ ev_timer *w, int revents __attribute__((unused))) {
     c_sock_cleanup(EV_A_ w->data, 0);
 }
 
+static int c_send_syn(struct o_c_sock *sock) {
+    struct tcphdr buf = {
+        .th_flags = TH_SYN
+    };
+    c_prep_s_addr(sock, &buf);
+
+    uint16_t tsz = htons(sizeof(buf));
+    buf.th_sum = ~csum_partial(&buf.th_seq, 16, csum_partial(&tsz, sizeof(tsz), sock->csum_p));
+
+    DBG("sending SYN to remote");
+    ssize_t sz = send(sock->rsock->fd, &buf, sizeof(buf), 0);
+    if (sz < 0) {
+        perror("send");
+        return 0;
+    } else if ((size_t)sz != sizeof(buf)) {
+        fprintf(stderr, "send %s our packet: tried %lu, sent %zd\n", (size_t)sz > sizeof(buf) ? "expanded" : "truncated", sizeof(buf), sz);
+    }
+
+    return 1;
+}
+
 static int c_adv_syn_tm(EV_P_ struct o_c_sock *sock) {
-    uint8_t next_retr = tcp_syn_retry_timeouts[sock->syn_retries++];
+    int8_t next_retr = tcp_syn_retry_timeouts[sock->syn_retries++];
+
+    if (next_retr < 0 || !c_send_syn(sock))
+        return 0;
+
     if (next_retr) {
         ev_timer_set(&sock->tm_w, next_retr, 0.);
         ev_timer_start(EV_A_ &sock->tm_w);
     }
-    return !!next_retr;
+
+    return 1;
 }
 
 static void c_syn_tm_cb(EV_P_ ev_timer *w, int revents __attribute__((unused))) {
-    if (c_adv_syn_tm(EV_A_ w->data)) {
-        // resend SYN
-    } else {
+    if (!c_adv_syn_tm(EV_A_ w->data)) {
         DBG("connection timed out");
         c_sock_cleanup(EV_A_ w->data, 0);
     }
 }
 
+/* client raw socket callback */
 static void cc_cb(struct ev_loop *loop, ev_io *w, int revents __attribute__((unused))) {
     DBG("-- entering cc_cb --");
 
@@ -222,14 +253,14 @@ static void cc_cb(struct ev_loop *loop, ev_io *w, int revents __attribute__((unu
             sock->status = TCP_ESTABLISHED;
 
             struct tcphdr shdr = {
-                .th_sport = sock->l_port,
-                .th_dport = ((struct sockaddr_in *)sock->rsock->r_addr)->sin_port,
-                .th_seq = htonl(sock->seq_num),
                 .th_ack = rhdr->th_seq,
                 .th_win = 65535,
-                .th_flags = TH_ACK,
-                .th_off = 5
+                .th_flags = TH_ACK
             };
+            c_prep_s_addr(sock, &shdr);
+
+            uint16_t tsz = htons(sizeof(shdr) + sock->pending_data_size);
+            shdr.th_sum = ~csum_partial(sock->pending_data, sock->pending_data_size, csum_partial(&shdr.th_seq, 16, csum_partial(&tsz, sizeof(tsz), sock->csum_p)));
 
             sock->seq_num += sock->pending_data_size;
 
@@ -267,26 +298,35 @@ static void cc_cb(struct ev_loop *loop, ev_io *w, int revents __attribute__((unu
             ev_timer_start(EV_A_ &sock->tm_w);
         }
 
-        should_ssz = rsz - rhdr->th_off * 32 / CHAR_BIT;
-        if (should_ssz > 0) {
-            DBG("sending %zd bytes to client", should_ssz);
-            ssz = sendto(rsock->c_data->s_sock, rbuf + rhdr->th_off * 32 / CHAR_BIT, should_ssz, 0, sock->c_address, rsock->c_data->s_addrlen);
+        if (rhdr->th_flags & ~(TH_PUSH | TH_ACK)) {
+            DBG("packet has strange flags, dropping");
+            return;
+        }
 
-            if (ssz < 0) {
-                perror("sendto");
-                ev_break(EV_A_ EVBREAK_ONE);
-                return;
-            } else if ((size_t)ssz != should_ssz) {
-                fprintf(stderr, "sendto %s our packet: tried %lu, sent %zd\n", (size_t)ssz > should_ssz ? "expanded" : "truncated", should_ssz, ssz);
+        if (sock->status == TCP_ESTABLISHED) {
+            should_ssz = rsz - rhdr->th_off * 32 / CHAR_BIT;
+            if (should_ssz > 0) {
+                DBG("sending %zd bytes to client", should_ssz);
+                ssz = sendto(rsock->c_data->s_sock, rbuf + rhdr->th_off * 32 / CHAR_BIT, should_ssz, 0, sock->c_address, rsock->c_data->s_addrlen);
+
+                if (ssz < 0) {
+                    perror("sendto");
+                    ev_break(EV_A_ EVBREAK_ONE);
+                    return;
+                } else if ((size_t)ssz != should_ssz) {
+                    fprintf(stderr, "sendto %s our packet: tried %lu, sent %zd\n", (size_t)ssz > should_ssz ? "expanded" : "truncated", should_ssz, ssz);
+                }
             }
         }
     }
+
     if (errno != EAGAIN) {
         perror("recvfrom");
         ev_break(EV_A_ EVBREAK_ONE);
     }
 }
 
+/* initialize new raw socket */
 static inline struct o_c_rsock * c_rsock_init(struct addrinfo *res) {
     struct o_c_rsock *rsock;
     rsock = malloc(sizeof(*rsock));
@@ -324,28 +364,33 @@ static inline struct o_c_rsock * c_rsock_init(struct addrinfo *res) {
 
     char proto[] = { 0, IPPROTO_TCP };
 
-    if (((struct sockaddr *)rsock->r_addr)->sa_family != our_addr.ss_family)
+    if (rsock->r_addr->sa_family != our_addr.ss_family)
         abort();
 
+    size_t addr_offset, addr_size;
+
     switch (our_addr.ss_family) {
     case AF_INET:
-        rsock->csum_p = csum_partial(&((struct sockaddr_in *)&our_addr)->sin_addr, sizeof(in_addr_t),
-                csum_partial(&((struct sockaddr_in *)rsock->r_addr)->sin_addr, sizeof(in_addr_t), 0));
+        addr_offset = offsetof(struct sockaddr_in, sin_addr);
+        addr_size = sizeof(in_addr_t);
         break;
     case AF_INET6:
-        rsock->csum_p = csum_partial(&((struct sockaddr_in6 *)&our_addr)->sin6_addr, sizeof(struct in6_addr),
-                csum_partial(&((struct sockaddr_in6 *)rsock->r_addr)->sin6_addr, sizeof(struct in6_addr), 0));
+        addr_offset = offsetof(struct sockaddr_in6, sin6_addr);
+        addr_size = sizeof(struct in6_addr);
         break;
     default:
         abort();
     }
 
-    rsock->csum_p = csum_partial(&((struct sockaddr_in *)rsock->r_addr)->sin_port, sizeof(in_port_t),
-            csum_partial(proto, sizeof(proto), rsock->csum_p));
+    rsock->csum_p = csum_partial(&IN_ADDR_PORT(rsock->r_addr), sizeof(in_port_t),
+            csum_partial(proto, sizeof(proto),
+            csum_partial((char *)&our_addr + addr_offset, addr_size,
+            csum_partial((char *)rsock->r_addr + addr_offset, addr_size, 0))));
 
     return rsock;
 }
 
+/* client UDP socket callback */
 static void cs_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) {
     DBG("-- entering cs_cb --");
     struct c_data *c_data = w->data;
@@ -367,7 +412,7 @@ static void cs_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) {
             sock = calloc(1, sizeof(*sock));
 
             struct addrinfo *res;
-            DBG("looking up %s:%s", c_data->r_host, c_data->r_port);
+            DBG("looking up [%s]:%s", c_data->r_host, c_data->r_port);
             // TODO: make this asynchronous
             int r = getaddrinfo(c_data->r_host, c_data->r_port, NULL, &res);
             if (r) {
@@ -414,34 +459,13 @@ static void cs_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) {
 
             sock->seq_num = random();
 
-            struct tcphdr buf = {
-                .th_sport = sock->l_port,
-                .th_dport = ((struct sockaddr_in *)sock->rsock->r_addr)->sin_port,
-                .th_seq = htonl(sock->seq_num),
-                .th_flags = TH_SYN,
-                .th_off = 5,
-            };
-            uint16_t tsz = htons(sizeof(buf));
-            buf.th_sum = ~csum_partial(&buf.th_seq, 16, csum_partial(&tsz, sizeof(tsz), sock->csum_p));
-
             sock->pending_data = malloc(sz);
             memcpy(sock->pending_data, rbuf, sz);
             sock->pending_data_size = sz;
 
-            DBG("sending SYN to remote");
-            sz = send(sock->rsock->fd, &buf, sizeof(buf), 0);
-            if (sz < 0) {
-                perror("send");
-                ev_break(EV_A_ EVBREAK_ONE);
-                return;
-            } else if ((size_t)sz != sizeof(buf)) {
-                fprintf(stderr, "send %s our packet: tried %lu, sent %zd\n", (size_t)sz > sizeof(buf) ? "expanded" : "truncated", sizeof(buf), sz);
-            }
-
-            // resend SYN
-
-            ev_timer_init(&sock->tm_w, c_syn_tm_cb, 0., tcp_syn_retry_timeouts[0]);
+            ev_init(&sock->tm_w, c_syn_tm_cb);
             sock->tm_w.data = sock;
+
             sock->syn_retries = 0;
             c_adv_syn_tm(EV_A_ sock);
 
@@ -451,13 +475,13 @@ static void cs_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) {
         }
 
         struct tcphdr tcp_hdr = {
-            .th_sport = sock->l_port,
-            .th_dport = ((struct sockaddr_in *)sock->rsock->r_addr)->sin_port,
-            .th_seq = htonl(sock->seq_num),
-            .th_off = 5,
             .th_win = 65535,
             .th_flags = TH_PUSH
         };
+        c_prep_s_addr(sock, &tcp_hdr);
+
+        uint16_t tsz = htons(sizeof(tcp_hdr) + sz);
+        tcp_hdr.th_sum = ~csum_partial(rbuf, sz, csum_partial(&tcp_hdr.th_seq, 16, csum_partial(&tsz, sizeof(tsz), sock->csum_p)));
 
         sock->seq_num += sz;
 
@@ -477,6 +501,10 @@ static void cs_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) {
         DBG("sending %zd raw bytes containing %zd bytes payload to remote", should_send_size, sz);
         sz = sendmsg(sock->rsock->fd, &msghdr, 0);
         if (sz < 0) {
+            if (errno == ENOBUFS) {
+                fprintf(stderr, "sendmsg: out of buffer space\n");
+                return;
+            }
             perror("sendmsg");
             ev_break(EV_A_ EVBREAK_ONE);
             return;
@@ -491,6 +519,7 @@ static void cs_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) {
     }
 }
 
+/* atexit cleanup */
 static void c_cleanup() {
     if (!global_c_data)
         return;
-- 
cgit v1.2.3-70-g09d2