Skip to content

Instantly share code, notes, and snippets.

@rlcamp
Last active May 19, 2022 16:18
Show Gist options
  • Save rlcamp/4dbc72df87dd6a0302aa1cd35020287f to your computer and use it in GitHub Desktop.
Save rlcamp/4dbc72df87dd6a0302aa1cd35020287f to your computer and use it in GitHub Desktop.
allow Wireguard to replace ppp
/* isc license probably */
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <fcntl.h>
#include <poll.h>
#include <termios.h>
#include <arpa/inet.h>
/* largest udp payload that fits in one ipv4 datagram (65535 - 20-byte ip header - 8-byte udp
header); the decoded-packet buffers are sized to this */
#define MAX_UDP_SIZE 65507
/* worst-case on-wire size of one escaped frame: every byte escaped, plus the end-of-frame byte */
#define MAX_ESCAPED_SIZE (MAX_UDP_SIZE * 2 + 1)
/* print a printf-style message prefixed with "error: " to stderr and exit with EXIT_FAILURE. the
first argument must be a string literal, since it is pasted onto the "error: " prefix */
#define NOPE(...) do { fprintf(stderr, "error: " __VA_ARGS__); exit(EXIT_FAILURE); } while(0)
/* frame escaping scheme is ppp-like: every frame ends with END_BYTE (0x7E), and any 0x7E or 0x7D
byte inside a frame is sent as ESC_BYTE (0x7D) followed by the original byte xor'd with ESC_MASK
(0x20), so an unescaped 0x7E on the wire always marks the end of a frame */
#define END_BYTE 0x7E
#define ESC_BYTE 0x7D
#define ESC_MASK 0x20
/* current wall-clock time as whole milliseconds since the unix epoch. TIME_UTC is the iso c11
   spelling of the realtime clock (what clock_gettime(CLOCK_REALTIME) reads on posix systems); the
   value is used both for log timestamps and for computing elapsed/delay intervals */
static unsigned long long current_time_in_milliseconds(void) {
    struct timespec wall;
    timespec_get(&wall, TIME_UTC);
    const unsigned long long seconds_part = wall.tv_sec * 1000ULL;
    const unsigned long long subsecond_part = wall.tv_nsec / 1000000;
    return seconds_part + subsecond_part;
}
/* format a one-line summary of a packet into out (at least 72 bytes) for logging, and return out.
   data[0] is printed as the message type. the 4 bytes at offset 4 (if size >= 8) and at offset 8
   (if size >= 12) are copied as-is into uint32_t, so they are printed in host byte order; missing
   fields print as zero. type 2 gets both fields labelled sender/receiver key index, every other
   type gets only the first. an empty packet has no type byte, so it is reported as such instead
   of reading whatever stale byte happens to be in the buffer */
static char * packet_description(char out[static 72], const unsigned char * data, const size_t size) {
    const size_t out_size = 72;
    if (!size) {
        snprintf(out, out_size, "empty packet");
        return out;
    }
    uint32_t key0 = 0, key1 = 0;
    if (size >= 8) memcpy(&key0, data + 4, sizeof(uint32_t));
    if (size >= 12) memcpy(&key1, data + 8, sizeof(uint32_t));
    /* snprintf rather than sprintf so a future format change cannot overrun the caller's buffer */
    if (0x2 == data[0])
        snprintf(out, out_size, "type %u, sender key index 0x%8.8X, receiver key index 0x%8.8X", (unsigned)data[0], (unsigned)key0, (unsigned)key1);
    else
        snprintf(out, out_size, "type %u, key index 0x%8.8X", (unsigned)data[0], (unsigned)key0);
    return out;
}
/* parse a decimal command-line argument, exiting with an error message unless the whole string is
   a non-negative number within [min, max] */
static unsigned long parse_cmdline_number(const char * const progname, const char * const what, const char * const text, const unsigned long min, const unsigned long max) {
    char * end = NULL;
    errno = 0;
    const unsigned long value = strtoul(text, &end, 10);
    /* strtoul silently negates "-N", so reject any minus sign outright */
    if (end == text || '\0' != *end || ERANGE == errno || strchr(text, '-') || value < min || value > max)
        NOPE("%s: invalid %s: \"%s\"\n", progname, what, text);
    return value;
}

/* map a requested baud rate onto a termios speed: the smallest listed rate at or above the request
   up to 57600, then 115200, or 230400 if at least that much was requested. if your desired baud
   rate is not represented here, add it */
static speed_t speed_for_baud(const unsigned int baud) {
    if (baud <= 2400) return B2400;
    if (baud <= 4800) return B4800;
    if (baud <= 9600) return B9600;
    if (baud <= 19200) return B19200;
    if (baud <= 38400) return B38400;
    if (baud <= 57600) return B57600;
    if (baud < 230400) return B115200;
    return B230400;
}

/* escape size bytes of plain into out per the ppp-like scheme and append END_BYTE. out must have
   room for 2 * size + 1 bytes. returns a pointer one past the last byte written */
static unsigned char * escape_frame(unsigned char * out, const unsigned char * const plain, const size_t size) {
    for (size_t iplain = 0; iplain < size; iplain++) {
        const unsigned char plain_byte = plain[iplain];
        if (END_BYTE == plain_byte || ESC_BYTE == plain_byte) {
            *(out++) = ESC_BYTE;
            *(out++) = plain_byte ^ ESC_MASK;
        } else
            *(out++) = plain_byte;
    }
    *(out++) = END_BYTE;
    return out;
}

/* forward udp datagrams to a serial line as escaped frames, and decoded serial frames back to the
   most recent udp sender (initially 127.0.0.1:[udp port to forward to]). outgoing datagrams whose
   first byte is 0x01 or 0x02 jump to the front of the queue, and may cut short a 0x04 packet that
   is mid-transmission, which is then resent in full afterwards.
   NOTE(review): the outgoing queue is unbounded; if udp arrives faster than the serial line drains,
   a new MAX_UDP_SIZE-sized entry is malloc'd for each packet without limit */
int main(const int argc, char ** const argv) {
    const char * const slash_in_argvzero = strrchr(argv[0], '/');
    const char * const progname = slash_in_argvzero ? slash_in_argvzero + 1 : argv[0];

    if (argc < 2) {
        fprintf(stderr, "%s: Forward packets between udp and serial, with ppp-like framing, for the purpose of tunneling Wireguard or other bidirectional udp traffic over a raw serial line\n\n", progname);
        fprintf(stderr, "Usage: %s [serial device] [baud rate] [udp port to listen on] [udp port to forward to]\n\n", argv[0]);
        fprintf(stderr, "Example: given a Wireguard interface with a ListenPort of 51820, configured to talk to a peer with an assumed Endpoint of 127.0.0.1:51821, you would run:\n");
        fprintf(stderr, " %s /dev/ttyS0 115200 51821 51820\n", argv[0]);
        fprintf(stderr, "Note that for Wireguard handshaking to be completed reliably over slow links, \"ip route\" or equivalent must also be used to restrict the TCP window size to around 55 seconds or less, at whatever baud rate is used.\n");
        exit(EXIT_FAILURE);
    }

    /* parse cmdline arguments. malformed or out-of-range numbers are rejected instead of being
       truncated: ports must fit in 16 bits, and a baud rate of zero would be meaningless */
    const char * const path_serial_device = argc > 1 ? argv[1] : "/dev/ttyS0";
    const unsigned int baud = argc > 2 ? (unsigned int)parse_cmdline_number(progname, "baud rate", argv[2], 1, ~0U) : 115200;
    const unsigned short udp_port = argc > 3 ? (unsigned short)parse_cmdline_number(progname, "udp port", argv[3], 0, 65535) : 51821;
    const unsigned short udp_output_port = argc > 4 ? (unsigned short)parse_cmdline_number(progname, "udp port", argv[4], 0, 65535) : 51820;

    const int fd_serial = open(path_serial_device, O_RDWR | O_NOCTTY | O_NONBLOCK);
    if (-1 == fd_serial) NOPE("%s: cannot open %s: %s\n", progname, path_serial_device, strerror(errno));

    struct termios ts;
    if (-1 == tcgetattr(fd_serial, &ts)) NOPE("%s: cannot tcgetattr: %s\n", progname, strerror(errno));
    cfmakeraw(&ts);
    /* turn off input and output processing */
    ts.c_iflag = 0;
    ts.c_oflag = 0;
#if 0
    /* FIXME: implementations are allowed to hang if there are more than MAX_CANON bytes between
     consecutive 0x7e or 0x0a bytes in the input. if there is an external guarantee (such as via
     smallish MTU) that this will never happen, then it's safe, but it would probably be better to
     use VTIME or some other mechanism */
    /* enable canonical mode so that read() returns when it sees a control character */
    ts.c_lflag = ICANON;
    /* and disable all control characters except for VEOL, which is set to END_BYTE */
    for (size_t icc = 0; icc < NCCS; icc++)
        ts.c_cc[icc] = VEOL == icc ? END_BYTE : _POSIX_VDISABLE;
#else
    /* return after 0.1 seconds if at least one byte has been received, regardless of whether full
     reads have been satisfied */
    ts.c_cc[VMIN] = 1;
    ts.c_cc[VTIME] = 1;
#endif
    if (-1 == cfsetspeed(&ts, speed_for_baud(baud)))
        NOPE("%s: cannot cfsetspeed: %s\n", progname, strerror(errno));
    if (-1 == tcsetattr(fd_serial, TCSANOW, &ts)) NOPE("%s: cannot tcsetattr: %s\n", progname, strerror(errno));
    /* attempt to clear stale data */
    if (-1 == tcflush(fd_serial, TCIOFLUSH)) NOPE("%s: cannot tcflush: %s\n", progname, strerror(errno));

    /* open a socket for receiving udp packets */
    const int fd_udp = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
    if (-1 == fd_udp) NOPE("%s: cannot socket(): %s\n", progname, strerror(errno));
    if (-1 == setsockopt(fd_udp, SOL_SOCKET, SO_SNDBUF, &(int){ 65536 }, sizeof(int))) NOPE("%s: cannot setsockopt(): %s\n", progname, strerror(errno));
    if (-1 == setsockopt(fd_udp, SOL_SOCKET, SO_RCVBUF, &(int){ 65536 }, sizeof(int))) NOPE("%s: cannot setsockopt(): %s\n", progname, strerror(errno));

    /* if a local port was given, bind udp socket to the port number given on the cmdline, accept
     packets from anywhere */
    if (udp_port && bind(fd_udp, (struct sockaddr *)&(struct sockaddr_in) {
        .sin_family = AF_INET,
        .sin_port = htons(udp_port),
        .sin_addr.s_addr = htonl(INADDR_ANY)
    }, sizeof(struct sockaddr_in)))
        NOPE("%s: cannot bind(%d): %s\n", progname, udp_port, strerror(errno));

    /* this will be overwritten each time we get a new packet from the actual peer */
    struct sockaddr_in peer = {
        .sin_family = AF_INET,
        .sin_port = htons(udp_output_port),
        .sin_addr.s_addr = htonl(INADDR_LOOPBACK)
    };

    /* holds the decoded packet being constructed as bytes arrive via serial, which will be sent via
     udp when complete. this is scoped here because we don't assume anything in the main loop about
     how many bytes we get per read() from the serial line, whether the boundaries fall between an
     escape byte and the following escaped byte, &c. */
    unsigned char * buffer_to_udp = malloc(MAX_UDP_SIZE);
    if (!buffer_to_udp) abort();
    size_t buffer_to_udp_size = 0;
    char last_byte_from_serial_was_escape = 0;
    /* set once a frame overflows buffer_to_udp, so the rest of it is dropped up to its end marker
       rather than being forwarded as a bogus packet of its own */
    char discarding_oversized_frame = 0;

    struct queued {
        struct queued * next;
        unsigned long long time;
        size_t plain_size;
        unsigned char plain[MAX_UDP_SIZE];
    } * queued_head = NULL, * queued_tail = NULL, * queued_freelist = NULL, * in_progress = NULL;

    /* holds the encoded packet being sent via serial. must be twice the size of the largest udp
     packet, plus one byte */
    unsigned char * buffer_to_serial = malloc(MAX_ESCAPED_SIZE), * buffer_to_serial_cursor = NULL, * buffer_to_serial_stop = NULL;
    if (!buffer_to_serial) abort();

    /* on-wire size is potentially twice the max udp packet size, plus the end-of-frame byte */
    unsigned char * buffer_from_serial = malloc(MAX_ESCAPED_SIZE);
    if (!buffer_from_serial) abort();

    /* 0.2 seconds at whatever the baud rate is, so 48 bytes at 2400 baud. this number is a tradeoff
     between max cpu load and penalty for abandoning a packet. never zero, or write() would make no
     progress at very low baud rates */
    const size_t max_write_size = baud >= 50 ? baud / 50 : 1;

    unsigned long long prior = 0;
    while (1) {
        if (prior) {
            const unsigned long long elapsed = current_time_in_milliseconds() - prior;
            if (elapsed > 1)
                fprintf(stderr, "warning: %s: %llu.%03llu elapsed outside of poll()\n", progname, elapsed / 1000, elapsed % 1000);
        }

        /* block until we get either new bytes from serial or a new packet from udp, or we are
         waiting to send bytes and are allowed to now do so */
        const int have_output = buffer_to_serial_cursor != buffer_to_serial_stop || NULL != queued_head;
        struct pollfd pollfds[] = {
            { .fd = fd_serial, .events = POLLIN },
            { .fd = fd_udp, .events = POLLIN },
            { .fd = fd_serial, .events = POLLOUT }
        };
        if (-1 == poll(pollfds, have_output ? 3 : 2, -1)) {
            /* a signal interrupted the wait; nothing has changed, so just wait again */
            if (EINTR == errno) {
                prior = 0;
                continue;
            }
            NOPE("%s: cannot poll: %s\n", progname, strerror(errno));
        }

        /* get current time in milliseconds for diagnostics */
        const unsigned long long now = current_time_in_milliseconds();
        prior = now;

        /* if bytes can be sent to the serial port... revents is a bitmask, so test bits rather
         than comparing for equality */
        if (have_output && (pollfds[2].revents & POLLOUT)) {
            /* if the head of the queue is an 0x01 or 0x02 byte, and we're currently sending a non-
             handshake packet, then abort and requeue the current packet. in_progress is NULL while
             the terminator of an already-truncated frame is still draining, and must be skipped */
            if (buffer_to_serial_cursor != buffer_to_serial_stop && in_progress && queued_head &&
                (0x01 == queued_head->plain[0] || 0x02 == queued_head->plain[0]) &&
                0x04 == in_progress->plain[0]) {
                fprintf(stderr, "%s: truncating current packet after %zu bytes, will restart after sending higher-priority type %u packet\n", progname, (size_t)(buffer_to_serial_cursor - buffer_to_serial), (unsigned)queued_head->plain[0]);
                if (buffer_to_serial_cursor == buffer_to_serial)
                    /* nothing of this frame has gone out yet, so there is nothing to terminate */
                    buffer_to_serial_stop = buffer_to_serial_cursor;
                else if (ESC_BYTE == buffer_to_serial_cursor[-1]) {
                    /* an escape prefix already went out: send its companion byte first, otherwise
                     the receiver would decode the terminator as escaped data and merge frames */
                    buffer_to_serial_cursor[1] = END_BYTE;
                    buffer_to_serial_stop = buffer_to_serial_cursor + 2;
                } else {
                    *buffer_to_serial_cursor = END_BYTE;
                    buffer_to_serial_stop = buffer_to_serial_cursor + 1;
                }
                /* requeue this packet right behind the handshake so that it may be restarted
                 later, keeping the tail pointer valid if the handshake was the last entry */
                in_progress->next = queued_head->next;
                queued_head->next = in_progress;
                if (queued_tail == queued_head) queued_tail = in_progress;
                in_progress = NULL;
            }

            /* if a packet is not yet being sent, dequeue one and escape it */
            if (buffer_to_serial_cursor == buffer_to_serial_stop) {
                struct queued * this = queued_head;
                queued_head = this->next;
                if (!queued_head) queued_tail = NULL;
                buffer_to_serial_cursor = buffer_to_serial;
                buffer_to_serial_stop = escape_frame(buffer_to_serial, this->plain, this->plain_size);
                char buf[72];
                fprintf(stderr, "%s: %llu.%03llu: send: %s, %zu bytes (delayed %llu.%03llu s)\n", progname, now / 1000, now % 1000, packet_description(buf, this->plain, this->plain_size), this->plain_size, (now - this->time) / 1000, (now - this->time) % 1000);
                in_progress = this;
                this->next = NULL;
            }

            /* send some enqueued bytes */
            const size_t bytes_remaining = buffer_to_serial_stop - buffer_to_serial_cursor;
            const ssize_t ret = write(fd_serial, buffer_to_serial_cursor, bytes_remaining > max_write_size ? max_write_size : bytes_remaining);
            if (ret > 0) buffer_to_serial_cursor += ret;
            else if (-1 == ret && EAGAIN != errno && EWOULDBLOCK != errno) NOPE("%s: cannot write(): %s\n", progname, strerror(errno));

            /* if we just finished sending a packet, put it in the freelist */
            if (buffer_to_serial_cursor == buffer_to_serial_stop && in_progress) {
                in_progress->next = queued_freelist;
                queued_freelist = in_progress;
                in_progress = NULL;
            }
        }

        /* if the serial line has bytes, or has hung up or errored (so that read() reports it
         instead of poll() returning immediately forever)... */
        if (pollfds[0].revents & (POLLIN | POLLHUP | POLLERR)) {
            /* get however many bytes the kernel feels like giving us right now */
            ssize_t count = read(fd_serial, buffer_from_serial, MAX_ESCAPED_SIZE);
            if (-1 == count) {
                if (EAGAIN != errno && EWOULDBLOCK != errno) NOPE("%s: cannot read(): %s\n", progname, strerror(errno));
                /* spurious wakeup on the non-blocking fd: nothing to decode this time */
                count = 0;
            }
            else if (!count) break;

            /* loop over however many bytes we just got from the serial port, assuming nothing about
             how they are batched */
            for (size_t iencoded = 0; iencoded < (size_t)count; iencoded++) {
                unsigned char byte = buffer_from_serial[iencoded];
                if (!last_byte_from_serial_was_escape && ESC_BYTE == byte) {
                    last_byte_from_serial_was_escape = 1;
                    continue;
                }

                /* if we're seeing the end of the frame... */
                if (!last_byte_from_serial_was_escape && END_BYTE == byte) {
                    if (discarding_oversized_frame)
                        discarding_oversized_frame = 0;
                    else if (buffer_to_udp_size) {
                        /* empty frames (e.g. the terminator of a frame truncated before any of it
                         was sent, or line noise) carry nothing worth forwarding */
                        char buf[72];
                        fprintf(stderr, "%s: %llu.%03llu: recv: %s, %zu bytes\n", progname, now / 1000, now % 1000, packet_description(buf, buffer_to_udp, buffer_to_udp_size), buffer_to_udp_size);
                        /* send it via udp */
                        if (!peer.sin_port)
                            fprintf(stderr, "%s: got %zu bytes from serial, but no local udp destination known yet\n", progname, buffer_to_udp_size);
                        else if (-1 == sendto(fd_udp, buffer_to_udp, buffer_to_udp_size, 0, (void *)&peer, sizeof(peer)))
                            fprintf(stderr, "warning: %s: failed to send to %u: %s\n", progname, ntohs(peer.sin_port), strerror(errno));
                    }
                    buffer_to_udp_size = 0;
                    continue;
                }

                if (last_byte_from_serial_was_escape) {
                    byte ^= ESC_MASK;
                    last_byte_from_serial_was_escape = 0;
                }

                if (discarding_oversized_frame) continue;

                if (MAX_UDP_SIZE == buffer_to_udp_size) {
                    fprintf(stderr, "warning: %s: no frame-end marker after %zu bytes, discarding contents of buffer\n", progname, buffer_to_udp_size);
                    buffer_to_udp_size = 0;
                    discarding_oversized_frame = 1;
                    continue;
                }
                buffer_to_udp[buffer_to_udp_size++] = byte;
            }
        }

        /* if at least one complete new udp packet can be recv'd... */
        if (pollfds[1].revents & POLLIN) {
            struct queued * this = queued_freelist;
            if (this) queued_freelist = this->next;
            else this = malloc(sizeof *this);
            if (!this) abort();
            this->next = NULL;
            const ssize_t count = recvfrom(fd_udp, this->plain, sizeof(this->plain), 0, (void *)&peer, &(socklen_t) { sizeof(peer) });
            if (-1 == count) NOPE("%s: cannot recvfrom(): %s\n", progname, strerror(errno));

            if (!count) {
                /* nothing to forward, and plain[0] would be stale, so recycle the slot */
                this->next = queued_freelist;
                queued_freelist = this;
            } else {
                this->plain_size = count;
                this->time = now;
                /* if the outgoing packet is a handshake request or reply... */
                if (0x1 == this->plain[0] || 0x2 == this->plain[0]) {
                    /* then place it at the head of the queue */
                    char buf[72];
                    fprintf(stderr, "%s: %llu.%03llu: enqueued %s\n", progname, now / 1000, now % 1000, packet_description(buf, this->plain, this->plain_size));
                    this->next = queued_head;
                    queued_head = this;
                    if (!queued_tail) queued_tail = this;
                } else {
                    /* otherwise place it at the tail of the queue */
                    if (queued_tail) queued_tail->next = this;
                    else queued_head = this;
                    queued_tail = this;
                }
            }
        }
    }

    /* reached only when read() on the serial line reports end-of-file */
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment