/*
 * TCP/TLS echo client/server.
 *
 * Originally published as GitHub gist x-yuri/3203591cf7b359c1c3f3f2fc88a2f584
 * (last active May 15, 2024).
 */

#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <openssl/err.h>
#include <openssl/ssl.h>

/* Compile-time tunables for the echo demo. */
enum {
    SERVER_PORT = 1024,                               /* TCP port: server binds it, client connects to it */
    BUFFER_SIZE = 256,                                /* size of the per-read I/O buffers */
    INFO_CALLBACK_TYPESTRING_INIT_BUFFER_SIZE = 256,
    INFO_CALLBACK_INIT_BUFFER_SIZE = 256,
    READLINE_INIT_BUFFER_SIZE = 256,                  /* first allocation made by read_line() */
};

/* Nonzero when the first command-line argument starts with 's' (server mode). */
int is_server;

/* Print the command-line synopsis to stdout: "s" selects the server, "c" the client. */
void usage(const char *program_name)
{
    fprintf(stdout, "Usage: %s {s | c}\n", program_name);
}

const int SSL_CB_values[] = {
    SSL_CB_LOOP,
    SSL_CB_EXIT,
    SSL_CB_READ,
    SSL_CB_WRITE,
    SSL_CB_ALERT,
    SSL_CB_HANDSHAKE_START,
    SSL_CB_HANDSHAKE_DONE,
};

const char *SSL_CB_strings[] = {
    "SSL_CB_LOOP",
    "SSL_CB_EXIT",
    "SSL_CB_READ",
    "SSL_CB_WRITE",
    "SSL_CB_ALERT",
    "SSL_CB_HANDSHAKE_START",
    "SSL_CB_HANDSHAKE_DONE",
};

const int SSL_CB_len = sizeof(SSL_CB_values) / sizeof(SSL_CB_values[0]);

void my_info_callback(const SSL *ssl, const int type, const int val)
{
    size_t type_string_size = INFO_CALLBACK_TYPESTRING_INIT_BUFFER_SIZE;
    char *type_string = malloc(type_string_size);
    if (!type_string) {
        perror("failed to allocate the type string buffer");
        return;
    }
    type_string[0] = '\0';

    size_t buf_size = INFO_CALLBACK_INIT_BUFFER_SIZE;
    char *buf = malloc(buf_size);
    if (!buf) {
        perror("failed to allocate the type string buffer");
        free(type_string);
        return;
    }
    const char *next_el;

    int empty = 1;
    for (int i = 0; i < SSL_CB_len; i++) {
        if (!(type & SSL_CB_values[i])) continue;
        if (empty)
            next_el = SSL_CB_strings[i];
        else {
            if (strlen(" | ") + strlen(SSL_CB_strings[i]) + 1 > buf_size) {
                buf_size = (strlen(" | ") + strlen(SSL_CB_strings[i])) * 2;
                char *new_buf = realloc(buf, buf_size);
                if (!new_buf) {
                    perror("failed to reallocate a buffer");
                    free(buf);
                    free(type_string);
                    return;
                }
                buf = new_buf;
            }
            strcpy(buf, " | ");
            strcat(buf, SSL_CB_strings[i]);
            next_el = buf;
        }
        if (strlen(type_string) + strlen(next_el) + 1 > type_string_size) {
            type_string_size = (strlen(type_string) + strlen(next_el)) * 2;
            char *new_type_string = realloc(type_string, type_string_size);
            if (!new_type_string) {
                perror("failed to reallocate the type string buffer");
                free(buf);
                free(type_string);
                return;
            }
            type_string = new_type_string;
        }
        strcat(type_string, next_el);
        empty = 0;
    }

    printf("-- type: %x, %s, val: %x\n", type, type_string, val);
    const char *state_string = SSL_state_string_long(ssl);
    if (strcmp(state_string, "unknown") != 0)
        printf("   state: %s\n", state_string);
    const char *alert_type_string = SSL_alert_type_string_long(val);
    if (strcmp(alert_type_string, "unknown") != 0)
        printf("   alert type: %s\n", alert_type_string);
    const char *alert_desc_string = SSL_alert_desc_string_long(val);
    if (strcmp(alert_desc_string, "unknown") != 0)
        printf("   alert desc: %s\n", alert_desc_string);

    free(buf);
    free(type_string);
}

/*
 * Certificate-verification hook installed by create_ssl_ctx().
 *
 * It adds no policy of its own. OpenSSL's verdict for the certificate being
 * checked (`preverify_ok`) is passed through unchanged. As a result, the
 * handshake outcome is decided entirely by the built-in chain verification.
 * The store context is deliberately ignored.
 */
int my_verify_callback(const int preverify_ok, __attribute__ ((unused)) X509_STORE_CTX *x509_ctx)
{
    const int accept_certificate = preverify_ok;
    return accept_certificate;
}

/*
 * Build the TLS context for the current role (selected by `is_server`).
 *
 * The context:
 *   - loads this side's certificate chain and private key from certs/;
 *   - requires the peer to present a certificate that verifies against
 *     certs/root.crt (the server also fails the handshake when the client
 *     sends no certificate);
 *   - installs my_info_callback for tracing.
 *
 * Returns a new SSL_CTX that the caller must SSL_CTX_free(), or NULL after
 * printing the OpenSSL error queue to stderr.
 */
SSL_CTX *create_ssl_ctx(void)
{
    SSL_CTX *ctx = SSL_CTX_new(
        is_server ? TLS_server_method() : TLS_client_method());
    if (ctx == NULL)
        goto fail;

    if (SSL_CTX_use_certificate_chain_file(ctx,
            is_server ? "certs/server.crt" : "certs/client.crt") != 1)
        goto fail;

    if (SSL_CTX_use_PrivateKey_file(ctx,
            is_server ? "certs/server.key" : "certs/client.key",
            SSL_FILETYPE_PEM) != 1)
        goto fail;

    SSL_CTX_set_verify(ctx,
        SSL_VERIFY_PEER
        | (is_server ? SSL_VERIFY_FAIL_IF_NO_PEER_CERT : 0),
        my_verify_callback);

    /* Trust only the demo CA, not the system trust store.
     * NOTE(review): SSL_CTX_load_verify_file() requires OpenSSL >= 3.0. */
    if (SSL_CTX_load_verify_file(ctx, "certs/root.crt") != 1)
        goto fail;

    SSL_CTX_set_info_callback(ctx, my_info_callback);
    return ctx;

fail:
    ERR_print_errors_fp(stderr);
    SSL_CTX_free(ctx); /* no-op when ctx is NULL */
    return NULL;
}

/*
 * Wrap the connected socket `s` in a new SSL object from `ssl_ctx` and run the
 * TLS handshake for the current role. The server uses SSL_accept() and the
 * client uses SSL_connect().
 *
 * Returns the established session, or NULL after printing the OpenSSL error
 * queue. The socket is never closed here; the caller keeps ownership of it.
 *
 * NOTE(review): no SSL_set1_host() is configured on the client. Any
 * certificate that verifies against the trusted root is therefore accepted,
 * whatever name it carries.
 */
SSL *create_ssl(SSL_CTX *ssl_ctx, const int s)
{
    SSL *session = SSL_new(ssl_ctx);
    if (!session) {
        ERR_print_errors_fp(stderr);
        return NULL;
    }

    /* Both handshake entry points share the int (*)(SSL *) signature. */
    int (*const handshake)(SSL *) = is_server ? SSL_accept : SSL_connect;

    /* || short-circuits: the handshake only runs once the fd is attached. */
    if (SSL_set_fd(session, s) != 1 || handshake(session) != 1) {
        ERR_print_errors_fp(stderr);
        SSL_free(session);
        return NULL;
    }
    return session;
}

/*
 * Echo every chunk read from `ssl` back to the peer until the session ends.
 *
 * Returns 0 when the peer closed the TLS session cleanly
 * (SSL_ERROR_ZERO_RETURN). Returns -1 on any other read or write failure,
 * after printing the OpenSSL error queue.
 */
int server_io_loop(SSL *ssl)
{
    char buf[BUFFER_SIZE];
    for (;;) {
        const int got = SSL_read(ssl, buf, BUFFER_SIZE);
        if (got <= 0) {
            /* Query the failure reason before any other call on this SSL. */
            if (SSL_get_error(ssl, got) == SSL_ERROR_ZERO_RETURN)
                return 0;
            ERR_print_errors_fp(stderr);
            return -1;
        }

        const int sent = SSL_write(ssl, buf, got);
        if (sent <= 0) {
            ERR_print_errors_fp(stderr);
            return -1;
        }
        assert(sent == got);
    }
}

/*
 * Write all `len` bytes of `data` to `fd`. Short writes are resumed, and calls
 * interrupted by a signal (EINTR) are retried.
 *
 * Returns 0 on success, or -1 with errno set by the failing write().
 */
static int write_all(const int fd, const char *data, size_t len)
{
    while (len > 0) {
        const ssize_t n = write(fd, data, len);
        if (n == -1) {
            if (errno == EINTR)
                continue;
            return -1;
        }
        data += n;
        len -= (size_t)n;
    }
    return 0;
}

/*
 * Read exactly `nread_from_stdin` bytes of echoed data from `ssl`, at most
 * BUFFER_SIZE per SSL_read(), and copy them to stdout.
 *
 * Returns 0 on success, or -1 on failure. A failed or closed TLS read prints
 * the OpenSSL error queue; a stdout write failure is reported with perror.
 */
int client_output_loop(SSL *ssl, const size_t nread_from_stdin)
{
    size_t nleft_to_read = nread_from_stdin;
    while (nleft_to_read > 0) {
        char buf[BUFFER_SIZE];
        const int nto_read = nleft_to_read > BUFFER_SIZE ? BUFFER_SIZE : (int)nleft_to_read;
        const int nread = SSL_read(ssl, buf, nto_read);
        if (nread <= 0) {
            ERR_print_errors_fp(stderr);
            return -1;
        }

        if (write_all(STDOUT_FILENO, buf, (size_t)nread) == -1) {
            perror("failed to write to stdout");
            return -1;
        }

        nleft_to_read -= (size_t)nread;
    }
    return 0;
}

/*
 * Read one line from stdin into *line, growing the buffer as needed.
 *
 * Buffer convention (as in POSIX getline()):
 *   - if *line is NULL or *line_size is 0, a buffer of
 *     READLINE_INIT_BUFFER_SIZE bytes is (re)allocated;
 *   - otherwise *line must point to *line_size writable bytes.
 * On return *line_size holds the current capacity. The caller owns *line and
 * must free() it, even after a -1 return.
 *
 * Returns the number of characters stored, including the trailing '\n' when
 * one was read; a final line without '\n' is returned as-is. Returns -1 on
 * allocation or read failure (after printing a message), and also when EOF is
 * reached before any character is read. Callers can tell those cases apart
 * with feof(stdin)/ferror(stdin).
 */
ssize_t read_line(char **line, size_t *line_size)
{
    if (!*line || *line_size == 0) {
        /* realloc(NULL, n) behaves like malloc(n). */
        char *new_line = realloc(*line, READLINE_INIT_BUFFER_SIZE);
        if (!new_line) {
            perror("failed to allocate the line");
            return -1;
        }
        *line = new_line;
        *line_size = READLINE_INIT_BUFFER_SIZE;
    }

    size_t nread = 0;
    (*line)[0] = '\0';
    for (;;) {
        if (nread == *line_size - 1) {
            /* Buffer is full and no newline has been seen yet: double it. */
            if (*line_size > ((size_t)-1) / 2) {
                fputs("stop typing, will you?\n", stderr);
                return -1;
            }
            const size_t new_size = *line_size * 2;
            char *new_line = realloc(*line, new_size);
            if (!new_line) {
                perror("failed to reallocate the line");
                return -1;
            }
            *line = new_line;
            *line_size = new_size;
        }

        /* fgets() takes an int size, so read huge buffers in INT_MAX chunks. */
        const size_t room = *line_size - nread;
        const int chunk = room > INT_MAX ? INT_MAX : (int)room;
        if (!fgets(*line + nread, chunk, stdin)) {
            /* ferror() (not a possibly stale errno) separates errors from EOF. */
            if (ferror(stdin)) {
                perror("failed to read from stdin");
                return -1;
            }
            break; /* EOF: return whatever was read so far */
        }

        nread += strlen(*line + nread);
        /* ssize_t is assumed no wider than long here. */
        if (nread > (size_t)LONG_MAX) {
            fputs("stop typing, will you?\n", stderr);
            return -1;
        }
        /* nread can stay 0 if the input contained a NUL byte; avoid (*line)[-1]. */
        if (nread > 0 && (*line)[nread - 1] == '\n')
            break;
    }

    if (nread == 0)
        return -1; /* EOF before any character was read */
    return (ssize_t)nread;
}

/*
 * Send all `len` bytes of `data` over `ssl`. Data goes out in chunks of at
 * most INT_MAX bytes, because SSL_write() takes an int length; each chunk
 * resumes where the previous one stopped.
 *
 * Returns 0 on success, or -1 after printing the OpenSSL error queue.
 */
static int ssl_write_all(SSL *ssl, const char *data, size_t len)
{
    while (len > 0) {
        const int chunk = len > INT_MAX ? INT_MAX : (int)len;
        const int nwritten = SSL_write(ssl, data, chunk);
        if (nwritten <= 0) {
            ERR_print_errors_fp(stderr);
            return -1;
        }
        data += nwritten;
        len -= (size_t)nwritten;
    }
    return 0;
}

/*
 * Interactive client loop. Each iteration reads a line from stdin, sends it
 * over `ssl`, and prints the echoed bytes. A line consisting only of "\n", or
 * end of input on stdin, ends the session normally.
 *
 * Returns 0 on a normal end. Returns -1 on a stdin, TLS, or stdout failure;
 * the helper that failed has already reported it.
 */
int client_io_loop(SSL *ssl)
{
    char *line = NULL;
    /* Must persist across iterations: once `line` is allocated, read_line()
     * trusts this value as its capacity. */
    size_t line_size = 0;
    int rc = 0;

    for (;;) {
        const ssize_t nread = read_line(&line, &line_size);
        if (nread == -1) {
            /* End of input (e.g. Ctrl-D) is a normal way to finish. */
            if (!(feof(stdin) && !ferror(stdin)))
                rc = -1;
            break;
        }

        if (strcmp(line, "\n") == 0)
            break;

        if (ssl_write_all(ssl, line, (size_t)nread) == -1
            || client_output_loop(ssl, (size_t)nread) == -1) {
            rc = -1;
            break;
        }
    }

    free(line);
    return rc;
}

/* Close `fd`, reporting (but otherwise ignoring) a failure with perror(what). */
static void close_or_warn(const int fd, const char *what)
{
    if (close(fd) == -1)
        perror(what);
}

/*
 * End a TLS session and free it.
 *
 * SSL_shutdown() returns 0 when our close_notify was sent but the peer's has
 * not arrived yet. That is the normal result of a one-way shutdown, so only a
 * negative return counts as an error. After a failed session the shutdown is
 * skipped entirely, because OpenSSL forbids SSL_shutdown() after a fatal
 * SSL_ERROR_SYSCALL/SSL_ERROR_SSL.
 */
static void finish_ssl(SSL *ssl, const int session_ok)
{
    if (session_ok && SSL_shutdown(ssl) < 0)
        ERR_print_errors_fp(stderr);
    SSL_free(ssl);
}

/*
 * Server role: listen on SERVER_PORT on all interfaces, accept a single
 * client, and echo its data until it disconnects.
 * Returns 0 on success, -1 on failure.
 */
static int run_server(SSL_CTX *ssl_ctx, const int s)
{
    /* Allow an immediate restart while the old port is still in TIME_WAIT. */
    const int reuse = 1;
    if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof reuse) == -1)
        perror("failed to set SO_REUSEADDR on the socket");

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof addr);
    addr.sin_family = AF_INET;
    addr.sin_port = htons(SERVER_PORT);
    addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if (bind(s, (struct sockaddr *)&addr, sizeof addr) == -1) {
        perror("failed to bind an address to the socket");
        return -1;
    }

    if (listen(s, 1) == -1) {
        perror("failed to start listening on the socket");
        return -1;
    }

    struct sockaddr_in client_addr;
    socklen_t client_addr_len = sizeof client_addr;
    const int client_s = accept(s, (struct sockaddr *)&client_addr, &client_addr_len);
    if (client_s == -1) {
        perror("failed to accept a connection");
        return -1;
    }

    int rc = -1;
    SSL *ssl = create_ssl(ssl_ctx, client_s);
    if (ssl != NULL) {
        rc = server_io_loop(ssl);
        finish_ssl(ssl, rc == 0);
    }
    close_or_warn(client_s, "failed to close the client socket");
    return rc;
}

/*
 * Client role: connect to SERVER_PORT on 127.0.0.1 and run the interactive
 * echo loop. Returns 0 on success, -1 on failure.
 */
static int run_client(SSL_CTX *ssl_ctx, const int s)
{
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof addr);
    addr.sin_family = AF_INET;
    addr.sin_port = htons(SERVER_PORT);
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    if (connect(s, (struct sockaddr *)&addr, sizeof addr) == -1) {
        perror("failed to connect to the server");
        return -1;
    }

    SSL *ssl = create_ssl(ssl_ctx, s);
    if (ssl == NULL)
        return -1;

    const int rc = client_io_loop(ssl);
    finish_ssl(ssl, rc == 0);
    return rc;
}

/*
 * Entry point. An argument starting with 's' runs the server; any other
 * argument runs the client. Exits with EXIT_SUCCESS only when the selected
 * role completes without error.
 */
int main(int argc, char **argv)
{
    if (argc < 2) {
        usage(argc > 0 && argv[0] ? argv[0] : "echo");
        return EXIT_FAILURE;
    }
    is_server = argv[1][0] == 's';

    SSL_CTX *ssl_ctx = create_ssl_ctx();
    if (ssl_ctx == NULL)
        return EXIT_FAILURE;

    int rc = -1;
    const int s = socket(AF_INET, SOCK_STREAM, 0);
    if (s == -1) {
        perror("failed to create a socket");
    } else {
        rc = is_server ? run_server(ssl_ctx, s) : run_client(ssl_ctx, s);
        close_or_warn(s, "failed to close the socket");
    }

    SSL_CTX_free(ssl_ctx);
    return rc == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment