Skip to content

Instantly share code, notes, and snippets.

@elyosh
Last active March 3, 2023 11:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save elyosh/922e6c15f8d4d7102c8ac9508b0cdc3b to your computer and use it in GitHub Desktop.
Save elyosh/922e6c15f8d4d7102c8ac9508b0cdc3b to your computer and use it in GitHub Desktop.
kTLS zerocopy sendfile offset bug reproducer
// gcc -Wall ktls_test.c -o ktls_test -lssl -lcrypto
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <resolv.h>

#include <linux/tls.h>

#include <openssl/err.h>
#include <openssl/sha.h>
#include <openssl/ssl.h>
/* Program-wide state shared by main(), compute_sha1() and serve_file(). */

/* Descriptor of the input file, opened read-only in main(). */
int in_fd;
/* Size of the input file in bytes, taken from fstat() in main(). */
int in_len;
/* File offset at which the served range starts (set with -o). */
int sendfile_offset;
/* Non-zero when TLS_TX_ZEROCOPY_RO should be enabled (set with -z). */
int zero_copy;
/* SHA-1 of the served range as 40 hex digits plus NUL (filled by compute_sha1()). */
char sha1_str[41];

/* Chunk size used when reading the input file for hashing. */
enum { READLEN = 65536 };
/*
 * Hash the bytes of `fd` from `sendfile_offset` to EOF. This is the same
 * range serve_file() sends. The digest is stored as 40 lower-case hex digits
 * plus a terminating NUL in the global `sha1_str`.
 *
 * On return the file position of `fd` is at EOF.
 * Any allocation, seek, read or OpenSSL failure is reported on stderr and
 * terminates the process.
 */
void compute_sha1(int fd) {
    unsigned char *in_buf = malloc(READLEN);
    if (in_buf == NULL) {
        perror("Unable to allocate read buffer");
        exit(EXIT_FAILURE);
    }

    /* Seek even for offset 0 so the digest never depends on the
       descriptor's current position. */
    if (lseek(fd, sendfile_offset, SEEK_SET) < 0) {
        perror("Unable to seek input file");
        exit(EXIT_FAILURE);
    }

    EVP_MD_CTX *mdctx = EVP_MD_CTX_new();
    if (mdctx == NULL || !EVP_DigestInit_ex(mdctx, EVP_sha1(), NULL)) {
        ERR_print_errors_fp(stderr);
        exit(EXIT_FAILURE);
    }

    for (;;) {
        /* read() returns ssize_t: -1 must not be mistaken for a huge length. */
        ssize_t n = read(fd, in_buf, READLEN);
        if (n == 0) {
            break; /* EOF */
        }
        if (n < 0) {
            if (errno == EINTR) {
                continue;
            }
            perror("Unable to read input file");
            exit(EXIT_FAILURE);
        }
        if (!EVP_DigestUpdate(mdctx, in_buf, (size_t)n)) {
            ERR_print_errors_fp(stderr);
            exit(EXIT_FAILURE);
        }
    }

    unsigned char md_value[EVP_MAX_MD_SIZE];
    unsigned int md_len = 0;
    if (!EVP_DigestFinal_ex(mdctx, md_value, &md_len) || md_len != SHA_DIGEST_LENGTH) {
        ERR_print_errors_fp(stderr);
        exit(EXIT_FAILURE);
    }
    EVP_MD_CTX_free(mdctx);
    free(in_buf);

    /* Two hex digits per byte; the bound keeps room for the final NUL. */
    for (unsigned int i = 0; i < md_len; i++) {
        snprintf(sha1_str + 2 * i, sizeof(sha1_str) - 2 * i, "%02x", md_value[i]);
    }
}
/*
 * Create an IPv6 TCP socket bound to the wildcard address (in6addr_any) on
 * `port` and put it into listening state with a backlog of 10.
 *
 * Returns the listening socket descriptor. Any failure to create, bind or
 * listen is reported with perror() and terminates the process.
 */
int create_listener(int port) {
    if (port < 0 || port > 65535) {
        fprintf(stderr, "Invalid listen port %d\n", port);
        exit(EXIT_FAILURE);
    }

    struct sockaddr_in6 addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin6_family = AF_INET6;
    addr.sin6_port = htons(port);
    addr.sin6_addr = in6addr_any;

    int sd = socket(AF_INET6, SOCK_STREAM, 0);
    if (sd < 0) {
        perror("Unable to create listen socket");
        exit(EXIT_FAILURE);
    }

    /* Let the reproducer be restarted immediately, without bind() failing
       with EADDRINUSE while old connections sit in TIME_WAIT. This is
       best-effort: a failure only produces a warning. */
    int one = 1;
    if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) {
        perror("Warning: unable to set SO_REUSEADDR");
    }

    if (bind(sd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
        perror("Unable to bind listen socket");
        exit(EXIT_FAILURE);
    }
    if (listen(sd, 10) < 0) {
        perror("Unable to listen");
        exit(EXIT_FAILURE);
    }
    return sd;
}
SSL_CTX *create_ssl_context() {
SSL_library_init();
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
const SSL_METHOD *method = TLS_server_method();
SSL_CTX *ssl_ctx = SSL_CTX_new(method);
if (ssl_ctx == NULL) {
ERR_print_errors_fp(stderr);
exit(EXIT_FAILURE);
}
SSL_CTX_set_max_proto_version(ssl_ctx, TLS1_2_VERSION);
SSL_CTX_set_cipher_list(ssl_ctx, "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-GCM-SHA256");
SSL_CTX_set_options(ssl_ctx, SSL_OP_CIPHER_SERVER_PREFERENCE|SSL_OP_ENABLE_KTLS);
return ssl_ctx;
}
/*
 * Install the certificate chain file `CertFile` and the PEM private key file
 * `KeyFile` into `ctx`, then confirm that the key matches the certificate.
 *
 * Every failure is fatal. A load failure dumps OpenSSL's error queue; a key
 * mismatch prints a short message. Either way the process then exits with
 * EXIT_FAILURE.
 */
void load_cert_key(SSL_CTX* ctx, const char* CertFile, const char* KeyFile)
{
    /* The key file is only tried once the certificate chain has loaded. */
    int loaded = SSL_CTX_use_certificate_chain_file(ctx, CertFile) > 0
              && SSL_CTX_use_PrivateKey_file(ctx, KeyFile, SSL_FILETYPE_PEM) > 0;
    if (!loaded) {
        ERR_print_errors_fp(stderr);
        exit(EXIT_FAILURE);
    }

    if (SSL_CTX_check_private_key(ctx) == 0) {
        fputs("Private key does not match the public certificate\n", stderr);
        exit(EXIT_FAILURE);
    }
}
/*
 * Serve one client connection on `ssl`, which must already be attached to a
 * socket with SSL_set_fd(). The steps are:
 *   1. complete the TLS handshake;
 *   2. read one request record and ignore its contents;
 *   3. send an HTTP/1.1 200 header carrying Content-Length and the
 *      X-Source-SHA1 of the payload;
 *   4. if -z was given, enable TLS_TX_ZEROCOPY_RO on the socket;
 *   5. send bytes [sendfile_offset, in_len) of in_fd with a single
 *      SSL_sendfile() call and print the result.
 *
 * Always frees `ssl` and closes its socket before returning.
 */
void serve_file(SSL *ssl) {
    char buf[1024];
    int sd = SSL_get_fd(ssl);
    int bytes = in_len - sendfile_offset;
    int hdr_len = 0;
    ossl_ssize_t sent = 0;
    /* Per SSL_shutdown(3), close_notify must not be sent after a fatal
       error, so track whether the session is still in a usable state. */
    int may_shutdown = 0;

    int ret = SSL_accept(ssl);
    if (ret != 1) {
        fprintf(stderr, "SSL_accept error\n");
        ERR_print_errors_fp(stderr);
        goto out;
    }
    may_shutdown = 1;

    ret = SSL_read(ssl, buf, sizeof(buf));
    if (ret <= 0) {
        fprintf(stderr, "SSL_read error\n");
        /* Only a clean close_notify from the peer leaves the session usable. */
        if (SSL_get_error(ssl, ret) != SSL_ERROR_ZERO_RETURN) {
            may_shutdown = 0;
        }
        ERR_print_errors_fp(stderr);
        goto out;
    }

    hdr_len = snprintf(buf, sizeof(buf), "HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nContent-Length: %d\r\nX-Source-SHA1: %s\r\n\r\n", bytes, sha1_str);
    if (hdr_len < 0 || (size_t)hdr_len >= sizeof(buf)) {
        fprintf(stderr, "Response header does not fit in buffer\n");
        goto out;
    }
    if (SSL_write(ssl, buf, hdr_len) <= 0) {
        fprintf(stderr, "SSL_write error\n");
        ERR_print_errors_fp(stderr);
        may_shutdown = 0;
        goto out;
    }

    if (zero_copy) {
        if (setsockopt(sd, SOL_TLS, TLS_TX_ZEROCOPY_RO, &zero_copy, sizeof(zero_copy)) < 0) {
            perror("Error enabling zerocopy sendfile");
            goto out;
        }
        printf("TLS_TX_ZEROCOPY_RO enabled\n");
    }

    /* SSL_sendfile() returns ossl_ssize_t; keep the full width when printing. */
    sent = SSL_sendfile(ssl, in_fd, sendfile_offset, bytes, 0);
    printf("sendfile(%d, %d, %d, %d) = %ld\n", sd, in_fd, sendfile_offset, bytes, (long)sent);
    if (sent < 0) {
        /* E.g. kTLS TX not active on this connection, or the kernel sendfile failed. */
        ERR_print_errors_fp(stderr);
    } else if (sent < bytes) {
        fprintf(stderr, "short sendfile: %ld of %d bytes sent\n", (long)sent, bytes);
    }

out:
    if (may_shutdown) {
        SSL_shutdown(ssl);
    }
    SSL_free(ssl); /* release SSL state */
    close(sd);     /* close connection */
}
/*
 * Parse `text` as the decimal value of command-line option `-opt`. The value
 * must be a complete number in [0, max]. Anything else prints a diagnostic
 * and terminates the process.
 */
static int parse_nonneg_int(const char *text, char opt, long max) {
    char *end = NULL;
    errno = 0;
    long v = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || v < 0 || v > max) {
        fprintf(stderr, "Invalid value '%s' for -%c\n", text, opt);
        exit(EXIT_FAILURE);
    }
    return (int)v;
}

/*
 * Entry point of the kTLS sendfile reproducer.
 *
 * The program:
 *   - parses the options;
 *   - opens and hashes the input range;
 *   - prepares the TLS context and the listening socket;
 *   - serves each accepted connection with serve_file().
 *
 * The accept loop only ends on an unrecoverable accept() error, in which
 * case resources are released and EXIT_FAILURE is returned.
 */
int main(int argc, char **argv) {
    char *cert_file = NULL;
    char *key_file = NULL;
    char *in_file = NULL;
    zero_copy = 0;
    int listen_port = 4443;
    sendfile_offset = 0;
    int c;
    int help = 0;
    while ((c = getopt(argc, argv, "hzo:c:k:i:p:")) != -1) {
        switch (c) {
        case 'i':
            in_file = optarg;
            break;
        case 'o':
            sendfile_offset = parse_nonneg_int(optarg, 'o', INT_MAX);
            break;
        case 'p':
            listen_port = parse_nonneg_int(optarg, 'p', 65535);
            break;
        case 'c':
            cert_file = optarg;
            break;
        case 'k':
            key_file = optarg;
            break;
        case 'z':
            zero_copy = 1;
            break;
        case 'h':
        default: /* unknown option or missing argument ('?') */
            help = 1;
            break;
        }
    }
    if (help || !cert_file || !key_file || !in_file) {
        fprintf(stderr, "Usage: %s -i <in_file> -c <ssl_cert> -k <ssl_key> [-p <listen_port>] [-o <sendfile_offset>] [-z]\n", argv[0]);
        exit(EXIT_FAILURE);
    }

    in_fd = open(in_file, O_RDONLY);
    if (in_fd < 0) {
        perror("Could not open input file");
        exit(EXIT_FAILURE);
    }
    struct stat f_stat;
    if (fstat(in_fd, &f_stat) < 0) {
        perror("Could not stat input file");
        exit(EXIT_FAILURE);
    }
    /* Lengths and offsets are plain ints throughout; refuse files that would truncate. */
    if (f_stat.st_size > INT_MAX) {
        fprintf(stderr, "Input file %s is too large (max %d bytes)\n", in_file, INT_MAX);
        exit(EXIT_FAILURE);
    }
    in_len = (int)f_stat.st_size;
    if (sendfile_offset >= in_len) {
        fprintf(stderr, "Invalid offset %d for file size %d\n", sendfile_offset, in_len);
        exit(EXIT_FAILURE);
    }

    compute_sha1(in_fd);
    printf("Serving file %s, will send %d bytes (%d - %d) with SHA1 sum %s\n", in_file, in_len - sendfile_offset, sendfile_offset, in_len, sha1_str);

    SSL_CTX *ssl_ctx = create_ssl_context();
    load_cert_key(ssl_ctx, cert_file, key_file);
    int server_sd = create_listener(listen_port);

    /* A client that disconnects mid-transfer would otherwise kill the server
       with SIGPIPE raised by the underlying write()/sendfile(). */
    signal(SIGPIPE, SIG_IGN);

    while (1) {
        /* The listener is AF_INET6, so a sockaddr_in is too small for the peer address. */
        struct sockaddr_storage addr;
        socklen_t len = sizeof(addr);
        int cli_sd = accept(server_sd, (struct sockaddr *)&addr, &len);
        if (cli_sd < 0) {
            if (errno == EINTR || errno == ECONNABORTED) {
                continue;
            }
            perror("accept");
            break;
        }
        SSL *ssl = SSL_new(ssl_ctx);
        if (ssl == NULL) {
            ERR_print_errors_fp(stderr);
            close(cli_sd);
            continue;
        }
        if (!SSL_set_fd(ssl, cli_sd)) {
            ERR_print_errors_fp(stderr);
            SSL_free(ssl);
            close(cli_sd);
            continue;
        }
        serve_file(ssl); /* frees ssl and closes cli_sd */
    }

    close(server_sd);
    SSL_CTX_free(ssl_ctx);
    close(in_fd);
    return EXIT_FAILURE;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment