Skip to content

Instantly share code, notes, and snippets.

@dwilliamson
Created June 6, 2014 12:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dwilliamson/b802334f6e9641d5c7b7 to your computer and use it in GitHub Desktop.
Save dwilliamson/b802334f6e9641d5c7b7 to your computer and use it in GitHub Desktop.
#pragma once
#include <clcpp/clcpp.h>
#include <tinycrt/tinycrt.h>
// Forward declaration of SOCKET to prevent any windows inclusion
// NOTE(review): Winsock defines SOCKET as UINT_PTR, which is 64 bits wide on x64 builds; this
// 32-bit typedef only matches 32-bit targets. Confirm the supported platforms (and that
// TinyWinsock.h agrees) before building for x64.
typedef u32 SOCKET;
// Forward declarations so this header doesn't need the core/clutl headers: below, these types
// only appear behind pointers/references or inside templates instantiated by the caller.
namespace core
{
template <typename TYPE> class Vector;
class HashTable;
struct String;
}
namespace clutl
{
class WriteBuffer;
class ObjectGroup;
class ParameterObjectCache;
struct Object;
}
// Reflection marker for namespace net. Presumably this tells clReflect to export only the
// declarations explicitly tagged with clcpp_attr(reflect/reflect_part) — see the clcpp docs.
clcpp_reflect_part(net)
namespace net
{
//
// Connection status
//
// Snapshot of a socket's readiness, as produced by a zero-timeout select() poll. An invalid
// (closed) socket reports only has_errors.
//
struct clcpp_attr(reflect) Status
{
// Initially in a state of no read/write
Status()
: can_read(false)
, can_write(false)
, has_errors(false)
{
}
// Initialise from a pre-defined state
Status(bool can_read, bool can_write, bool has_errors)
: can_read(can_read)
, can_write(can_write)
, has_errors(has_errors)
{
}
// Data (or, for a listening socket, a pending connection) is available
bool can_read;
// A send would not block
bool can_write;
// The socket is in an error state or is invalid
bool has_errors;
};
// Outcome of Connection::Send
enum clcpp_attr(reflect) SendResult
{
// The whole buffer was sent
SEND_Success,
// The socket wasn't writable on entry, or the timeout expired before everything was sent.
// NOTE: part of the buffer may already have been sent in the latter case.
SEND_TimedOut,
// Socket error (TCPIPConnection closes its socket when this happens) or no connection
SEND_Error,
};
// Outcome of Connection::Receive
enum clcpp_attr(reflect) RecvResult
{
// Safe
// RECV_Success: every requested byte was received
RECV_Success,
// RECV_NoData: the socket wasn't readable on entry; nothing was consumed
RECV_NoData,
// Unhealthy, probably requires a disconnect
// RECV_TimedOut: gave up before all bytes arrived; bytes already read are not reported back
RECV_TimedOut,
RECV_Error,
};
// Whether a connection's diagnostics are passed to core::LogText (checked by NETWORK_LOG)
enum clcpp_attr(reflect) DiskLog
{
DISKLOG_Enable,
DISKLOG_Disable,
};
//
// Base connection interface
//
// NOTE: default arguments on virtual functions are taken from the static type at the call
// site, not from the object's dynamic type. Implementations must therefore repeat the same
// defaults; they currently all use 20ms.
//
struct clcpp_attr(reflect_part) Connection
{
virtual ~Connection() { };
// Check the connection status
virtual Status PollStatus() = 0;
// Check for any incoming connection requests and accept them
// (returns 0 if there is none; a returned connection is heap-allocated and owned by the caller)
virtual Connection* AcceptConnection() = 0;
// Block sending all data specified by the length parameter. Optionally specify
// a timeout period in milliseconds.
virtual SendResult Send(const void* data, u32 length, u32 timeout_ms = 20) = 0;
// Block receiving all data specified by the length parameter. Optionally specify
// a timeout period in milliseconds.
virtual RecvResult Receive(void* data, u32 length, u32 timeout_ms = 20) = 0;
// Close the connection, rendering the object unusable after that point
virtual void Close() = 0;
};
//
// A TCP/IP connection where Winsock API are non-blocking but Send and Receive
// will block dependent upon how much data is requested.
//
// Send/Receive busy-wait on WSAEWOULDBLOCK until their timeout. Any socket error closes the
// socket; from then on PollStatus reports has_errors. A constructor that fails also leaves the
// socket invalid, so check PollStatus after construction.
//
class clcpp_attr(reflect_part) TCPIPConnection : public Connection
{
public:
// Create a listening connection
// (bound to INADDR_ANY:port with a backlog of 1)
TCPIPConnection(unsigned short port, DiskLog disk_log = DISKLOG_Enable);
// Create a connection to a remote address
// (resolved with gethostbyname and connected with connect() before switching to non-blocking)
TCPIPConnection(const char* address, unsigned short port, DiskLog disk_log = DISKLOG_Enable);
~TCPIPConnection();
// Connection interface implementation
Status PollStatus();
// Returns a new heap-allocated connection owned by the caller, or 0
TCPIPConnection* AcceptConnection();
SendResult Send(const void* data, u32 length, u32 timeout_ms = 20);
RecvResult Receive(void* data, u32 length, u32 timeout_ms = 20);
// Safe to call more than once
void Close();
private:
// Create a connection to an existing socket
TCPIPConnection(SOCKET s, DiskLog disk_log);
// Noncopyable
TCPIPConnection(const TCPIPConnection&);
TCPIPConnection& operator= (const TCPIPConnection&);
SOCKET CreateSocket();
void SetNonBlocking();
// Endpoint connection
SOCKET m_Socket;
DiskLog m_DiskLog;
};
// Values are used directly as the opcode of outgoing WebSocket frames (1 = text, 2 = binary)
enum WebSocketMode
{
WEBSOCKETMODE_Text = 1,
WEBSOCKETMODE_Binary = 2,
};
//
// Wrapper around a TCPIPConnection, function as a WebSocket.
//
// Owns the wrapped TCPIPConnection and deletes it in the destructor. Each Send becomes a single
// unmasked final frame. Receive strips frame headers, unmasks payloads and concatenates them
// into the caller's buffer.
//
class clcpp_attr(reflect_part) WebSocketConnection : public Connection
{
public:
// Create a listening connection
WebSocketConnection(unsigned short port, WebSocketMode mode, DiskLog disk_log = DISKLOG_Enable);
~WebSocketConnection();
// Connection interface implementation
Status PollStatus();
// Accepts a TCP connection and completes the HTTP upgrade handshake. Returns a new
// heap-allocated connection (owned by the caller), or 0
WebSocketConnection* AcceptConnection();
SendResult Send(const void* data, u32 length, u32 timeout_ms = 20);
RecvResult Receive(void* data, u32 length, u32 timeout_ms = 20);
void Close();
private:
// Takes ownership of connection
WebSocketConnection(TCPIPConnection* connection, WebSocketMode mode, DiskLog disk_log);
// Noncopyable
WebSocketConnection(const WebSocketConnection&);
WebSocketConnection& operator = (const WebSocketConnection&);
// Reads the next frame header and primes the frame/mask state below
bool ReceiveFrameHeader();
TCPIPConnection* m_TCPIPConnection;
WebSocketMode m_Mode;
DiskLog m_DiskLog;
// Payload bytes of the current frame not yet handed to the caller
u32 m_FrameBytesRemaining;
// Payload position within the current frame; selects the mask byte (offset & 3)
u32 m_MaskOffset;
// Masking key read from the current frame's header
u8 m_DataMask[4];
};
//
// Enumeration of all possible message types
// (stored in the one-byte MessageHeader::type field, so explicit values matter)
//
enum clcpp_attr(reflect) MessageType
{
MSGTYPE_Message = 0,
MSGTYPE_FunctionCall = 1,
MSGTYPE_Replication = 2,
};
//
// The header for the basic message passing protocol, handling typed messages, function calls and
// object replication. Typically packed next to message data in the incoming stream.
//
// Packed to 1-byte alignment because MessageIO copies the struct to and from the stream
// byte-for-byte. No byte-order conversion is applied to the fields.
//
#pragma pack(push, 1)
struct clcpp_attr(reflect) MessageHeader
{
MessageHeader();
MessageHeader(MessageType type, u32 id, u32 size);
void GetDebugDescription(core::String& dest) const;
// What kind of data does this message describe? (a MessageType value stored in one byte)
unsigned char type;
//
// The ID means different things for different message types.
//
// MSGTYPE_Message: The hash of the message type being sent.
// MSGTYPE_FunctionCall: The hash of the function name being called.
// MSGTYPE_Replication: The unique ID of the object being replicated.
//
u32 id;
// Size of the data the header describes (excluding the header itself)
u32 data_size;
};
#pragma pack(pop)
//
// Can send and receive reflection-based messages with any connection type.
// If there is no active connection, these functions will fail safely.
//
// TODO: Send needs to BLOCK or stick in a queue and return immediately.
// TODO: Receive needs to be capable of receiving partial messages
//
class clcpp_attr(reflect_part) MessageIO
{
public:
// reflection_db must be non-null. The connection is not owned and may be 0 (set it later with
// ChangeConnection).
MessageIO(const clcpp::Database* reflection_db, Connection* connection = 0);
~MessageIO();
// Set and unset the active connection
void ChangeConnection(Connection* connection);
// Get the next message as header and raw JSON data
RecvResult GetNextMessage(MessageHeader& header, clutl::WriteBuffer& data);
// Send a raw data message packet with data encoded as a JSON string
// (the header and header.data_size bytes of data go out back-to-back in a single Send)
SendResult SendMessage(const MessageHeader& header, const void* data);
// Send an object as a message, serialising its properties to JSON
SendResult SendMessage(const clutl::Object& message, u32 json_flags = 0);
// Send an object as a replication request, serialising its properties to JSON
SendResult SendReplicateObject(const clutl::Object& object, u32 json_flags = 0);
// Register message handlers: presumably methods of receiver_class invoked on this_ptr for the
// message types found in message_ns (the implementation is not visible here — confirm there)
void BindHandlers(void* this_ptr, const clcpp::Class* receiver_class, const clcpp::Namespace* message_ns);
// Get the next message, parse it as JSON and create a message object of its type
RecvResult DispatchMessages(clutl::ObjectGroup* group0, clutl::ObjectGroup* group1);
//
// These functions will allocate whatever message object you require on the local stack and
// automatically assign the type parameter, saving any need to dynamically allocate any memory.
//
// They will also automatically populate the fields in the message, in the order that they are
// passed to the function.
//
// They are statically bound to the message and parameter types, only looking up the required
// data on the first call.
//
// TODO: Replace with a runtime equivalent (some kind of RPC layer)
//
template <typename TYPE>
SendResult Send()
{
TYPE msg;
msg.type = clcpp::GetType<TYPE>();
return SendMessage(msg);
}
template <typename TYPE, typename A>
SendResult Send(const A& a)
{
TYPE msg;
msg.type = clcpp::GetType<TYPE>();
static const u32 offset_a = GetFieldOffset(msg.type, 0);
CopyField(msg, offset_a, a);
return SendMessage(msg);
}
template <typename TYPE, typename A, typename B>
SendResult Send(const A& a, const B& b)
{
TYPE msg;
msg.type = clcpp::GetType<TYPE>();
static const u32 offset_a = GetFieldOffset(msg.type, 0);
static const u32 offset_b = GetFieldOffset(msg.type, 1);
CopyField(msg, offset_a, a);
CopyField(msg, offset_b, b);
return SendMessage(msg);
}
private:
//
// Type parameterised copying of fields. Reduces amount of generated code and allows me to specialise
// for types that don't have implicit assignment operators.
//
template <typename TYPE, typename A>
static void CopyField(TYPE& msg, u32 field_offset, const A& data)
{
*(A*)((char*)&msg + field_offset) = data;
}
template <typename TYPE, typename A>
static void CopyField(TYPE& msg, u32 field_offset, const core::Vector<A>& data)
{
core::Vector<A>& dest = *(core::Vector<A>*)((char*)&msg + field_offset);
dest.copy_from(data);
}
bool DispatchMessage(const MessageHeader& header);
bool DispatchFunctionCall(const MessageHeader& header, clutl::ObjectGroup* group);
void DispatchObjectReplication(const MessageHeader& header, clutl::ObjectGroup* group0, clutl::ObjectGroup* group1);
bool ParseJSONMessage(clutl::Object* object);
bool ReportError(const char* error, const MessageHeader& header);
void SetHandler(const clcpp::Type* type, const clcpp::Function* handler, void* this_ptr);
static u32 GetFieldOffset(const clcpp::Type* type, u32 index);
// Reflection database used to serialise messages
const clcpp::Database* m_ReflectionDB;
// Destination buffer for any received messages
clutl::WriteBuffer* m_MessageBuffer;
// Message dispatch table
core::HashTable* m_MessageHandlers;
// Cache for parameters to function calls
clutl::ParameterObjectCache* m_ParamObjectCache;
// Pointer to active connection whose lifetime is managed externally
net::Connection* m_Connection;
};
}
//
// TODO: Need to remove error handling from individual send/recv calls and push everything into queues.
// This would allow partial completes to eventually succeed without losing data.
// Would help if this was threaded/async so that receives can be processed on-demand.
//
// TODO: MessageIO needs to be able to partially receive messages and continue functioning. Probably needs to
// sit atop the same message system that drives the network socket IO.
//
// TODO: Timeouts don't mean anything for big data retrieval, as the timeout expires while data is still being
// validly received. The timeout may need to reset itself on each valid receipt of data.
//
#include "Network.h"
#include "Core.h"
#include "SHA1.h"
#include "Base64.h"
#include <clutl/Objects.h>
#include <clutl/Serialise.h>
#include <clutl/SerialiseFunction.h>
#include <TinyCRT/TinyCRT.h>
#include <TinyCRT/TinyWinsock.h>
// Custom version of CORE_VA_NARGS that returns a max of 2
// Could do with PP min
// Presumably yields 1 for a lone format string (selecting NETWORK_LOG1) and 2 for a format string
// with arguments (selecting NETWORK_LOG2); CORE_VA_NARGS_IMPL lives in Core.h.
#define NETWORK_LOG_VA_NARGS(...) CORE_EXPAND(CORE_VA_NARGS_IMPL(__VA_ARGS__, 2, 2, 2, 2, 2, 2, 2, 1, 0))
// Every log line is prefixed with "file(line) : [function] - "
#define NETWORK_LOG_FORMAT "%s(%d) : [%s] - "
#define NETWORK_LOG_PARAMS __FILE__, __LINE__, __FUNCTION__
#define NETWORK_LOG1(msg) core::LogText(NETWORK_LOG_FORMAT msg, NETWORK_LOG_PARAMS)
#define NETWORK_LOG2(msg, ...) core::LogText(NETWORK_LOG_FORMAT msg, NETWORK_LOG_PARAMS, __VA_ARGS__)
// Log only when disk_log is DISKLOG_Enable. The expansion is a complete if/else, so in
// `if (x) NETWORK_LOG(...); else y();` the caller's `else` binds to the caller's `if`. A bare
// trailing `if` would capture that `else` instead.
#define NETWORK_LOG(disk_log, ...) if ((disk_log) == net::DISKLOG_Enable) CORE_JOIN(NETWORK_LOG, NETWORK_LOG_VA_NARGS(__VA_ARGS__))(__VA_ARGS__); else ((void)0)
namespace
{
// Number of outstanding InitialiseNetwork references. WSAStartup runs when this goes from 0 to 1,
// and WSACleanup when ShutdownNetwork brings it back to 0.
i32 g_NetworkInitRefCount = 0;
// Last Winsock error code captured via WSAGetLastError() after a failed select/send/recv
i32 g_LastError = 0;
// Take a reference on Winsock. The first reference starts Winsock 2.2; later ones just bump the
// count. On failure the count is left unchanged, so a balancing ShutdownNetwork call is a no-op
// when no other user holds a reference.
void InitialiseNetwork(net::DiskLog disk_log)
{
	// Check to see if the network is already initialised
	if (g_NetworkInitRefCount)
	{
		g_NetworkInitRefCount++;
		return;
	}

	// Initialise winsock at version 2.2
	NETWORK_LOG(disk_log, "WSAStartup\n");
	WSADATA wsa_data;
	if (WSAStartup(MAKEWORD(2, 2), &wsa_data))
	{
		NETWORK_LOG(disk_log, "WSAStartup failed\n");
		return;
	}
	if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2)
	{
		NETWORK_LOG(disk_log, "Incorrect Winsock version (%d.%d)\n", LOBYTE(wsa_data.wVersion), HIBYTE(wsa_data.wVersion));
		// WSAStartup itself succeeded, so it must still be balanced by WSACleanup even though
		// the version is rejected; nothing else would ever release it.
		WSACleanup();
		return;
	}
	g_NetworkInitRefCount++;
}
// Drop one Winsock reference; the call that releases the last one shuts Winsock down.
void ShutdownNetwork(net::DiskLog disk_log)
{
	// Nothing held: initialisation never succeeded or everything was already released
	if (g_NetworkInitRefCount == 0)
		return;

	g_NetworkInitRefCount--;
	if (g_NetworkInitRefCount != 0)
		return;

	NETWORK_LOG(disk_log, "WSACleanup\n");
	WSACleanup();
}
// Poll one socket with a zero-timeout select(). An invalid socket or a select() failure reports
// only the error flag; otherwise each flag mirrors the socket's membership in the matching set.
net::Status PollSocketStatus(SOCKET s, net::DiskLog disk_log)
{
	if (s == INVALID_SOCKET)
		return net::Status(false, false, true);

	fd_set readable;
	fd_set writable;
	fd_set errored;
	FD_ZERO(&readable);
	FD_SET(s, &readable);
	FD_ZERO(&writable);
	FD_SET(s, &writable);
	FD_ZERO(&errored);
	FD_SET(s, &errored);

	// A zeroed timeval makes select() return immediately
	timeval no_wait;
	no_wait.tv_sec = 0;
	no_wait.tv_usec = 0;
	const int select_result = select(0, &readable, &writable, &errored, &no_wait);
	if (select_result == SOCKET_ERROR)
	{
		g_LastError = WSAGetLastError();
		NETWORK_LOG(disk_log, "Socket error selecting (%d)\n", g_LastError);
		return net::Status(false, false, true);
	}

	const bool can_read = FD_ISSET(s, &readable) != 0;
	const bool can_write = FD_ISSET(s, &writable) != 0;
	const bool has_errors = FD_ISSET(s, &errored) != 0;
	return net::Status(can_read, can_write, has_errors);
}
}
// Create a non-blocking listening socket bound to INADDR_ANY:port with a backlog of 1.
// On failure m_Socket stays INVALID_SOCKET (PollStatus reports errors) and the network
// reference taken here is released again.
net::TCPIPConnection::TCPIPConnection(unsigned short port, DiskLog disk_log)
	: m_Socket(INVALID_SOCKET)
	, m_DiskLog(disk_log)
{
	InitialiseNetwork(m_DiskLog);

	// Try to create the socket
	SOCKET s = CreateSocket();
	if (s == INVALID_SOCKET)
	{
		// Close() only releases the network reference for a valid socket, so release it here
		ShutdownNetwork(m_DiskLog);
		return;
	}

	// Bind the socket to the incoming port
	NETWORK_LOG(m_DiskLog, "Binding to port %d\n", port);
	sockaddr_in sin = { 0 };
	sin.sin_family = AF_INET;
	sin.sin_addr.s_addr = INADDR_ANY;
	sin.sin_port = htons(port);
	if (bind(s, (sockaddr*)&sin, sizeof(sin)) == SOCKET_ERROR)
	{
		NETWORK_LOG(m_DiskLog, "Failed to bind\n");
		closesocket(s);
		ShutdownNetwork(m_DiskLog);
		return;
	}

	// Connection is valid, remaining code is socket state modification
	// (from here on Close() handles both the socket and the network reference)
	m_Socket = s;

	// Enter a listening state with a backlog of 1 connection
	NETWORK_LOG(m_DiskLog, "Socket entering listening state\n");
	if (listen(m_Socket, 1) == SOCKET_ERROR)
	{
		NETWORK_LOG(m_DiskLog, "Failed to listen\n");
		Close();
		return;
	}

	SetNonBlocking();
}
// Adopt a socket returned by accept(): take a network reference (released by Close) and switch
// the socket to non-blocking mode.
net::TCPIPConnection::TCPIPConnection(SOCKET s, DiskLog disk_log)
: m_Socket(s)
, m_DiskLog(disk_log)
{
InitialiseNetwork(m_DiskLog);
SetNonBlocking();
}
// Connect (blocking) to address:port, then switch the socket to non-blocking mode.
// On any failure m_Socket stays INVALID_SOCKET, no socket is leaked and the network reference
// taken here is released again.
net::TCPIPConnection::TCPIPConnection(const char* address, unsigned short port, DiskLog disk_log)
	: m_Socket(INVALID_SOCKET)
	, m_DiskLog(disk_log)
{
	// m_DiskLog must be initialised above: every log call below reads it
	InitialiseNetwork(m_DiskLog);

	// Try to create the socket
	SOCKET s = CreateSocket();
	if (s == INVALID_SOCKET)
	{
		// Close() only releases the network reference for a valid socket, so release it here
		ShutdownNetwork(m_DiskLog);
		return;
	}

	// Can't lookup remote host? Also reject results whose address doesn't fit sin_addr
	// before it is copied below.
	sockaddr_in sa = { 0 };
	hostent* remote_host = gethostbyname(address);
	if (remote_host == 0 ||
		remote_host->h_addr_list[0] == 0 ||
		remote_host->h_length != (int)sizeof(sa.sin_addr))
	{
		NETWORK_LOG(m_DiskLog, "Failed to resolve host %s\n", address);
		closesocket(s);
		ShutdownNetwork(m_DiskLog);
		return;
	}

	// Try to connect to the endpoint
	NETWORK_LOG(m_DiskLog, "Connecting to endpoint %s:%d\n", address, port);
	sa.sin_family = AF_INET;
	sa.sin_port = htons(port);
	memcpy(&sa.sin_addr, remote_host->h_addr_list[0], remote_host->h_length);
	if (connect(s, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR)
	{
		NETWORK_LOG(m_DiskLog, "Failed to connect\n");
		closesocket(s);
		ShutdownNetwork(m_DiskLog);
		return;
	}

	// Connection is valid, remaining code is socket state modification
	m_Socket = s;
	SetNonBlocking();
}
// Close() does nothing once the socket is INVALID_SOCKET, so this is safe after an explicit Close()
net::TCPIPConnection::~TCPIPConnection()
{
Close();
}
// Non-blocking status poll. An error status also closes the socket, so every later poll keeps
// reporting has_errors.
net::Status net::TCPIPConnection::PollStatus()
{
	const Status status = PollSocketStatus(m_Socket, m_DiskLog);
	if (!status.has_errors)
		return status;

	Close();
	return status;
}
// Accept one pending connection on this listening socket.
// Returns a new heap-allocated TCPIPConnection (owned by the caller) in these cases:
//  - returns 0 when nothing is pending;
//  - returns 0 when the pending connection went away before accept() ran;
//  - returns 0 after a real accept error, which also closes the listening socket.
net::TCPIPConnection* net::TCPIPConnection::AcceptConnection()
{
	// Ensure there is an incoming connection
	Status status = PollStatus();
	if (status.has_errors || !status.can_read)
		return 0;

	// Accept the connection
	NETWORK_LOG(m_DiskLog, "Accepting new connection\n");
	SOCKET s = accept(m_Socket, 0, 0);

	// accept() signals failure with INVALID_SOCKET (not SOCKET_ERROR)
	if (s == INVALID_SOCKET)
	{
		// On a non-blocking socket a connection reported by select() can vanish before accept();
		// that is transient and no reason to stop listening.
		g_LastError = WSAGetLastError();
		if (g_LastError == WSAEWOULDBLOCK)
			return 0;

		NETWORK_LOG(m_DiskLog, "Error while accepting socket connection\n");
		Close();
		return 0;
	}

	return new TCPIPConnection(s, m_DiskLog);
}
// Send all `length` bytes, looping over partial send() calls.
// WSAEWOULDBLOCK is retried in a busy loop until timeout_ms has elapsed since the call started.
// Any other failure closes the connection and returns SEND_Error.
// Returns SEND_TimedOut immediately if the socket isn't writable on entry.
net::SendResult net::TCPIPConnection::Send(const void* data, u32 length, u32 timeout_ms)
{
	// Can't send if there are socket errors
	const Status status = PollStatus();
	if (status.has_errors)
		return SEND_Error;
	if (!status.can_write)
		return SEND_TimedOut;

	char* next = (char*)data;
	char* const last = next + length;
	core::Milliseconds start_ms = core::GetLowResTimer();

	while (next < last)
	{
		const i32 sent = send(m_Socket, next, last - next, 0);
		if (sent != SOCKET_ERROR && sent != 0)
		{
			// Progress: skip past what was accepted and try the rest
			next += sent;
			continue;
		}

		// Anything other than "would block" is fatal for the connection
		g_LastError = WSAGetLastError();
		if (g_LastError != WSAEWOULDBLOCK)
		{
			NETWORK_LOG(m_DiskLog, "Socket error sending (%d)\n", g_LastError);
			Close();
			return SEND_Error;
		}

		// Restart the clock if the tick count wrapped, giving a slight hitch every 49.7 days
		const core::Milliseconds now_ms = core::GetLowResTimer();
		if (now_ms < start_ms)
		{
			start_ms = now_ms;
			continue;
		}

		// A timeout (endpoint gone, slow consumer, full local buffers) is not itself an error,
		// so hand it back to the caller.
		// TODO: This strategy breaks down if a send partially completes and then times out!
		if (now_ms - start_ms > core::Milliseconds(timeout_ms))
		{
			NETWORK_LOG(m_DiskLog, "Send timeout (%dms)\n", timeout_ms);
			return SEND_TimedOut;
		}
	}

	return SEND_Success;
}
// Receive exactly `length` bytes into `data`, looping over partial recv() calls.
// Outcomes:
//  - RECV_NoData: the socket isn't readable on entry; nothing is consumed.
//  - RECV_Error: a socket error or a graceful close by the peer; the connection is closed.
//  - RECV_TimedOut: timeout_ms elapsed first. Bytes already read into `data` have been
//    consumed from the socket but are not reported back.
net::RecvResult net::TCPIPConnection::Receive(void* data, u32 length, u32 timeout_ms)
{
	// Ensure there is data to receive
	Status status = PollStatus();
	if (status.has_errors)
		return RECV_Error;
	if (!status.can_read)
		return RECV_NoData;

	char* cur_data = (char*)data;
	char* end_data = cur_data + length;

	// Loop until all data has been received
	core::Milliseconds start_ms = core::GetLowResTimer();
	while (cur_data < end_data)
	{
		i32 bytes_received = recv(m_Socket, cur_data, end_data - cur_data, 0);
		if (bytes_received > 0)
		{
			// Jump over the data received
			cur_data += bytes_received;
			continue;
		}

		if (bytes_received == 0)
		{
			// recv() returns 0 only when the peer has gracefully closed the connection.
			// WSAGetLastError is only meaningful for SOCKET_ERROR, and a stale WSAEWOULDBLOCK
			// would make this spin until the timeout on every call.
			NETWORK_LOG(m_DiskLog, "Connection closed by peer\n");
			Close();
			return RECV_Error;
		}

		// Close the connection if receiving fails for any other reason other than blocking
		g_LastError = WSAGetLastError();
		if (g_LastError != WSAEWOULDBLOCK)
		{
			NETWORK_LOG(m_DiskLog, "Socket error receiving (%d)\n", g_LastError);
			Close();
			return RECV_Error;
		}

		// First check for tick-count overflow and reset, giving a slight hitch every 49.7 days
		core::Milliseconds cur_ms = core::GetLowResTimer();
		if (cur_ms < start_ms)
		{
			start_ms = cur_ms;
			continue;
		}

		//
		// Timeout can happen when:
		//
		//    1) data is delayed by sender
		//    2) sender fails to send a complete set of packets
		//
		// As not all of these scenarios are errors, we need to pass this information back to the caller.
		//
		// TODO: This strategy breaks down if a receive partially completes and then times out!
		//
		if (cur_ms - start_ms > core::Milliseconds(timeout_ms))
		{
			NETWORK_LOG(m_DiskLog, "Receive timeout (%dms)\n", timeout_ms);
			return RECV_TimedOut;
		}
	}

	return RECV_Success;
}
void net::TCPIPConnection::Close()
{
if (m_Socket != INVALID_SOCKET)
{
NETWORK_LOG(m_DiskLog, "Closing TCP/IP Socket\n");
// Shutdown the connection, stopping all sends
i32 result = shutdown(m_Socket, SD_SEND);
if (result != SOCKET_ERROR)
{
NETWORK_LOG(m_DiskLog, "Data still to be received on socket, receiving until complete\n");
// Keep receiving until the peer closes the connection
i32 total = 0;
char temp_buf[128];
while (result > 0)
{
result = recv(m_Socket, temp_buf, sizeof(temp_buf), 0);
total += result;
}
NETWORK_LOG(m_DiskLog, "%d bytes of discarded data received during socket shutdown\n", total);
}
// Close the socket
closesocket(m_Socket);
NETWORK_LOG(m_DiskLog, "Socket closed\n");
// Invalidate the socket and issue a network shutdown request
m_Socket = INVALID_SOCKET;
ShutdownNetwork(m_DiskLog);
}
}
// Create a TCP stream socket over IPv4. Returns INVALID_SOCKET on failure, which is the value
// every caller checks for.
SOCKET net::TCPIPConnection::CreateSocket()
{
	NETWORK_LOG(m_DiskLog, "Creating socket\n");
	SOCKET s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);

	// socket() reports failure with INVALID_SOCKET, not SOCKET_ERROR
	if (s == INVALID_SOCKET)
	{
		g_LastError = WSAGetLastError();
		NETWORK_LOG(m_DiskLog, "Failed to create socket\n");
		return INVALID_SOCKET;
	}
	return s;
}
void net::TCPIPConnection::SetNonBlocking()
{
// Set as non-blocking
NETWORK_LOG(m_DiskLog, "Setting non-blocking socket\n");
u_long nonblock = 1;
if (ioctlsocket(m_Socket, FIONBIO, &nonblock) == SOCKET_ERROR)
{
NETWORK_LOG(m_DiskLog, "Failed to set non-blocking\n");
Close();
}
}
namespace
{
// Find the HTTP header field `field_name` (e.g. "Host:") in a null-terminated request and return a
// pointer to its value, with leading spaces/tabs skipped. Returns 0 if the field isn't present.
// Only matches at the start of a line count. Otherwise "Host:" would match inside
// "X-Forwarded-Host:", letting one header masquerade as another.
char* GetField(char* buffer, const char* field_name)
{
	for (char* field = (char*)strstr(buffer, field_name); field != 0; field = (char*)strstr(field + 1, field_name))
	{
		// Reject matches embedded in another header's name or value
		if (field != buffer && field[-1] != '\n')
			continue;

		// Skip over the field name and any optional whitespace (SP / HTAB) before the value
		field += strlen(field_name);
		while (*field == ' ' || *field == '\t')
			field++;
		return field;
	}
	return 0;
}
// Server side of the RFC 6455 opening handshake on a freshly accepted connection:
//  1. Read the client's HTTP upgrade request up to the blank line that ends its headers, with a
//     1s overall budget.
//  2. Require a GET with Sec-WebSocket-Version 13 or 8 and a Host beginning with "localhost:".
//  3. Reply "101 Switching Protocols" with the Sec-WebSocket-Accept hash of the client key.
// Returns true only if that response was sent successfully.
bool WebSocketHandshake(net::TCPIPConnection& connection, net::DiskLog disk_log)
{
	NETWORK_LOG(disk_log, "Start\n");
	core::Milliseconds start_ms = core::GetLowResTimer();

	// Really inefficient way of receiving the handshake data from the browser
	// Not really sure how to do this any better, as the termination requirement is \r\n\r\n
	char buffer[1024];
	char* buffer_ptr = buffer;
	// The last slot is reserved for the null terminator written after the loop
	char* buffer_last = buffer + sizeof(buffer) - 1;
	while (buffer_ptr < buffer_last)
	{
		net::RecvResult result = connection.Receive(buffer_ptr, 1);
		if (result == net::RECV_Error)
		{
			// Pointer differences are ptrdiff_t; cast so the argument matches %d
			NETWORK_LOG(disk_log, "Failed to receive handshake data - %d bytes received\n", (int)(buffer_ptr - buffer));
			return false;
		}

		// If there's a stall receiving the data, check for a handshake timeout
		if (result == net::RECV_NoData || result == net::RECV_TimedOut)
		{
			core::Milliseconds now_ms = core::GetLowResTimer();
			if (now_ms - start_ms > core::Milliseconds(1000))
			{
				NETWORK_LOG(disk_log, "Timeout receving handshake data\n");
				return false;
			}
			continue;
		}

		// Just in case new enums are added...
		core::Assert(result == net::RECV_Success);

		// buffer_ptr is the byte just received: stop once the last four bytes are \r\n\r\n
		if (buffer_ptr - buffer >= 3 &&
			buffer_ptr[-3] == '\r' &&
			buffer_ptr[-2] == '\n' &&
			buffer_ptr[-1] == '\r' &&
			buffer_ptr[0] == '\n')
			break;

		buffer_ptr++;
	}
	*buffer_ptr = 0;

	// HTTP GET instruction
	if (memcmp(buffer, "GET", 3) != 0)
	{
		NETWORK_LOG(disk_log, "Handshake failed, not HTTP GET\n");
		return false;
	}

	// Look for the version number and verify that it's supported
	char* version = GetField(buffer, "Sec-WebSocket-Version:");
	if (version == 0)
	{
		NETWORK_LOG(disk_log, "Handshake failed, can't locate WebSocket version\n");
		return false;
	}
	int api_version = atoi(version);
	if (api_version != 13 && api_version != 8)
	{
		NETWORK_LOG(disk_log, "Handshake failed, unsupported version (%d)\n", api_version);
		return false;
	}

	// Make sure this is a localhost connection only
	// TODO: This can be spoofed so need to use getpeername with TCP/IP connection!
	char* host = GetField(buffer, "Host:");
	if (host == 0)
	{
		NETWORK_LOG(disk_log, "Handshake failed, can't locate host\n");
		return false;
	}
	// Check the length first so memcmp never reads beyond the terminated request text
	const char* localhost = "localhost:";
	if (strlen(host) < strlen(localhost) || memcmp(host, localhost, strlen(localhost)) != 0)
	{
		NETWORK_LOG(disk_log, "Handshake failed, host is not localhost\n");
		return false;
	}

	// Look for the key start and null-terminate it within the receive buffer
	char* key = GetField(buffer, "Sec-WebSocket-Key:");
	if (key == 0)
	{
		NETWORK_LOG(disk_log, "Handshake failed, can't locate WebSocket key\n");
		return false;
	}
	char* key_end = (char*)strstr(key, "\r\n");
	if (key_end == 0)
	{
		NETWORK_LOG(disk_log, "Handshake failed, WebSocket key is ill-formed\n");
		return false;
	}
	*key_end = 0;

	// Concatenate the browser's key with the WebSocket Protocol GUID and base64 encode
	// the hash, to prove to the browser that this is a bonafide WebSocket server
	core::String hash_string(key);
	hash_string += core::String("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
	core::SHA1 hash;
	core::SHA1_Calculate(hash_string.data(), hash_string.length(), hash);
	core::Base64_Encode(hash.data, sizeof(hash.data), hash_string);

	// Send the response back to the client with a longer timeout than usual
	core::String response_string;
	response_string.setv(
		"HTTP/1.1 101 Switching Protocols\r\n"
		"Upgrade: websocket\r\n"
		"Connection: Upgrade\r\n"
		"Sec-WebSocket-Accept: %s\r\n\r\n", hash_string.c_str());
	net::SendResult result = connection.Send(response_string.c_str(), response_string.length(), 1000);
	switch (result)
	{
	case net::SEND_Success:
		return true;
	case net::SEND_TimedOut:
		NETWORK_LOG(disk_log, "Timed out sending handshake response\n");
		return false;
	case net::SEND_Error:
		NETWORK_LOG(disk_log, "Error sending handshake response\n");
		return false;
	}

	// Any unrecognised send result is treated as a failed handshake
	return false;
}
// Write `size` big-endian into dest[0, dest_size). The first dest_offset bytes are zero padding;
// the remaining bytes hold the low-order bytes of `size`, most significant first.
static void WriteSize(u32 size, u8* dest, u32 dest_size, u32 dest_offset)
{
	u32 index = 0;

	// Leading zero padding
	for (; index < dest_offset && index < dest_size; index++)
		dest[index] = 0;

	// Byte at `index` carries bits [(dest_size - 1 - index) * 8, +8) of size
	for (; index < dest_size; index++)
	{
		const u32 shift = (dest_size - 1 - index) * 8;
		dest[index] = (u8)((size >> shift) & 0xFF);
	}
}
}
// Listening WebSocket endpoint: creates and owns a listening TCPIPConnection on `port`.
net::WebSocketConnection::WebSocketConnection(unsigned short port, WebSocketMode mode, DiskLog disk_log)
	: m_TCPIPConnection(new TCPIPConnection(port, disk_log))
	, m_Mode(mode)
	, m_DiskLog(disk_log)
	, m_FrameBytesRemaining(0)
	, m_MaskOffset(0)
{
	// No frame has been read yet
	for (int i = 0; i < 4; i++)
		m_DataMask[i] = 0;
}

// Accepted WebSocket endpoint: takes ownership of `connection` (AcceptConnection passes one that
// has already completed the handshake).
net::WebSocketConnection::WebSocketConnection(TCPIPConnection* connection, WebSocketMode mode, DiskLog disk_log)
	: m_TCPIPConnection(connection)
	, m_Mode(mode)
	, m_DiskLog(disk_log)
	, m_FrameBytesRemaining(0)
	, m_MaskOffset(0)
{
	for (int i = 0; i < 4; i++)
		m_DataMask[i] = 0;
}

// Frees the owned TCP/IP connection, whose destructor closes the socket
net::WebSocketConnection::~WebSocketConnection()
{
	delete m_TCPIPConnection;
}
// Forwards to the wrapped TCP/IP connection, which closes its socket when errors are reported
net::Status net::WebSocketConnection::PollStatus()
{
core::Assert(m_TCPIPConnection != 0);
return m_TCPIPConnection->PollStatus();
}
// Accept a pending TCP connection and upgrade it with the WebSocket handshake.
// Returns a new heap-allocated WebSocketConnection (owned by the caller) that shares this
// connection's mode and logging setting. Returns 0 if nothing is pending or the handshake fails;
// in the latter case the accepted TCP connection is destroyed.
net::WebSocketConnection* net::WebSocketConnection::AcceptConnection()
{
	core::Assert(m_TCPIPConnection != 0);

	TCPIPConnection* accepted = m_TCPIPConnection->AcceptConnection();
	if (accepted != 0)
	{
		if (WebSocketHandshake(*accepted, m_DiskLog))
			return new WebSocketConnection(accepted, m_Mode, m_DiskLog);

		// Refuse clients that fail the upgrade
		delete accepted;
	}
	return 0;
}
// Send `data` as one final (FIN), unmasked WebSocket frame whose opcode is the connection's mode.
// The length uses the narrowest encoding: 7-bit, 16-bit (marker 126) or 64-bit (marker 127).
net::SendResult net::WebSocketConnection::Send(const void* data, u32 length, u32 timeout_ms)
{
	core::Assert(m_TCPIPConnection != 0);

	// Can't send if there are socket errors
	const Status status = PollStatus();
	if (status.has_errors)
		return SEND_Error;
	if (!status.can_write)
		return SEND_TimedOut;

	// Byte 0: FIN bit | opcode. Byte 1: mask bit (left clear) | length or length marker.
	u8 header[10];
	header[0] = (u8)(0x80 | (u8)m_Mode);
	u32 header_size;
	if (length > 65535)
	{
		header[1] = 127;
		WriteSize(length, header + 2, 8, 4);
		header_size = 10;
	}
	else if (length > 125)
	{
		header[1] = 126;
		WriteSize(length, header + 2, 2, 0);
		header_size = 4;
	}
	else
	{
		header[1] = (u8)length;
		header_size = 2;
	}

	// Header and payload are sent together in a single buffer
	const u32 frame_size = header_size + length;
	u8* frame = new u8[frame_size];
	memcpy(frame, header, header_size);
	memcpy(frame + header_size, data, length);

	const SendResult result = m_TCPIPConnection->Send(frame, frame_size, timeout_ms);
	delete [] frame;
	return result;
}
// Fill `data` with exactly `length` payload bytes, reading and unmasking as many frames as needed.
// Outcomes:
//  - RECV_NoData: the socket isn't readable on entry.
//  - RECV_Error: a frame header fails to arrive/parse (which also closes the connection), or the
//    TCP receive fails.
//  - RECV_TimedOut: timeout_ms elapses while waiting for payload.
net::RecvResult net::WebSocketConnection::Receive(void* data, u32 length, u32 timeout_ms)
{
	core::Assert(m_TCPIPConnection != 0);

	// Ensure there is data to receive
	Status status = PollStatus();
	if (status.has_errors)
		return RECV_Error;
	if (!status.can_read)
		return RECV_NoData;

	char* cur_data = (char*)data;
	char* end_data = cur_data + length;
	//NETWORK_LOG(m_DiskLog, "Next values: %d, %d\n", m_FrameBytesRemaining, m_MaskOffset);
	core::Milliseconds start_ms = core::GetLowResTimer();
	while (cur_data < end_data)
	{
		// Get next WebSocket frame if we've run out of data to read from the socket
		if (m_FrameBytesRemaining == 0)
		{
			if (!ReceiveFrameHeader())
			{
				// Frame header potentially partially received so need to close
				m_TCPIPConnection->Close();
				return RECV_Error;
			}

			// An empty frame carries no payload; go straight to the next header
			if (m_FrameBytesRemaining == 0)
				continue;
		}

		// Read as much of the current frame as fits in the space still left in the caller's
		// buffer. Clamping to the full `length` would overrun the buffer when a read spans frames.
		u32 space_left = (u32)(end_data - cur_data);
		u32 bytes_to_read = m_FrameBytesRemaining < space_left ? m_FrameBytesRemaining : space_left;

		// Pass the caller's timeout down. A timed-out inner receive may already have consumed
		// some bytes without reporting them, so retrying in short slices would silently drop data.
		RecvResult result = m_TCPIPConnection->Receive(cur_data, bytes_to_read, timeout_ms);
		if (result == RECV_Error)
			return RECV_Error;

		// If there's a stall receiving the data, check for timeout
		if (result == RECV_NoData || result == RECV_TimedOut)
		{
			core::Milliseconds now_ms = core::GetLowResTimer();
			if (now_ms - start_ms > core::Milliseconds(timeout_ms))
			{
				NETWORK_LOG(m_DiskLog, "Timeout receiving WebSocket frame data (%dms)\n", timeout_ms);
				return RECV_TimedOut;
			}
			continue;
		}

		// Apply data mask. The mask position carries across calls via m_MaskOffset;
		// an all-zero mask would be an identity XOR and is skipped.
		if ((m_DataMask[0] | m_DataMask[1] | m_DataMask[2] | m_DataMask[3]) != 0)
		{
			u8* payload = (u8*)cur_data;
			for (u32 i = 0; i < bytes_to_read; i++)
			{
				payload[i] ^= m_DataMask[m_MaskOffset & 3];
				m_MaskOffset++;
			}
		}

		cur_data += bytes_to_read;
		m_FrameBytesRemaining -= bytes_to_read;
	}

	return RECV_Success;
}
// Closes the underlying socket; the TCPIPConnection object itself is freed by the destructor
void net::WebSocketConnection::Close()
{
core::Assert(m_TCPIPConnection != 0);
m_TCPIPConnection->Close();
}
bool net::WebSocketConnection::ReceiveFrameHeader()
{
// TODO: Specify infinite timeout?
//NETWORK_LOG("RECEIVE FRAME HEADER\n");
// Get message header
u8 msg_header[2] = { 0, 0 };
if (m_TCPIPConnection->Receive(msg_header, 2) != RECV_Success)
return false;
// Check for WebSocket Protocol disconnect
if (msg_header[0] == 0x88)
{
NETWORK_LOG(m_DiskLog, "Websocket Protocol Disconnect requested\n");
return false;
}
// Check that the client isn't sending messages we don't understand
if (msg_header[0] != 0x81 && msg_header[0] != 0x82)
{
NETWORK_LOG(m_DiskLog, "Couldn't parse frame header\n");
return false;
}
// Get message length and check to see if it's a marker for a wider length
int msg_length = msg_header[1] & 0x7F;
int size_bytes_remaining = 0;
switch (msg_length)
{
case 126: size_bytes_remaining = 2; break;
case 127: size_bytes_remaining = 8; break;
}
if (size_bytes_remaining > 0)
{
// Receive the wider bytes of the length
u8 size_bytes[4];
if (m_TCPIPConnection->Receive(size_bytes, size_bytes_remaining) != RECV_Success)
{
NETWORK_LOG(m_DiskLog, "Partially received wide frame header size\n");
return false;
}
// Calculate new length, MSB first
msg_length = 0;
for (int i = 0; i < size_bytes_remaining; i++)
msg_length |= size_bytes[i] << ((size_bytes_remaining - 1 - i) * 8);
}
// Receive any message data masks
bool mask_present = (msg_header[1] & 0x80) != 0;
if (mask_present)
{
if (m_TCPIPConnection->Receive(m_DataMask, 4) != RECV_Success)
{
NETWORK_LOG(m_DiskLog, "Partially received frame header data mask\n");
return false;
}
}
m_FrameBytesRemaining = msg_length;
m_MaskOffset = 0;
//NETWORK_LOG("Start values: %d, %d\n", m_FrameBytesRemaining, m_MaskOffset);
return true;
}
namespace
{
// Entry in MessageIO's dispatch table. Entries are heap-allocated and deleted by ~MessageIO.
struct MessageHandlerDesc
{
// A map from type to handler
const clcpp::Type* type;
const clcpp::Function* handler;
// Receiver object stored with the handler (presumably the `this` for the call; the dispatch
// code is not visible here — confirm)
void* this_ptr;
};
// Pointer-serialisation policy for SaveJSON. Only pointers to reflected classes flagged as clutl
// objects are saved, and each is written as the target's unique_id (0 for null).
struct PointerSave : public clutl::IPtrSave
{
	bool CanSavePtr(void* ptr, const clcpp::Field* field, const clcpp::Type* type)
	{
		// Only class types can carry the object flag; raw/primitive pointers are never saved
		if (type->kind != clcpp::Primitive::KIND_CLASS)
			return false;

		const clcpp::Class* class_type = type->AsClass();
		const bool is_object = (class_type->flag_attributes & clutl::FLAG_ATTR_IS_OBJECT) != 0;
		if (!is_object)
			return false;

		// Null is always saveable; a live target needs a unique ID to be referenced at all
		if (ptr == 0)
			return true;
		const clutl::Object* target = (const clutl::Object*)ptr;
		return target->unique_id != 0;
	}

