Skip to content

Instantly share code, notes, and snippets.

@beordle
Forked from ayende/libuv_tls.c
Created February 18, 2022 12:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save beordle/b582470084a2adfb0027222191d99efd to your computer and use it in GitHub Desktop.
Save beordle/b582470084a2adfb0027222191d99efd to your computer and use it in GitHub Desktop.
#include "stdafx.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <uv.h>
#define DEFAULT_PORT 7000
#define DEFAULT_BACKLOG 128
typedef struct tls_uv_connection_state tls_uv_connection_state_t;
typedef struct {
tls_uv_connection_state_t* (*create_connection)(uv_tcp_t* connection);
int (*connection_established)(tls_uv_connection_state_t* connection);
void (*connection_closed)(tls_uv_connection_state_t* connection, int status);
int (*read)(tls_uv_connection_state_t* connection, void* buf, ssize_t nread);
} connection_handler_t;
typedef struct {
SSL_CTX *ctx;
uv_loop_t* loop;
connection_handler_t protocol;
tls_uv_connection_state_t* pending_writes;
} tls_uv_server_state_t;
typedef struct tls_uv_connection_state {
tls_uv_server_state_t* server;
uv_tcp_t* handle;
SSL *ssl;
BIO *read, *write;
struct {
tls_uv_connection_state_t** prev_holder;
tls_uv_connection_state_t* next;
int in_queue;
size_t pending_writes_count;
uv_buf_t* pending_writes_buffer;
} pending;
} tls_uv_connection_state_t;
void alloc_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf) {
buf->base = (char*)malloc(suggested_size);
buf->len = suggested_size;
}
void report_connection_failure(int status) {
fprintf(stderr, "New connection error %s\n", uv_strerror(status));
}
void remove_connection_from_queue(tls_uv_connection_state_t* cur) {
if (cur->pending.pending_writes_buffer != NULL) {
free(cur->pending.pending_writes_buffer);
}
if (cur->pending.prev_holder != NULL) {
*cur->pending.prev_holder = cur->pending.next;
}
memset(&cur->pending, 0, sizeof(cur->pending));
}
void abort_connection_on_error(tls_uv_connection_state_t* state) {
uv_close((uv_handle_t*)state->handle, NULL);
SSL_free(state->ssl);
// implicitly freed by SSL_free
//BIO_free(state->read);
//BIO_free(state->write);
remove_connection_from_queue(state);
free(state);
}
void maybe_flush_ssl(tls_uv_connection_state_t* state) {
if (state->pending.in_queue)
return;
if (BIO_pending(state->write) == 0 && state->pending.pending_writes_count > 0)
return;
state->pending.next = state->server->pending_writes;
if (state->pending.next != NULL) {
state->pending.next->pending.prev_holder = &state->pending.next;
}
state->pending.prev_holder = &state->server->pending_writes;
state->pending.in_queue = 1;
state->server->pending_writes = state;
}
void handle_read(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf) {
tls_uv_connection_state_t* state = client->data;
BIO_write(state->read, buf->base, nread);
while (1)
{
int rc = SSL_read(state->ssl, buf->base, buf->len);
if (rc <= 0) {
rc = SSL_get_error(state->ssl, rc);
if (rc != SSL_ERROR_WANT_READ) {
state->server->protocol.connection_closed(state, rc);
abort_connection_on_error(state);
break;
}
maybe_flush_ssl(state);
// need to read more, we'll let libuv handle this
break;
}
if (state->server->protocol.read(state, buf->base, rc) == 0) {
// protocol asked to close the socket
abort_connection_on_error(state);
break;
}
}
free(buf->base);
}
void complete_write(uv_write_t* r, int status) {
tls_uv_connection_state_t* state = r->data;
free(r->write_buffer.base);
free(r);
if (status < 0) {
state->server->protocol.connection_closed(state, status);
abort_connection_on_error(state);
}
}
void flush_ssl_buffer(tls_uv_connection_state_t* cur) {
int rc = BIO_pending(cur->write);
if (rc > 0) {
uv_buf_t buf = uv_buf_init(malloc(rc), rc);
BIO_read(cur->write, buf.base, rc);
uv_write_t* r = calloc(1, sizeof(uv_write_t));
r->data = cur;
uv_write(r, (uv_stream_t*)cur->handle, &buf, 1, complete_write);
}
}
void try_flush_ssl_state(uv_handle_t * handle) {
tls_uv_server_state_t* server_state = handle->data;
tls_uv_connection_state_t** head = &server_state->pending_writes;
while (*head != NULL) {
tls_uv_connection_state_t* cur = *head;
flush_ssl_buffer(cur);
if (cur->pending.pending_writes_count == 0) {
remove_connection_from_queue(cur);
continue;
}
// here we have pending writes to deal with, so we'll try stuffing them
// into the SSL buffer
int used = 0;
for (size_t i = 0; i < cur->pending.pending_writes_count; i++)
{
int rc = SSL_write(cur->ssl,
cur->pending.pending_writes_buffer[i].base,
cur->pending.pending_writes_buffer[i].len);
if (rc > 0) {
used++;
continue;
}
rc = SSL_get_error(cur->ssl, rc);
if (rc == SSL_ERROR_WANT_WRITE) {
flush_ssl_buffer(cur);
i--;// retry
continue;
}
if (rc != SSL_ERROR_WANT_READ) {
server_state->protocol.connection_closed(cur, rc);
abort_connection_on_error(cur);
cur->pending.in_queue = 0;
break;
}
// we are waiting for reads from the network
// we can't remove this instance, so we play
// with the pointer and start the scan/remove
// from this position
head = &cur->pending.next;
break;
}
flush_ssl_buffer(cur);
if (used == cur->pending.pending_writes_count) {
remove_connection_from_queue(cur);
}
else {
cur->pending.pending_writes_count -= used;
memmove(cur->pending.pending_writes_buffer,
cur->pending.pending_writes_buffer + sizeof(uv_buf_t)*used,
sizeof(uv_buf_t) * cur->pending.pending_writes_count);
}
}
}
void prepare_if_need_to_flush_ssl_state(uv_prepare_t * handle) {
try_flush_ssl_state((uv_handle_t*)handle);
}
void check_if_need_to_flush_ssl_state(uv_check_t * handle) {
try_flush_ssl_state((uv_handle_t*)handle);
}
int connection_write(tls_uv_connection_state_t* state, void* buf, int size) {
int rc = SSL_write(state->ssl, buf, size);
if (rc > 0)
{
maybe_flush_ssl(state);
return 1;
}
rc = SSL_get_error(state->ssl, rc);
if (rc == SSL_ERROR_WANT_WRITE) {
flush_ssl_buffer(state);
rc = SSL_write(state->ssl, buf, size);
if (rc > 0)
return 1;
}
if (rc != SSL_ERROR_WANT_READ) {
state->server->protocol.connection_closed(state, rc);
abort_connection_on_error(state);
return 0;
}
// we need to re negotiate with the client, so we can't accept the write yet
// we'll copy it to the side for now and retry after the next read
uv_buf_t copy = uv_buf_init(malloc(size), size);
memcpy(copy.base, buf, size);
state->pending.pending_writes_count++;
state->pending.pending_writes_buffer = realloc(state->pending.pending_writes_buffer,
sizeof(uv_buf_t) * state->pending.pending_writes_count);
state->pending.pending_writes_buffer[state->pending.pending_writes_count - 1] = copy;
maybe_flush_ssl(state);
return 1;
}
void on_new_connection(uv_stream_t *server, int status) {
if (status < 0) {
report_connection_failure(status);
return;
}
tls_uv_server_state_t* server_state = server->data;
uv_tcp_t *client = (uv_tcp_t*)malloc(sizeof(uv_tcp_t));
uv_tcp_init(server_state->loop, client);
status = uv_accept(server, (uv_stream_t*)client);
if (status != 0) {
uv_close((uv_handle_t*)client, NULL);
report_connection_failure(status);
return;
}
tls_uv_connection_state_t* state = server_state->protocol.create_connection(client);
state->ssl = SSL_new(server_state->ctx);
SSL_set_accept_state(state->ssl);
state->server = server_state;
state->handle = client;
state->read = BIO_new(BIO_s_mem());
state->write = BIO_new(BIO_s_mem());
BIO_set_nbio(state->read, 1);
BIO_set_nbio(state->write, 1);
SSL_set_bio(state->ssl, state->read, state->write);
client->data = state;
if (server_state->protocol.connection_established(state) == 0) {
abort_connection_on_error(state);
return;
}
uv_read_start((uv_stream_t*)client, alloc_buffer, handle_read);
}
int network_one_time_init() {
WSADATA wsaData;
return WSAStartup(MAKEWORD(2, 2), &wsaData);
}
tls_uv_connection_state_t* on_create_connection(uv_tcp_t* connection) {
return calloc(1, sizeof(tls_uv_connection_state_t));
}
int on_connection_established(tls_uv_connection_state_t* connection) {
return connection_write(connection, "OK\r\n", 4);
}
void on_connection_closed(tls_uv_connection_state_t* connection, int status) {
report_connection_failure(status);
}
int on_read(tls_uv_connection_state_t* connection, void* buf, ssize_t nread) {
return connection_write(connection, buf, nread);
}
int main() {
char * cert = "C:\\Work\\NetworkProtocol\\NetworkProtocol\\example-com.cert.pem";
char * key = "C:\\Work\\NetworkProtocol\\NetworkProtocol\\example-com.key.pem";
network_one_time_init();
SSL_load_error_strings();
ERR_load_BIO_strings();
ERR_load_crypto_strings();
OpenSSL_add_ssl_algorithms();
const SSL_METHOD* method = TLSv1_2_server_method();
SSL_CTX* ctx = SSL_CTX_new(method);
SSL_CTX_use_certificate_file(ctx, cert, SSL_FILETYPE_PEM);
SSL_CTX_use_PrivateKey_file(ctx, key, SSL_FILETYPE_PEM);
uv_loop_t* loop = uv_default_loop();
tls_uv_server_state_t server_state = {
.ctx = ctx,
.loop = loop,
.protocol = {
.create_connection = on_create_connection,
.connection_closed = on_connection_closed,
.read = on_read,
.connection_established = on_connection_established
}
};
uv_tcp_t server;
uv_tcp_init(loop, &server);
server.data = &server_state;
struct sockaddr_in addr;
uv_ip4_addr("0.0.0.0", DEFAULT_PORT, &addr);
uv_tcp_bind(&server, (const struct sockaddr*)&addr, 0);
int r = uv_listen((uv_stream_t*)&server, DEFAULT_BACKLOG, on_new_connection);
if (r) {
fprintf(stderr, "Listen error %s\n", uv_strerror(r));
return 1;
}
uv_prepare_t before_io;
before_io.data = &server_state;
uv_prepare_init(loop, &before_io);
uv_prepare_start(&before_io, prepare_if_need_to_flush_ssl_state);
uv_check_t after_io;
after_io.data = &server_state;
uv_check_init(loop, &after_io);
uv_check_start(&after_io, check_if_need_to_flush_ssl_state);
return uv_run(loop, UV_RUN_DEFAULT);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment