Skip to content

Instantly share code, notes, and snippets.

@Maxdamantus
Last active July 20, 2021 19:50
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Maxdamantus/e32ab94dbc5d9d43298428400020620e to your computer and use it in GitHub Desktop.
Save Maxdamantus/e32ab94dbc5d9d43298428400020620e to your computer and use it in GitHub Desktop.
#include <ctype.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <pthread.h>
#include <unistd.h>

#include <openssl/bio.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/x509v3.h>
/* Proxy-wide state built once in main() and handed to every client thread. */
struct state {
    /* Context used for the client-facing TLS side (created with a server method in main). */
    SSL_CTX *cctx;
    /* Context used for the connection to the real server (created with a client method in main). */
    SSL_CTX *sctx;
    /* Private key read from CAKEY.pem; signs, and is also the public key of, generated host certificates. */
    EVP_PKEY *privkey;
    /* CA certificate read from CACERT.pem; source of validity dates and issuer for generated certificates. */
    X509 *cacert;
};
/* Per-connection context built in handle_client() and passed to checksslhead(). */
struct cstate {
// Proxy-wide state this connection belongs to.
struct state *state;
// BIO for the client-facing TLS session.
BIO *cbio;
};
/*
 * Argument bundle for a handle_client() thread.
 * Allocated with malloc() by mitm_listen() and freed by handle_client().
 */
struct handle_client_args {
// Proxy-wide state shared by all client threads.
struct state *state;
// Accepted client socket; handle_client() closes it when the session ends.
int csock;
};
/*
 * Fatal OpenSSL assertion.
 * Returns normally when `a` is non-zero. Otherwise prints
 * "SSL failure: <msg>; " followed by the queued OpenSSL errors to stderr
 * and terminates the process with EXIT_FAILURE.
 */
static void assertssl(int a, char *msg){
    if (a)
        return;

    fprintf(stderr, "SSL failure: %s; ", msg);
    ERR_print_errors_fp(stderr);
    exit(EXIT_FAILURE);
}
/*
 * Fatal assertion for calls that report failure through errno.
 * Returns normally when `a` is non-zero. Otherwise reports `msg` together
 * with the current errno text via perror() and terminates the process with
 * EXIT_FAILURE.
 */
static void asserterrno(int a, char *msg){
    if (a)
        return;

    perror(msg);
    exit(EXIT_FAILURE);
}
/*
 * Non-fatal OpenSSL check.
 * When `a` is zero, prints "SSL error: <msg>; " followed by the queued
 * OpenSSL errors to stderr. Always returns `a` unchanged so the call can be
 * used directly as a condition.
 */
static int checksslstderr(int a, char *msg){
    if (a)
        return a;

    fprintf(stderr, "SSL error: %s; ", msg);
    ERR_print_errors_fp(stderr);
    return a;
}
/*
 * Non-fatal OpenSSL check used once a client connection context exists.
 * When `a` is zero, prints "SSL error: <msg>; " followed by the queued
 * OpenSSL errors to stderr. Always returns `a` unchanged.
 * `state` is currently unused.
 */
static int checksslhead(struct cstate *state, int a, char *msg){
    if (a)
        return a;

    fprintf(stderr, "SSL error: %s; ", msg);
    // TODO: consider writing an HTTP response or something containing the error (not convinced this is actually desirable)
    ERR_print_errors_fp(stderr);
    return a;
}
/*
 * Generic fatal assertion.
 * Returns normally when `a` is non-zero. Otherwise prints
 * "error: <msg>" on its own line to stderr and terminates the process with
 * EXIT_FAILURE.
 */
static void assertx(int a, char *msg){
    if (a)
        return;

    fprintf(stderr, "error: %s\n", msg);
    exit(EXIT_FAILURE);
}
// Forward declarations for the functions defined after main().