	u32 SavePtr(void* ptr)
	{
		// The target's unique ID stands in for the pointer
		if (ptr == 0)
			return 0;
		const clutl::Object* target = (const clutl::Object*)ptr;
		return target->unique_id;
	}
};
}
// All-zero header (type 0 is MSGTYPE_Message)
net::MessageHeader::MessageHeader()
: type(0)
, id(0)
, data_size(0)
{
}
// The MessageType is narrowed into the one-byte `type` field
net::MessageHeader::MessageHeader(MessageType type, u32 id, u32 size)
: type(type)
, id(id)
, data_size(size)
{
}
// Format a one-line, human-readable summary of the header into `dest`.
void net::MessageHeader::GetDebugDescription(core::String& dest) const
{
	const char* type_name;
	switch (type)
	{
	case MSGTYPE_Message:
		type_name = "TYPE_Message";
		break;
	case MSGTYPE_FunctionCall:
		type_name = "TYPE_FunctionCall";
		break;
	case MSGTYPE_Replication:
		type_name = "TYPE_Replication";
		break;
	default:
		type_name = "Unknown";
		break;
	}
	dest.setv("type: %s, id: 0x%x, size: 0x%x", type_name, id, data_size);
}
// reflection_db is required. The connection is not owned and may be 0. The message buffer,
// handler table and parameter cache created here are owned and freed by ~MessageIO.
net::MessageIO::MessageIO(const clcpp::Database* reflection_db, Connection* connection)
: m_ReflectionDB(reflection_db)
, m_MessageBuffer(0)
, m_MessageHandlers(0)
, m_ParamObjectCache(0)
, m_Connection(connection)
{
core::Assert(m_ReflectionDB != 0);
// Create message exchange buffers
m_MessageBuffer = new clutl::WriteBuffer;
m_MessageHandlers = new core::HashTable();
m_ParamObjectCache = new clutl::ParameterObjectCache();
}
// Free the owned helpers. The handler table owns its MessageHandlerDesc entries, so those are
// deleted before the table itself. The connection is not owned and is left alone.
net::MessageIO::~MessageIO()
{
	if (m_MessageHandlers)
	{
		core::HashTableIterator it(*m_MessageHandlers);
		while (it.IsValid())
		{
			delete (MessageHandlerDesc*)it.GetPtr();
			it.MoveNext();
		}
	}

	delete m_ParamObjectCache;
	delete m_MessageHandlers;
	delete m_MessageBuffer;
}
void net::MessageIO::ChangeConnection(Connection* connection)
{
	// Redirect all subsequent sends/receives; the previous connection is not released here.
	// A null connection makes the send/receive functions return their *_Error result.
	// TODO: Reset internal message gathering
	m_Connection = connection;
}
net::RecvResult net::MessageIO::GetNextMessage(MessageHeader& header, clutl::WriteBuffer& message_buffer)
{
	// Nothing to read from without a connection
	if (m_Connection == 0)
		return RECV_Error;

	// Stage 1: the fixed-size header. "No data" is a normal, quiet outcome.
	const RecvResult header_result = m_Connection->Receive(&header, sizeof(header));
	if (header_result == RECV_NoData)
		return header_result;
	if (header_result != RECV_Success)
	{
		ReportError("RECEIVE message header", header);
		return header_result;
	}

	//NETWORK_LOG("Header: %d, %x, %d\n", header.type, header.id, header.data_size);

	// Stage 2: the payload, sized by the header, read with a longer (1000ms) timeout
	message_buffer.Reset();
	message_buffer.Alloc(header.data_size);
	const RecvResult data_result = m_Connection->Receive((void*)message_buffer.GetData(), header.data_size, 1000);
	if (data_result != RECV_Success)
		ReportError("RECEIVE message data", header);
	return data_result;
}
// Sends one message: the header immediately followed by header.data_size bytes from 'data'.
// 'data' may be null only when header.data_size is 0. Returns SEND_Error when there is no
// connection, otherwise the connection's Send result (failures are also logged).
net::SendResult net::MessageIO::SendMessage(const MessageHeader& header, const void* data)
{
	if (m_Connection == 0)
		return SEND_Error;

	// Pack header and data into one block for sending
	// TODO: This is a limitation of the IConnection interface working with frame-based WebSockets - remove new!
	u32 packed_size = sizeof(header) + header.data_size;
	u8* packed_data = new u8[packed_size];
	memcpy(packed_data, &header, sizeof(header));

	// memcpy with a null source is undefined behaviour even for a zero length, and a
	// header-only message may legitimately come with no payload pointer
	if (header.data_size != 0)
	{
		core::Assert(data != 0);
		memcpy(packed_data + sizeof(header), data, header.data_size);
	}

	// Send everything together
	SendResult result = m_Connection->Send(packed_data, packed_size);
	if (result != SEND_Success)
		ReportError("SEND message", header);

	delete [] packed_data;
	return result;
}
net::SendResult net::MessageIO::SendMessage(const clutl::Object& message, u32 json_flags)
{
	if (m_Connection == 0)
		return SEND_Error;
	core::Assert(m_MessageBuffer != 0);
	core::Assert(message.type != 0);

	// Serialise the message as JSON into the shared buffer
	// TODO: Handle any save errors - there are none yet
	clutl::WriteBuffer& buffer = *m_MessageBuffer;
	buffer.Reset();
	PointerSave ptr_save;
	clutl::SaveJSON(buffer, &message, message.type, &ptr_save, json_flags);
	const i32 data_size = buffer.GetBytesWritten();
	//core::LogText("Message: %s\n", m_MessageBuffer->GetData());

	// The header ID carries the message type's name hash; the receiver looks the type up by it
	const MessageHeader header(MSGTYPE_Message, message.type->name.hash, data_size);
	return SendMessage(header, buffer.GetData());
}
net::SendResult net::MessageIO::SendReplicateObject(const clutl::Object& object, u32 json_flags)
{
	if (m_Connection == 0)
		return SEND_Error;
	core::Assert(m_MessageBuffer != 0);
	core::Assert(object.type != 0);

	// The receiver finds the target object by unique ID, so unnamed (ID 0) objects can't be replicated
	core::Assert(object.unique_id != 0);

	// Serialise the object's current state as JSON into the shared buffer
	// TODO: Handle any save errors - there are none yet
	clutl::WriteBuffer& buffer = *m_MessageBuffer;
	buffer.Reset();
	PointerSave ptr_save;
	clutl::SaveJSON(buffer, &object, object.type, &ptr_save, json_flags);
	const i32 data_size = buffer.GetBytesWritten();

	const MessageHeader header(MSGTYPE_Replication, object.unique_id, data_size);
	return SendMessage(header, buffer.GetData());
}
void net::MessageIO::BindHandlers(void* this_ptr, const clcpp::Class *receiver_class, const clcpp::Namespace *message_ns)
{
	core::Assert(this_ptr != 0);
	core::Assert(receiver_class != 0);
	core::Assert(message_ns != 0);

	// Pair every Object-derived class in message_ns with each receiver method that accepts it.
	// When several methods match one class, SetHandler leaves the last match bound.
	const clcpp::CArray<const clcpp::Class*>& msg_classes = message_ns->classes;
	const clcpp::CArray<const clcpp::Function*>& candidates = receiver_class->methods;
	for (u32 c = 0; c < msg_classes.size; c++)
	{
		const clcpp::Class* msg_class = msg_classes[c];

		// Only Object-derived classes can be dispatched as messages
		if ((msg_class->flag_attributes & clutl::FLAG_ATTR_IS_OBJECT) == 0)
			continue;

		for (u32 m = 0; m < candidates.size; m++)
		{
			const clcpp::Function* candidate = candidates[m];
			const clcpp::CArray<const clcpp::Field*>& params = candidate->parameters;

			// A handler has exactly two reflected parameters: index 0 (presumably the implicit
			// 'this') and index 1, which must be of the message class
			const bool accepts_msg = params.size == 2 && params[1]->type == msg_class;
			if (accepts_msg)
				SetHandler(msg_class, candidate, this_ptr);
		}
	}
}
// Receives and dispatches every queued message until the connection reports no more data.
// group0 is searched for function-call object pointers; replication targets are searched
// in group0 then group1.
// Returns:
//   RECV_NoData  - the queue was empty on entry
//   RECV_Success - at least one message was dispatched and the queue is now drained
//   RECV_Error   - a handler reported failure (forces an error on the connection)
//   otherwise    - the timeout/error result from GetNextMessage
net::RecvResult net::MessageIO::DispatchMessages(clutl::ObjectGroup* group0, clutl::ObjectGroup* group1)
{
	// TODO: Does MessageIO needs its own return values?
	MessageHeader header;

	// Promoted to RECV_Success as soon as one message has been dispatched
	net::RecvResult return_result = RECV_NoData;

	// Empty the message queue
	net::RecvResult result;
	while ((result = GetNextMessage(header, *m_MessageBuffer)) == RECV_Success)
	{
		// Parse and dispatch message to handler
		bool call_result = true;
		switch (header.type)
		{
		case MSGTYPE_Message:
			call_result = DispatchMessage(header);
			break;
		case MSGTYPE_FunctionCall:
			call_result = DispatchFunctionCall(header, group0);
			break;
		case MSGTYPE_Replication:
			DispatchObjectReplication(header, group0, group1);
			break;
		default:
			// The payload has already been consumed; report it rather than dropping it silently
			ReportError("DISPATCH message of unknown type", header);
			break;
		}

		// Force an error on the connection if a dispatch fails
		if (!call_result)
			return RECV_Error;
		return_result = RECV_Success;
	}

	// Queue drained normally: report whether anything was dispatched
	if (result == RECV_NoData)
		return return_result;

	// Timeouts and receive errors are passed straight back to the caller
	return result;
}
// Handles one MSGTYPE_Message whose JSON payload is already in m_MessageBuffer.
// header.id is looked up as a type in the reflection database; an object of that type is
// created, filled from the JSON, passed to the registered handler method, then deleted.
// Returns false only when the handler itself returns false. A missing handler table,
// unknown type, failed creation, parse failure or unhandled type all log (or, for parse
// failures, rely on ParseJSONMessage to log) and return true so dispatching can continue.
// NOTE: MSVC x86-32 only - inline __asm is not supported by MSVC on x64/ARM.
bool net::MessageIO::DispatchMessage(const MessageHeader& header)
{
if (m_MessageHandlers == 0)
return true;
// Lookup the message type
const clcpp::Type* type = m_ReflectionDB->GetType(header.id);
if (type == 0)
{
ReportError("GetType for DispatchMessage", header);
return true;
}
// Create the message object
clutl::Object* message = clutl::CreateObject(type);
if (message == 0)
{
ReportError("CREATE object for DispatchMessage", header);
return true;
}
//core::LogText("RECEIVE: %.*s\n", m_MessageBuffer->GetBytesWritten(), m_MessageBuffer->GetData());
// Every exit path from here on must Delete(message)
if (!ParseJSONMessage(message))
{
Delete(message);
return true;
}
// Locate the correct message handler (keyed by type name hash, as registered by SetHandler)
MessageHandlerDesc* desc = (MessageHandlerDesc*)m_MessageHandlers->find(message->type->name.hash);
if (desc == 0)
{
core::LogText("NETWORK ERROR: Unhandled message RECEIVED (type = %s)\n", message->type->name.text);
Delete(message);
return true;
}
// Store the address and this pointer locally, because the inline assembler below
// can only reference them as named locals
u32 address = desc->handler->address;
void* this_ptr = desc->this_ptr;
core::Assert(address != 0);
core::Assert(this_ptr != 0);
// Call with the platform ABI 'thiscall': 'this' goes in ECX, the single message pointer
// argument is pushed on the stack, and the callee pops its own arguments (hence no stack
// adjustment after the call). The return value is read from EAX.
unsigned res = 0;
__asm
{
mov ecx, this_ptr
push message
call address
mov res, eax
}
Delete(message);
// Assumes the handler returns bool, which occupies only the low byte (AL) of EAX;
// the upper bytes are not guaranteed to be cleared, so mask them off
return (res & 0xFF) != 0;
}
// Handles one MSGTYPE_FunctionCall whose JSON-encoded parameters are in m_MessageBuffer.
// header.id identifies the function in the reflection database. Object pointer parameters
// arrive as unique IDs and are patched into live pointers by searching 'group' (and its
// parents). 'group' may be null, in which case such pointers are passed as null.
// Always returns true; lookup and parse failures are logged and the call is skipped.
bool net::MessageIO::DispatchFunctionCall(const MessageHeader& header, clutl::ObjectGroup* group)
{
	// Lookup the function being called
	const clcpp::Function* function = m_ReflectionDB->GetFunction(header.id);
	if (function == 0)
	{
		ReportError("handle RECEIVED function call", header);
		return true;
	}

	// Parse and create all parameter objects
	clutl::ReadBuffer read_buffer(*m_MessageBuffer);
	if (!clutl::BuildParameterObjectCache_JSON(*m_ParamObjectCache, function, read_buffer))
	{
		ReportError("Parse function parameters", header);
		return true;
	}

	// Walk through every Object pointer parameter
	clutl::ParameterData& parameters = m_ParamObjectCache->GetParameters();
	u32 nb_parameters = parameters.GetNbParameters();
	for (u32 i = 0; i < nb_parameters; i++)
	{
		clutl::ParameterData::ParamDesc& param = parameters.GetParameter(i);
		if (param.op != clcpp::Qualifier::POINTER)
			continue;
		if (param.type->kind != clcpp::Primitive::KIND_CLASS)
			continue;
		const clcpp::Class* class_type = (const clcpp::Class*)param.type;
		if (!(class_type->flag_attributes & clutl::FLAG_ATTR_IS_OBJECT))
			continue;

		// Patch up: the pointer slot currently holds the 32-bit unique ID (x86-32 only,
		// where pointers are 4 bytes). An ID of 0 is already a null pointer.
		u32 unique_id = *(u32*)param.object;
		if (unique_id != 0)
		{
			// Without a group the ID can't be resolved; pass null rather than letting the
			// raw ID be reinterpreted as an address
			clutl::Object* ptr = group != 0 ? group->FindObjectSearchParents(unique_id) : 0;
			*(clutl::Object**)param.object = ptr;
		}
	}

	// Call the function
	// TODO: Global primitives aren't parented to the global namespace!
	if (function->parent != 0 && function->parent->kind == clcpp::Primitive::KIND_CLASS)
		clutl::CallFunction_x86_32_msvc_thiscall(function, parameters);
	else
		clutl::CallFunction_x86_32_msvc_cdecl(function, parameters);

	// TODO: What about return value?
	return true;
}
// Handles one MSGTYPE_Replication: header.id is the unique ID of an existing object, which is
// searched for in group0 and then group1 (either may be null). Its fields are then overwritten
// from the JSON payload in m_MessageBuffer. Failures are logged only.
void net::MessageIO::DispatchObjectReplication(const MessageHeader& header, clutl::ObjectGroup* group0, clutl::ObjectGroup* group1)
{
	// Locate the object being replicated
	clutl::Object* object = 0;
	if (group0 != 0)
		object = group0->FindObjectSearchParents(header.id);
	if (object == 0 && group1 != 0)
		object = group1->FindObjectSearchParents(header.id);
	if (object == 0)
	{
		ReportError("FIND object for DispatchObjectReplication", header);
		return;
	}

	// Just replicate the data into the object! Parse errors are logged by ParseJSONMessage.
	ParseJSONMessage(object);
}
// Loads 'object' (using its reflected type) from the JSON text in m_MessageBuffer.
// Returns true on success. On failure, logs the error kind, its line/column/position, the
// buffer state and the raw data, then returns false. 'object' must be non-null.
bool net::MessageIO::ParseJSONMessage(clutl::Object* object)
{
	// Parse the object
	clutl::ReadBuffer read_buffer(*m_MessageBuffer);
	clutl::JSONError result = clutl::LoadJSON(read_buffer, object, object->type);
	if (result.code == clutl::JSONError::NONE)
		return true;

	// Uncomment to re-parse and debug the error
	//clutl::ReadBuffer new_read_buffer(*m_MessageBuffer);
	//result = clutl::LoadJSON(new_read_buffer, object, object->type);

	// Determine the error message
	const char* error;
	switch (result.code)
	{
	case clutl::JSONError::UNEXPECTED_END_OF_DATA:
		error = "unexpected end of data";
		break;
	case clutl::JSONError::EXPECTING_HEX_DIGIT:
		error = "expecting hex digit";
		break;
	case clutl::JSONError::EXPECTING_DIGIT:
		error = "expecting digit";
		break;
	case clutl::JSONError::UNEXPECTED_CHARACTER:
		error = "unexpected character";
		break;
	case clutl::JSONError::INVALID_KEYWORD:
		error = "invalid keyword";
		break;
	case clutl::JSONError::INVALID_ESCAPE_SEQUENCE:
		error = "invalid escape sequence";
		break;
	case clutl::JSONError::UNEXPECTED_TOKEN:
		error = "unexpected token";
		break;
	default:
		error = "unknown error";
		break;
	}

	// Report error
	core::LogText("NETWORK ERROR: Failed to parse JSON message data, %s (type = %s)\n", error, object->type->name.text);
	core::LogText(" Line %d, Column %d, Position %d\n", result.line, result.column, result.position);
	core::LogText(" Buffer size %d bytes (%d bytes remaining)\n", read_buffer.GetTotalBytes(), read_buffer.GetBytesRemaining());

	// The buffer may not be null-terminated, so print exactly the bytes written.
	// '%.*s' takes the precision as an int argument before the string.
	core::LogText(" Data: %.*s\n", (int)m_MessageBuffer->GetBytesWritten(), m_MessageBuffer->GetData());
	return false;
}
bool net::MessageIO::ReportError(const char* error, const MessageHeader& header)
{
core::String message_info;
header.GetDebugDescription(message_info);
core::LogText("MESSAGE IO ERROR: Failed to %s (%s)\n", error, message_info.c_str());
// Returns false so that it can be passed on to other failing functions
return false;
}
void net::MessageIO::SetHandler(const clcpp::Type* type, const clcpp::Function* handler, void* this_ptr)
{
	core::Assert(type != 0);
	core::Assert(handler != 0);

	// Find-or-create the descriptor for this message type, keyed by its name hash.
	// A newly created descriptor is owned by the table and freed in the destructor.
	MessageHandlerDesc* desc = (MessageHandlerDesc*)m_MessageHandlers->find(type->name.hash);
	if (desc == 0)
	{
		desc = new MessageHandlerDesc();
		desc->type = type;
		m_MessageHandlers->insert(type->name.hash, desc);
	}

	// New or existing, (re)bind the handler method and its receiver
	desc->handler = handler;
	desc->this_ptr = this_ptr;
}
// Returns the byte offset of the index-th field of class 'type', where fields are ordered by
// ascending offset (memory order) rather than declaration order. Supports classes with at
// most MAX_NB_FIELDS fields; 'type' must be a class.
u32 net::MessageIO::GetFieldOffset(const clcpp::Type* type, u32 index)
{
	// Locally sorted fields, from lowest offset to highest
	static const i32 MAX_NB_FIELDS = 5;
	const clcpp::Field* sorted_fields[MAX_NB_FIELDS] = { 0, 0, 0, 0, 0 };

	// Check there are enough fields
	const clcpp::CArray<const clcpp::Field*>& fields = type->AsClass()->fields;
	core::Assert(fields.size <= MAX_NB_FIELDS);

	// Never write past the local array, even if the assert above is compiled out
	const u32 nb_fields = fields.size < (u32)MAX_NB_FIELDS ? fields.size : (u32)MAX_NB_FIELDS;

	// Insertion sort by offset. Larger entries slide up one slot, moving from the top down
	// so nothing is overwritten before it has been copied. Equal offsets keep declaration order.
	for (u32 i = 0; i < nb_fields; i++)
	{
		const clcpp::Field* field = fields[i];
		u32 j = i;
		while (j > 0 && sorted_fields[j - 1]->offset > field->offset)
		{
			sorted_fields[j] = sorted_fields[j - 1];
			j--;
		}
		sorted_fields[j] = field;
	}

	// Ensure there is a valid result
	core::Assert(index < nb_fields);
	core::Assert(sorted_fields[index] != 0);
	return (u32)sorted_fields[index]->offset;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment