/*
 * Created June 4, 2019 18:15
 *
 * Save ssrlive/c2fbfcd4421dbb07db91cc6767dc033d to your computer and use it in GitHub Desktop.
 * WebSocket server demo
 * This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 * Learn more about bidirectional Unicode characters.
 */
#include <ctype.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <unistd.h>

#include <http_parser.h>
#include <mbedtls/base64.h>
#include <mbedtls/sha1.h>

#if defined(WIN32) || defined(_WIN32)
#include <WinSock2.h>
#include <WS2tcpip.h>
#pragma comment(lib,"ws2_32.lib")
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#endif
/* Tunables for the demo server. Every expansion is parenthesized so each
 * macro behaves as a single value inside larger expressions. */
#define MAX_CONNECTIONS_BACKLOG 5                 /* backlog passed to listen() */
#define MAX_HTTP_HEADERS 20                       /* header slots per request */
#define MSG_BUFFER_SIZE (1024 * 80)               /* bytes read per recv() */
#define MAX_HTTP_MSG_SIZE (MSG_BUFFER_SIZE * 10)  /* cap on the initial HTTP request */
/* 2^32, written out because `2^32` in C is bitwise XOR (which equals 34).
 * NOTE(review): not referenced by the visible code. */
#define MAX_FRAME_SIZE 4294967296ULL

/* One request header; key and value are heap-allocated NUL-terminated
 * copies made by the http_parser callbacks. */
struct http_header {
    char *key;
    char *value;
};

/* All headers collected for one request. */
struct http_headers {
    size_t count;                                 /* slots whose key AND value are set */
    struct http_header headers[MAX_HTTP_HEADERS];
    int complete;                                 /* non-zero once the header block ended */
};
/*
 * Print `err` followed by the description of the current errno (perror)
 * and terminate the process with exit status 1. Does not return.
 */
void exit_with_error(const char *err)
{
    const int failure_status = 1;

    perror(err);
    exit(failure_status);
}
/*
 * Return the local port (host byte order) that `sockfd` is bound to, as
 * reported by getsockname() and read as an IPv4 sockaddr_in.
 * Returns -1 when getsockname() fails.
 */
int get_socket_port(const int sockfd)
{
    struct sockaddr_in bound_addr;
    socklen_t addr_size = sizeof bound_addr;

    if (getsockname(sockfd, (struct sockaddr *)&bound_addr, &addr_size) != 0) {
        return -1;
    }
    return ntohs(bound_addr.sin_port);
}
int on_header_field(const http_parser *parser, const char *at, const size_t len) { | |
struct http_headers *parsed_headers = (struct http_headers *)parser->data; | |
char *key; | |
if (parsed_headers->count == MAX_HTTP_HEADERS) { | |
printf("Request contained too many headers\n"); | |
exit(1); | |
} | |
key = (char *) calloc(1, len + 1); | |
strncpy(key, at, len); | |
parsed_headers->headers[parsed_headers->count].key = key; | |
return 0; | |
} | |
/*
 * http_parser `on_header_value` callback: stores a NUL-terminated copy of
 * the header value in slot `count` (whose name on_header_field() filled),
 * then advances `count` so the next header uses a fresh slot.
 *
 * Returns 0 to continue parsing, or non-zero (no free slot, or out of
 * memory) so http_parser_execute() stops with an error.
 */
int on_header_value(http_parser *parser, char *at, size_t len) {
    struct http_headers *parsed_headers = (struct http_headers *) parser->data;
    char *value;

    /* Never write past the fixed-size header array. */
    if (parsed_headers->count >= MAX_HTTP_HEADERS) {
        return 1;
    }
    value = (char *)malloc(len + 1);
    if (value == NULL) {
        return 1;
    }
    /* `at` points into the receive buffer and is not NUL-terminated. */
    memcpy(value, at, len);
    value[len] = '\0';
    parsed_headers->headers[parsed_headers->count].value = value;
    parsed_headers->count++;
    return 0;
}
/*
 * http_parser `on_headers_complete` callback: marks the header block of the
 * `struct http_headers` behind `parser->data` as finished. Always returns 0.
 */
int on_headers_complete(http_parser *parser)
{
    struct http_headers *headers = (struct http_headers *)parser->data;

    headers->complete = 1;
    return 0;
}
const char * get_header_val(const struct http_headers *headers, const char *header_key) { | |
size_t len = headers->count; | |
size_t i; | |
for(i = 0; i < len; i++) { | |
struct http_header h = headers->headers[i]; | |
if (strcmp(h.key, header_key) == 0) { | |
return h.value; | |
} | |
} | |
return NULL; | |
} | |
/*
 * Convert the NUL-terminated string `s` to upper case in place with
 * toupper() (current C locale). A NULL pointer is ignored.
 *
 * Each byte is passed to toupper() as unsigned char: where char is signed,
 * bytes >= 0x80 (e.g. UTF-8 in WebSocket text) are negative, and passing a
 * negative value other than EOF to toupper() is undefined behavior.
 */
void upcase(char *s) {
    if (s == NULL) {
        return;
    }
    for (; *s != '\0'; ++s) {
        *s = (char)toupper((unsigned char)*s);
    }
}
char * generate_sec_websocket_accept(const char *sec_websocket_key, void*(*allocator)(size_t size)) { | |
#ifndef SHA_DIGEST_LENGTH | |
#define SHA_DIGEST_LENGTH 20 | |
#endif | |
mbedtls_sha1_context sha1_ctx = { 0 }; | |
unsigned char sha1_hash[SHA_DIGEST_LENGTH] = { 0 }; | |
size_t b64_str_len = 0; | |
char *b64_str; | |
size_t concatenated_val_len; | |
char *concatenated_val; | |
if (sec_websocket_key==NULL || 0==strlen(sec_websocket_key) || allocator==NULL) { | |
return NULL; | |
} | |
#define WEBSOCKET_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" | |
concatenated_val_len = strlen(sec_websocket_key) + strlen(WEBSOCKET_GUID); | |
concatenated_val = (char *) calloc(concatenated_val_len + 1, sizeof(char)); | |
strcat(concatenated_val, sec_websocket_key); | |
strcat(concatenated_val, WEBSOCKET_GUID); | |
mbedtls_sha1_init(&sha1_ctx); | |
mbedtls_sha1_starts(&sha1_ctx); | |
mbedtls_sha1_update(&sha1_ctx, (unsigned char *)concatenated_val, concatenated_val_len); | |
mbedtls_sha1_finish(&sha1_ctx, sha1_hash); | |
mbedtls_sha1_free(&sha1_ctx); | |
mbedtls_base64_encode(NULL, 0, &b64_str_len, sha1_hash, sizeof(sha1_hash)); | |
b64_str = (char *) allocator(b64_str_len + 1); | |
b64_str[b64_str_len] = 0; | |
mbedtls_base64_encode((unsigned char *)b64_str, b64_str_len, &b64_str_len, sha1_hash, sizeof(sha1_hash)); | |
free(concatenated_val); | |
return b64_str; | |
} | |
/*
 * Decode the client-to-server WebSocket frame that starts at buf[0]
 * (RFC 6455, section 5.2) and return its unmasked payload.
 *
 * The returned buffer comes from `allocator`, holds the payload followed by
 * one NUL byte (so text payloads can be used as C strings), and belongs to
 * the caller. If `payload_len` is non-NULL it receives the payload length.
 *
 * Returns NULL when an argument is invalid, the frame is not masked (every
 * client frame must be), `len` is too short for the header, masking key and
 * the payload length the header announces, or the allocation fails.
 * FIN/RSV/opcode bits are not interpreted; bytes after this frame are ignored.
 */
unsigned char * websocket_server_retrieve_payload(unsigned char *buf, size_t len, void*(*allocator)(size_t size), size_t *payload_len) {
    size_t header_len = 2;      /* FIN/opcode byte + MASK/length byte */
    size_t data_len;
    const unsigned char *masking_key;
    const unsigned char *payload_start;
    unsigned char *payload;
    size_t i;

    if (buf == NULL || len < 2 || allocator == NULL) {
        return NULL;
    }
    if ((buf[1] & 0x80) == 0) {
        return NULL;
    }

    data_len = (size_t)(buf[1] & 0x7F);
    if (data_len == 126) {
        /* 16-bit extended length in bytes 2..3, network byte order. */
        if (len < 4) {
            return NULL;
        }
        data_len = ((size_t)buf[2] << 8) | (size_t)buf[3];
        header_len = 4;
    } else if (data_len == 127) {
        /* 64-bit extended length in bytes 2..9, network byte order. */
        uint64_t wide_len = 0;
        if (len < 10) {
            return NULL;
        }
        for (i = 2; i < 10; i++) {
            wide_len = (wide_len << 8) | buf[i];
        }
        /* Longer than what was received is rejected anyway; checking first
         * keeps the conversion to size_t lossless. */
        if (wide_len > (uint64_t)len) {
            return NULL;
        }
        data_len = (size_t)wide_len;
        header_len = 10;
    }

    /* The 4-byte masking key follows the length; the payload follows it.
     * Both must lie entirely inside the received bytes. */
    if (len - header_len < 4 || data_len > len - header_len - 4) {
        return NULL;
    }
    masking_key = buf + header_len;
    payload_start = masking_key + 4;

    payload = (unsigned char *) allocator(data_len + 1);
    if (payload == NULL) {
        return NULL;
    }
    for (i = 0; i < data_len; i++) {
        payload[i] = (unsigned char)(payload_start[i] ^ masking_key[i % 4]);
    }
    payload[data_len] = 0;
    if (payload_len) {
        *payload_len = data_len;
    }
    return payload;
}
/*
 * Build an unmasked, single-fragment (FIN = 1) text frame (opcode 0x1)
 * carrying `payload_len` bytes of `payload` (RFC 6455, section 5.2).
 *
 * The 7-bit, 16-bit (marker 126) or 64-bit (marker 127) length form is
 * chosen from `payload_len`; multi-byte lengths are written big-endian.
 * The buffer comes from `allocator`, belongs to the caller, and is
 * zero-filled, including one spare byte after the frame. If `frame_len`
 * is non-NULL it receives the number of bytes to send.
 *
 * Returns NULL when `payload` or `allocator` is NULL, the size would
 * overflow, or allocation fails. An empty payload yields a 2-byte frame.
 */
unsigned char * websocket_server_build_frame(const char *payload, size_t payload_len, void*(*allocator)(size_t), size_t *frame_len) {
    enum { MAX_HEADER_LEN = 10 };   /* 2 base bytes + 8-byte extended length */
    unsigned char *frame_buf;
    size_t buf_size;
    size_t offset;
    size_t i;

    if (payload == NULL || allocator == NULL) {
        return NULL;
    }
    if (payload_len > SIZE_MAX - MAX_HEADER_LEN - 1) {
        return NULL;
    }
    buf_size = payload_len + MAX_HEADER_LEN + 1;
    frame_buf = (unsigned char *) allocator(buf_size);
    if (frame_buf == NULL) {
        return NULL;
    }
    memset(frame_buf, 0, buf_size);

    /* FIN = 1, RSV1..3 = 0, opcode = 0x1 (text). The MASK bit in byte 1
     * stays 0 because server-to-client frames are not masked. */
    frame_buf[0] = 0x81;
    if (payload_len < 126) {
        frame_buf[1] = (unsigned char)payload_len;
        offset = 2;
    } else if (payload_len <= 0xFFFF) {
        frame_buf[1] = 126;
        frame_buf[2] = (unsigned char)(payload_len >> 8);
        frame_buf[3] = (unsigned char)payload_len;
        offset = 4;
    } else {
        uint64_t wide_len = (uint64_t)payload_len;
        frame_buf[1] = 127;
        for (i = 0; i < 8; i++) {
            frame_buf[2 + i] = (unsigned char)(wide_len >> (56 - 8 * i));
        }
        offset = 10;
    }
    memcpy(frame_buf + offset, payload, payload_len);
    if (frame_len) {
        *frame_len = offset + payload_len;
    }
    return frame_buf;
}
/*
 * Answer a WebSocket upgrade request with "101 Switching Protocols" and
 * the Sec-WebSocket-Accept value computed from the request's
 * Sec-WebSocket-Key header.
 *
 * Returns 1 once the whole response has been sent; 0 if the key header is
 * missing, the accept value cannot be computed, the response does not fit,
 * or sending fails.
 */
int perform_websocket_handshake(int client_sockfd, const struct http_headers *upgrade_headers) {
    const char *sec_websocket_val = get_header_val(upgrade_headers, "Sec-WebSocket-Key");
    char response[2048];
    char *b64_str;
    int response_len;
    size_t sent = 0;

    if (sec_websocket_val == NULL) {
        printf("Request did not contain a Sec-WebSocket-Key\n");
        return 0;
    }
    b64_str = generate_sec_websocket_accept(sec_websocket_val, &malloc);
    if (b64_str == NULL) {
        return 0;
    }
    response_len = snprintf(response, sizeof(response),
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Accept: %s\r\n\r\n",
        b64_str);
    free(b64_str);
    if (response_len < 0 || (size_t)response_len >= sizeof(response)) {
        return 0;
    }
    /* send() may transmit only part of the buffer; loop until all is sent. */
    while (sent < (size_t)response_len) {
        ssize_t n = send(client_sockfd, response + sent, (size_t)response_len - sent, 0);
        if (n <= 0) {
            perror("Error sending WebSocket handshake");
            return 0;
        }
        sent += (size_t)n;
    }
    fprintf(stderr, "Handshake sent\n");
    return 1;
}
/*
 * Send `payload` to the client as one unmasked text frame built by
 * websocket_server_build_frame(), then release the frame buffer.
 *
 * Returns 0 when the whole frame was sent, -1 if the frame could not be
 * built or send() failed.
 */
int send_websocket_frame(int client_sockfd, const char *payload, size_t payload_len) {
    size_t msg_size = 0;
    size_t sent = 0;
    int result = 0;
    unsigned char *frame_buf = websocket_server_build_frame(payload, payload_len, &malloc, &msg_size);

    if (frame_buf == NULL) {
        return -1;
    }
    /* send() may transmit only part of the frame; loop until all is sent. */
    while (sent < msg_size) {
        ssize_t n = send(client_sockfd, (const char *)frame_buf + sent, msg_size - sent, 0);
        if (n <= 0) {
            perror("Error sending WebSocket frame");
            result = -1;
            break;
        }
        sent += (size_t)n;
    }
    /* The frame was allocated with malloc (passed as the allocator above). */
    free(frame_buf);
    return result;
}
void handle_websocket(int client_sockfd, const struct http_headers *upgrade_headers) { | |
// RFC6455 for websocket: https://tools.ietf.org/html/rfc6455 | |
int handshake_was_successful = | |
perform_websocket_handshake(client_sockfd, upgrade_headers); | |
if (handshake_was_successful) { | |
unsigned char buf[MSG_BUFFER_SIZE] = { 0 }; | |
while(1) { | |
ssize_t client_msg_len = recv(client_sockfd, (char *)buf, sizeof(buf), 0); | |
if (client_msg_len < 0) { | |
perror("Error reading websocket frame from client"); | |
return; | |
} else if (client_msg_len == 0) { | |
fprintf(stderr, "Client closed connection while in frame mode\n"); | |
break; | |
} else { | |
size_t payload_len = 0; | |
unsigned char *masked_payload_data = | |
websocket_server_retrieve_payload(buf, client_msg_len, &malloc, &payload_len); | |
printf("Unmasked Payload: %s\n", masked_payload_data); | |
upcase((char *)masked_payload_data); | |
// Echo the upcased message back to the client | |
send_websocket_frame(client_sockfd, (char *)masked_payload_data, payload_len); | |
free(masked_payload_data); | |
} | |
} | |
} else { | |
fprintf(stderr, "Client handshake failed\n"); | |
} | |
} | |
void handle_client_connection(int client_sockfd) { | |
// The first thing a client should send is a HTTP GET request | |
// to upgrade the connection to Websocket. Anything else is a | |
// bad request. | |
ssize_t bytes_recv_total = 0; | |
http_parser_settings settings = {0}; | |
settings.on_header_field = (http_data_cb)on_header_field; | |
settings.on_header_value = (http_data_cb)on_header_value; | |
settings.on_headers_complete = (http_cb)on_headers_complete; | |
for(;;) { | |
struct http_headers *parsed_headers; | |
char buf[MSG_BUFFER_SIZE] = { 0 }; | |
ssize_t bytes_recv; | |
http_parser *parser = (http_parser *)malloc(sizeof(http_parser)); | |
http_parser_init(parser, HTTP_REQUEST); | |
parsed_headers = (struct http_headers *)calloc(1, sizeof(struct http_headers)); | |
parser->data = parsed_headers; | |
bytes_recv = recv(client_sockfd, buf, sizeof(buf), 0); | |
bytes_recv_total += bytes_recv; | |
if (bytes_recv > 0) { | |
http_parser_execute(parser, &settings, buf, bytes_recv); | |
printf("%s\n", buf); | |
if (parsed_headers->complete) { | |
if (parser->upgrade) { | |
printf("Connection upgrade requested. Performing upgrade\n"); | |
handle_websocket(client_sockfd, parsed_headers); | |
break; | |
} else { | |
fprintf(stderr, "Request was not an upgrade request.\n"); | |
break; | |
} | |
} | |
} else if (bytes_recv == 0) { | |
fprintf(stderr, "Client closed the connection\n"); | |
break; | |
} else { | |
perror("Error recieving data from client"); | |
break; | |
} | |
if (bytes_recv_total > MAX_HTTP_MSG_SIZE) { | |
fprintf(stderr, "Client initial HTTP message exceeded the maximum message size\n"); | |
break; | |
} else continue; | |
free(parser); | |
free(parsed_headers); | |
} | |
} | |
/*
 * Serve clients one at a time, forever. Accept a connection on the listening
 * socket `server_sockfd`, handle it to completion, then shut it down and
 * close it. Exits the process (via exit_with_error) if accept() fails.
 */
void accept_connections_through(int server_sockfd) {
    for (;;) {
        struct sockaddr_in client_address;
        socklen_t client_address_len = (socklen_t) sizeof(client_address);
        int client_sockfd =
            (int) accept(server_sockfd, (struct sockaddr *) &client_address, &client_address_len);
        /* 0 is a valid descriptor; only negative values signal failure. */
        if (client_sockfd < 0) {
            exit_with_error("Error accepting client connection");
        }
        handle_client_connection(client_sockfd);
        shutdown(client_sockfd, 2);   /* 2 == SHUT_RDWR (POSIX) == SD_BOTH (Winsock) */
        /* shutdown() does not release the descriptor. Without closing it,
         * every connection leaks one until accept() fails and the server exits. */
#if defined(WIN32) || defined(_WIN32)
        closesocket(client_sockfd);
#else
        close(client_sockfd);
#endif
    }
}
int try_open_server_on_port(int port) { | |
int server_socket = socket(AF_INET, SOCK_STREAM, 0); | |
struct sockaddr_in server_address = {0}; | |
int socket_could_not_bind; | |
int socket_could_not_start_listening; | |
int server_port; | |
if (server_socket < 0) | |
exit_with_error("Error opening socket"); | |
server_address.sin_family = AF_INET; | |
server_address.sin_addr.s_addr = INADDR_ANY; | |
server_address.sin_port = port; | |
socket_could_not_bind = | |
bind(server_socket, (struct sockaddr *) &server_address, sizeof(server_address)) < 0; | |
if (socket_could_not_bind) | |
exit_with_error("Error binding the socket to an address"); | |
socket_could_not_start_listening = listen(server_socket, MAX_CONNECTIONS_BACKLOG) < 0; | |
if (socket_could_not_start_listening) | |
exit_with_error("Error making the socket start listening"); | |
server_port = get_socket_port(server_socket); | |
if (server_port < 0) | |
exit_with_error("Error getting the server's port\n"); | |
fprintf(stderr, "Server listening on port %i\n", server_port); | |
return server_socket; | |
} | |
/*
 * Entry point. Initialise Winsock when building for Windows, open a
 * listening socket on an OS-chosen port, and serve clients forever.
 */
int main(void) {
    int server_sockfd;
#if defined(WIN32) || defined(_WIN32)
    /* WSADATA/WSAStartup exist only in Winsock; guarding them keeps the
     * POSIX build compiling. */
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        fprintf(stderr, "WSAStartup failed\n");
        return 1;
    }
#endif
    server_sockfd = try_open_server_on_port(0);
    accept_connections_through(server_sockfd);
    return 0;
}
/*
 * Sign up for free to join this conversation on GitHub.
 * Already have an account? Sign in to comment.
 */