diff options
-rw-r--r-- | Makefile | 6 | ||||
-rw-r--r-- | src/checksum.c | 16 | ||||
-rw-r--r-- | src/checksum.h | 12 | ||||
-rw-r--r-- | src/client.c | 46 | ||||
-rw-r--r-- | src/common.h | 2 | ||||
-rw-r--r-- | src/server.c | 45 | ||||
-rwxr-xr-x | test.sh | 27 |
7 files changed, 115 insertions, 39 deletions
@@ -2,13 +2,13 @@ CFLAGS += -Wall -Wextra -Wwrite-strings -std=c99 -D_BSD_SOURCE -D_DEFAULT_SOURCE LDLIBS := -lev -NET_OBJS := src/client.o src/server.o -OBJS := src/udpastcp.o src/checksum.o $(NET_OBJS) +NET_OBJS := src/checksum.o src/client.o src/server.o +OBJS := src/udpastcp.o $(NET_OBJS) udpastcp: $(OBJS) $(LINK.c) $^ $(LOADLIBES) $(LDLIBS) -o $@ -# networking code needs aliasing to be efficient +# networking code needs aliasing to work at all $(NET_OBJS): CFLAGS+=-fno-strict-aliasing -Wno-sign-compare clean: diff --git a/src/checksum.c b/src/checksum.c index 5c6415a..53d4e5a 100644 --- a/src/checksum.c +++ b/src/checksum.c @@ -33,6 +33,7 @@ #include <assert.h> #include <endian.h> +#include <netinet/in.h> #include <stdint.h> #include "checksum.h" @@ -127,3 +128,18 @@ uint16_t csum_partial(const void *buff, int len, uint16_t wsum) result += 1; return result; } + +uint16_t csum_sockaddr_partial(const struct sockaddr *addr, int incl_port, uint16_t wsum) +{ + if (incl_port) + wsum = csum_partial(&((struct sockaddr_in *)addr)->sin_port, sizeof(in_port_t), wsum); + + switch (addr->sa_family) { + case AF_INET: + return csum_partial(&((struct sockaddr_in *)addr)->sin_addr, sizeof(struct in_addr), wsum); + case AF_INET6: + return csum_partial(&((struct sockaddr_in6 *)addr)->sin6_addr, sizeof(struct in6_addr), wsum); + default: + abort(); + } +} diff --git a/src/checksum.h b/src/checksum.h index 7d4a003..700ee1a 100644 --- a/src/checksum.h +++ b/src/checksum.h @@ -1,5 +1,17 @@ +#include <netinet/in.h> #include <stdarg.h> #include <stdint.h> #include <stdlib.h> +/* calculates the checksum of len bytes at buff when combined with wsum. + * return value is already in network order, but must be inverted before + * sending. + * example: hdr.th_sum = csum_partial(data, len, csum_partial(hdr, hdrlen, 0)); + */ uint16_t csum_partial(const void *buff, int len, uint16_t wsum); + +/* calculates the checksum of a sockaddr_in or sockaddr_in6. + * if incl_port is set then the sin_port will be included. + * otherwise identical to csum_partial. + */ +uint16_t csum_sockaddr_partial(const struct sockaddr *addr, int incl_port, uint16_t wsum); diff --git a/src/client.c b/src/client.c index 380bf42..c78f3b1 100644 --- a/src/client.c +++ b/src/client.c @@ -5,6 +5,7 @@ #include <limits.h> #include <netdb.h> #include <netinet/in.h> +#include <netinet/ip.h> #include <netinet/tcp.h> #include <stdint.h> #include <stdio.h> @@ -20,8 +21,6 @@ #define PORTS_IN_INT (sizeof(int) * CHAR_BIT) -#define IN_ADDR_PORT(addr) (((struct sockaddr_in *)addr)->sin_port) - struct c_data { const char *r_host; const char *r_port; @@ -229,14 +228,31 @@ static void cc_cb(struct ev_loop *loop, ev_io *w, int revents __attribute__((unu while ((rsz = recvfrom(w->fd, rbuf, sizeof(rbuf), 0, (struct sockaddr *)&rsock->c_data->pkt_addr, &pkt_addrlen)) != -1) { DBG("received %zd raw bytes on client", rsz); + DBG("%u %zu", pkt_addrlen, sizeof(struct sockaddr_in6)); if (pkt_addrlen > sizeof(struct sockaddr_in6)) abort(); + char *rptr = rbuf; + + if (rsock->r_addr->sa_family == AF_INET) { + if ((size_t)rsz < sizeof(struct iphdr)) { + DBG("packet is smaller than IP header, ignoring"); + return; + } + + if (((struct iphdr *)rptr)->protocol != IPPROTO_TCP) + abort(); + + uint32_t ihl = ((struct iphdr *)rptr)->ihl * 4; + rptr = rptr + ihl; + rsz -= ihl; + } + if ((size_t)rsz < sizeof(struct tcphdr)) return; - struct tcphdr *rhdr = (struct tcphdr *)rbuf; + struct tcphdr *rhdr = (struct tcphdr *)rptr; struct o_c_sock *sock; @@ -307,7 +323,7 @@ static void cc_cb(struct ev_loop *loop, ev_io *w, int revents __attribute__((unu should_ssz = rsz - rhdr->th_off * 32 / CHAR_BIT; if (should_ssz > 0) { DBG("sending %zd bytes to client", should_ssz); - ssz = sendto(rsock->c_data->s_sock, rbuf + rhdr->th_off * 32 / CHAR_BIT, should_ssz, 0, sock->c_address, rsock->c_data->s_addrlen); + ssz = sendto(rsock->c_data->s_sock, rptr + rhdr->th_off * 32 / CHAR_BIT, should_ssz, 0, sock->c_address, rsock->c_data->s_addrlen); if (ssz < 0) { perror("sendto"); @@ -367,25 +383,9 @@ static inline struct o_c_rsock * c_rsock_init(struct addrinfo *res) { if (rsock->r_addr->sa_family != our_addr.ss_family) abort(); - size_t addr_offset, addr_size; - - switch (our_addr.ss_family) { - case AF_INET: - addr_offset = offsetof(struct sockaddr_in, sin_addr); - addr_size = sizeof(in_addr_t); - break; - case AF_INET6: - addr_offset = offsetof(struct sockaddr_in6, sin6_addr); - addr_size = sizeof(struct in6_addr); - break; - default: - abort(); - } - - rsock->csum_p = csum_partial(&IN_ADDR_PORT(rsock->r_addr), sizeof(in_port_t), - csum_partial(proto, sizeof(proto), - csum_partial((char *)&our_addr + addr_offset, addr_size, - csum_partial((char *)rsock->r_addr + addr_offset, addr_size, 0)))); + rsock->csum_p = csum_partial(proto, sizeof(proto), + csum_sockaddr_partial((struct sockaddr *)&our_addr, 0, + csum_sockaddr_partial(rsock->r_addr, 1, 0))); return rsock; } diff --git a/src/common.h b/src/common.h index f78b2ba..5800581 100644 --- a/src/common.h +++ b/src/common.h @@ -3,3 +3,5 @@ #else #define DBG(...) #endif + +#define IN_ADDR_PORT(addr) (((struct sockaddr_in *)addr)->sin_port) diff --git a/src/server.c b/src/server.c index 462212f..098c313 100644 --- a/src/server.c +++ b/src/server.c @@ -4,6 +4,7 @@ #include <fcntl.h> #include <netdb.h> #include <netinet/in.h> +#include <netinet/ip.h> #include <netinet/tcp.h> #include <stdint.h> #include <stdio.h> @@ -11,7 +12,9 @@ #include <string.h> #include <sys/socket.h> #include <unistd.h> + #include "common.h" +#include "checksum.h" #include "server.h" #include "uthash.h" @@ -23,6 +26,7 @@ struct o_s_sock { UT_hash_handle hh; int c_sock; uint16_t seq_num; + uint16_t csum_p; uint8_t status; }; @@ -34,6 +38,7 @@ struct s_data { struct o_s_sock *o_socks_by_caddr; int s_sock; socklen_t s_addrlen; + uint16_t csum_p; }; static inline void s_prep_c_addr(struct o_s_sock *sock, struct tcphdr *hdr) { @@ -153,6 +158,22 @@ static void ss_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) { if (c_addrlen != s_data->s_addrlen) abort(); + char *rptr = rbuf; + + if (s_data->s_addr->sa_family == AF_INET) { + if ((size_t)sz < sizeof(struct iphdr)) { + DBG("packet is smaller than IP header, ignoring"); + return; + } + + if (((struct iphdr *)rbuf)->protocol != IPPROTO_TCP) + abort(); + + uint32_t ihl = ((struct iphdr *)rbuf)->ihl * 4; + rptr = rbuf + ihl; + sz -= ihl; + } + #ifdef DEBUG char hbuf[NI_MAXHOST]; r = getnameinfo((struct sockaddr *)&s_data->pkt_addr, c_addrlen, hbuf, sizeof(hbuf), NULL, 0, NI_NUMERICHOST); @@ -161,7 +182,7 @@ static void ss_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) { ev_break(EV_A_ EVBREAK_ONE); return; } - DBG("received %zd bytes from %s", sz, hbuf); + DBG("received %zd payload bytes from %s", sz, hbuf); #endif if ((size_t)sz < sizeof(struct tcphdr)) { @@ -169,7 +190,7 @@ static void ss_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) { return; } - struct tcphdr *tcphdr = (struct tcphdr *)rbuf; + struct tcphdr *tcphdr = (struct tcphdr *)rptr; DBG("packet received on port %hu", ntohs(tcphdr->th_dport)); @@ -201,6 +222,8 @@ static void ss_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) { sock->c_sock = -1; sock->status = TCP_SYN_RECV; + sock->csum_p = csum_sockaddr_partial((struct sockaddr *)&s_data->pkt_addr, 1, s_data->csum_p); + struct tcphdr buf = { .th_sport = tcphdr->th_dport, .th_dport = tcphdr->th_sport, @@ -210,6 +233,9 @@ static void ss_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) { .th_off = 5 }; + uint16_t tsz = htons(sizeof(buf)); + buf.th_sum = ~csum_partial(&buf.th_seq, 16, csum_partial(&tsz, sizeof(tsz), sock->csum_p)); + HASH_ADD(hh, s_data->o_socks_by_caddr, c_addr, c_addrlen, sock); ((struct sockaddr_in *)&s_data->pkt_addr)->sin_port = htons(0); @@ -239,11 +265,13 @@ static void ss_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) { return; } + /* if (th_flags == TH_RST) { DBG("RST received, cleaning up socket"); sock->status = TCP_CLOSE; s_sock_cleanup(EV_A_ sock); } + */ if (th_flags & ~(TH_PUSH | TH_ACK)) { DBG("TCP flags not PSH and/or ACK, dropping packet"); @@ -297,7 +325,7 @@ static void ss_cb(EV_P_ ev_io *w, int revents __attribute__((unused))) { assert(sock->status == TCP_ESTABLISHED); DBG("sending %zu bytes to client", (size_t)(sz - tcphdr->th_off * 4)); - sz = send(sock->c_sock, rbuf + tcphdr->th_off * 4, sz - tcphdr->th_off * 4, 0); + sz = send(sock->c_sock, rptr + tcphdr->th_off * 4, sz - tcphdr->th_off * 4, 0); if (sz < 0) { perror("send"); ev_break(EV_A_ EVBREAK_ONE); @@ -323,11 +351,15 @@ int start_server(const char *s_host, const char *s_port, const char *r_host, con return 1; } + char proto[] = { 0, IPPROTO_TCP }; + struct s_data s_data = { .s_addr = res->ai_addr, .s_addrlen = res->ai_addrlen, .r_host = r_host, - .r_port = r_port + .r_port = r_port, + .csum_p = csum_sockaddr_partial(res->ai_addr, 1, + csum_partial(&proto, sizeof(proto), 0)) }; s_data.s_sock = socket(s_data.s_addr->sa_family, SOCK_RAW, IPPROTO_TCP); @@ -337,6 +369,11 @@ int start_server(const char *s_host, const char *s_port, const char *r_host, con return 1; } + if (bind(s_data.s_sock, res->ai_addr, res->ai_addrlen) == -1) { + perror("bind"); + return 2; + } + if (fcntl(s_data.s_sock, F_SETFL, O_NONBLOCK) == -1) { perror("fcntl"); freeaddrinfo(res); @@ -1,17 +1,17 @@ #!/bin/sh # this script tests basic udpastcp functionality. -test1() { +test_bidi() { ( pids= trap 'kill $pids' EXIT - ./udpastcp client localhost 36563 localhost 64109 & + ./udpastcp client "$1" 36563 "$1" 64109 & pids="$!" - ./udpastcp server localhost 64109 localhost 41465 & + ./udpastcp server "$1" 64109 "$1" 41465 & pids="$pids $!" - ( ( sleep 0.4; echo BBBBBBBB; ) | socat udp6-listen:41465 - ) & + ( ( sleep 0.4; echo BBBBBBBB; ) | socat "udp-listen:41465,pf=${2}" - ) & pids="$pids $!" - ( ( sleep 0.2; echo AAAAAAAA; ) | socat - 'udp-connect:[::1]:36563' ) & + ( ( sleep 0.2; echo AAAAAAAA; ) | socat - "udp-connect:localhost:36563,pf=${2}" ) & pids="$pids $!" sleep 0.5 ) @@ -20,9 +20,18 @@ test1() { nl=' ' -if [ "$( test1 )" = "AAAAAAAA${nl}BBBBBBBB" ]; then - echo "Test succeeded." +if [ "$( test_bidi 127.0.0.1 ip4 )" = "AAAAAAAA${nl}BBBBBBBB" ]; then + echo "IPv4 test succeeded." else - echo "Test failed." - exit 1 + echo "IPv4 test failed." + r=1 fi + +if [ "$( test_bidi ::1 ip6 )" = "AAAAAAAA${nl}BBBBBBBB" ]; then + echo "IPv6 test succeeded." +else + echo "IPv6 test failed." + r=1 +fi + +exit $r |