Skip to content

Instantly share code, notes, and snippets.

@ukreator
Created October 18, 2014 11:32
Show Gist options
  • Save ukreator/5570d00696059c32ec6b to your computer and use it in GitHub Desktop.
Save ukreator/5570d00696059c32ec6b to your computer and use it in GitHub Desktop.
DTLS re-handshake bug reproduction
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <pthread.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
/**
 * A private key together with the self-signed certificate issued for it.
 * Produced by generateIdentity(); whoever holds it owns both objects and
 * releases them with EVP_PKEY_free / X509_free.
 */
typedef struct DtlsIdentity_ {
    EVP_PKEY *key;          /* private key, from generateRsaKeyPair() */
    X509 *certificate;      /* self-signed certificate for `key` */
} DtlsIdentity;
/**
 * Helper functions (defined at the bottom of the file).
 * Empty parameter lists are spelled (void) so these are real prototypes and
 * every call is argument-checked by the compiler.
 */
unsigned long idFunction(void);
void opensslLockingFunc(int mode, int n, const char* file, int line);
void opensslInit(void);
void opensslCleanup(void);
EVP_PKEY* generateRsaKeyPair(void);
X509* generateCertificate(EVP_PKEY* pkey, const char* commonName);
DtlsIdentity generateIdentity(void);
/**
 * Which side of the DTLS association a PeerContext plays; initContext()
 * uses it to pick the client/server method and connect/accept state.
 */
enum DtlsRole
{
    DTLS_CLIENT = 0, /* SSL_set_connect_state */
    DTLS_SERVER      /* SSL_set_accept_state; implicitly 1 */
};
/**
 * Everything one endpoint of the in-memory DTLS exchange needs.
 */
typedef struct PeerContext_ {
    SSL *ssl;                 /* connection; owns inBio/outBio after SSL_set_bio */
    SSL_CTX *ctx;             /* context `ssl` was created from */
    BIO *inBio;               /* memory BIO that incoming datagrams are written into */
    BIO *outBio;              /* memory BIO that OpenSSL writes outgoing records to */
    X509 *certificate;        /* NOTE(review): not set or read in the visible code */
    EVP_PKEY *key;            /* NOTE(review): not set or read in the visible code */
    enum DtlsRole role;       /* must be set before initContext() */
    char label[20];           /* "[C] " / "[S] " log prefix */
    bool handshakeCompleted;  /* initial handshake has finished */
    bool activeRenegotiation; /* a renegotiation is currently being tracked */
    int negotiationsCount;    /* completed handshakes, initial one included */
} PeerContext;
/* Forward declaration: sslInfoCallback() forwards into this logger, whose
 * definition appears further down the file. */
void sslInfoCallbackInternal(PeerContext* ctx, const SSL* s, int where, int ret);
void sslInfoCallback(const SSL* s, int where, int ret)
{
sslInfoCallbackInternal((PeerContext*)SSL_get_app_data(s), s, where, ret);
}
/**
 * Certificate verification callback registered by initContext().
 * Accepts every certificate, whatever OpenSSL's own verdict `ok` was.
 *
 * @return always 1 (accept)
 */
int sslVerifyCallback(int ok, X509_STORE_CTX* store)
{
    (void)ok;
    (void)store;
    // certificates are deliberately not checked, for simplicity
    return 1;
}
/**
 * Prepares one DTLS endpoint for the in-memory handshake.
 *
 * The caller must set ctx->role first. This sets the log label and state
 * flags, generates a fresh key and self-signed certificate, creates an
 * SSL_CTX and SSL for the role, and attaches a pair of memory BIOs
 * (inBio: datagrams delivered to this peer; outBio: what OpenSSL writes).
 * Any OpenSSL setup failure trips an assert.
 */
void initContext(PeerContext* ctx)
{
    // set labels to distinguish client/server for logging:
    snprintf(ctx->label, sizeof(ctx->label), "%s",
             (ctx->role == DTLS_CLIENT) ? "[C] " : "[S] ");
    ctx->handshakeCompleted = false;
    ctx->activeRenegotiation = false;
    ctx->negotiationsCount = 0;
    // generate new certificate and private key:
    DtlsIdentity identity = generateIdentity();
    assert(identity.key && identity.certificate);
    ctx->ctx = (ctx->role == DTLS_CLIENT) ?
        SSL_CTX_new(DTLSv1_client_method()) :
        SSL_CTX_new(DTLSv1_server_method());
    assert(ctx->ctx);
    // Checked so a rejected key/certificate fails here instead of surfacing
    // later as an opaque handshake error.
    int ok = SSL_CTX_use_certificate(ctx->ctx, identity.certificate);
    assert(ok == 1);
    ok = SSL_CTX_use_PrivateKey(ctx->ctx, identity.key);
    assert(ok == 1);
    // the SSL_CTX keeps its own references, so drop the local ones:
    EVP_PKEY_free(identity.key);
    X509_free(identity.certificate);
    SSL_CTX_set_info_callback(ctx->ctx, &sslInfoCallback);
    // verification is enabled, but sslVerifyCallback accepts every certificate
    SSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
                       &sslVerifyCallback);
    SSL_CTX_set_verify_depth(ctx->ctx, 1);
    ok = SSL_CTX_set_cipher_list(ctx->ctx, "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
    assert(ok == 1);
    (void)ok; // only read by assert(), which NDEBUG removes
    SSL_CTX_set_read_ahead(ctx->ctx, 1);
    // "bad decompression" error fix for some linuxes:
    SSL_CTX_set_options(ctx->ctx, SSL_OP_NO_COMPRESSION);
    // disable tickets for simplicity:
    SSL_CTX_set_options(ctx->ctx, SSL_OP_NO_TICKET);
    ctx->ssl = SSL_new(ctx->ctx);
    assert(ctx->ssl);
    ctx->inBio = BIO_new(BIO_s_mem());
    ctx->outBio = BIO_new(BIO_s_mem());
    assert(ctx->inBio && ctx->outBio);
    // lets sslInfoCallback find this PeerContext again
    SSL_set_app_data(ctx->ssl, ctx);
    if (ctx->role == DTLS_CLIENT)
        SSL_set_connect_state(ctx->ssl);
    else
        SSL_set_accept_state(ctx->ssl);
    SSL_set_bio(ctx->ssl, ctx->inBio, ctx->outBio); //< the SSL object owns the bio now
}
/**
 * Starts a locally initiated renegotiation on a peer whose initial handshake
 * has completed.
 *
 * Marks the context as renegotiating, clears both memory BIOs and asks
 * OpenSSL to schedule a new handshake. The handshake messages themselves are
 * produced by later handshakeIteration() calls, which use SSL_do_handshake
 * while activeRenegotiation is set.
 */
void renegotiate(PeerContext* ctx)
{
    printf("%s <<<<Renegotiation requested>>>>\n", ctx->label);
    assert(ctx->handshakeCompleted);
    assert(!ctx->activeRenegotiation);
    ctx->activeRenegotiation = true;
    (void)BIO_reset(ctx->inBio);
    (void)BIO_reset(ctx->outBio);
    // SSL_renegotiate returns 1 on success; ignoring a 0 would leave the
    // caller waiting for a handshake that never starts.
    int scheduled = SSL_renegotiate(ctx->ssl);
    assert(scheduled == 1);
    (void)scheduled; // only read by assert(), which NDEBUG removes
}
/**
 * Advances one peer's DTLS state machine by a single step.
 *
 * Uses SSL_read once the initial handshake has completed and no
 * renegotiation is being tracked; otherwise SSL_do_handshake. Updates
 * ctx->handshakeCompleted, ctx->activeRenegotiation and
 * ctx->negotiationsCount as handshakes start and finish.
 *
 * @param ctx        peer to advance
 * @param dataToSend out: points into ctx->outBio's buffer (not a copy)
 * @param len        out: number of bytes currently held in ctx->outBio
 * @param timeoutMs  out: written only when true is returned
 * @return true if the peer wants to read and DTLSv1_get_timeout reported a
 *         timer (remaining time stored in *timeoutMs); false otherwise.
 *         Any SSL error other than WANT_READ aborts the program.
 */
bool handshakeIteration(PeerContext* ctx, uint8_t** dataToSend, size_t* len, int* timeoutMs)
{
    bool wantRead = false;
    // we don't actually read data, but need this for SSL_read:
    uint8_t buf[4096];
    // SSL_read after initial negotiation, SSL_do_handshake for the initial
    // handshake and while a renegotiation is being tracked
    int res = (ctx->handshakeCompleted && !ctx->activeRenegotiation) ?
        SSL_read(ctx->ssl, buf, sizeof(buf)) : SSL_do_handshake(ctx->ssl);
    // get pointer to data written by handshake
    *len = BIO_get_mem_data(ctx->outBio, dataToSend);
    int err = SSL_get_error(ctx->ssl, res);
    struct timeval timeout;
    // check if remote side requested renegotiation
    if (!ctx->activeRenegotiation && ctx->handshakeCompleted && SSL_renegotiate_pending(ctx->ssl) == 1)
    {
        printf("%s Remote renegotiation detected\n", ctx->label);
        ctx->activeRenegotiation = true;
    }
    // check if renegotiation finished
    bool renegotiationFinished = ctx->activeRenegotiation && SSL_renegotiate_pending(ctx->ssl) == 0;
    // handle handshake errors
    switch (err)
    {
    case SSL_ERROR_NONE:
        if (!ctx->handshakeCompleted || renegotiationFinished)
        {
            ctx->handshakeCompleted = true;
            ctx->activeRenegotiation = false;
            ctx->negotiationsCount++;
        }
        break;
    case SSL_ERROR_WANT_READ:
        if (renegotiationFinished)
        {
            ctx->activeRenegotiation = false;
            ctx->negotiationsCount++;
        }
        else if (DTLSv1_get_timeout(ctx->ssl, &timeout))
        {
            *timeoutMs = timeout.tv_sec * 1000 + timeout.tv_usec / 1000;
            wantRead = true;
            printf("%s WANT_READ with timeout %d\n", ctx->label, *timeoutMs);
        }
        break;
    default:
        printf("Unexpected error while processing DTLS: %d\n", err);
        assert(false);
        // assert() vanishes under NDEBUG; never carry on with a broken session
        abort();
    }
    return wantRead;
}
/**
 * Delivers one datagram to a peer: clears both of the peer's memory BIOs,
 * then writes `size` bytes from `data` into its input BIO for the next SSL
 * call to consume.
 *
 * NOTE(review): resetting outBio here also discards any output this peer
 * produced but the caller has not forwarded yet — presumably intended, confirm.
 */
void handleIncomingData(PeerContext* ctx, uint8_t* data, size_t size)
{
    // size is a size_t, hence %zu
    printf("%s INCOMING DATA of size %zu\n", ctx->label, size);
    (void)BIO_reset(ctx->inBio);
    (void)BIO_reset(ctx->outBio);
    // BIO_write takes an int length; datagram sizes fit comfortably
    int written = BIO_write(ctx->inBio, data, (int)size);
    assert(written == (int)size);
    (void)written; // only read by assert(), which NDEBUG removes
}
/**
 * Logs OpenSSL state-machine events for `ctx`, each prefixed with ctx->label:
 * every loop step (SSL_CB_LOOP), every alert read or written (SSL_CB_ALERT),
 * and every exit from a handshake call with ret <= 0 (SSL_CB_EXIT).
 *
 * NOTE(review): negative ret on SSL_CB_EXIT likely also covers non-fatal
 * WANT_READ exits, so "failed" may be logged for non-errors — confirm.
 */
void sslInfoCallbackInternal(PeerContext* ctx, const SSL* s, int where, int ret)
{
    // Default so the name is never read uninitialised when `where` carries
    // neither the connect nor the accept bit.
    const char* method = "undefined";
    int w = where & ~SSL_ST_MASK;
    if (w & SSL_ST_CONNECT)
        method = "SSL_connect";
    else if (w & SSL_ST_ACCEPT)
        method = "SSL_accept";
    if (where & SSL_CB_LOOP)
    {
        printf("%s %s: %s\n", ctx->label, method, SSL_state_string_long(s));
    }
    else if (where & SSL_CB_ALERT)
    {
        const char* direction = (where & SSL_CB_READ) ? "read" : "write";
        printf("%s SSL3 alert %s: %s : %s \n", ctx->label, direction,
               SSL_alert_type_string_long(ret),
               SSL_alert_desc_string_long(ret));
    }
    else if ((where & SSL_CB_EXIT) && ret <= 0)
    {
        printf("%s %s failed in %s \n", ctx->label, method,
               SSL_state_string_long(s));
    }
}
int main()
{
opensslInit();
PeerContext clientCtx;
clientCtx.role = DTLS_CLIENT;
PeerContext serverCtx;
serverCtx.role = DTLS_SERVER;
initContext(&clientCtx);
initContext(&serverCtx);
uint8_t* data;
size_t len;
bool clientWantRead = false;
bool serverWantRead = false;
int timeoutMsClient = 0;
int timeoutMsServer = 0;
// starting to "listen" on server:
handshakeIteration(&serverCtx, &data, &len, &timeoutMsServer);
assert(len == 0);
// initial negotiation:
while (1)
{
handshakeIteration(&clientCtx, &data, &len, &timeoutMsClient);
if (len)
handleIncomingData(&serverCtx, data, len);
if (clientCtx.handshakeCompleted)
break;
handshakeIteration(&serverCtx, &data, &len, &timeoutMsServer);
if (len)
handleIncomingData(&clientCtx, data, len);
}
assert(clientCtx.negotiationsCount == 1);
printf("======== Renegotiating ========\n");
renegotiate(&clientCtx);
int clientPacketCounter = 0;
// renegotiation loop:
while (1)
{
clientWantRead = handshakeIteration(&clientCtx, &data, &len, &timeoutMsClient);
if (len)
{
clientPacketCounter++;
printf("Client has some data to send to the server. Size %d\n", len);
//
if (clientPacketCounter != 2)
handleIncomingData(&serverCtx, data, len);
else
printf("Intentionally dropping the packet and waiting\n");
if (clientPacketCounter > 5)
{
printf("Error: too much packets sent!\n");
goto end;
}
}
// one renegotiation is enough (1 - initial negotiation, 2 - 1st renegotioation)
if (clientCtx.negotiationsCount == 2)
break;
serverWantRead = handshakeIteration(&serverCtx, &data, &len, &timeoutMsServer);
if (len)
{
printf("Server has some data to send to the client of size %d\n", len);
handleIncomingData(&clientCtx, data, len);
// client read request satisfied, no need to start timer at the end of the iteration:
clientWantRead = false;
}
if (clientWantRead || serverWantRead)
{
int timeout = timeoutMsClient < timeoutMsServer ? timeoutMsClient : timeoutMsServer;
printf("Waiting for %d ms for client to generate new flight\n", timeout);
struct timespec rem;
struct timespec req;
req.tv_sec = timeout / 1000;
req.tv_nsec = (timeout % 1000) * 1000000;
nanosleep(&req, &rem);
printf("Waiting is over\n");
}
}
printf("Renegotiated successfully!\n");
end:
SSL_shutdown(clientCtx.ssl);
SSL_free(clientCtx.ssl);
SSL_CTX_free(clientCtx.ctx);
SSL_shutdown(serverCtx.ssl);
SSL_free(serverCtx.ssl);
SSL_CTX_free(serverCtx.ctx);
opensslCleanup();
return 0;
}
// Helper functions implementation

/*
 * Lock table for opensslLockingFunc: CRYPTO_num_locks() mutexes allocated in
 * opensslInit() and released in opensslCleanup(). Static storage starts NULL.
 */
static pthread_mutex_t *mutex_buf;
/**
 * Thread-id callback for OpenSSL: identifies the calling thread by its
 * pthread_t converted to unsigned long.
 *
 * NOTE(review): assumes pthread_t converts meaningfully to unsigned long
 * (it is integral on glibc) — POSIX does not guarantee this.
 */
unsigned long idFunction()
{
    const pthread_t self = pthread_self();
    return (unsigned long)self;
}
/**
 * Static-locking callback for OpenSSL: locks mutex_buf[n] when `mode` has
 * CRYPTO_LOCK set, unlocks it otherwise. `file` and `line` are ignored.
 */
void opensslLockingFunc(int mode, int n,
                        const char* file, int line)
{
    (void)file;
    (void)line;
    pthread_mutex_t* lock = &mutex_buf[n];
    if ((mode & CRYPTO_LOCK) == 0)
    {
        pthread_mutex_unlock(lock);
        return;
    }
    pthread_mutex_lock(lock);
}
/**
 * Initialises OpenSSL (library, error strings, algorithms) and installs this
 * file's pthread-based locking and thread-id callbacks. Pair with
 * opensslCleanup(). Aborts if the lock table cannot be allocated.
 */
void opensslInit()
{
    SSL_library_init();
    SSL_load_error_strings();
    OpenSSL_add_all_algorithms();
    // one pthread mutex per OpenSSL static lock, indexed by opensslLockingFunc
    const int lockCount = CRYPTO_num_locks();
    mutex_buf = calloc((size_t)lockCount, sizeof *mutex_buf);
    if (mutex_buf == NULL && lockCount > 0)
    {
        fprintf(stderr, "Failed to allocate OpenSSL lock table\n");
        abort();
    }
    for (int i = 0; i < lockCount; i++)
        pthread_mutex_init(&mutex_buf[i], NULL);
    CRYPTO_set_locking_callback(&opensslLockingFunc);
    CRYPTO_set_id_callback(&idFunction);
}
/**
 * Undoes opensslInit(). It removes the locking/id callbacks, then destroys
 * and frees the mutex table. It also releases OpenSSL's global state: error
 * strings, this thread's error state, EVP tables and ex_data.
 */
void opensslCleanup()
{
    CRYPTO_set_id_callback(NULL);
    CRYPTO_set_locking_callback(NULL);
    // guard so a cleanup without a successful init never indexes NULL
    if (mutex_buf != NULL)
    {
        const int lockCount = CRYPTO_num_locks();
        for (int i = 0; i < lockCount; i++)
            pthread_mutex_destroy(&mutex_buf[i]);
        free(mutex_buf);
        mutex_buf = NULL; // no dangling pointer if init/cleanup run again
    }
    ERR_free_strings();
    ERR_remove_state(0);
    EVP_cleanup();
    CRYPTO_cleanup_all_ex_data();
}
// RSA modulus size in bits for generated keys. 2048 is the current minimum
// recommendation; OpenSSL refuses RSA keys below 2048 bits at security
// level 2 and above, which some distributions use by default.
const int gKeyLength = 2048;
// number of random bits for certificate serial number
const int gRandomBitsNum = 64;
// one year certificate validity, in seconds (notAfter offset from now)
const int gCertificateLifetime = 60 * 60 * 24 * 365;
// notBefore offset from now, in seconds: backdates the certificate by one day
// to compensate for slightly incorrect system clocks
const int gCertificateValidationWindow = -60 * 60 * 24;
/**
 * Generates a fresh RSA key pair of gKeyLength bits with public exponent
 * 65537, wrapped in an EVP_PKEY.
 *
 * @return new EVP_PKEY owned by the caller, or NULL on any failure.
 */
EVP_PKEY* generateRsaKeyPair()
{
    EVP_PKEY* keyPair = EVP_PKEY_new();
    BIGNUM* publicExponent = BN_new();
    RSA* rsaKey = RSA_new();

    bool ok = keyPair != NULL && publicExponent != NULL && rsaKey != NULL;
    ok = ok && BN_set_word(publicExponent, 0x10001);
    ok = ok && RSA_generate_key_ex(rsaKey, gKeyLength, publicExponent, NULL);
    // on success keyPair takes ownership of rsaKey
    ok = ok && EVP_PKEY_assign_RSA(keyPair, rsaKey);

    BN_free(publicExponent);
    if (!ok)
    {
        EVP_PKEY_free(keyPair);
        RSA_free(rsaKey);
        return NULL;
    }
    return keyPair;
}
/**
 * Builds a self-signed X.509 version-1 certificate for `pkey`.
 *
 * Subject and issuer are both CN=`commonName`, and the serial number is
 * gRandomBitsNum pseudo-random bits. notBefore is offset from now by
 * gCertificateValidationWindow seconds, and notAfter by gCertificateLifetime
 * seconds. The certificate is signed with SHA-256 using `pkey` itself.
 *
 * @return new X509 owned by the caller, or NULL on any failure.
 */
X509* generateCertificate(EVP_PKEY* pkey, const char* commonName)
{
    X509* cert = X509_new();
    BIGNUM* serial = NULL;
    X509_NAME* subject = NULL;
    ASN1_INTEGER* asn1Serial = NULL;

    int ok = cert != NULL
        && X509_set_pubkey(cert, pkey)
        // random serial, copied into the certificate's own ASN1_INTEGER
        && (serial = BN_new()) != NULL
        && BN_pseudo_rand(serial, gRandomBitsNum, 0, 0)
        && (asn1Serial = X509_get_serialNumber(cert)) != NULL
        && BN_to_ASN1_INTEGER(serial, asn1Serial) != NULL
        && X509_set_version(cert, 0L)
        // self-signed: subject and issuer are the same name
        && (subject = X509_NAME_new()) != NULL
        && X509_NAME_add_entry_by_NID(subject, NID_commonName, MBSTRING_UTF8,
                                      (unsigned char*)commonName, -1, -1, 0)
        && X509_set_subject_name(cert, subject)
        && X509_set_issuer_name(cert, subject)
        && X509_gmtime_adj(X509_get_notBefore(cert), gCertificateValidationWindow) != NULL
        && X509_gmtime_adj(X509_get_notAfter(cert), gCertificateLifetime) != NULL
        && X509_sign(cert, pkey, EVP_sha256());

    BN_free(serial);
    X509_NAME_free(subject);
    if (!ok)
    {
        X509_free(cert);
        cert = NULL;
    }
    return cert;
}
DtlsIdentity generateIdentity()
{
DtlsIdentity id;
id.key = generateRsaKeyPair();
id.certificate = generateCertificate(id.key, "TestCompany Inc");
return id;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment