Skip to content

Instantly share code, notes, and snippets.

@snej

snej/mbedtls.cc Secret

Created August 27, 2019 23:58
Show Gist options
  • Save snej/a9cde8eb3f8ee0d228898f8715bb4f26 to your computer and use it in GitHub Desktop.
Save snej/a9cde8eb3f8ee0d228898f8715bb4f26 to your computer and use it in GitHub Desktop.
WIP mbedTLS support for uSockets
/*
* Authored by Jens Alfke, 2019; parts adapted from code by Alex Hultman.
* Intellectual property of third-party.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if !defined(LIBUS_NO_SSL) && defined(LIBUS_USE_MBEDTLS)
#include "libusockets.h"
extern "C" {
#include "internal/internal.h"
}
/* This module contains the entire mbedTLS implementation
* of the SSL socket and socket context interfaces. */
#include <mbedtls/entropy.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/certs.h>
#include <mbedtls/error.h>
#include <mbedtls/x509.h>
#include <mbedtls/ssl.h>
#include <mbedtls/net_sockets.h>
#include <mbedtls/error.h>
#include <mbedtls/debug.h>
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <atomic>
// typedefs for sanity's sake: let the uSockets struct tags used throughout this file be
// written without the `struct` keyword, the way the C headers' users would. (In C++ the
// bare names already resolve, so these redeclarations are harmless no-ops there.)
typedef struct us_internal_ssl_socket_context_t us_internal_ssl_socket_context_t;
typedef struct us_internal_ssl_socket_t us_internal_ssl_socket_t;
typedef struct us_socket_context_t us_socket_context_t;
typedef struct us_socket_context_options_t us_socket_context_options_t;
typedef struct us_listen_socket_t us_listen_socket_t;
typedef struct us_loop_t us_loop_t;
typedef struct us_socket_t us_socket_t;
#pragma mark - LOGGING:

// Verbosity of this backend. The same value is handed to mbedtls_debug_set_threshold()
// by the SSL context constructor:
//   0 = nothing, 1 = errors, 2 = state changes, 3 = informational, 4 = verbose.
// This file's own log() only produces output at level 2 or higher.
#ifndef LIBUS_MBEDTLS_LOG_LEVEL
#define LIBUS_MBEDTLS_LOG_LEVEL 1
#endif

// Prepended to every line this backend writes to stderr.
#define kLogPrefix "uSockets+mbedTLS:: "

#if LIBUS_MBEDTLS_LOG_LEVEL < 2
// Logging compiled out: the format and arguments are accepted and ignored.
static inline void log(const char *fmt, ...) { }
#else
// Writes one printf-style diagnostic line, prefixed with kLogPrefix, to stderr.
static void log(const char *fmt, ...) {
    fputs(kLogPrefix, stderr);
    va_list ap;
    va_start(ap, fmt);
    vfprintf(stderr, fmt, ap);
    va_end(ap);
    fputs("\n", stderr);
}
#endif
#if LIBUS_MBEDTLS_LOG_LEVEL < 1
// Error reporting compiled out: `err` is passed straight through.
static inline int checkErr(int err, const char *fn) { return err; }
#else
// Pass-through error check for mbedTLS calls: if `err` is negative (an mbedTLS error code),
// prints "<fn> returned <description> (-0xNNNN)" to stderr. Always returns `err` unchanged,
// so calls can be wrapped inline:  if (checkErr(f(...), "f")) return false;
static int checkErr(int err, const char *fn) {
    if (err >= 0)
        return err;
    char message[100];
    mbedtls_strerror(err, message, sizeof message);
    fprintf(stderr, kLogPrefix "%s returned %s (-0x%04X)\n", fn, message, -err);
    return err;
}
#endif
#pragma mark - LOOP SSL DATA:
// Per-event-loop mbedTLS state: an entropy source and a CTR_DRBG random generator seeded
// from it. Created lazily by us_internal_init_loop_ssl_data() and shared by the SSL socket
// contexts on that loop (they pass random_context() to mbedtls_ssl_conf_rng).
struct LoopSSLData {
    // Allocated with malloc/free. This operator new is not declared noexcept, so it must
    // never return null (the constructor would then run on a null pointer); treat an
    // allocation failure as fatal instead.
    void* operator new(size_t size) {
        void *ptr = malloc(size);
        if (!ptr)
            abort();
        return ptr;
    }
    void operator delete(void *ptr) {free(ptr);}

    LoopSSLData() {
        // Initialize random number generator:
        mbedtls_entropy_init( &_entropy );
        mbedtls_ctr_drbg_init( &_ctr_drbg );
        // Route the result through checkErr so a seeding failure is at least reported in
        // NDEBUG builds, where the assert below compiles away.
        int err = checkErr(mbedtls_ctr_drbg_seed(&_ctr_drbg, mbedtls_entropy_func, &_entropy,
                                                 (const uint8_t *)kEntropyPersonalization,
                                                 strlen(kEntropyPersonalization)),
                           "mbedtls_ctr_drbg_seed");
        assert(err == 0);
        (void)err;  // only read by the assert
    }

    ~LoopSSLData() {
        mbedtls_ctr_drbg_free(&_ctr_drbg);
        mbedtls_entropy_free(&_entropy);
    }

    // The generator to hand to mbedtls_ssl_conf_rng (with mbedtls_ctr_drbg_random).
    mbedtls_ctr_drbg_context* random_context() {return &_ctr_drbg;}

private:
    // Personalization string mixed into the DRBG seed.
    static constexpr const char* kEntropyPersonalization = "uSockets";
    mbedtls_entropy_context _entropy;
    mbedtls_ctr_drbg_context _ctr_drbg;
};
#pragma mark - SSL SOCKET CONTEXT:
// An SSL socket context: a uSockets socket context whose allocation is extended with the
// mbedTLS configuration, this side's certificate and private key, and the application's
// callbacks that the SSL layer chains to.
struct us_internal_ssl_socket_context_t : public us_socket_context_t {
    // Allocates through us_create_socket_context, requesting `size + ext_size` bytes so the
    // application's extension data (see us_internal_ssl_socket_context_ext) follows this struct.
    void* operator new(size_t size, us_loop_t *loop, size_t ext_size) {
        return us_create_socket_context(0, loop, size + ext_size, {0});
    }
    // Hands the memory back to uSockets once the destructor has run.
    void operator delete(void *context) {
        us_socket_context_free(0, (us_socket_context_t*)context);
    }

    // Puts the mbedTLS objects into a valid empty state and routes mbedTLS debug output
    // through log_mbedtls(). Fallible setup happens in init().
    us_internal_ssl_socket_context_t() {
        mbedtls_ssl_config_init(&_sslConfig);
        mbedtls_x509_crt_init(&_cert);
        mbedtls_pk_init(&_privateKey);
        mbedtls_ssl_conf_dbg(&_sslConfig, log_mbedtls, this);
        // Note: this threshold takes no context argument, i.e. it is global to mbedTLS.
        mbedtls_debug_set_threshold(LIBUS_MBEDTLS_LOG_LEVEL);
    }

    ~us_internal_ssl_socket_context_t() {
        mbedtls_ssl_config_free(&_sslConfig);
        mbedtls_x509_crt_free(&_cert);
        mbedtls_pk_free(&_privateKey);
    }

    // initialization that might fail:
    // Creates the loop's RNG if needed and attaches it, applies mbedTLS client defaults and,
    // when both a certificate file and a key file are given, loads them as this side's
    // identity (certificates after the first one in the file become the CA chain).
    // Returns false (after logging) on any mbedTLS error.
    bool init(us_socket_context_options_t options) {
        us_internal_init_loop_ssl_data(loop);
        mbedtls_ssl_conf_rng(&_sslConfig, mbedtls_ctr_drbg_random, loopSSLData()->random_context() );
        int err = checkErr(mbedtls_ssl_config_defaults(&_sslConfig,
                                                       MBEDTLS_SSL_IS_CLIENT, //FIXME: server too
                                                       MBEDTLS_SSL_TRANSPORT_STREAM,
                                                       MBEDTLS_SSL_PRESET_DEFAULT),
                           "mbedtls_ssl_config_defaults");
        if (err) return false;
        if (options.cert_file_name && options.key_file_name) {
            if (checkErr(mbedtls_x509_crt_parse_file(&_cert, options.cert_file_name),
                         "mbedtls_x509_crt_parse_file"))
                return false;
            if (checkErr( mbedtls_pk_parse_keyfile(&_privateKey, options.key_file_name,
                                                   options.passphrase),
                         "mbedtls_pk_parse_keyfile"))
                return false;
            mbedtls_ssl_conf_ca_chain(&_sslConfig, _cert.next, NULL);
            if (checkErr( mbedtls_ssl_conf_own_cert(&_sslConfig, &_cert, &_privateKey),
                         "mbedtls_ssl_conf_own_cert"))
                return false;
        }
        // VERIFY_OPTIONAL: the peer certificate is still checked, but a failure does not
        // abort the handshake.
        if (options.skip_peer_cert_validation)
            mbedtls_ssl_conf_authmode(&_sslConfig, MBEDTLS_SSL_VERIFY_OPTIONAL);
        // TODO: options.dh_params_file_name
        return true;
    }

    // initialize as a child context (not implemented yet: always aborts)
    void init(us_internal_ssl_socket_context_t *parent_context) {
        _isParent = false;
        //ssl_context = parent_context->ssl_context;
        abort();// TODO: Implement
    }

    // mbedTLS debug callback (installed by the constructor): forwards one message to log(),
    // shortening the source path to its last component and dropping a trailing newline.
    static void log_mbedtls(void *ctx, int level,
                            const char *file, int line,
                            const char *str )
    {
        (void)ctx;
        (void)level;
        if (const char *slash = strrchr(file, '/'))
            file = slash+1;
        size_t len = strlen(str);
        // Check len first: on an empty message str[len-1] would read out of bounds.
        if (len > 0 && str[len-1] == '\n')
            --len;
        // The precision argument of "%.*s" must be an int, not a size_t.
        log("%s:%04d: %.*s", file, line, (int)len, str);
    }

    // The per-loop RNG state created by us_internal_init_loop_ssl_data().
    LoopSSLData* loopSSLData() {
        return (LoopSSLData*) loop->data.ssl_data;
    }

    mbedtls_ssl_config _sslConfig;          // mbedTLS SSL configuration
    mbedtls_x509_crt _cert;                 // My cert (followed by its chain)
    mbedtls_pk_context _privateKey;         // My private key
    bool _isParent = true;                  // Am I a parent or child context?

    // Client callbacks (the application's handlers, invoked by us_internal_ssl_socket_t):
    us_internal_ssl_socket_t *(*_on_open)(us_internal_ssl_socket_t*, int is_client, char *ip, int ip_length, int error) = nullptr;
    us_internal_ssl_socket_t *(*_on_data)(us_internal_ssl_socket_t*, char *data, int length) = nullptr;
    us_internal_ssl_socket_t *(*_on_writable)(us_internal_ssl_socket_t*) = nullptr;
    us_internal_ssl_socket_t *(*_on_close)(us_internal_ssl_socket_t*) = nullptr;
    us_internal_ssl_socket_t *(*_on_end)(us_internal_ssl_socket_t*) = nullptr;
};
#pragma mark - SSL SOCKET:
struct us_internal_ssl_socket_t : public us_socket_t {
void* operator new(size_t, void *placement) {return placement;}
void operator delete(void *) { }
bool init(const char *host =nullptr, int port =0) {
log("Init TLS socket to %s:%d", host, port);
mbedtls_ssl_init(&_ssl);
if (checkErr(mbedtls_ssl_setup(&_ssl, &sslContext()->_sslConfig),
"mbedtls_ssl_setup"))
return false;
if (host && checkErr(mbedtls_ssl_set_hostname(&_ssl, host),
"mbedtls_ssl_set_hostname"))
return false;
return true;
}
us_internal_ssl_socket_context_t* sslContext() {
return (us_internal_ssl_socket_context_t *) us_socket_context(0, this);
}
// Continue the TLS handshake. Will set _open to true when complete. Returns false on error.
bool handshake() {
if (_open)
return true;
int error = mbedtls_ssl_handshake(&_ssl);
switch (error) {
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_WANT_WRITE:
case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
return true;
case 0:
log("!!! TLS handshake complete !!!");
_open = true;
// Chain to client's on_open function:
sslContext()->_on_open(this, _is_client, _ip, _ip_length, 0);
return true;
default:
open_failed(error);
return false;
}
}
void open_failed(int err) {
char desc[100];
if (err > 0) {
strerror_r(err, desc, sizeof(desc)); // POSIX (errno)
log("!!! Closing socket due to POSIX errno %d: %s", err, desc);
} else {
mbedtls_strerror(err, desc, sizeof(desc)); // mbedTLS
log("!!! Closing socket due to mbedTLS status -0x%04X: %s", -err, desc);
}
sslContext()->_on_open(this, _is_client, _ip, _ip_length, err);
us_internal_ssl_socket_close(this);
}
void close_with_error(int err) {
if (err == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
log("!!! Socket closed by peer");
us_internal_ssl_socket_shutdown(this);
} else {
char desc[100];
if (err > 0) {
strerror_r(err, desc, sizeof(desc)); // POSIX (errno)
log("!!! Closing socket due to POSIX errno %d: %s", err, desc);
} else {
mbedtls_strerror(err, desc, sizeof(desc)); // mbedTLS
log("!!! Closing socket due to mbedTLS status -0x%04X: %s", -err, desc);
}
}
us_internal_ssl_socket_close(this);
}
//==== Called by uSockets:
// Underlying socket has opened.
us_socket_t *on_open(bool is_client, char *ip, int ip_length, int error) {
_is_client = is_client;
_ip_length = ip_length;
memcpy(_ip, ip, ip_length);
if (error && is_client) {
open_failed(error);
return this;
}
mbedtls_ssl_set_bio(&_ssl, this,
[](void *ctx, const uint8_t *buf, size_t len) {
return ((us_internal_ssl_socket_t*)ctx)->bio_send(buf, len); },
[](void *ctx, uint8_t *buf, size_t len) {
return ((us_internal_ssl_socket_t*)ctx)->bio_recv(buf, len); },
nullptr);
handshake();
return this;
}
us_socket_t *on_writable() {
if (!_open)
handshake();
if (_open)
sslContext()->_on_writable(this);
return this;
}
// Underlying socket has closed.
us_socket_t *on_close() {
log("!!! on_close");
mbedtls_ssl_free(&_ssl);
// Chain to client on_close function:
return sslContext()->_on_close(this);
}
// Underlying socket is half-closed (read stream has reached EOF).
us_socket_t *on_end() {
log("!!! Closing socket due to on_end");
if (_open)
sslContext()->_on_end(this);
else
us_internal_ssl_socket_close(this);
return this;
}
// Underlying socket has (encrypted) data available
us_socket_t *on_data(char *data, int length) {
log("<<< Received %d bytes from peer (on_data)", length);
// Point to the data so bio_recv() below can find it:
_receivedData = (uint8_t*) data;
_receivedDataLen = length;
if (!_open)
handshake();
if (_open) {
// To get mbedTLS to read the encrypted data, we will pull cleartext data out of it --
// this will trigger calls to bio_recv:
int n;
while ((n = mbedtls_ssl_read(&_ssl, _decryptedDataBuf, sizeof(_decryptedDataBuf))) > 0) {
log(" >>> mbedTLS decrypted %d bytes (on_data)", n);
sslContext()->_on_data(this, (char*)_decryptedDataBuf, n); // Pass decrypted data to application
}
if (n <= 0) {
switch (n) {
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_WANT_WRITE:
case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
case MBEDTLS_ERR_SSL_CLIENT_RECONNECT:
// not an error
break;
case 0:
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
// EOF:
log("!!! Socket closed by peer");
sslContext()->_on_end(this);
return this;
default:
// error:
char desc[100];
mbedtls_strerror(n, desc, sizeof(desc));
log("!!! Closing socket due to mbedTLS status -0x%04X: %s", -n, desc);
us_internal_ssl_socket_close(this);
return this;
}
}
}
// Assume mbedTLS read all the encrypted data:
assert(_receivedDataLen == 0);
_receivedData = nullptr;
return this;
}
// Application wants to write data.
int write(const char *data, int length, bool msg_more) {
log(">>> App writing %d bytes to socket (write)", length);
if (us_socket_is_closed(0, this) || us_internal_ssl_socket_is_shut_down(this))
return 0;
// Push the data into mbedTLS, which will call bio_send as it produces encrypted data:
int written = mbedtls_ssl_write(&_ssl, (const uint8_t*)data, length);
if (written >= 0) {
if (written < length)
log(" ... %d of %d bytes written", written, length);
return written;
} else {
log(" ... mbedTLS didn't write; status -0x%x", -written);
if (written != MBEDTLS_ERR_SSL_WANT_READ)
checkErr(written, "mbedtls_ssl_write");
return 0;
}
}
bool is_shut_down() {
return us_socket_is_shut_down(0, this) || _shuttingDown;
}
void shutdown() {
if (us_socket_is_closed(0, this) || is_shut_down())
return;
int err = checkErr( mbedtls_ssl_close_notify(&_ssl),
"mbedtls_ssl_close_notify");
_shuttingDown = true;
if (err < 0) {
// we get here if we are shutting down while still in init
us_socket_shutdown(0, this);
}
}
//==== BIO methods called by mbedTLS:
// mbedTLS has data to send to the underlying socket:
int bio_send(const uint8_t *buf, size_t length) {
int written = us_socket_write(0, this, (const char*)buf, length, false);
log(" >>> mbedTLS sent %d of %zu bytes to peer (bio_send)", written, length);
if (written > 0)
return written;
else
return MBEDTLS_ERR_NET_SEND_FAILED; //???
}
// mbedTLS wants to read from the underlying socket:
int bio_recv(uint8_t *buf, size_t length) {
// This function is being called from within my call to mbedtls_ssl_read, wherein I've
// set up these pointers to the just-read data:
size_t n = length;
if (n > _receivedDataLen)
n = _receivedDataLen;
if (n > 0) {
memcpy(buf, _receivedData, n);
_receivedData += n;
_receivedDataLen -= n;
log(" <<< mbedTLS read %zu of %zu bytes from peer (bio_recv)", n, length);
return n;
} else if (_shuttingDown) {
log(" <<< mbedTLS read 0 of %zu bytes from peer, at EOF (bio_recv)", length);
return 0;
} else {
log(" <<< mbedTLS read 0 of %zu bytes from peer (bio_recv)", length);
return MBEDTLS_ERR_SSL_WANT_READ;
}
}
private:
mbedtls_ssl_context _ssl;
bool _is_client;
char _ip[16];
size_t _ip_length;
bool _open = false;
bool _shuttingDown = false;
uint8_t* _receivedData = nullptr;
size_t _receivedDataLen = 0;
uint8_t _decryptedDataBuf[256];
};
#pragma mark - LOOP FUNCTIONS:
/* Creates the loop's LoopSSLData (entropy + DRBG) on first use; later calls do nothing. */
void us_internal_init_loop_ssl_data(us_loop_t *loop) {
    if (loop->data.ssl_data)
        return;
    loop->data.ssl_data = new LoopSSLData;
}
/* Called when the loop is freed: destroys its LoopSSLData (if any) and clears the slot. */
void us_internal_free_loop_ssl_data(us_loop_t *loop) {
    LoopSSLData *data = (LoopSSLData *) loop->data.ssl_data;
    loop->data.ssl_data = NULL;
    delete data;  // deleting a null pointer is a no-op
}
#pragma mark - CONTEXT FUNCTIONS:
// Creates an SSL socket context on `loop` with `context_ext_size` bytes of extension data,
// configured from `options`. Returns nullptr (after destroying the context) if the
// mbedTLS configuration fails.
us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(us_loop_t *loop,
                                                                        int context_ext_size,
                                                                        us_socket_context_options_t options) {
    auto *ctx = new (loop, context_ext_size) us_internal_ssl_socket_context_t;
    if (ctx->init(options))
        return ctx;
    delete ctx;
    return nullptr;
}
// Child contexts (sharing the parent's SSL configuration) are not implemented for mbedTLS
// yet: this always aborts.
// TODO: allocate with `new (parentContext->loop, context_ext_size)` and call
//       init(parentContext) once that initializer is implemented.
us_internal_ssl_socket_context_t *us_internal_create_child_ssl_socket_context(us_internal_ssl_socket_context_t *parentContext,
                                                                              int context_ext_size) {
    (void)parentContext;
    (void)context_ext_size;
    abort(); //FIXME //TODO
}
// Destroys an SSL socket context: the destructor frees the mbedTLS config, cert and key,
// then the class's operator delete returns the memory via us_socket_context_free.
void us_internal_ssl_socket_context_free(us_internal_ssl_socket_context_t *context) {
    delete context;
}
// Installs the application's on_open handler. uSockets gets a trampoline into
// us_internal_ssl_socket_t::on_open (which starts the TLS handshake); the application's
// handler is kept on the context and invoked from there.
void us_internal_ssl_socket_context_on_open(us_internal_ssl_socket_context_t *context,
                                            us_internal_ssl_socket_t *(*on_open)(us_internal_ssl_socket_t *s, int is_client,
                                                                                 char *ip, int ip_length, int error)) {
    auto trampoline = [](us_socket_t *raw, int is_client, char *ip, int ip_length, int error) {
        auto *ssl_socket = (us_internal_ssl_socket_t *) raw;
        return ssl_socket->on_open(is_client, ip, ip_length, error);
    };
    us_socket_context_on_open(0, context, trampoline);
    context->_on_open = on_open;
}
// Installs the application's on_close handler behind a trampoline into
// us_internal_ssl_socket_t::on_close, which frees the TLS session before chaining to it.
void us_internal_ssl_socket_context_on_close(us_internal_ssl_socket_context_t *context,
                                             us_internal_ssl_socket_t *(*on_close)(us_internal_ssl_socket_t *s)) {
    auto trampoline = [](us_socket_t *raw) {
        auto *ssl_socket = (us_internal_ssl_socket_t *) raw;
        return ssl_socket->on_close();
    };
    us_socket_context_on_close(0, context, trampoline);
    context->_on_close = on_close;
}
// Installs the application's on_data handler. uSockets delivers encrypted bytes to
// us_internal_ssl_socket_t::on_data, which decrypts them and passes cleartext to it.
void us_internal_ssl_socket_context_on_data(us_internal_ssl_socket_context_t *context,
                                            us_internal_ssl_socket_t *(*on_data)(us_internal_ssl_socket_t *s, char *data,
                                                                                 int length)) {
    auto trampoline = [](us_socket_t *raw, char *data, int length) {
        auto *ssl_socket = (us_internal_ssl_socket_t *) raw;
        return ssl_socket->on_data(data, length);
    };
    us_socket_context_on_data(0, context, trampoline);
    context->_on_data = on_data;
}
// Installs the application's on_writable handler behind a trampoline into
// us_internal_ssl_socket_t::on_writable, which first resumes any pending handshake.
void us_internal_ssl_socket_context_on_writable(us_internal_ssl_socket_context_t *context,
                                                us_internal_ssl_socket_t *(*on_writable)(us_internal_ssl_socket_t *s)) {
    auto trampoline = [](us_socket_t *raw) {
        auto *ssl_socket = (us_internal_ssl_socket_t *) raw;
        return ssl_socket->on_writable();
    };
    us_socket_context_on_writable(0, context, trampoline);
    context->_on_writable = on_writable;
}
// Installs the application's on_timeout handler directly with uSockets (no TLS-level work
// is needed). The function-pointer cast relies on an SSL socket beginning with its
// us_socket_t base, so both pointer types address the same object.
void us_internal_ssl_socket_context_on_timeout(us_internal_ssl_socket_context_t *context,
                                               us_internal_ssl_socket_t *(*on_timeout)(us_internal_ssl_socket_t *s)) {
    typedef us_socket_t *(*plain_timeout_handler)(us_socket_t *);
    us_socket_context_on_timeout(0, context, reinterpret_cast<plain_timeout_handler>(on_timeout));
}
// Installs the application's on_end (half-close) handler behind a trampoline into
// us_internal_ssl_socket_t::on_end, which only chains to it once the TLS session is open.
void us_internal_ssl_socket_context_on_end(us_internal_ssl_socket_context_t *context,
                                           us_internal_ssl_socket_t *(*on_end)(us_internal_ssl_socket_t *s)) {
    auto trampoline = [](us_socket_t *raw) {
        auto *ssl_socket = (us_internal_ssl_socket_t *) raw;
        return ssl_socket->on_end();
    };
    us_socket_context_on_end(0, context, trampoline);
    context->_on_end = on_end;
}
// Listens on host:port. Each accepted socket gets room for the SSL fields beyond
// us_socket_t plus `socket_ext_size` bytes of application extension data.
// NOTE(review): the config is client-only (FIXME in init) and nothing visible here
// constructs/init()s accepted sockets, so server-side TLS is not functional yet.
us_listen_socket_t *us_internal_ssl_socket_context_listen(us_internal_ssl_socket_context_t *context,
                                                          const char *host, int port, int options,
                                                          int socket_ext_size) {
    const size_t ssl_fields_size = sizeof(us_internal_ssl_socket_t) - sizeof(us_socket_t);
    return us_socket_context_listen(0, context, host, port, options,
                                    ssl_fields_size + socket_ext_size);
}
// Starts a TLS connection to host:port. The uSockets socket is allocated with room for the
// SSL fields plus `socket_ext_size` bytes, then constructed in place and given an mbedTLS
// session (with `host` used for mbedtls_ssl_set_hostname). Returns nullptr if uSockets
// could not create the connecting socket.
us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect(us_internal_ssl_socket_context_t *context,
                                                                 const char *host, int port, int options,
                                                                 int socket_ext_size) {
    auto s = us_socket_context_connect(0, context, host, port, options,
                                       sizeof(us_internal_ssl_socket_t) - sizeof(us_socket_t) + socket_ext_size);
    // us_socket_context_connect can return null (presumably when the connect could not even
    // be started — confirm); never construct an object at, or call init() on, a null address.
    if (!s)
        return nullptr;
    auto socket = new (s) us_internal_ssl_socket_t;
    // NOTE(review): init()'s failure result is still ignored here, as before.
    socket->init(host, port);
    return socket;
}
// Returns the application's extension area, which starts immediately after the SSL
// context struct in the same allocation.
void *us_internal_ssl_socket_context_ext(us_internal_ssl_socket_context_t *context) {
    return reinterpret_cast<char *>(context) + sizeof(*context);
}
// Moves socket `s` into `context`, sized for the SSL fields plus `ext_size` extension bytes.
// "todo: this is completely untested" -- from the OpenSSL equivalent
// NOTE(review): on_open registered the socket's old address as the mbedTLS BIO context,
// and init() bound _ssl to the old context's config; if uSockets moves the socket or the
// context changes here, neither pointer is updated — confirm before relying on this.
us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket(us_internal_ssl_socket_context_t *context,
                                                                      us_internal_ssl_socket_t *s, int ext_size) {
    const size_t ssl_fields_size = sizeof(us_internal_ssl_socket_t) - sizeof(struct us_socket_t);
    auto *adopted = us_socket_context_adopt_socket(0, context, s, ssl_fields_size + ext_size);
    return (us_internal_ssl_socket_t *) adopted;
}
#pragma mark - SOCKET FUNCTIONS:
// Encrypts and sends `length` bytes of application data; returns how many were accepted.
int us_internal_ssl_socket_write(us_internal_ssl_socket_t *s, const char *data, int length, int msg_more) {
    const bool more = (msg_more != 0);
    return s->write(data, length, more);
}
// Returns the application's per-socket extension area, located right after the SSL
// socket struct in the same allocation.
void *us_internal_ssl_socket_ext(us_internal_ssl_socket_t *s) {
    return reinterpret_cast<char *>(s) + sizeof(*s);
}
// 1 if the socket is shut down at the TCP level or a TLS shutdown has begun, else 0.
int us_internal_ssl_socket_is_shut_down(us_internal_ssl_socket_t *s) {
    return s->is_shut_down() ? 1 : 0;
}
// Begins a TLS shutdown (close_notify) on `s`; see us_internal_ssl_socket_t::shutdown.
void us_internal_ssl_socket_shutdown(us_internal_ssl_socket_t *s) {
    s->shutdown();
}
// Closes the underlying uSockets socket, which in turn triggers the SSL on_close handling,
// and returns the result of us_socket_close as an SSL socket pointer.
us_internal_ssl_socket_t * us_internal_ssl_socket_close(us_internal_ssl_socket_t *s) {
    auto *closed = us_socket_close(0, s);
    return (us_internal_ssl_socket_t *) closed;
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment