Skip to content

Instantly share code, notes, and snippets.

@twist84
Forked from mmozeiko/tls_client.c
Last active April 13, 2023 12:46
Show Gist options
  • Save twist84/600decfa81ccf7931492529499812d02 to your computer and use it in GitHub Desktop.
Save twist84/600decfa81ccf7931492529499812d02 to your computer and use it in GitHub Desktop.
Simple example of a TLS socket client using the Win32 Schannel API
#include "tls_client.hpp"
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <winsock2.h>
#define SECURITY_WIN32
#include <security.h>
#include <schannel.h>
#include <shlwapi.h>
#include <assert.h>
#include <stdio.h>
#pragma comment (lib, "ws2_32.lib")
#pragma comment (lib, "secur32.lib")
#pragma comment (lib, "shlwapi.lib")
// Capacity of the ciphertext receive buffer: a 16 KiB TLS record payload plus headroom for the
// record header, MAC and padding (deliberately generous).
#define TLS_MAX_PACKET_SIZE (16384+512)

// Per-connection state kept behind c_tls_client's opaque pointer.
struct s_socket
{
	// --- transport and Schannel handles ---
	SOCKET s;                           // TCP socket created and connected in connect()
	CredHandle handle;                  // outbound credentials from AcquireCredentialsHandleW
	CtxtHandle context;                 // security context built by the TLS handshake
	SecPkgContext_StreamSizes sizes;    // header / trailer / max message sizes queried after the handshake

	// --- bookkeeping for `incoming` ---
	long received;                      // number of ciphertext bytes currently held in `incoming`
	long used;                          // bytes of `incoming` occupied by the record currently being read
	long available;                     // decrypted bytes not yet copied out to the caller
	char* decrypted;                    // start of pending plaintext inside `incoming` (decrypted in place)

	char incoming[TLS_MAX_PACKET_SIZE]; // receive buffer for ciphertext
};
c_tls_client::c_tls_client() :
m_socket(new s_socket{})
{
}
// Releases the state allocated by the constructor. This does not call disconnect(); the
// connection must be torn down explicitly beforehand.
c_tls_client::~c_tls_client()
{
	delete m_socket; // deleting a null pointer is a no-op, so no explicit check is required
}
// returns 0 on success or negative value on error
long c_tls_client::connect(char const* hostname, unsigned short port)
{
// initialize windows sockets
WSADATA wsadata;
if (WSAStartup(MAKEWORD(2, 2), &wsadata) != 0)
{
return -1;
}
// create TCP IPv4 socket
m_socket->s = socket(AF_INET, SOCK_STREAM, 0);
if (m_socket->s == INVALID_SOCKET)
{
WSACleanup();
return -1;
}
wchar_t hostname_wide[256];
wnsprintfW(hostname_wide, 256, L"%hs", hostname);
wchar_t sport[64];
wnsprintfW(sport, 64, L"%u", port);
// connect to server
if (!WSAConnectByNameW(m_socket->s, hostname_wide, sport, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr))
{
closesocket(m_socket->s);
WSACleanup();
return -1;
}
// initialize schannel
{
SCHANNEL_CRED cred =
{
.dwVersion = SCHANNEL_CRED_VERSION,
.grbitEnabledProtocols = SP_PROT_TLS1_2, // allow only TLS v1.2
.dwFlags = SCH_USE_STRONG_CRYPTO // use only strong crypto alogorithms
| SCH_CRED_AUTO_CRED_VALIDATION // automatically validate server certificate
| SCH_CRED_NO_DEFAULT_CREDS, // no client certificate authentication
};
if (AcquireCredentialsHandleW(nullptr, (LPWSTR)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, nullptr, &cred, nullptr, nullptr, &m_socket->handle, nullptr) != SEC_E_OK)
{
closesocket(m_socket->s);
WSACleanup();
return -1;
}
}
m_socket->received = m_socket->used = m_socket->available = 0;
m_socket->decrypted = nullptr;
// perform tls handshake
// 1) call InitializeSecurityContext to create/update schannel context
// 2) when it returns SEC_E_OK - tls handshake completed
// 3) when it returns SEC_I_INCOMPLETE_CREDENTIALS - server requests client certificate (not supported here)
// 4) when it returns SEC_I_CONTINUE_NEEDED - send token to server and read data
// 5) when it returns SEC_E_INCOMPLETE_MESSAGE - need to read more data from server
// 6) otherwise read data from server and go to step 1
CtxtHandle* context = nullptr;
int result = 0;
for (;;)
{
SecBuffer inbuffers[2] = { 0 };
inbuffers[0].BufferType = SECBUFFER_TOKEN;
inbuffers[0].pvBuffer = m_socket->incoming;
inbuffers[0].cbBuffer = m_socket->received;
inbuffers[1].BufferType = SECBUFFER_EMPTY;
SecBuffer outbuffers[1] = { 0 };
outbuffers[0].BufferType = SECBUFFER_TOKEN;
SecBufferDesc indesc = { SECBUFFER_VERSION, ARRAYSIZE(inbuffers), inbuffers };
SecBufferDesc outdesc = { SECBUFFER_VERSION, ARRAYSIZE(outbuffers), outbuffers };
DWORD flags = ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM;
SECURITY_STATUS sec = InitializeSecurityContextW(
&m_socket->handle,
context,
context ? nullptr : (SEC_WCHAR*)hostname_wide,
flags,
0,
0,
context ? &indesc : nullptr,
0,
context ? nullptr : &m_socket->context,
&outdesc,
&flags,
nullptr);
// after first call to InitializeSecurityContext context is available and should be reused for next calls
context = &m_socket->context;
if (inbuffers[1].BufferType == SECBUFFER_EXTRA)
{
MoveMemory(m_socket->incoming, m_socket->incoming + (m_socket->received - inbuffers[1].cbBuffer), inbuffers[1].cbBuffer);
m_socket->received = inbuffers[1].cbBuffer;
}
else
{
m_socket->received = 0;
}
if (sec == SEC_E_OK)
{
// tls handshake completed
break;
}
else if (sec == SEC_I_INCOMPLETE_CREDENTIALS)
{
// server asked for client certificate, not supported here
result = -1;
break;
}
else if (sec == SEC_I_CONTINUE_NEEDED)
{
// need to send data to server
char* buffer = static_cast<char*>(outbuffers[0].pvBuffer);
int size = outbuffers[0].cbBuffer;
while (size != 0)
{
int d = send(m_socket->s, buffer, size, 0);
if (d <= 0)
{
break;
}
size -= d;
buffer += d;
}
if (outbuffers[0].pvBuffer)
FreeContextBuffer(outbuffers[0].pvBuffer);
if (size != 0)
{
// failed to fully send data to server
result = -1;
break;
}
}
else if (sec != SEC_E_INCOMPLETE_MESSAGE)
{
// SEC_E_CERT_EXPIRED - certificate expired or revoked
// SEC_E_WRONG_PRINCIPAL - bad hostname
// SEC_E_UNTRUSTED_ROOT - cannot vertify CA chain
// SEC_E_ILLEGAL_MESSAGE / SEC_E_ALGORITHM_MISMATCH - cannot negotiate crypto algorithms
result = -1;
break;
}
// read more data from server when possible
if (m_socket->received == sizeof(m_socket->incoming))
{
// server is sending too much data instead of proper handshake?
result = -1;
break;
}
int r = recv(m_socket->s, m_socket->incoming + m_socket->received, sizeof(m_socket->incoming) - m_socket->received, 0);
if (r == 0)
{
// server disconnected socket
return 0;
}
else if (r < 0)
{
// socket error
result = -1;
break;
}
m_socket->received += r;
}
if (result != 0)
{
DeleteSecurityContext(context);
FreeCredentialsHandle(&m_socket->handle);
closesocket(m_socket->s);
WSACleanup();
return result;
}
QueryContextAttributes(context, SECPKG_ATTR_STREAM_SIZES, &m_socket->sizes);
return 0;
}
// disconnects socket & releases resources (call this even if tls_write/tls_read function return error)
void c_tls_client::disconnect()
{
DWORD type = SCHANNEL_SHUTDOWN;
SecBuffer inbuffers[1];
inbuffers[0].BufferType = SECBUFFER_TOKEN;
inbuffers[0].pvBuffer = &type;
inbuffers[0].cbBuffer = sizeof(type);
SecBufferDesc indesc = { SECBUFFER_VERSION, ARRAYSIZE(inbuffers), inbuffers };
ApplyControlToken(&m_socket->context, &indesc);
SecBuffer outbuffers[1];
outbuffers[0].BufferType = SECBUFFER_TOKEN;
SecBufferDesc outdesc = { SECBUFFER_VERSION, ARRAYSIZE(outbuffers), outbuffers };
DWORD flags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM;
if (InitializeSecurityContextW(&m_socket->handle, &m_socket->context, nullptr, flags, 0, 0, &outdesc, 0, nullptr, &outdesc, &flags, nullptr) == SEC_E_OK)
{
char* buffer = static_cast<char*>(outbuffers[0].pvBuffer);
int size = outbuffers[0].cbBuffer;
while (size != 0)
{
int d = send(m_socket->s, buffer, size, 0);
if (d <= 0)
{
// ignore any failures socket will be closed anyway
break;
}
buffer += d;
size -= d;
}
FreeContextBuffer(outbuffers[0].pvBuffer);
}
shutdown(m_socket->s, SD_BOTH);
DeleteSecurityContext(&m_socket->context);
FreeCredentialsHandle(&m_socket->handle);
closesocket(m_socket->s);
WSACleanup();
}
// Encrypts `size` bytes from `buffer` and sends them as TLS records of at most
// sizes.cbMaximumMessage plaintext bytes each.
// Returns 0 on success or a negative value on error: invalid arguments, not connected,
// EncryptMessage failure, or a send error / server disconnect.
long c_tls_client::write(void const* buffer, long size)
{
	if (size < 0 || (size > 0 && buffer == nullptr))
	{
		// a negative size would otherwise be cast to a huge unsigned byte count
		return -1;
	}
	if (size > 0 && m_socket->sizes.cbMaximumMessage == 0)
	{
		// sizes are only filled in by a successful connect(); a zero record size would loop forever
		return -1;
	}

	// Local (not static) so that writes on different instances or threads never share it.
	// Header and trailer are filled by EncryptMessage, so no zeroing is needed.
	char wbuffer[TLS_MAX_PACKET_SIZE];
	assert(m_socket->sizes.cbHeader + m_socket->sizes.cbMaximumMessage + m_socket->sizes.cbTrailer <= sizeof(wbuffer));

	char const* data = static_cast<char const*>(buffer);
	while (size != 0)
	{
		unsigned long use = min(static_cast<unsigned long>(size), m_socket->sizes.cbMaximumMessage);

		// record layout inside wbuffer: [header][plaintext -> ciphertext][trailer]
		SecBuffer buffers[3];
		buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
		buffers[0].pvBuffer = wbuffer;
		buffers[0].cbBuffer = m_socket->sizes.cbHeader;
		buffers[1].BufferType = SECBUFFER_DATA;
		buffers[1].pvBuffer = wbuffer + m_socket->sizes.cbHeader;
		buffers[1].cbBuffer = use;
		buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
		buffers[2].pvBuffer = wbuffer + m_socket->sizes.cbHeader + use;
		buffers[2].cbBuffer = m_socket->sizes.cbTrailer;

		CopyMemory(buffers[1].pvBuffer, data, use);

		SecBufferDesc desc = { SECBUFFER_VERSION, ARRAYSIZE(buffers), buffers };
		SECURITY_STATUS sec = EncryptMessage(&m_socket->context, 0, &desc, 0);
		if (sec != SEC_E_OK)
		{
			// this should not happen, but just in case check it
			return -1;
		}

		int total = static_cast<int>(buffers[0].cbBuffer + buffers[1].cbBuffer + buffers[2].cbBuffer);
		int sent = 0;
		while (sent != total)
		{
			int d = send(m_socket->s, wbuffer + sent, total - sent, 0);
			if (d <= 0)
			{
				// error sending data to socket, or server disconnected
				return -1;
			}
			sent += d;
		}

		data += use;
		size -= static_cast<long>(use);
	}
	return 0;
}
// Blocking read: waits for data and copies up to `size` decrypted bytes into `buffer`.
// Returns the number of bytes copied (<= size). It returns 0 when the server closed the socket,
// or closed the TLS session, before any byte was copied. It returns a negative value on error,
// including invalid arguments.
long c_tls_client::read(void* buffer, long size)
{
	if (size < 0 || (size > 0 && buffer == nullptr))
	{
		// a negative size would reach CopyMemory as a huge byte count
		return -1;
	}

	char* out = static_cast<char*>(buffer);
	long result = 0;
	while (size != 0)
	{
		if (m_socket->decrypted)
		{
			// if there is decrypted data available, then use it as much as possible
			long use = min(size, m_socket->available);
			CopyMemory(out, m_socket->decrypted, use);
			out += use;
			size -= use;
			result += use;

			if (use == m_socket->available)
			{
				// all decrypted data is used, remove ciphertext from incoming buffer so next time it starts from beginning
				MoveMemory(m_socket->incoming, m_socket->incoming + m_socket->used, m_socket->received - m_socket->used);
				m_socket->received -= m_socket->used;
				m_socket->used = 0;
				m_socket->available = 0;
				m_socket->decrypted = nullptr;
			}
			else
			{
				m_socket->available -= use;
				m_socket->decrypted += use;
			}
		}
		else
		{
			// if any ciphertext data available then try to decrypt it
			if (m_socket->received != 0)
			{
				SecBuffer buffers[4] = {};
				assert(m_socket->sizes.cBuffers == ARRAYSIZE(buffers));

				buffers[0].BufferType = SECBUFFER_DATA;
				buffers[0].pvBuffer = m_socket->incoming;
				buffers[0].cbBuffer = m_socket->received;
				buffers[1].BufferType = SECBUFFER_EMPTY;
				buffers[2].BufferType = SECBUFFER_EMPTY;
				buffers[3].BufferType = SECBUFFER_EMPTY;

				SecBufferDesc desc = { SECBUFFER_VERSION, ARRAYSIZE(buffers), buffers };

				SECURITY_STATUS sec = DecryptMessage(&m_socket->context, &desc, 0, nullptr);
				if (sec == SEC_E_OK)
				{
					assert(buffers[0].BufferType == SECBUFFER_STREAM_HEADER);
					assert(buffers[1].BufferType == SECBUFFER_DATA);
					assert(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER);

					// plaintext was decrypted in place inside `incoming`; bytes reported as EXTRA
					// belong to the next record and are kept for the next decrypt
					m_socket->decrypted = static_cast<char*>(buffers[1].pvBuffer);
					m_socket->available = buffers[1].cbBuffer;
					m_socket->used = m_socket->received - (buffers[3].BufferType == SECBUFFER_EXTRA ? buffers[3].cbBuffer : 0);

					// data is now decrypted, go back to beginning of loop to copy memory to output buffer
					continue;
				}
				else if (sec == SEC_I_CONTEXT_EXPIRED)
				{
					// server closed TLS connection (but socket is still open)
					m_socket->received = 0;
					return result;
				}
				else if (sec == SEC_I_RENEGOTIATE)
				{
					// server wants to renegotiate TLS connection, not implemented here
					return -1;
				}
				else if (sec != SEC_E_INCOMPLETE_MESSAGE)
				{
					// some other schannel or TLS protocol error
					return -1;
				}

				// otherwise sec == SEC_E_INCOMPLETE_MESSAGE which means need to read more data
			}

			// otherwise not enough data received to decrypt

			if (result != 0)
			{
				// some data is already copied to output buffer, so return that before blocking with recv
				break;
			}

			if (m_socket->received == sizeof(m_socket->incoming))
			{
				// server is sending too much garbage data instead of proper TLS packet
				return -1;
			}

			// wait for more ciphertext data from server
			int r = recv(m_socket->s, m_socket->incoming + m_socket->received, static_cast<int>(sizeof(m_socket->incoming) - m_socket->received), 0);
			if (r == 0)
			{
				// server disconnected socket
				return 0;
			}
			else if (r < 0)
			{
				// error receiving data from socket
				result = -1;
				break;
			}
			m_socket->received += r;
		}
	}

	return result;
}
#pragma once
struct s_socket;

// Minimal blocking TLS 1.2 client for Windows, built on Winsock and Schannel.
// Connection state lives in a heap-allocated s_socket owned through m_socket, so instances are
// neither copyable nor assignable. The destructor only frees that state; call disconnect()
// first to close a connection.
class c_tls_client
{
public:
	c_tls_client();
	~c_tls_client();

	// m_socket is an owning raw pointer: a copy would make two objects delete the same state
	c_tls_client(c_tls_client const&) = delete;
	c_tls_client& operator=(c_tls_client const&) = delete;

	// connects to hostname:port and performs the TLS handshake
	// returns 0 on success or negative value on error (resources are already released on error)
	long connect(char const* hostname, unsigned short port);

	// disconnects socket & releases resources (call this after a successful connect, even if the write/read functions return error)
	void disconnect();

	// returns 0 on success or negative value on error
	long write(void const* buffer, long size);

	// blocking read, waits & reads up to size bytes, returns amount of bytes received on success (<= size)
	// returns 0 on disconnect or negative value on error
	long read(void* buffer, long size);

protected:
	s_socket* m_socket;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment