Skip to content

Instantly share code, notes, and snippets.

@mmozeiko
Last active December 1, 2024 23:48
Show Gist options
  • Save mmozeiko/c0dfcc8fec527a90a02145d2cc0bfb6d to your computer and use it in GitHub Desktop.
Save mmozeiko/c0dfcc8fec527a90a02145d2cc0bfb6d to your computer and use it in GitHub Desktop.
Simple example of a TLS socket client using the Win32 Schannel API
#define WIN32_LEAN_AND_MEAN
#include <winsock2.h>
#include <windows.h>
#define SECURITY_WIN32
#include <security.h>
#include <schannel.h>
#include <shlwapi.h>
#include <assert.h>
#include <stdio.h>
#pragma comment (lib, "ws2_32.lib")
#pragma comment (lib, "secur32.lib")
#pragma comment (lib, "shlwapi.lib")
// Size of the ciphertext staging buffers: one maximum TLS record payload (16384 bytes) plus extra
// overhead for the record header/mac/padding (probably an overestimate).
#define TLS_MAX_PACKET_SIZE (16384+512)
// State of one client TLS connection: filled in by tls_connect, released by tls_disconnect.
typedef struct {
    SOCKET sock;                      // connected TCP socket
    CredHandle handle;                // schannel credentials acquired with AcquireCredentialsHandleA
    CtxtHandle context;               // schannel security context created during the handshake
    SecPkgContext_StreamSizes sizes;  // record header/trailer/max message sizes, queried after the handshake
    int received;                     // byte count in incoming buffer (ciphertext)
    int used;                         // byte count used from incoming buffer to decrypt current packet
    int available;                    // byte count available for decrypted bytes
    char* decrypted;                  // points to incoming buffer where data is decrypted inplace
    char incoming[TLS_MAX_PACKET_SIZE]; // ciphertext read from the socket; DecryptMessage also writes plaintext here in place
} tls_socket;
// Opens a TCP connection to hostname:port and performs a TLS 1.2 client handshake over it with schannel.
// On success s->sock, s->handle, s->context and s->sizes are initialized and must later be released with
// tls_disconnect; ciphertext that arrived right after the handshake stays in s->incoming for tls_read.
// returns 0 on success or negative value on error (on error everything acquired here is already released)
static int tls_connect(tls_socket* s, const char* hostname, unsigned short port)
{
    // initialize windows sockets
    WSADATA wsadata;
    if (WSAStartup(MAKEWORD(2, 2), &wsadata) != 0)
    {
        return -1;
    }

    // create TCP IPv4 socket
    s->sock = socket(AF_INET, SOCK_STREAM, 0);
    if (s->sock == INVALID_SOCKET)
    {
        WSACleanup();
        return -1;
    }

    char sport[64];
    wnsprintfA(sport, sizeof(sport), "%u", port);

    // connect to server
    if (!WSAConnectByNameA(s->sock, hostname, sport, NULL, NULL, NULL, NULL, NULL, NULL))
    {
        closesocket(s->sock);
        WSACleanup();
        return -1;
    }

    // initialize schannel
    {
        SCHANNEL_CRED cred =
        {
            .dwVersion = SCHANNEL_CRED_VERSION,
            .dwFlags = SCH_USE_STRONG_CRYPTO          // use only strong crypto algorithms
                     | SCH_CRED_AUTO_CRED_VALIDATION  // automatically validate server certificate
                     | SCH_CRED_NO_DEFAULT_CREDS,     // no client certificate authentication
            .grbitEnabledProtocols = SP_PROT_TLS1_2,  // allow only TLS v1.2
        };
        if (AcquireCredentialsHandleA(NULL, UNISP_NAME_A, SECPKG_CRED_OUTBOUND, NULL, &cred, NULL, NULL, &s->handle, NULL) != SEC_E_OK)
        {
            closesocket(s->sock);
            WSACleanup();
            return -1;
        }
    }

    s->received = s->used = s->available = 0;
    s->decrypted = NULL;

    // mark the context as not-yet-created so error cleanup only deletes a context schannel actually made
    SecInvalidateHandle(&s->context);

    // perform tls handshake
    // 1) call InitializeSecurityContext to create/update schannel context
    // 2) send any output token it produced (for SEC_I_CONTINUE_NEEDED and also for the final SEC_E_OK)
    // 3) when it returns SEC_E_OK - tls handshake completed
    // 4) when it returns SEC_I_INCOMPLETE_CREDENTIALS - server requests client certificate (not supported here)
    // 5) when it returns SEC_I_CONTINUE_NEEDED - process already buffered leftover input first, otherwise read from server
    // 6) when it returns SEC_E_INCOMPLETE_MESSAGE - keep the partial record and read more data from server
    // 7) any other status is an error
    CtxtHandle* context = NULL;
    int result = 0;
    for (;;)
    {
        SecBuffer inbuffers[2] = { 0 };
        inbuffers[0].BufferType = SECBUFFER_TOKEN;
        inbuffers[0].pvBuffer = s->incoming;
        inbuffers[0].cbBuffer = s->received;
        inbuffers[1].BufferType = SECBUFFER_EMPTY;

        SecBuffer outbuffers[1] = { 0 };
        outbuffers[0].BufferType = SECBUFFER_TOKEN;

        SecBufferDesc indesc = { SECBUFFER_VERSION, ARRAYSIZE(inbuffers), inbuffers };
        SecBufferDesc outdesc = { SECBUFFER_VERSION, ARRAYSIZE(outbuffers), outbuffers };

        DWORD flags = ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM;
        SECURITY_STATUS sec = InitializeSecurityContextA(
            &s->handle,
            context,
            context ? NULL : (SEC_CHAR*)hostname,
            flags,
            0,
            0,
            context ? &indesc : NULL,
            0,
            context ? NULL : &s->context,
            &outdesc,
            &flags,
            NULL);

        // after first call to InitializeSecurityContext context is available and should be reused for next calls
        context = &s->context;

        if (inbuffers[1].BufferType == SECBUFFER_EXTRA)
        {
            // schannel did not consume the tail of the input; keep it for the next call (or for tls_read)
            MoveMemory(s->incoming, s->incoming + (s->received - inbuffers[1].cbBuffer), inbuffers[1].cbBuffer);
            s->received = inbuffers[1].cbBuffer;
        }
        else if (sec != SEC_E_INCOMPLETE_MESSAGE)
        {
            // all input was consumed
            s->received = 0;
        }
        // on SEC_E_INCOMPLETE_MESSAGE nothing was consumed: keep the partial record so more bytes get appended to it

        // schannel may produce a token that must be sent both while the handshake continues and on the final SEC_E_OK
        int sendfailed = 0;
        if (outbuffers[0].pvBuffer != NULL)
        {
            if (sec == SEC_E_OK || sec == SEC_I_CONTINUE_NEEDED)
            {
                char* buffer = outbuffers[0].pvBuffer;
                int size = (int)outbuffers[0].cbBuffer;
                while (size != 0)
                {
                    int d = send(s->sock, buffer, size, 0);
                    if (d <= 0)
                    {
                        break;
                    }
                    size -= d;
                    buffer += d;
                }
                sendfailed = (size != 0);
            }
            FreeContextBuffer(outbuffers[0].pvBuffer);
        }

        if (sendfailed)
        {
            // failed to fully send data to server
            result = -1;
            break;
        }

        if (sec == SEC_E_OK)
        {
            // tls handshake completed
            break;
        }
        else if (sec == SEC_I_INCOMPLETE_CREDENTIALS)
        {
            // server asked for client certificate, not supported here
            result = -1;
            break;
        }
        else if (sec != SEC_I_CONTINUE_NEEDED && sec != SEC_E_INCOMPLETE_MESSAGE)
        {
            // SEC_E_CERT_EXPIRED - certificate expired or revoked
            // SEC_E_WRONG_PRINCIPAL - bad hostname
            // SEC_E_UNTRUSTED_ROOT - cannot verify CA chain
            // SEC_E_ILLEGAL_MESSAGE / SEC_E_ALGORITHM_MISMATCH - cannot negotiate crypto algorithms
            result = -1;
            break;
        }

        if (sec == SEC_I_CONTINUE_NEEDED && s->received != 0)
        {
            // more handshake ciphertext is already buffered - process it before blocking in recv,
            // the server may be waiting for our reply to it and send nothing more until then
            continue;
        }

        // read more data from server when possible
        if (s->received == (int)sizeof(s->incoming))
        {
            // server is sending too much data instead of proper handshake?
            result = -1;
            break;
        }

        int r = recv(s->sock, s->incoming + s->received, (int)(sizeof(s->incoming) - s->received), 0);
        if (r == 0)
        {
            // server disconnected socket before the handshake completed - this is a failure, not success
            result = -1;
            break;
        }
        else if (r < 0)
        {
            // socket error
            result = -1;
            break;
        }
        s->received += r;
    }

    if (result == 0 && QueryContextAttributes(context, SECPKG_ATTR_STREAM_SIZES, &s->sizes) != SEC_E_OK)
    {
        // without stream sizes tls_read/tls_write cannot frame records
        result = -1;
    }

    if (result != 0)
    {
        if (SecIsValidHandle(&s->context))
        {
            DeleteSecurityContext(&s->context);
        }
        FreeCredentialsHandle(&s->handle);
        closesocket(s->sock);
        WSACleanup();
    }
    return result;
}
// Sends a TLS close_notify to the server (best effort), then shuts down the socket and releases the
// schannel context, credentials and winsock reference acquired by a successful tls_connect.
// Call this even if tls_write/tls_read returned an error.
static void tls_disconnect(tls_socket* s)
{
    // tell schannel we want to shut the TLS session down
    DWORD type = SCHANNEL_SHUTDOWN;

    SecBuffer inbuffers[1];
    inbuffers[0].BufferType = SECBUFFER_TOKEN;
    inbuffers[0].pvBuffer = &type;
    inbuffers[0].cbBuffer = sizeof(type);

    SecBufferDesc indesc = { SECBUFFER_VERSION, ARRAYSIZE(inbuffers), inbuffers };
    if (ApplyControlToken(&s->context, &indesc) == SEC_E_OK)
    {
        // zero-initialized so pvBuffer stays NULL unless schannel allocates a token
        SecBuffer outbuffers[1] = { 0 };
        outbuffers[0].BufferType = SECBUFFER_TOKEN;
        SecBufferDesc outdesc = { SECBUFFER_VERSION, ARRAYSIZE(outbuffers), outbuffers };

        DWORD flags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM;

        // no input is needed to generate the shutdown token, only an output buffer
        if (InitializeSecurityContextA(&s->handle, &s->context, NULL, flags, 0, 0, NULL, 0, NULL, &outdesc, &flags, NULL) == SEC_E_OK)
        {
            char* buffer = outbuffers[0].pvBuffer;
            int size = buffer ? (int)outbuffers[0].cbBuffer : 0;
            while (size != 0)
            {
                int d = send(s->sock, buffer, size, 0);
                if (d <= 0)
                {
                    // ignore any failures socket will be closed anyway
                    break;
                }
                buffer += d;
                size -= d;
            }
        }

        if (outbuffers[0].pvBuffer != NULL)
        {
            FreeContextBuffer(outbuffers[0].pvBuffer);
        }
    }

    shutdown(s->sock, SD_BOTH);

    DeleteSecurityContext(&s->context);
    FreeCredentialsHandle(&s->handle);
    closesocket(s->sock);
    WSACleanup();
}
// Encrypts size bytes from buffer and sends them to the server, split into TLS records carrying at most
// s->sizes.cbMaximumMessage plaintext bytes each.
// returns 0 on success or negative value on error (negative size, encryption failure or socket error)
static int tls_write(tls_socket* s, const void* buffer, int size)
{
    if (size < 0)
    {
        // min() below compares against an unsigned value, so a negative size would become a huge
        // chunk length and read far past the end of buffer
        return -1;
    }

    // every record (header + payload + trailer) is assembled in a stack buffer of TLS_MAX_PACKET_SIZE bytes
    assert(s->sizes.cbHeader + s->sizes.cbMaximumMessage + s->sizes.cbTrailer <= TLS_MAX_PACKET_SIZE);

    while (size != 0)
    {
        int use = (int)min((unsigned long)size, s->sizes.cbMaximumMessage);

        char wbuffer[TLS_MAX_PACKET_SIZE];

        SecBuffer buffers[3];
        buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
        buffers[0].pvBuffer = wbuffer;
        buffers[0].cbBuffer = s->sizes.cbHeader;
        buffers[1].BufferType = SECBUFFER_DATA;
        buffers[1].pvBuffer = wbuffer + s->sizes.cbHeader;
        buffers[1].cbBuffer = use;
        buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
        buffers[2].pvBuffer = wbuffer + s->sizes.cbHeader + use;
        buffers[2].cbBuffer = s->sizes.cbTrailer;

        // EncryptMessage encrypts the payload in place between header and trailer
        CopyMemory(buffers[1].pvBuffer, buffer, use);

        SecBufferDesc desc = { SECBUFFER_VERSION, ARRAYSIZE(buffers), buffers };
        SECURITY_STATUS sec = EncryptMessage(&s->context, 0, &desc, 0);
        if (sec != SEC_E_OK)
        {
            // this should not happen, but just in case check it
            return -1;
        }

        // the trailer may be shorter than cbTrailer, so send only what EncryptMessage reported
        int total = buffers[0].cbBuffer + buffers[1].cbBuffer + buffers[2].cbBuffer;
        int sent = 0;
        while (sent != total)
        {
            int d = send(s->sock, wbuffer + sent, total - sent, 0);
            if (d <= 0)
            {
                // error sending data to socket, or server disconnected
                return -1;
            }
            sent += d;
        }

        buffer = (char*)buffer + use;
        size -= use;
    }

    return 0;
}
// Blocking read of decrypted application data into buffer.
// Plaintext left over from a previous call is returned first. If some bytes were already copied and getting
// more would require blocking in recv, the call returns early with what it has.
// returns amount of bytes received on success (<= size);
// returns 0 when the server disconnects the socket, or when the server ends the TLS session
// (SEC_I_CONTEXT_EXPIRED) before anything was copied in this call;
// returns negative value on error (including a negative size)
static int tls_read(tls_socket* s, void* buffer, int size)
{
    if (size < 0)
    {
        // a negative size would be passed to CopyMemory below as a huge length
        return -1;
    }

    int result = 0;

    while (size != 0)
    {
        if (s->decrypted)
        {
            // if there is decrypted data available, then use it as much as possible
            int use = min(size, s->available);
            CopyMemory(buffer, s->decrypted, use);
            buffer = (char*)buffer + use;
            size -= use;
            result += use;

            if (use == s->available)
            {
                // all decrypted data is used, remove ciphertext from incoming buffer so next time it starts from beginning
                MoveMemory(s->incoming, s->incoming + s->used, s->received - s->used);
                s->received -= s->used;
                s->used = 0;
                s->available = 0;
                s->decrypted = NULL;
            }
            else
            {
                s->available -= use;
                s->decrypted += use;
            }
        }
        else
        {
            // if any ciphertext data available then try to decrypt it
            if (s->received != 0)
            {
                SecBuffer buffers[4];
                assert(s->sizes.cBuffers == ARRAYSIZE(buffers));

                buffers[0].BufferType = SECBUFFER_DATA;
                buffers[0].pvBuffer = s->incoming;
                buffers[0].cbBuffer = s->received;
                buffers[1].BufferType = SECBUFFER_EMPTY;
                buffers[2].BufferType = SECBUFFER_EMPTY;
                buffers[3].BufferType = SECBUFFER_EMPTY;

                SecBufferDesc desc = { SECBUFFER_VERSION, ARRAYSIZE(buffers), buffers };

                SECURITY_STATUS sec = DecryptMessage(&s->context, &desc, 0, NULL);
                if (sec == SEC_E_OK)
                {
                    assert(buffers[0].BufferType == SECBUFFER_STREAM_HEADER);
                    assert(buffers[1].BufferType == SECBUFFER_DATA);
                    assert(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER);

                    s->decrypted = buffers[1].pvBuffer;
                    s->available = buffers[1].cbBuffer;
                    // bytes past the decrypted record (SECBUFFER_EXTRA) belong to the next record and stay buffered
                    s->used = s->received - (buffers[3].BufferType == SECBUFFER_EXTRA ? buffers[3].cbBuffer : 0);

                    // data is now decrypted, go back to beginning of loop to copy memory to output buffer
                    continue;
                }
                else if (sec == SEC_I_CONTEXT_EXPIRED)
                {
                    // server closed TLS connection (but socket is still open)
                    s->received = 0;
                    return result;
                }
                else if (sec == SEC_I_RENEGOTIATE)
                {
                    // server wants to renegotiate TLS connection, not implemented here
                    return -1;
                }
                else if (sec != SEC_E_INCOMPLETE_MESSAGE)
                {
                    // some other schannel or TLS protocol error
                    return -1;
                }
                // otherwise sec == SEC_E_INCOMPLETE_MESSAGE which means need to read more data
            }
            // otherwise not enough data received to decrypt

            if (result != 0)
            {
                // some data is already copied to output buffer, so return that before blocking with recv
                break;
            }

            if (s->received == sizeof(s->incoming))
            {
                // server is sending too much garbage data instead of proper TLS packet
                return -1;
            }

            // wait for more ciphertext data from server
            int r = recv(s->sock, s->incoming + s->received, sizeof(s->incoming) - s->received, 0);
            if (r == 0)
            {
                // server disconnected socket
                return 0;
            }
            else if (r < 0)
            {
                // error receiving data from socket
                result = -1;
                break;
            }
            s->received += r;
        }
    }

    return result;
}
// Example: fetch `path` from `hostname` over HTTPS and save the raw HTTP response to response.txt.
int main()
{
    const char* hostname = "www.google.com";
    //const char* hostname = "badssl.com";
    //const char* hostname = "expired.badssl.com";
    //const char* hostname = "wrong.host.badssl.com";
    //const char* hostname = "self-signed.badssl.com";
    //const char* hostname = "untrusted-root.badssl.com";
    const char* path = "/";

    tls_socket s;
    if (tls_connect(&s, hostname, 443) != 0)
    {
        printf("Error connecting to %s\n", hostname);
        return -1;
    }

    printf("Connected!\n");

    // send request (bounded: a long hostname/path must not overflow req)
    char req[1024];
    int len = snprintf(req, sizeof(req), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", path, hostname);
    if (len < 0 || len >= (int)sizeof(req))
    {
        printf("Request too long\n");
        tls_disconnect(&s);
        return -1;
    }

    if (tls_write(&s, req, len) != 0)
    {
        tls_disconnect(&s);
        return -1;
    }

    // write response to file
    FILE* f = fopen("response.txt", "wb");
    if (f == NULL)
    {
        printf("Error opening response.txt\n");
        tls_disconnect(&s);
        return -1;
    }

    int received = 0;
    for (;;)
    {
        char buf[65536];
        int r = tls_read(&s, buf, sizeof(buf));
        if (r < 0)
        {
            printf("Error receiving data\n");
            break;
        }
        else if (r == 0)
        {
            printf("Socket disconnected\n");
            break;
        }
        else
        {
            if (fwrite(buf, 1, r, f) != (size_t)r)
            {
                printf("Error writing response.txt\n");
                break;
            }
            fflush(f);
            received += r;
        }
    }
    fclose(f);

    printf("Received %d bytes\n", received);

    tls_disconnect(&s);
    return 0;
}
@mmozeiko
Copy link
Author

mmozeiko commented Jun 16, 2023

I'd say this is a pretty complete example for simple TLS connections. Client certificate authentication is one thing that's missing. Another, more minor, thing is ALPN, which some TLS-based protocols (like Syncthing) use - I have not looked into how to do that, but I believe it should be doable.

Giving the user the ability to read the server certificate and override its verification (like ignoring the expiration date or the hostname) would be another useful thing to offer in the API.

That said, I feel that manually using TLS sockets would be a very rare situation on Windows. Most likely you want HTTPS client connections, and for that the WinHTTP API makes much more sense - it even supports WebSockets. For listening server connections, use the HTTP Server API.

@RandyGaul
Copy link

RandyGaul commented Jun 16, 2023

I did spot one potential issue. After swapping to non-blocking sockets the timings of recv/send are different (I end up calling them quite quickly). Sometimes during the handshake SEC_E_INCOMPLETE_MESSAGE can be encountered, meaning a decrypt failed as the full record was not present. In this case we need to call recv and append more data, then try again.

if (inbuffers[1].BufferType == SECBUFFER_EXTRA)
{
	MoveMemory(s->incoming, s->incoming + (s->received - inbuffers[1].cbBuffer), inbuffers[1].cbBuffer);
	s->received = inbuffers[1].cbBuffer;
}
else
{
	s->received = 0;
}

But this code snippet zeroes out the received buffer no matter what. Instead, I tried a slight modification that does not clear it to zero if we get a SECBUFFER_MISSING buffer type, so more data can then be appended and decrypted later.

if (inbuffers[1].BufferType == SECBUFFER_EXTRA)
{
	MoveMemory(s->incoming, s->incoming + (s->received - inbuffers[1].cbBuffer), inbuffers[1].cbBuffer);
	s->received = inbuffers[1].cbBuffer;
}
else if (inbuffers[1].BufferType != SECBUFFER_MISSING)
{
	s->received = 0;
}

Testing this is a bit annoying as it was only happening occasionally to begin with, but, I haven't hit the issue since. Please let me know your thoughts!

@mmozeiko
Copy link
Author

Yes, that sounds reasonable, I have missed that case.
You can probably test such a case more reliably by writing a wrapper over the "recv" call that returns just 1 byte at a time to the caller. Basically, read everything that's available from the socket, but pretend you got only 1 byte, and return the rest from a cached buffer in further calls until the buffer is empty.

@RandyGaul
Copy link

Here's my initial implementation of a TLS header, the Windows implementation completely based off of Martins' code here :) https://github.com/RandyGaul/cute_headers/blob/master/cute_tls.h

Thanks again for posting this stuff! Super helpful.

I'll be using this code to hook up an HTTP layer in another project quite soon~

@mmozeiko
Copy link
Author

@never-unsealed What instability? The InitializeSecurityContext documentation says to pass NULL for phNewContext in subsequent calls when using the Schannel API:

When using the Schannel SSP, on calls after the first call, pass the handle returned here as the phContext parameter and specify NULL for phNewContext.

Maybe doing non-NULL there is relevant for non-schannel providers?

curl and Qt tls usage also passes NULL there.

@mmozeiko
Copy link
Author

Hmm, that sounds like an issue with the kernel-mode implementation of it.
Because I'd be very surprised if you could BSOD just by passing NULL there in user-space call. I mean nothing is impossible, but this would mean Windows breaking compatibility with a lot of deployed Qt/curl applications out there - not even counting custom code using schannel.

@RandyGaul
Copy link

Thanks for bringing it up anyways @never-unsealed

@Moon4u
Copy link

Moon4u commented Dec 20, 2023

I am trying to use SChannel and stumbled upon this gist, but I have a question:

        SECURITY_STATUS sec = InitializeSecurityContextA(
            &s->handle,
            context,
            context ? NULL : (SEC_CHAR*)hostname,
            flags,
            0,
            0,
            context ? &indesc : NULL,
            0,
            context ? NULL : &s->context,
            &outdesc,
            &flags,
            NULL);

        //.....

        int r = recv(s->sock, s->incoming + s->received, sizeof(s->incoming) - s->received, 0);
        if (r == 0)
        {
            // server disconnected socket
            return 0;
        }

In this snippet (or this line): if the server disconnects, then we can't send/recv any data, so shouldn't tls_connect return -1 in this case?

Sorry if this is a dumb question, I am learning.

@ticehujl1
Copy link

Do you have plans to update your example to support TLS1.3 ?
i.e. Update to SCH_CREDENTIALS and handle a Renegotiate from DecryptMessage

John

@IvanGazul
Copy link

thanks a lot, that helped me to write custom TLS/SSL socket for my bot.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment