#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <openssl/err.h>
#include <openssl/ssl.h>
/* TCP port the server binds to and the client connects to. */
#define SERVER_PORT 1024
/* Size in bytes of the fixed buffers used to shuttle data in the I/O loops. */
#define BUFFER_SIZE 256
/*
 * Initial allocation sizes (in bytes) for the growable heap buffers used by
 * the info callback and by read_line().
 */
#define INFO_CALLBACK_TYPESTRING_INIT_BUFFER_SIZE 256
#define INFO_CALLBACK_INIT_BUFFER_SIZE 256
#define READLINE_INIT_BUFFER_SIZE 256
/*
 * Non-zero when the program runs in server mode. Set once in main() from the
 * first command-line argument and read by the SSL setup helpers.
 */
int is_server;
/* Print a one-line synopsis of the command-line syntax to stdout. */
void usage(const char *program_name)
{
    fprintf(stdout, "Usage: %s {s | c}\n", program_name);
}
/*
 * Event flags that my_info_callback() looks for in its `type` argument.
 * Entry i here is described by entry i of SSL_CB_strings, so the two tables
 * must be kept in the same order.
 */
const int SSL_CB_values[] = {
SSL_CB_LOOP,
SSL_CB_EXIT,
SSL_CB_READ,
SSL_CB_WRITE,
SSL_CB_ALERT,
SSL_CB_HANDSHAKE_START,
SSL_CB_HANDSHAKE_DONE,
};
/* Printable names for the flags in SSL_CB_values, in the same order. */
const char *SSL_CB_strings[] = {
"SSL_CB_LOOP",
"SSL_CB_EXIT",
"SSL_CB_READ",
"SSL_CB_WRITE",
"SSL_CB_ALERT",
"SSL_CB_HANDSHAKE_START",
"SSL_CB_HANDSHAKE_DONE",
};
/* Number of entries in SSL_CB_values; used as the loop bound for both tables. */
const int SSL_CB_len = sizeof(SSL_CB_values) / sizeof(SSL_CB_values[0]);
void my_info_callback(const SSL *ssl, const int type, const int val)
{
size_t type_string_size = INFO_CALLBACK_TYPESTRING_INIT_BUFFER_SIZE;
char *type_string = malloc(type_string_size);
if (!type_string) {
perror("failed to allocate the type string buffer");
return;
}
type_string[0] = '\0';
size_t buf_size = INFO_CALLBACK_INIT_BUFFER_SIZE;
char *buf = malloc(buf_size);
if (!buf) {
perror("failed to allocate the type string buffer");
free(type_string);
return;
}
const char *next_el;
int empty = 1;
for (int i = 0; i < SSL_CB_len; i++) {
if (!(type & SSL_CB_values[i])) continue;
if (empty)
next_el = SSL_CB_strings[i];
else {
if (strlen(" | ") + strlen(SSL_CB_strings[i]) + 1 > buf_size) {
buf_size = (strlen(" | ") + strlen(SSL_CB_strings[i])) * 2;
char *new_buf = realloc(buf, buf_size);
if (!new_buf) {
perror("failed to reallocate a buffer");
free(buf);
free(type_string);
return;
}
buf = new_buf;
}
strcpy(buf, " | ");
strcat(buf, SSL_CB_strings[i]);
next_el = buf;
}
if (strlen(type_string) + strlen(next_el) + 1 > type_string_size) {
type_string_size = (strlen(type_string) + strlen(next_el)) * 2;
char *new_type_string = realloc(type_string, type_string_size);
if (!new_type_string) {
perror("failed to reallocate the type string buffer");
free(buf);
free(type_string);
return;
}
type_string = new_type_string;
}
strcat(type_string, next_el);
empty = 0;
}
printf("-- type: %x, %s, val: %x\n", type, type_string, val);
const char *state_string = SSL_state_string_long(ssl);
if (strcmp(state_string, "unknown") != 0)
printf(" state: %s\n", state_string);
const char *alert_type_string = SSL_alert_type_string_long(val);
if (strcmp(alert_type_string, "unknown") != 0)
printf(" alert type: %s\n", alert_type_string);
const char *alert_desc_string = SSL_alert_desc_string_long(val);
if (strcmp(alert_desc_string, "unknown") != 0)
printf(" alert desc: %s\n", alert_desc_string);
free(buf);
free(type_string);
}
/*
 * Certificate verification callback that adds no policy of its own: the
 * verdict passed in as preverify_ok is handed back unchanged and the store
 * context is not inspected.
 */
int my_verify_callback(const int preverify_ok, __attribute__ ((unused)) X509_STORE_CTX *x509_ctx)
{
    const int verdict = preverify_ok;

    return verdict;
}
/*
 * Build the SSL_CTX for the role selected by the global `is_server`.
 *
 * Loads the role's certificate chain and private key from certs/, requests
 * peer verification against certs/root.crt (in server mode a missing peer
 * certificate is a failure), and installs the verify and info callbacks.
 *
 * Returns the new context, which the caller must release with
 * SSL_CTX_free(), or NULL on failure after printing the OpenSSL error queue
 * to stderr.
 */
SSL_CTX *create_ssl_ctx(void)
{
    const char *cert_file = is_server ? "certs/server.crt" : "certs/client.crt";
    const char *key_file = is_server ? "certs/server.key" : "certs/client.key";

    SSL_CTX *ctx = SSL_CTX_new(
        is_server ? TLS_server_method() : TLS_client_method());
    if (ctx == NULL) {
        goto fail;
    }

    if (SSL_CTX_use_certificate_chain_file(ctx, cert_file) != 1) {
        goto fail;
    }
    if (SSL_CTX_use_PrivateKey_file(ctx, key_file, SSL_FILETYPE_PEM) != 1) {
        goto fail;
    }

    SSL_CTX_set_verify(ctx,
                       SSL_VERIFY_PEER
                       | (is_server ? SSL_VERIFY_FAIL_IF_NO_PEER_CERT : 0),
                       my_verify_callback);

    if (SSL_CTX_load_verify_file(ctx, "certs/root.crt") != 1) {
        goto fail;
    }

    SSL_CTX_set_info_callback(ctx, my_info_callback);
    return ctx;

fail:
    ERR_print_errors_fp(stderr);
    /* SSL_CTX_free() accepts NULL, so this also covers SSL_CTX_new() failing. */
    SSL_CTX_free(ctx);
    return NULL;
}
/*
 * Create an SSL object on top of socket `s` and run the TLS handshake:
 * SSL_accept() when the global `is_server` is set, SSL_connect() otherwise.
 *
 * Returns the connected SSL object, or NULL on any failure after printing
 * the OpenSSL error queue to stderr. The socket itself is never closed here.
 */
SSL *create_ssl(SSL_CTX *ssl_ctx, const int s)
{
    SSL *conn = SSL_new(ssl_ctx);
    if (!conn) {
        ERR_print_errors_fp(stderr);
        return NULL;
    }

    int ok = SSL_set_fd(conn, s) == 1;
    if (ok) {
        const int handshake = is_server ? SSL_accept(conn) : SSL_connect(conn);
        ok = handshake == 1;
    }

    if (!ok) {
        ERR_print_errors_fp(stderr);
        SSL_free(conn);
        return NULL;
    }
    return conn;
}
/*
 * Echo loop for the server side: every chunk read from the peer is written
 * straight back to it.
 *
 * Returns 0 when SSL_read() stops with SSL_ERROR_ZERO_RETURN, and -1 (after
 * printing the OpenSSL error queue to stderr) on any other read failure or
 * on a write failure.
 */
int server_io_loop(SSL *ssl)
{
    char chunk[BUFFER_SIZE];

    for (;;) {
        const int received = SSL_read(ssl, chunk, BUFFER_SIZE);
        if (received <= 0) {
            if (SSL_get_error(ssl, received) == SSL_ERROR_ZERO_RETURN) {
                return 0;
            }
            ERR_print_errors_fp(stderr);
            return -1;
        }

        const int sent = SSL_write(ssl, chunk, received);
        if (sent <= 0) {
            ERR_print_errors_fp(stderr);
            return -1;
        }
        assert(sent == received);
    }
}
/*
 * Read exactly `nread_from_stdin` bytes from the TLS connection and copy
 * them to standard output.
 *
 * Data is pulled in chunks of at most BUFFER_SIZE bytes; each chunk is
 * written out completely (short writes are retried) before the next read.
 *
 * Returns 0 once everything has been copied, or -1 if SSL_read() fails
 * (OpenSSL error queue printed to stderr) or write() fails (perror).
 */
int client_output_loop(SSL *ssl, const size_t nread_from_stdin)
{
    char chunk[BUFFER_SIZE];

    for (size_t pending = nread_from_stdin; pending > 0; ) {
        int want = BUFFER_SIZE;
        if (pending <= BUFFER_SIZE) {
            want = (int)pending;
        }

        const int got = SSL_read(ssl, chunk, want);
        if (got <= 0) {
            ERR_print_errors_fp(stderr);
            return -1;
        }

        const char *cursor = chunk;
        size_t unwritten = (size_t)got;
        while (unwritten > 0) {
            const ssize_t put = write(STDOUT_FILENO, cursor, unwritten);
            if (put == -1) {
                perror("failed to write to stdout");
                return -1;
            }
            cursor += put;
            unwritten -= (size_t)put;
        }

        pending -= (size_t)got;
    }
    return 0;
}
/*
 * Read one line from stdin into a growable, heap-allocated buffer.
 *
 * line:      in/out. *line is either NULL (a buffer is allocated) or a
 *            buffer previously returned through this parameter; it may be
 *            reallocated. The caller owns it and must free() it.
 * line_size: in/out. Capacity of *line in bytes. Only read when *line is
 *            not NULL; updated whenever the buffer is (re)allocated.
 *
 * On success the buffer holds a NUL-terminated line, including its trailing
 * '\n' unless end of input was reached first, and the number of bytes in it
 * is returned.
 *
 * Returns -1 when:
 *   - end of input is reached before any byte could be read; nothing is
 *     printed and feof(stdin) is true, so callers can tell this apart from
 *     a failure;
 *   - allocation fails, stdin reports a read error, or the line is too long
 *     to handle; a message is printed to stderr.
 */
ssize_t read_line(char **line, size_t *line_size)
{
    /* Make sure there is a usable buffer before the first fgets(). */
    if (!*line || *line_size == 0) {
        char *fresh = realloc(*line, READLINE_INIT_BUFFER_SIZE);
        if (!fresh) {
            perror("failed to allocate the line");
            return -1;
        }
        *line = fresh;
        *line_size = READLINE_INIT_BUFFER_SIZE;
    }
    (*line)[0] = '\0';

    size_t nread = 0;
    for (;;) {
        /* Grow once only the slot for the terminating NUL is left. */
        if (nread + 1 >= *line_size) {
            if (*line_size > (size_t)-1 / 2) {
                /* Doubling would overflow size_t. */
                fputs("stop typing, will you?", stderr);
                return -1;
            }
            const size_t new_size = *line_size * 2;
            char *grown = realloc(*line, new_size);
            if (!grown) {
                perror("failed to reallocate the line");
                return -1;
            }
            *line = grown;
            *line_size = new_size;
        }

        /* fgets() takes an int size, so the free space must fit in one. */
        const size_t room = *line_size - nread;
        if (room > INT_MAX) {
            fputs("stop typing, will you?", stderr);
            return -1;
        }

        char *tail = *line + nread;
        if (!fgets(tail, (int)room, stdin)) {
            /* NULL means either a read error or EOF with nothing read. */
            if (ferror(stdin)) {
                perror("failed to read from stdin");
                return -1;
            }
            if (nread == 0) {
                return -1;  /* plain end of input, no data */
            }
            break;          /* final line without a trailing newline */
        }

        nread += strlen(tail);
        if (nread > LONG_MAX) {
            fputs("stop typing, will you?", stderr);
            return -1;
        }
        if (nread > 0 && (*line)[nread - 1] == '\n') {
            break;
        }
    }
    return (ssize_t)nread;
}
/*
 * Interactive loop for the client side.
 *
 * Repeatedly reads a line from stdin with read_line(), sends it over the
 * TLS connection and then copies the same number of bytes of the peer's
 * reply to stdout via client_output_loop(). The loop ends normally on an
 * empty line ("\n") or when stdin reaches end of input.
 *
 * Returns 0 on a normal end, -1 on a read, send or receive failure. The
 * failing step reports the error itself (read_line() and
 * client_output_loop() print their own messages; send failures print the
 * OpenSSL error queue to stderr).
 */
int client_io_loop(SSL *ssl)
{
    int rc = -1;
    char *line = NULL;
    /*
     * Declared outside the loop: read_line() relies on the capacity it
     * stored on the previous call whenever `line` is not NULL.
     */
    size_t line_size = 0;

    for (;;) {
        const ssize_t nread = read_line(&line, &line_size);
        if (nread == -1) {
            /* End of input is a normal way to finish, not a failure. */
            if (feof(stdin) && !ferror(stdin)) {
                rc = 0;
            }
            goto done;
        }
        if (strcmp(line, "\n") == 0) {
            rc = 0;
            goto done;
        }

        /* Send the whole line, resuming after whatever was already sent. */
        const char *unsent = line;
        size_t nleft_to_write = (size_t)nread;
        while (nleft_to_write > 0) {
            const int towrite =
                nleft_to_write > INT_MAX ? INT_MAX : (int)nleft_to_write;
            const int nwritten = SSL_write(ssl, unsent, towrite);
            if (nwritten <= 0) {
                ERR_print_errors_fp(stderr);
                goto done;
            }
            unsent += nwritten;
            nleft_to_write -= (size_t)nwritten;
        }

        if (client_output_loop(ssl, (size_t)nread) == -1) {
            goto done;
        }
    }

done:
    free(line);
    return rc;
}
/*
 * Entry point. The first argument selects the role: anything starting with
 * 's' runs the echo server, anything else runs the client.
 *
 * Server: binds to SERVER_PORT on any address, accepts a single connection,
 * performs the TLS handshake and echoes data until the peer is done.
 * Client: connects to SERVER_PORT on the loopback address, performs the TLS
 * handshake and runs the interactive line loop.
 *
 * Returns EXIT_SUCCESS when the I/O loop finished without error, otherwise
 * EXIT_FAILURE. All resources acquired so far are released on every path.
 */
int main(const int argc, const char **argv)
{
    if (argc < 2) {
        usage(argv[0]);
        return EXIT_FAILURE;
    }
    is_server = argv[1][0] == 's';

    SSL_CTX *ssl_ctx = create_ssl_ctx();
    if (ssl_ctx == NULL) {
        return EXIT_FAILURE;
    }

    /* Stays EXIT_FAILURE unless the role's I/O loop completes cleanly. */
    int status = EXIT_FAILURE;

    const int s = socket(AF_INET, SOCK_STREAM, 0);
    if (s == -1) {
        perror("failed to create a socket");
        goto free_ctx;
    }

    /* Both roles use the same family and port; only the address differs. */
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof addr);
    addr.sin_family = AF_INET;
    addr.sin_port = htons(SERVER_PORT);

    if (is_server) {
        addr.sin_addr.s_addr = INADDR_ANY;
        if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
            perror("failed to bind an address to the socket");
            goto close_socket;
        }
        if (listen(s, 1) == -1) {
            perror("failed to start listening on the socket");
            goto close_socket;
        }

        struct sockaddr_in client_addr;
        socklen_t client_addr_len = sizeof(client_addr);
        const int client_s = accept(s, (struct sockaddr *)&client_addr, &client_addr_len);
        if (client_s == -1) {
            perror("failed to accept a connection");
            goto close_socket;
        }

        SSL *ssl = create_ssl(ssl_ctx, client_s);
        if (ssl != NULL) {
            if (server_io_loop(ssl) != -1) {
                status = EXIT_SUCCESS;
            }
            /* Teardown is the same whether or not the loop succeeded. */
            if (SSL_shutdown(ssl) != 1) {
                ERR_print_errors_fp(stderr);
            }
            SSL_free(ssl);
        }
        if (close(client_s) == -1) {
            perror("failed to close the client socket");
        }
    } else {
        addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
        if (connect(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
            perror("failed to connect to the server");
            goto close_socket;
        }

        SSL *ssl = create_ssl(ssl_ctx, s);
        if (ssl != NULL) {
            if (client_io_loop(ssl) != -1) {
                status = EXIT_SUCCESS;
            }
            /* Teardown is the same whether or not the loop succeeded. */
            if (SSL_shutdown(ssl) != 1) {
                ERR_print_errors_fp(stderr);
            }
            SSL_free(ssl);
        }
    }

close_socket:
    if (close(s) == -1) {
        perror("failed to close the socket");
    }
free_ctx:
    SSL_CTX_free(ssl_ctx);
    return status;
}
/*
 * Source: GitHub gist x-yuri/3203591cf7b359c1c3f3f2fc88a2f584,
 * "TCP/TLS echo client/server" (last active May 15, 2024).
 */