// Binds listen_address:listen_port and starts one detached handle_client() thread per accepted connection.
static void mitm_listen(struct state *state, struct in_addr listen_address, int listen_port);
// TLS server-name callback registered on state->cctx by main(); `data` is the struct state pointer.
static int sni_callback(SSL *cssl, int *al, void *data);
// Thread entry point; `args_` is a malloc'd struct handle_client_args that the thread frees.
static void *handle_client(void *args_);
// Writes a serial number derived from the CA serial and `hostname` into `out`.
static void generate_host_serial_number(struct state *state, const char *hostname, ASN1_INTEGER *out);
// Returns a new certificate for `hostname` signed with state->privkey; the caller frees it with X509_free().
static X509 *generate_host_cert(struct state *state, const char *hostname);
// Writes all `len` bytes of `data` to `out`, polling `outfd` while the BIO asks for a retry; returns 1 on success, 0 on failure.
static int blocking_write(BIO *out, int outfd, char *data, int len);
// Reads an HTTP CONNECT request from `sock`, stores the target in `addressout`, and answers 200; returns non-zero on success.
static int accept_http_connect(int sock, char *addressout, int addresslen);
// Install / remove the pthread-based locking callbacks for OpenSSL; both return 1 on success, 0 otherwise.
static int thread_setup(void);
static int thread_cleanup(void);
/*
 * Entry point.
 *
 * Usage: PROGRAM CAKEY.pem CACERT.pem LISTEN-ADDRESS LISTEN-PORT
 *
 * Loads the CA private key and certificate, prepares one SSL_CTX for the
 * client-facing side (cctx) and one for the upstream side (sctx), ignores
 * SIGPIPE, and then hands control to mitm_listen(), which loops forever.
 * Every start-up failure is fatal and exits with EXIT_FAILURE.
 */
int main(int argc, char **argv){
    if (argc != 5) {
        fprintf(stderr, "Usage: %s CAKEY.pem CACERT.pem LISTEN-ADDRESS LISTEN-PORT\n", argv[0]);
        return EXIT_FAILURE;
    }

    char *cakey_pem = argv[1];
    char *cacert_pem = argv[2];

    struct in_addr listen_address;
    assertx(inet_aton(argv[3], &listen_address), "invalid LISTEN-ADDRESS");

    // atoi() would silently turn garbage into port 0; parse strictly and
    // require a value that fits a TCP port.
    char *port_end = NULL;
    long port_value = strtol(argv[4], &port_end, 10);
    assertx(port_end != argv[4] && *port_end == '\0'
            && port_value >= 1 && port_value <= 65535,
            "invalid LISTEN-PORT");
    int listen_port = (int)port_value;

    // The locking callbacks must be in place before any thread touches OpenSSL.
    assertx(thread_setup(), "thread_setup");
    SSL_library_init();
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();

    struct state *state = &(struct state){
        .sctx = SSL_CTX_new(SSLv23_client_method()),
        .cctx = SSL_CTX_new(SSLv23_server_method()),
    };
    assertssl(state->sctx != NULL, "SSL_CTX_new (sctx)");
    assertssl(state->cctx != NULL, "SSL_CTX_new (cctx)");

    // Upstream servers are verified against the system trust store.
    assertssl(SSL_CTX_set_default_verify_paths(state->sctx), "SSL_CTX_set_default_verify_paths");
    SSL_CTX_set_verify(state->sctx, SSL_VERIFY_PEER, NULL);

    FILE *privkeyfile = fopen(cakey_pem, "rb");
    asserterrno(privkeyfile != NULL, "fopen (privkey)");
    state->privkey = PEM_read_PrivateKey(privkeyfile, NULL, NULL, NULL);
    assertssl(state->privkey != NULL, "PEM_read_PrivateKey");
    fclose(privkeyfile);

    FILE *cacertfile = fopen(cacert_pem, "rb");
    asserterrno(cacertfile != NULL, "fopen (cacert)");
    state->cacert = PEM_read_X509(cacertfile, NULL, NULL, NULL);
    assertssl(state->cacert != NULL, "PEM_read_X509");
    fclose(cacertfile);

    SSL_CTX_set_ecdh_auto(state->cctx, 1);
    // The CA certificate is sent to clients as part of the chain. The context
    // takes over the reference, so cacert is released by SSL_CTX_free(cctx).
    assertssl(SSL_CTX_add_extra_chain_cert(state->cctx, state->cacert), "SSL_CTX_add_extra_chain_cert");
    SSL_CTX_set_tlsext_servername_callback(state->cctx, sni_callback);
    SSL_CTX_set_tlsext_servername_arg(state->cctx, state);

    // A peer closing mid-write must surface as a write error, not kill the process.
    asserterrno(sigaction(SIGPIPE, &(struct sigaction){ .sa_handler = SIG_IGN }, NULL) == 0, "sigaction");

    mitm_listen(state, listen_address, listen_port);

    // NOTE: if we were to reach this point in the program, we'd need to make sure all threads have completed
    SSL_CTX_free(state->sctx);
    SSL_CTX_free(state->cctx);
    EVP_PKEY_free(state->privkey);
    thread_cleanup();
    return EXIT_SUCCESS;
}
/*
 * TLS server-name (SNI) callback for the client-facing context.
 *
 * Only logs the requested name. The per-host certificate and key are
 * installed on the SSL object by handle_client() before the handshake, from
 * the host named in the CONNECT request, so nothing has to be selected here.
 *
 * Returns SSL_TLSEXT_ERR_OK when the client sent a server name and
 * SSL_TLSEXT_ERR_NOACK when it did not. The previous `return 1` was
 * SSL_TLSEXT_ERR_ALERT_WARNING, i.e. "name not accepted, send a warning
 * alert", which some clients treat as a failed handshake.
 *
 * `al` and `data` are unused.
 */
static int sni_callback(SSL *cssl, int *al, void *data){
    (void)al;
    (void)data;

    const char *hostname = SSL_get_servername(cssl, TLSEXT_NAMETYPE_host_name);

    // Passing NULL to %s is undefined behavior; print a placeholder instead.
    printf("sni_callback: %s\n", hostname != NULL ? hostname : "(null)");

    return hostname != NULL ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_NOACK;
}
/*
 * Listen on listen_address:listen_port and serve clients forever.
 *
 * Each accepted connection is handed to a detached handle_client() thread
 * together with a heap-allocated struct handle_client_args, which the thread
 * frees. Failure to create the listening socket is fatal. A connection that
 * cannot be handed to a thread (out of memory, thread creation failure) is
 * closed and the loop keeps serving. This function does not return.
 */
static void mitm_listen(struct state *state, struct in_addr listen_address, int listen_port){
    int ssock = socket(AF_INET, SOCK_STREAM, 0);
    asserterrno(ssock >= 0, "socket (listen)");
    asserterrno(setsockopt(ssock, SOL_SOCKET, SO_REUSEADDR, &(int){ 1 }, sizeof (int)) >= 0, "setsockopt");

    struct sockaddr_in saddr = {
        .sin_family = AF_INET,
        .sin_port = htons(listen_port),
        .sin_addr = listen_address
    };
    asserterrno(bind(ssock, (struct sockaddr *)&saddr, sizeof saddr) >= 0, "bind");
    asserterrno(listen(ssock, 100) >= 0, "listen");

    for (;;) {
        int csock = accept(ssock, NULL, NULL);
        if (csock < 0) {
            // An interrupted call or a client that gave up while queued is
            // not a reason to take the whole proxy down.
            if (errno == EINTR || errno == ECONNABORTED)
                continue;
            asserterrno(0, "accept");
        }

        struct handle_client_args *args = malloc(sizeof *args);
        if (args == NULL) {
            fprintf(stderr, "error: out of memory; dropping connection\n");
            close(csock);
            continue;
        }
        *args = (struct handle_client_args){
            .state = state,
            .csock = csock
        };

        pthread_t thread;
        int rc = pthread_create(&thread, NULL, handle_client, args);
        if (rc != 0) {
            // No thread took ownership of args/csock, so release them here.
            // `thread` is indeterminate and must not be detached.
            fprintf(stderr, "error: pthread_create: %s; dropping connection\n", strerror(rc));
            free(args);
            close(csock);
            continue;
        }
        pthread_detach(thread);
    }
}
/*
 * Thread entry point serving one proxied connection.
 *
 * `args_` is a malloc'd struct handle_client_args; it is freed here and the
 * client socket it carries is closed before the thread returns.
 *
 * Steps:
 *   1. Read the HTTP CONNECT request and split the target into host and port.
 *   2. Run a TLS server handshake with the client using a certificate
 *      generated for that host.
 *   3. Open a verified TLS connection to the real host.
 *   4. Relay decrypted data in both directions until both sides have closed
 *      or a write fails permanently.
 *
 * Set-up failures before the client handshake are fatal for the whole
 * process (assert*); later failures only end this connection.
 * Always returns NULL.
 */
static void *handle_client(void *args_){
    struct handle_client_args *args = args_;
    struct state *state = args->state;
    int csock = args->csock;
    free(args);

    BIO *sbio = NULL, *cbio = NULL;
    char address[1024];
    if(!accept_http_connect(csock, address, sizeof address))
        goto end;

    char *hostname = address, *port = strchr(address, ':');
    // port is NULL when the target has no ':'; never hand NULL to %s.
    printf("hostname = %s, port = %s\n", hostname, port != NULL ? port : "(null)");
    if(port == NULL)
        goto end;
    // Split "host:port" in place.
    *port = 0;
    port++;

    // Client-facing side: we act as the TLS server on the accepted socket.
    cbio = BIO_new_ssl(state->cctx, 0);
    SSL *cssl = NULL;
    BIO_get_ssl(cbio, &cssl);
    BIO_set_close(cbio, BIO_CLOSE);
    assertssl(cssl != NULL, "cssl != null");
    assertssl(SSL_set_fd(cssl, csock), "SSL_set_fd");

    X509 *cert = generate_host_cert(state, hostname);
    assertssl(SSL_use_certificate(cssl, cert), "SSL_CTX_use_certificate");
    assertssl(SSL_use_PrivateKey(cssl, state->privkey), "SSL_CTX_use_PrivateKey");
    X509_free(cert);

    // NOTE: errors are no longer necessarily assertions from this point
    if(!checksslstderr(SSL_accept(cssl) > 0, "SSL_accept"))
        goto end;

    struct cstate cstate = {
        .cbio = cbio,
        .state = state
    };

    // Server-facing side: connect to the real host, sending SNI and
    // requiring the peer certificate to match the requested host name.
    sbio = BIO_new_ssl_connect(state->sctx);
    SSL *sssl = NULL;
    BIO_get_ssl(sbio, &sssl);
    assertssl(sssl != NULL, "sssl != null");
    BIO_set_conn_hostname(sbio, hostname);
    BIO_set_conn_port(sbio, port);
    assertx(SSL_set_tlsext_host_name(sssl, hostname), "SSL_set_tlsext_host_name");
    X509_VERIFY_PARAM *vpm = SSL_get0_param(sssl);
    assertx(X509_VERIFY_PARAM_set1_host(vpm, hostname, 0), "X509_VERIFY_PARAM_set1_host");

    if(!checksslhead(&cstate, BIO_do_connect(sbio) > 0, "BIO_do_connect"))
        goto end;
    if(!checksslhead(&cstate, BIO_do_handshake(sbio) > 0, "BIO_do_handshake"))
        goto end;
    int ssock = SSL_get_fd(sssl);
    if(!checksslhead(&cstate, ssock >= 0, "SSL_get_fd"))
        goto end;

    // The relay loop multiplexes both sockets with poll(), so neither may block.
    asserterrno(!fcntl(csock, F_SETFL, O_NONBLOCK), "fcntl");
    asserterrno(!fcntl(ssock, F_SETFL, O_NONBLOCK), "fcntl");

    // alive[x] holds the fd of a direction still being read, or -1 once that
    // side has closed or failed.
    int alive[] = { csock, ssock };
    int alive_s = sizeof alive/sizeof *alive;
    for(;;){
        int reads = 0;
        for(int x = 0; x < alive_s; x++){
            BIO *in, *out;
            int infd = alive[x];
            int outfd;
            if(infd == csock){
                in = cbio;
                out = sbio;
                outfd = ssock;
            }else if(infd == ssock){
                in = sbio;
                out = cbio;
                outfd = csock;
            }else
                continue;

            char buf[1024];
            int len = BIO_read(in, buf, sizeof buf);
            if(len <= 0){
                if(!BIO_should_retry(in)){
                    // This side is finished: pass the shutdown on to the
                    // peer and stop reading from it.
                    BIO_ssl_shutdown(out);
                    // TODO: consider reporting error
                    alive[x] = -1;
                }
            }else{
                reads++;
                if(!blocking_write(out, outfd, buf, len)){
                    // The destination is gone for good. Continuing would only
                    // read and discard the rest of the stream, so end the
                    // session instead.
                    fprintf(stderr, "error: relay write failed for %s:%s; closing connection\n", hostname, port);
                    ERR_print_errors_fp(stderr);
                    goto end;
                }
            }
        }
        // Keep draining while data is flowing; only sleep when both reads came up empty.
        if(reads > 0)
            continue;

        struct pollfd pollfds[alive_s];
        int empty = 1;
        for(int x = 0; x < alive_s; x++){
            // poll() ignores entries whose fd is negative.
            pollfds[x] = (struct pollfd){ .fd = alive[x], .events = POLLIN };
            if(alive[x] >= 0)
                empty = 0;
        }
        if(empty)
            break;
        poll(pollfds, alive_s, -1);
    }

end:
    BIO_ssl_shutdown(sbio);
    BIO_ssl_shutdown(cbio);
    BIO_free_all(sbio);
    BIO_free_all(cbio);
    close(csock);
    return NULL;
}
/*
 * Write all `len` bytes of `data` to the BIO `out`, whose underlying
 * descriptor `outfd` is non-blocking.
 *
 * When the BIO asks for a retry, waits with poll() on `outfd` for the
 * direction(s) the BIO reports it needs, then tries again.
 *
 * Returns 1 once everything was written, 0 when a write fails without the
 * BIO requesting a retry.
 */
static int blocking_write(BIO *out, int outfd, char *data, int len){
    char *cursor = data;
    int remaining = len;

    while (remaining > 0) {
        int written = BIO_write(out, cursor, remaining);
        if (written > 0) {
            cursor += written;
            remaining -= written;
            continue;
        }

        if (!BIO_should_retry(out))
            return 0;

        // Wait for whatever the BIO says it is blocked on.
        short wanted = 0;
        if (BIO_should_read(out))
            wanted |= POLLIN;
        if (BIO_should_write(out))
            wanted |= POLLOUT;

        struct pollfd waiter = {
            .fd = outfd,
            .events = wanted
        };
        poll(&waiter, 1, -1);
    }
    return 1;
}
static X509 *generate_host_cert(struct state *state, const char *hostname){
X509 *cert = X509_new();
assertx(cert != NULL, "X509_new");
generate_host_serial_number(state, hostname, X509_get_serialNumber(cert));
X509_set_notBefore(cert, X509_get_notBefore(state->cacert));
X509_set_notAfter(cert, X509_get_notAfter(state->cacert));
X509_set_pubkey(cert, state->privkey);
X509_set_issuer_name(cert, X509_get_issuer_name(state->cacert));
X509_NAME *name = X509_get_subject_name(cert);
X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC, (unsigned char *)hostname, -1, -1, 0);
X509_sign(cert, state->privkey, EVP_sha256());
return cert;
}
/*
 * Derive a deterministic serial number for the certificate of `hostname`
 * and store it in `out`.
 *
 * The serial is the first 20 octets of SHA-256(CA serial bytes || hostname)
 * with the top bit cleared. The same CA and host always give the same
 * serial; different hosts give different serials. 20 octets is the maximum
 * serial length RFC 5280 allows, and clearing the top bit keeps the DER
 * encoding from needing an extra leading zero octet.
 *
 * Any OpenSSL or allocation failure is fatal for the process.
 */
static void generate_host_serial_number(struct state *state, const char *hostname, ASN1_INTEGER *out){
    enum { SERIAL_OCTETS = 20 };

    BIGNUM *sn = BN_new();
    assertssl(sn != NULL, "BN_new");
    assertssl(ASN1_INTEGER_to_BN(X509_get_serialNumber(state->cacert), sn) != NULL, "ASN1_INTEGER_to_BN");

    // A CA serial of 0 has zero bytes. Allocate at least one byte so the
    // buffer is always valid (a zero-length VLA is undefined behavior).
    int snlen = BN_num_bytes(sn);
    unsigned char *snbytes = malloc(snlen > 0 ? (size_t)snlen : 1);
    assertx(snbytes != NULL, "malloc (serial)");
    BN_bn2bin(sn, snbytes);

    SHA256_CTX c;
    SHA256_Init(&c);
    SHA256_Update(&c, snbytes, (size_t)snlen);
    SHA256_Update(&c, hostname, strlen(hostname));
    unsigned char hashbytes[SHA256_DIGEST_LENGTH];
    SHA256_Final(hashbytes, &c);
    free(snbytes);

    hashbytes[0] &= 0x7f; // keep the top bit clear so no sign-padding octet is needed

    // Reuse sn to hold the truncated digest.
    assertssl(BN_bin2bn(hashbytes, SERIAL_OCTETS, sn) != NULL, "BN_bin2bn");
    assertssl(BN_to_ASN1_INTEGER(sn, out) != NULL, "BN_to_ASN1_INTEGER");
    BN_free(sn);
}
/*
 * Read an HTTP CONNECT request from `sock` and acknowledge it.
 *
 * Expects "CONNECT <target> ..." (method matched case-insensitively),
 * copies <target> as a NUL-terminated string into `addressout` (capacity
 * `addresslen` bytes including the terminator), skips the remaining header
 * lines up to the empty line, and replies "HTTP/1.1 200 OK".
 *
 * The request is read one byte at a time on purpose: nothing past the end
 * of the headers may be consumed, because the bytes that follow belong to
 * the TLS handshake.
 *
 * Returns non-zero on success. Returns 0 when the request is not a CONNECT,
 * the request line ends before the target is complete, the target does not
 * fit in `addressout`, the peer closes early, or the reply cannot be sent in
 * full. On failure the contents of `addressout` are unspecified.
 */
static int accept_http_connect(int sock, char *addressout, int addresslen){
    static const char method[] = "connect ";
    for (const char *expect = method; *expect; expect++) {
        char c;
        // <ctype.h> functions require an unsigned char value; a plain char
        // with the high bit set would otherwise be undefined behavior.
        if (read(sock, &c, 1) < 1 || tolower((unsigned char)c) != *expect)
            return 0;
    }

    // Target: skip extra leading spaces, then copy up to the next space.
    int start = 1;
    for (;;) {
        char c;
        if (read(sock, &c, 1) < 1)
            return 0;
        if (start && c == ' ')
            continue;
        start = 0;
        if (c == '\n')
            return 0;
        if (c == ' ')
            c = 0;
        if (addresslen < 1)
            return 0;
        *addressout = c;
        addressout++;
        addresslen--;
        if (!c)
            break;
    }

    // Headers end at the first empty line: two '\n' with only '\r' between.
    int nl = 0;
    for (;;) {
        char c;
        if (read(sock, &c, 1) < 1)
            return 0;
        if (c == '\r')
            continue;
        if (c == '\n')
            nl++;
        else
            nl = 0;
        if (nl == 2)
            break;
    }

    static const char okay[] = "HTTP/1.1 200 OK\r\n\r\n";
    return write(sock, okay, sizeof okay - 1) == (ssize_t)(sizeof okay - 1);
}
/*
 * This array will store all of the mutexes available to OpenSSL.
 * It is allocated with CRYPTO_num_locks() entries by thread_setup(),
 * indexed by the lock number passed to locking_function(), and freed
 * (and reset to NULL) by thread_cleanup().
 */
static pthread_mutex_t *mutex_buf = NULL;
/*
 * Locking callback registered with OpenSSL by thread_setup().
 * Locks mutex number `n` when `mode` has CRYPTO_LOCK set and unlocks it
 * otherwise. `file` and `line` are not used.
 */
static void locking_function(int mode, int n, const char *file, int line)
{
    pthread_mutex_t *target = mutex_buf + n;

    if (mode & CRYPTO_LOCK) {
        pthread_mutex_lock(target);
        return;
    }
    pthread_mutex_unlock(target);
}
/*
 * Thread-id callback registered with OpenSSL by thread_setup().
 * Returns the calling thread's pthread_t converted to unsigned long.
 */
static unsigned long id_function(void)
{
    pthread_t self = pthread_self();
    return (unsigned long)self;
}
/*
 * Allocate and initialise one mutex per OpenSSL lock and register the
 * thread-id and locking callbacks.
 *
 * Returns 1 on success. Returns 0 when allocation or any mutex
 * initialisation fails; in that case nothing is registered, everything
 * already initialised is released, and mutex_buf is left unchanged.
 */
static int thread_setup(void)
{
    int count = CRYPTO_num_locks();

    pthread_mutex_t *locks = malloc((size_t)count * sizeof *locks);
    if (!locks)
        return 0;

    for (int i = 0; i < count; i++) {
        if (pthread_mutex_init(&locks[i], NULL) != 0) {
            // Undo the mutexes that were already set up.
            while (i-- > 0)
                pthread_mutex_destroy(&locks[i]);
            free(locks);
            return 0;
        }
    }

    // Publish the array only once every mutex in it is usable.
    mutex_buf = locks;
    CRYPTO_set_id_callback(id_function);
    CRYPTO_set_locking_callback(locking_function);
    return 1;
}
/*
 * Undo thread_setup(): unregister the OpenSSL callbacks, destroy every
 * mutex, and free the mutex array.
 *
 * Returns 1 when the array was released, 0 when there was nothing to
 * release (mutex_buf is NULL).
 */
static int thread_cleanup(void)
{
    if (mutex_buf == NULL)
        return 0;

    CRYPTO_set_id_callback(NULL);
    CRYPTO_set_locking_callback(NULL);

    int idx = 0;
    while (idx < CRYPTO_num_locks()) {
        pthread_mutex_destroy(mutex_buf + idx);
        idx++;
    }

    free(mutex_buf);
    mutex_buf = NULL;
    return 1;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment