Skip to content

Instantly share code, notes, and snippets.

@marcbarry
Forked from reklis/Example.cpp
Created November 8, 2016 17:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcbarry/6ddba19eb6164fe62417ab69f0f9bc24 to your computer and use it in GitHub Desktop.
Save marcbarry/6ddba19eb6164fe62417ab69f0f9bc24 to your computer and use it in GitHub Desktop.
reliability-and-flow-control
/*
Reliability and Flow Control Example
From "Networking for Game Programmers" - http://www.gaffer.org/networking-for-game-programmers
Author: Glenn Fiedler <gaffer@gaffer.org>
*/
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include "Net.h"
//#define SHOW_ACKS
using namespace std;
using namespace net;
// network and timing settings for the example
constexpr int ServerPort = 30000;          // UDP port the server binds
constexpr int ClientPort = 30001;          // UDP port the client binds
constexpr int ProtocolId = 0x11223344;     // protocol id handed to ReliableConnection
constexpr float DeltaTime = 1.0f / 30.0f;  // fixed frame step in seconds
constexpr float SendRate = 1.0f / 30.0f;   // not read by main(), which takes its rate from FlowControl
constexpr float TimeOut = 10.0f;           // connection timeout in seconds, handed to ReliableConnection
constexpr int PacketSize = 256;            // payload bytes per packet sent by main()
// Adaptive send-rate controller driven by round trip time (in milliseconds).
//
// It is always in one of two modes: Good (30 packets per second) or Bad
// (10 packets per second). Good drops to Bad as soon as the RTT exceeds
// 250ms. Bad returns to Good once the RTT has stayed at or under 250ms for
// longer than the current penalty time. Dropping out of Good after less than
// 10 seconds there doubles the penalty (capped at 60s). Every time more than
// 10 seconds accumulate in Good, the penalty is halved (floored at 1s).
class FlowControl
{
public:
    FlowControl()
    {
        printf( "flow control initialized\n" );
        Reset();
    }

    // Return to the initial state: Bad mode with a 4 second penalty.
    void Reset()
    {
        penalty_reduction_accumulator = 0.0f;
        good_conditions_time = 0.0f;
        penalty_time = 4.0f;
        mode = Bad;
    }

    // Advance the controller by deltaTime seconds given the latest RTT in ms.
    void Update( float deltaTime, float rtt )
    {
        const float RTT_Threshold = 250.0f;
        if ( mode == Good )
            UpdateWhileGood( deltaTime, rtt, RTT_Threshold );
        else
            UpdateWhileBad( deltaTime, rtt, RTT_Threshold );
    }

    // Packets per second to send in the current mode.
    float GetSendRate()
    {
        if ( mode == Good )
            return 30.0f;
        return 10.0f;
    }

private:
    enum Mode
    {
        Good,
        Bad
    };

    // Good mode: a high RTT drops us to Bad (lengthening the penalty when Good
    // did not last long); otherwise track time in Good and shorten the penalty.
    void UpdateWhileGood( float deltaTime, float rtt, float threshold )
    {
        if ( rtt > threshold )
        {
            printf( "*** dropping to bad mode ***\n" );
            mode = Bad;
            const bool quickRelapse = good_conditions_time < 10.0f;
            if ( quickRelapse && penalty_time < 60.0f )
            {
                const float doubled = penalty_time * 2.0f;
                penalty_time = doubled > 60.0f ? 60.0f : doubled;
                printf( "penalty time increased to %.1f\n", penalty_time );
            }
            good_conditions_time = 0.0f;
            penalty_reduction_accumulator = 0.0f;
            return;
        }

        good_conditions_time += deltaTime;
        penalty_reduction_accumulator += deltaTime;
        const bool reductionDue = penalty_reduction_accumulator > 10.0f;
        if ( reductionDue && penalty_time > 1.0f )
        {
            const float halved = penalty_time / 2.0f;
            penalty_time = halved < 1.0f ? 1.0f : halved;
            printf( "penalty time reduced to %.1f\n", penalty_time );
            penalty_reduction_accumulator = 0.0f;
        }
    }

    // Bad mode: accumulate time with an acceptable RTT (any bad sample resets
    // it) and upgrade to Good once that time exceeds the penalty.
    void UpdateWhileBad( float deltaTime, float rtt, float threshold )
    {
        good_conditions_time = ( rtt <= threshold ) ? good_conditions_time + deltaTime : 0.0f;
        if ( good_conditions_time > penalty_time )
        {
            printf( "*** upgrading to good mode ***\n" );
            good_conditions_time = 0.0f;
            penalty_reduction_accumulator = 0.0f;
            mode = Good;
        }
    }

    Mode mode;
    float penalty_time;                   // seconds of good RTT needed to leave Bad mode
    float good_conditions_time;           // seconds of good conditions in the current stretch
    float penalty_reduction_accumulator;  // seconds in Good since the penalty last changed
};
// ----------------------------------------------
// Example driver. With a dotted-quad IPv4 argument it runs as a client that
// connects to that address on ServerPort; otherwise it runs as the server.
// Both sides send zero-filled packets at the rate chosen by FlowControl and
// print round trip / loss / bandwidth statistics four times a second.
int main( int argc, char * argv[] )
{
    // parse command line
    enum Mode
    {
        Client,
        Server
    };
    Mode mode = Server;
    Address address;
    if ( argc >= 2 )
    {
        int a, b, c, d;
        // sscanf returns the number of fields converted (or EOF), so anything
        // but 4 leaves some octets uninitialized; each octet must also fit in
        // the unsigned char that Address takes
        const bool parsed = sscanf( argv[1], "%d.%d.%d.%d", &a, &b, &c, &d ) == 4 &&
            a >= 0 && a <= 255 && b >= 0 && b <= 255 &&
            c >= 0 && c <= 255 && d >= 0 && d <= 255;
        if ( parsed )
        {
            mode = Client;
            address = Address( a, b, c, d, ServerPort );
        }
        else
        {
            printf( "ignoring invalid server address \"%s\"\n", argv[1] );
        }
    }

    // initialize
    if ( !InitializeSockets() )
    {
        printf( "failed to initialize sockets\n" );
        return 1;
    }
    ReliableConnection connection( ProtocolId, TimeOut );
    const int port = mode == Server ? ServerPort : ClientPort;
    if ( !connection.Start( port ) )
    {
        printf( "could not start connection on port %d\n", port );
        ShutdownSockets();
        return 1;
    }
    if ( mode == Client )
        connection.Connect( address );
    else
        connection.Listen();

    bool connected = false;
    float sendAccumulator = 0.0f;
    float statsAccumulator = 0.0f;
    FlowControl flowControl;
    while ( true )
    {
        // update flow control (it expects the round trip time in milliseconds)
        if ( connection.IsConnected() )
            flowControl.Update( DeltaTime, connection.GetReliabilitySystem().GetRoundTripTime() * 1000.0f );
        const float sendRate = flowControl.GetSendRate();

        // detect changes in connection state
        if ( mode == Server && connected && !connection.IsConnected() )
        {
            flowControl.Reset();
            printf( "reset flow control\n" );
            connected = false;
        }
        if ( !connected && connection.IsConnected() )
        {
            printf( "client connected to server\n" );
            connected = true;
        }
        if ( !connected && connection.ConnectFailed() )
        {
            printf( "connection failed\n" );
            break;
        }

        // send and receive packets
        sendAccumulator += DeltaTime;
        while ( sendAccumulator > 1.0f / sendRate )
        {
            unsigned char packet[PacketSize];
            memset( packet, 0, sizeof( packet ) );
            connection.SendPacket( packet, sizeof( packet ) );
            sendAccumulator -= 1.0f / sendRate;
        }
        while ( true )
        {
            unsigned char packet[PacketSize];
            const int bytes_read = connection.ReceivePacket( packet, sizeof( packet ) );
            if ( bytes_read == 0 )
                break;
        }

        // show packets that were acked this frame
#ifdef SHOW_ACKS
        unsigned int * acks = NULL;
        int ack_count = 0;
        connection.GetReliabilitySystem().GetAcks( &acks, ack_count );
        if ( ack_count > 0 )
        {
            printf( "acks: %u", acks[0] );
            for ( int i = 1; i < ack_count; ++i )
                printf( ",%u", acks[i] );
            printf( "\n" );
        }
#endif

        // update connection
        connection.Update( DeltaTime );

        // show connection stats
        statsAccumulator += DeltaTime;
        while ( statsAccumulator >= 0.25f && connection.IsConnected() )
        {
            const float rtt = connection.GetReliabilitySystem().GetRoundTripTime();
            const unsigned int sent_packets = connection.GetReliabilitySystem().GetSentPackets();
            const unsigned int acked_packets = connection.GetReliabilitySystem().GetAckedPackets();
            const unsigned int lost_packets = connection.GetReliabilitySystem().GetLostPackets();
            const float sent_bandwidth = connection.GetReliabilitySystem().GetSentBandwidth();
            const float acked_bandwidth = connection.GetReliabilitySystem().GetAckedBandwidth();
            printf( "rtt %.1fms, sent %u, acked %u, lost %u (%.1f%%), sent bandwidth = %.1fkbps, acked bandwidth = %.1fkbps\n",
                rtt * 1000.0f, sent_packets, acked_packets, lost_packets,
                sent_packets > 0 ? (float) lost_packets / (float) sent_packets * 100.0f : 0.0f,
                sent_bandwidth, acked_bandwidth );
            statsAccumulator -= 0.25f;
        }

        net::wait( DeltaTime );
    }

    // close the socket while the socket library is still initialized
    if ( connection.IsRunning() )
        connection.Stop();
    ShutdownSockets();
    return 0;
}
# makefile for macosx
flags = -Wall -DDEBUG -std=c++14 # -O3

# build any program from its .cpp file; every program also depends on Net.h
% : %.cpp Net.h
	g++ $< -o $@ ${flags}

all : Example Test

test : Test
	./Test

server : Example
	./Example

client : Example
	./Example 127.0.0.1

clean:
	rm -f Test Example

# these targets name actions, not files. Without this, on the default
# case-insensitive macOS filesystem the "test" target is considered satisfied
# by the "Test" binary itself and ./Test never runs.
.PHONY : all test server client clean
/*
Simple Network Library from "Networking for Game Programmers"
http://www.gaffer.org/networking-for-game-programmers
Author: Glenn Fiedler <gaffer@gaffer.org>
*/
#ifndef NET_H
#define NET_H
#include <cstring> // for memcpy
// platform detection
#define PLATFORM_WINDOWS 1
#define PLATFORM_MAC 2
#define PLATFORM_UNIX 3
#if defined(_WIN32)
#define PLATFORM PLATFORM_WINDOWS
#elif defined(__APPLE__)
#define PLATFORM PLATFORM_MAC
#else
#define PLATFORM PLATFORM_UNIX
#endif
#if PLATFORM == PLATFORM_WINDOWS
#include <winsock2.h>
#pragma comment( lib, "wsock32.lib" )
#elif PLATFORM == PLATFORM_MAC || PLATFORM == PLATFORM_UNIX
#include <sys/socket.h>
#include <netinet/in.h>
#include <fcntl.h>
#else
#error unknown platform!
#endif
#include <assert.h>
#include <vector>
#include <map>
#include <stack>
#include <list>
#include <algorithm>
#include <functional>
namespace net
{
// platform independent wait for n seconds
// Both variants are inline because this header defines them: a plain function
// definition would be emitted by every translation unit including the header
// and collide at link time.
#if PLATFORM == PLATFORM_WINDOWS
inline void wait( float seconds )
{
    Sleep( (int) ( seconds * 1000.0f ) );
}
#else
} // namespace net -- closed so <unistd.h> declares usleep/close at global scope
#include <unistd.h>
namespace net
{
inline void wait( float seconds )
{
    usleep( (int) ( seconds * 1000000.0f ) );
}
#endif
// internet address: an IPv4 address stored as a 32-bit integer whose most
// significant byte is the first octet (a.b.c.d -> a is bits 24-31), plus a port
class Address
{
public:
    Address()
        : address( 0 ), port( 0 )
    {
    }

    // Build from the four octets of a dotted-quad address. Octets are widened
    // to unsigned before shifting: shifting a promoted int left by 24 overflows
    // for octets >= 128, which is undefined behaviour before C++20.
    Address( unsigned char a, unsigned char b, unsigned char c, unsigned char d, unsigned short port )
        : address( ( (unsigned int) a << 24 ) | ( (unsigned int) b << 16 ) |
                   ( (unsigned int) c << 8 ) | (unsigned int) d ),
          port( port )
    {
    }

    // Build from an already packed 32-bit address (same byte layout as above).
    Address( unsigned int address, unsigned short port )
        : address( address ), port( port )
    {
    }

    unsigned int GetAddress() const
    {
        return address;
    }

    // individual octets, first (most significant) to last
    unsigned char GetA() const
    {
        return ( unsigned char ) ( address >> 24 );
    }

    unsigned char GetB() const
    {
        return ( unsigned char ) ( address >> 16 );
    }

    unsigned char GetC() const
    {
        return ( unsigned char ) ( address >> 8 );
    }

    unsigned char GetD() const
    {
        return ( unsigned char ) ( address );
    }

    unsigned short GetPort() const
    {
        return port;
    }

    bool operator == ( const Address & other ) const
    {
        return address == other.address && port == other.port;
    }

    bool operator != ( const Address & other ) const
    {
        return ! ( *this == other );
    }

    // strict weak ordering (address first, then port) so Address can be used
    // as a key in std::map
    bool operator < ( const Address & other ) const
    {
        if ( address != other.address )
            return address < other.address;
        return port < other.port;
    }

private:
    unsigned int address;
    unsigned short port;
};
// sockets

// One-time socket library setup. On Windows this requests Winsock 2.2 and
// returns true only if WSAStartup succeeded; elsewhere nothing is needed and
// it always returns true.
inline bool InitializeSockets()
{
#if PLATFORM == PLATFORM_WINDOWS
    WSADATA WsaData;
    // WSAStartup returns zero (NO_ERROR) on success and an error code otherwise
    return WSAStartup( MAKEWORD(2,2), &WsaData ) == NO_ERROR;
#else
    return true;
#endif
}

// Counterpart of InitializeSockets(): releases Winsock on Windows, no-op elsewhere.
inline void ShutdownSockets()
{
#if PLATFORM == PLATFORM_WINDOWS
    WSACleanup();
#endif
}
// Non-blocking UDP socket bound to a local port. The descriptor value 0 is
// used as the "not open" marker (so a descriptor of 0 returned by ::socket is
// treated as a failure).
class Socket
{
public:
    Socket()
    {
        socket = 0;
    }

    ~Socket()
    {
        Close();
    }

    // The destructor closes the descriptor, so a copy would close it twice.
    Socket( const Socket & ) = delete;
    Socket & operator = ( const Socket & ) = delete;

    // Create the socket, bind it to `port` on all local interfaces and make it
    // non-blocking. On any failure the socket is left closed and false returned.
    bool Open( unsigned short port )
    {
        assert( !IsOpen() );

        // create socket
        socket = ::socket( AF_INET, SOCK_DGRAM, IPPROTO_UDP );
        if ( socket <= 0 )
        {
            printf( "failed to create socket\n" );
            socket = 0;
            return false;
        }

        // bind to port; value-initialize so sin_zero (and sin_len on BSD-derived
        // systems such as macOS) are zero rather than stack garbage
        sockaddr_in address = {};
        address.sin_family = AF_INET;
        address.sin_addr.s_addr = htonl( INADDR_ANY );
        address.sin_port = htons( port );
        if ( bind( socket, (const sockaddr*) &address, sizeof(sockaddr_in) ) < 0 )
        {
            printf( "failed to bind socket\n" );
            Close();
            return false;
        }

        // set non-blocking io
#if PLATFORM == PLATFORM_MAC || PLATFORM == PLATFORM_UNIX
        // add O_NONBLOCK to the current status flags instead of replacing them
        const int flags = fcntl( socket, F_GETFL, 0 );
        if ( flags == -1 || fcntl( socket, F_SETFL, flags | O_NONBLOCK ) == -1 )
        {
            printf( "failed to set non-blocking socket\n" );
            Close();
            return false;
        }
#elif PLATFORM == PLATFORM_WINDOWS
        DWORD nonBlocking = 1;
        if ( ioctlsocket( socket, FIONBIO, &nonBlocking ) != 0 )
        {
            printf( "failed to set non-blocking socket\n" );
            Close();
            return false;
        }
#endif

        return true;
    }

    // Close the descriptor if one is open; safe to call repeatedly.
    void Close()
    {
        if ( socket != 0 )
        {
#if PLATFORM == PLATFORM_MAC || PLATFORM == PLATFORM_UNIX
            close( socket );
#elif PLATFORM == PLATFORM_WINDOWS
            closesocket( socket );
#endif
            socket = 0;
        }
    }

    bool IsOpen() const
    {
        return socket != 0;
    }

    // Send one datagram; true only if all `size` bytes were handed to the OS.
    bool Send( const Address & destination, const void * data, int size )
    {
        assert( data );
        assert( size > 0 );

        if ( socket == 0 )
            return false;

        assert( destination.GetAddress() != 0 );
        assert( destination.GetPort() != 0 );

        sockaddr_in address = {};
        address.sin_family = AF_INET;
        address.sin_addr.s_addr = htonl( destination.GetAddress() );
        address.sin_port = htons( destination.GetPort() );

        const int sent_bytes = (int) sendto( socket, (const char*)data, size, 0, (sockaddr*)&address, sizeof(sockaddr_in) );

        return sent_bytes == size;
    }

    // Receive one datagram into `data` (at most `size` bytes). Returns the byte
    // count and fills `sender`, or 0 when nothing was read.
    int Receive( Address & sender, void * data, int size )
    {
        assert( data );
        assert( size > 0 );

        if ( socket == 0 )
            return 0;

#if PLATFORM == PLATFORM_WINDOWS
        typedef int socklen_t;
#endif

        sockaddr_in from;
        socklen_t fromLength = sizeof( from );

        const int received_bytes = (int) recvfrom( socket, (char*)data, size, 0, (sockaddr*)&from, &fromLength );

        if ( received_bytes <= 0 )
            return 0;

        const unsigned int address = ntohl( from.sin_addr.s_addr );
        const unsigned short port = ntohs( from.sin_port );

        sender = Address( address, port );

        return received_bytes;
    }

private:
    int socket;
};
// connection
// A virtual connection over one UDP socket. Every packet is prefixed with the
// 4-byte protocol id (most significant byte first); packets with another
// prefix are ignored. A server accepts the first sender of a valid packet
// while it is not connected; a client becomes connected when it first hears
// from the address it connected to. Packets from any other sender are
// dropped. The connection times out when nothing has been accepted for
// `timeout` seconds.
class Connection
{
public:
    enum Mode
    {
        None,
        Client,
        Server
    };

    Connection( unsigned int protocolId, float timeout )
    {
        this->protocolId = protocolId;
        this->timeout = timeout;
        mode = None;
        running = false;
        ClearData();
    }

    virtual ~Connection()
    {
        if ( IsRunning() )
            Stop();
    }

    // Open the socket on `port`; returns false if it could not be opened.
    bool Start( int port )
    {
        assert( !running );
        printf( "start connection on port %d\n", port );
        if ( !socket.Open( port ) )
            return false;
        running = true;
        OnStart();
        return true;
    }

    void Stop()
    {
        assert( running );
        printf( "stop connection\n" );
        bool connected = IsConnected();
        ClearData();
        socket.Close();
        running = false;
        if ( connected )
            OnDisconnect();
        OnStop();
    }

    bool IsRunning() const
    {
        return running;
    }

    void Listen()
    {
        printf( "server listening for connection\n" );
        bool connected = IsConnected();
        ClearData();
        if ( connected )
            OnDisconnect();
        mode = Server;
        state = Listening;
    }

    void Connect( const Address & address )
    {
        printf( "client connecting to %d.%d.%d.%d:%d\n",
            address.GetA(), address.GetB(), address.GetC(), address.GetD(), address.GetPort() );
        bool connected = IsConnected();
        ClearData();
        if ( connected )
            OnDisconnect();
        mode = Client;
        state = Connecting;
        this->address = address;
    }

    bool IsConnecting() const
    {
        return state == Connecting;
    }

    bool ConnectFailed() const
    {
        return state == ConnectFail;
    }

    bool IsConnected() const
    {
        return state == Connected;
    }

    bool IsListening() const
    {
        return state == Listening;
    }

    Mode GetMode() const
    {
        return mode;
    }

    // Advance the timeout clock. A pending connect becomes ConnectFail; an
    // established connection drops back to Disconnected.
    virtual void Update( float deltaTime )
    {
        assert( running );
        timeoutAccumulator += deltaTime;
        if ( timeoutAccumulator > timeout )
        {
            if ( state == Connecting )
            {
                printf( "connect timed out\n" );
                ClearData();
                state = ConnectFail;
                OnDisconnect();
            }
            else if ( state == Connected )
            {
                printf( "connection timed out\n" );
                ClearData();    // leaves state == Disconnected
                OnDisconnect();
            }
        }
    }

    // Send `size` payload bytes prefixed with the protocol id. Returns false if
    // there is no peer address yet or the socket did not send the whole packet.
    virtual bool SendPacket( const unsigned char data[], int size )
    {
        assert( running );
        if ( address.GetAddress() == 0 )
            return false;
        // heap buffer rather than a variable-length array, which is a compiler
        // extension that MSVC (used for the Windows build) rejects
        std::vector<unsigned char> packet( size + 4 );
        packet[0] = (unsigned char) ( protocolId >> 24 );
        packet[1] = (unsigned char) ( ( protocolId >> 16 ) & 0xFF );
        packet[2] = (unsigned char) ( ( protocolId >> 8 ) & 0xFF );
        packet[3] = (unsigned char) ( ( protocolId ) & 0xFF );
        if ( size > 0 )
            std::memcpy( packet.data() + 4, data, size );
        return socket.Send( address, packet.data(), size + 4 );
    }

    // Read one datagram. Returns the payload size copied into `data` (at most
    // `size`), or 0 when nothing valid from the peer was available.
    virtual int ReceivePacket( unsigned char data[], int size )
    {
        assert( running );
        std::vector<unsigned char> packet( size + 4 );
        Address sender;
        const int bytes_read = socket.Receive( sender, packet.data(), size + 4 );
        // need the 4-byte protocol id plus at least one payload byte
        if ( bytes_read <= 4 )
            return 0;
        if ( packet[0] != (unsigned char) ( protocolId >> 24 ) ||
             packet[1] != (unsigned char) ( ( protocolId >> 16 ) & 0xFF ) ||
             packet[2] != (unsigned char) ( ( protocolId >> 8 ) & 0xFF ) ||
             packet[3] != (unsigned char) ( protocolId & 0xFF ) )
            return 0;
        if ( mode == Server && !IsConnected() )
        {
            printf( "server accepts connection from client %d.%d.%d.%d:%d\n",
                sender.GetA(), sender.GetB(), sender.GetC(), sender.GetD(), sender.GetPort() );
            state = Connected;
            address = sender;
            OnConnect();
        }
        if ( sender == address )
        {
            if ( mode == Client && state == Connecting )
            {
                printf( "client completes connection with server\n" );
                state = Connected;
                OnConnect();
            }
            timeoutAccumulator = 0.0f;
            std::memcpy( data, packet.data() + 4, bytes_read - 4 );
            return bytes_read - 4;
        }
        return 0;
    }

    // bytes this layer adds in front of the payload (the protocol id)
    int GetHeaderSize() const
    {
        return 4;
    }

protected:
    virtual void OnStart() {}
    virtual void OnStop() {}
    virtual void OnConnect() {}
    virtual void OnDisconnect() {}

private:
    void ClearData()
    {
        state = Disconnected;
        timeoutAccumulator = 0.0f;
        address = Address();
    }

    enum State
    {
        Disconnected,
        Listening,
        Connecting,
        ConnectFail,
        Connected
    };

    unsigned int protocolId;
    float timeout;

    bool running;
    Mode mode;
    State state;
    Socket socket;
    float timeoutAccumulator;
    Address address;
};
// packet queue to store information about sent and received packets sorted in sequence order
// + we define ordering using the "sequence_more_recent" function, this works provided there is a large gap when sequence wrap occurs
// One queued packet record. Members default to zero so a record whose caller
// fills in only `sequence` never carries indeterminate time/size values.
struct PacketData
{
    unsigned int sequence = 0;  // packet sequence number
    float time = 0.0f;          // time offset since packet was sent or received (depending on context)
    int size = 0;               // packet size in bytes
};
// True when sequence s1 is newer than s2 in a sequence space [0, max_sequence]
// that wraps around: a forward distance of at most half the space counts as
// newer, while a larger gap is taken to mean that s1 wrapped past s2.
// Equal sequences are never more recent than each other.
inline bool sequence_more_recent( unsigned int s1, unsigned int s2, unsigned int max_sequence )
{
    const unsigned int half_range = max_sequence / 2;
    if ( s1 == s2 )
        return false;
    if ( s1 > s2 )
        return s1 - s2 <= half_range;
    return s2 - s1 > half_range;
}
class PacketQueue : public std::list<PacketData>
{
public:
bool exists( unsigned int sequence )
{
for ( iterator itor = begin(); itor != end(); ++itor )
if ( itor->sequence == sequence )
return true;
return false;
}
void insert_sorted( const PacketData & p, unsigned int max_sequence )
{
if ( empty() )
{
push_back( p );
}
else
{
if ( !sequence_more_recent( p.sequence, front().sequence, max_sequence ) )
{
push_front( p );
}
else if ( sequence_more_recent( p.sequence, back().sequence, max_sequence ) )
{
push_back( p );
}
else
{
for ( PacketQueue::iterator itor = begin(); itor != end(); itor++ )
{
assert( itor->sequence != p.sequence );
if ( sequence_more_recent( itor->sequence, p.sequence, max_sequence ) )
{
insert( itor, p );
break;
}
}
}
}
}
void verify_sorted( unsigned int max_sequence )
{
PacketQueue::iterator prev = end();
for ( PacketQueue::iterator itor = begin(); itor != end(); itor++ )
{
assert( itor->sequence <= max_sequence );
if ( prev != end() )
{
assert( sequence_more_recent( itor->sequence, prev->sequence, max_sequence ) );
prev = itor;
}
}
}
};
// reliability system to support reliable connection
// + manages sent, received, pending ack and acked packet queues
// + separated out from reliable connection because it is quite complex and i want to unit test it!
// Sequence numbers run from 0 to max_sequence and then wrap to 0; their
// ordering is decided by sequence_more_recent.
class ReliabilitySystem
{
public:
    ReliabilitySystem( unsigned int max_sequence = 0xFFFFFFFF )
    {
        // Reset() initializes every other member, including rtt_maximum
        this->max_sequence = max_sequence;
        Reset();
    }

    // Drop all queued packets and statistics and restart both sequences at 0.
    void Reset()
    {
        local_sequence = 0;
        remote_sequence = 0;
        sentQueue.clear();
        receivedQueue.clear();
        pendingAckQueue.clear();
        ackedQueue.clear();
        sent_packets = 0;
        recv_packets = 0;
        lost_packets = 0;
        acked_packets = 0;
        sent_bandwidth = 0.0f;
        acked_bandwidth = 0.0f;
        rtt = 0.0f;
        rtt_maximum = 1.0f;
    }

    // Record a packet of `size` payload bytes sent with the current local
    // sequence, then advance the local sequence (wrapping after max_sequence).
    void PacketSent( int size )
    {
        if ( sentQueue.exists( local_sequence ) )
        {
            printf( "local sequence %u exists\n", local_sequence );
            for ( PacketQueue::iterator itor = sentQueue.begin(); itor != sentQueue.end(); ++itor )
                printf( " + %u\n", itor->sequence );
        }
        assert( !sentQueue.exists( local_sequence ) );
        assert( !pendingAckQueue.exists( local_sequence ) );
        PacketData data;
        data.sequence = local_sequence;
        data.time = 0.0f;
        data.size = size;
        sentQueue.push_back( data );
        pendingAckQueue.push_back( data );
        sent_packets++;
        local_sequence++;
        if ( local_sequence > max_sequence )
            local_sequence = 0;
    }

    // Record a received packet; duplicates only bump the receive counter.
    void PacketReceived( unsigned int sequence, int size )
    {
        recv_packets++;
        if ( receivedQueue.exists( sequence ) )
            return;
        PacketData data;
        data.sequence = sequence;
        data.time = 0.0f;
        data.size = size;
        // keep the queue in sequence order even when packets arrive out of
        // order: generate_ack_bits stops at the first record not older than the
        // ack, and UpdateQueues trims from the front using back() as the latest
        receivedQueue.insert_sorted( data, max_sequence );
        if ( sequence_more_recent( sequence, remote_sequence, max_sequence ) )
            remote_sequence = sequence;
    }

    unsigned int GenerateAckBits()
    {
        return generate_ack_bits( GetRemoteSequence(), receivedQueue, max_sequence );
    }

    void ProcessAck( unsigned int ack, unsigned int ack_bits )
    {
        process_ack( ack, ack_bits, pendingAckQueue, ackedQueue, acks, acked_packets, rtt, max_sequence );
    }

    // Per-frame update: clears last frame's acks, ages and trims the queues and
    // recomputes the bandwidth estimates.
    void Update( float deltaTime )
    {
        acks.clear();
        AdvanceQueueTime( deltaTime );
        UpdateQueues();
        UpdateStats();
#ifdef NET_UNIT_TEST
        Validate();
#endif
    }

    void Validate()
    {
        sentQueue.verify_sorted( max_sequence );
        receivedQueue.verify_sorted( max_sequence );
        pendingAckQueue.verify_sorted( max_sequence );
        ackedQueue.verify_sorted( max_sequence );
    }

    // utility functions

    // Bit position for `sequence` in an ack_bits field relative to `ack`: the
    // packet just before ack is bit 0. `sequence` must be older than `ack`.
    static int bit_index_for_sequence( unsigned int sequence, unsigned int ack, unsigned int max_sequence )
    {
        assert( sequence != ack );
        assert( !sequence_more_recent( sequence, ack, max_sequence ) );
        if ( sequence > ack )
        {
            // sequence is from before the wrap around
            assert( ack < 33 );
            assert( max_sequence >= sequence );
            return ack + ( max_sequence - sequence );
        }
        else
        {
            assert( ack >= 1 );
            assert( sequence <= ack - 1 );
            return ack - 1 - sequence;
        }
    }

    // Bitfield of which of the 32 sequences preceding `ack` are present in
    // received_queue, which must be sorted oldest to newest.
    static unsigned int generate_ack_bits( unsigned int ack, const PacketQueue & received_queue, unsigned int max_sequence )
    {
        unsigned int ack_bits = 0;
        for ( PacketQueue::const_iterator itor = received_queue.begin(); itor != received_queue.end(); ++itor )
        {
            if ( itor->sequence == ack || sequence_more_recent( itor->sequence, ack, max_sequence ) )
                break;
            int bit_index = bit_index_for_sequence( itor->sequence, ack, max_sequence );
            if ( bit_index <= 31 )
                ack_bits |= 1u << bit_index;    // unsigned: bit 31 must not hit the sign bit
        }
        return ack_bits;
    }

    // Move every pending packet acknowledged by (ack, ack_bits) into
    // acked_queue, appending its sequence to `acks`, counting it, and folding
    // its age into the smoothed rtt.
    static void process_ack( unsigned int ack, unsigned int ack_bits,
                             PacketQueue & pending_ack_queue, PacketQueue & acked_queue,
                             std::vector<unsigned int> & acks, unsigned int & acked_packets,
                             float & rtt, unsigned int max_sequence )
    {
        if ( pending_ack_queue.empty() )
            return;

        PacketQueue::iterator itor = pending_ack_queue.begin();
        while ( itor != pending_ack_queue.end() )
        {
            bool acked = false;

            if ( itor->sequence == ack )
            {
                acked = true;
            }
            else if ( !sequence_more_recent( itor->sequence, ack, max_sequence ) )
            {
                int bit_index = bit_index_for_sequence( itor->sequence, ack, max_sequence );
                if ( bit_index <= 31 )
                    acked = ( ack_bits >> bit_index ) & 1;
            }

            if ( acked )
            {
                rtt += ( itor->time - rtt ) * 0.1f;
                acked_queue.insert_sorted( *itor, max_sequence );
                acks.push_back( itor->sequence );
                acked_packets++;
                itor = pending_ack_queue.erase( itor );
            }
            else
                ++itor;
        }
    }

    // data accessors

    unsigned int GetLocalSequence() const
    {
        return local_sequence;
    }

    unsigned int GetRemoteSequence() const
    {
        return remote_sequence;
    }

    unsigned int GetMaxSequence() const
    {
        return max_sequence;
    }

    // Sequences acked since the last Update(). When there are none, *acks may
    // be null and count is 0; the pointer is invalidated by the next Update().
    void GetAcks( unsigned int ** acks, int & count )
    {
        *acks = this->acks.empty() ? nullptr : this->acks.data();
        count = (int) this->acks.size();
    }

    unsigned int GetSentPackets() const
    {
        return sent_packets;
    }

    unsigned int GetReceivedPackets() const
    {
        return recv_packets;
    }

    unsigned int GetLostPackets() const
    {
        return lost_packets;
    }

    unsigned int GetAckedPackets() const
    {
        return acked_packets;
    }

    float GetSentBandwidth() const
    {
        return sent_bandwidth;
    }

    float GetAckedBandwidth() const
    {
        return acked_bandwidth;
    }

    float GetRoundTripTime() const
    {
        return rtt;
    }

    // bytes of sequence/ack/ack_bits header this system needs per packet
    int GetHeaderSize() const
    {
        return 12;
    }

protected:

    void AdvanceQueueTime( float deltaTime )
    {
        for ( PacketData & packet : sentQueue )
            packet.time += deltaTime;
        for ( PacketData & packet : receivedQueue )
            packet.time += deltaTime;
        for ( PacketData & packet : pendingAckQueue )
            packet.time += deltaTime;
        for ( PacketData & packet : ackedQueue )
            packet.time += deltaTime;
    }

    // Trim the queues: sent and pending records expire after rtt_maximum
    // (expired pending records count as lost), acked records after twice
    // that, and received records once they are 34 or more sequences behind
    // the newest.
    void UpdateQueues()
    {
        const float epsilon = 0.001f;

        while ( sentQueue.size() && sentQueue.front().time > rtt_maximum + epsilon )
            sentQueue.pop_front();

        if ( receivedQueue.size() )
        {
            const unsigned int latest_sequence = receivedQueue.back().sequence;
            const unsigned int minimum_sequence = latest_sequence >= 34 ? ( latest_sequence - 34 ) : max_sequence - ( 34 - latest_sequence );
            while ( receivedQueue.size() && !sequence_more_recent( receivedQueue.front().sequence, minimum_sequence, max_sequence ) )
                receivedQueue.pop_front();
        }

        while ( ackedQueue.size() && ackedQueue.front().time > rtt_maximum * 2 - epsilon )
            ackedQueue.pop_front();

        while ( pendingAckQueue.size() && pendingAckQueue.front().time > rtt_maximum + epsilon )
        {
            pendingAckQueue.pop_front();
            lost_packets++;
        }
    }

    // Sent bandwidth sums sentQueue (the last rtt_maximum seconds). Acked
    // bandwidth sums acked records sent between rtt_maximum and
    // 2 * rtt_maximum seconds ago (presumably so late acks have had time to
    // arrive). Both are reported in kilobits per second.
    void UpdateStats()
    {
        int sent_bytes_per_second = 0;
        for ( PacketQueue::iterator itor = sentQueue.begin(); itor != sentQueue.end(); ++itor )
            sent_bytes_per_second += itor->size;
        int acked_packets_per_second = 0;
        int acked_bytes_per_second = 0;
        for ( PacketQueue::iterator itor = ackedQueue.begin(); itor != ackedQueue.end(); ++itor )
        {
            if ( itor->time >= rtt_maximum )
            {
                acked_packets_per_second++;
                acked_bytes_per_second += itor->size;
            }
        }
        sent_bytes_per_second /= rtt_maximum;
        acked_bytes_per_second /= rtt_maximum;
        sent_bandwidth = sent_bytes_per_second * ( 8 / 1000.0f );
        acked_bandwidth = acked_bytes_per_second * ( 8 / 1000.0f );
    }

private:

    unsigned int max_sequence;          // maximum sequence value before wrap around (used to test sequence wrap at low # values)
    unsigned int local_sequence;        // local sequence number for most recently sent packet
    unsigned int remote_sequence;       // remote sequence number for most recently received packet

    unsigned int sent_packets;          // total number of packets sent
    unsigned int recv_packets;          // total number of packets received
    unsigned int lost_packets;          // total number of packets lost
    unsigned int acked_packets;         // total number of packets acked

    float sent_bandwidth;               // approximate sent bandwidth over the last rtt_maximum seconds
    float acked_bandwidth;              // approximate acked bandwidth (see UpdateStats)
    float rtt;                          // estimated round trip time
    float rtt_maximum;                  // maximum expected round trip time (hard coded to one second for the moment)

    std::vector<unsigned int> acks;     // acked packets from last set of packet receives. cleared each update!

    PacketQueue sentQueue;              // sent packets used to calculate sent bandwidth (kept until rtt_maximum)
    PacketQueue pendingAckQueue;        // sent packets which have not been acked yet (kept until rtt_maximum, then counted as lost)
    PacketQueue receivedQueue;          // received packets for determining acks to send (kept while newer than most recent recv sequence - 34)
    PacketQueue ackedQueue;             // acked packets (kept until rtt_maximum * 2)
};
// connection with reliability (seq/ack)
class ReliableConnection : public Connection
{
public:
ReliableConnection( unsigned int protocolId, float timeout, unsigned int max_sequence = 0xFFFFFFFF )
: Connection( protocolId, timeout ), reliabilitySystem( max_sequence )
{
ClearData();
#ifdef NET_UNIT_TEST
packet_loss_mask = 0;
#endif
}
~ReliableConnection()
{
if ( IsRunning() )
Stop();
}
// overriden functions from "Connection"
bool SendPacket( const unsigned char data[], int size )
{
#ifdef NET_UNIT_TEST
if ( reliabilitySystem.GetLocalSequence() & packet_loss_mask )
{
reliabilitySystem.PacketSent( size );
return true;
}
#endif
const int header = 12;
unsigned char packet[header+size];
unsigned int seq = reliabilitySystem.GetLocalSequence();
unsigned int ack = reliabilitySystem.GetRemoteSequence();
unsigned int ack_bits = reliabilitySystem.GenerateAckBits();
WriteHeader( packet, seq, ack, ack_bits );
std::memcpy( packet + header, data, size );
if ( !Connection::SendPacket( packet, size + header ) )
return false;
reliabilitySystem.PacketSent( size );
return true;
}
int ReceivePacket( unsigned char data[], int size )
{
const int header = 12;
if ( size <= header )
return false;
unsigned char packet[header+size];
int received_bytes = Connection::ReceivePacket( packet, size + header );
if ( received_bytes == 0 )
return false;
if ( received_bytes <= header )
return false;
unsigned int packet_sequence = 0;
unsigned int packet_ack = 0;
unsigned int packet_ack_bits = 0;
ReadHeader( packet, packet_sequence, packet_ack, packet_ack_bits );
reliabilitySystem.PacketReceived( packet_sequence, received_bytes - header );
reliabilitySystem.ProcessAck( packet_ack, packet_ack_bits );
std::memcpy( data, packet + header, received_bytes - header );
return received_bytes - header;
}
void Update( float deltaTime )
{
Connection::Update( deltaTime );
reliabilitySystem.Update( deltaTime );
}
int GetHeaderSize() const
{
return Connection::GetHeaderSize() + reliabilitySystem.GetHeaderSize();
}
ReliabilitySystem & GetReliabilitySystem()
{
return reliabilitySystem;
}
// unit test controls
#ifdef NET_UNIT_TEST
void SetPacketLossMask( unsigned int mask )
{
packet_loss_mask = mask;
}
#endif
protected:
void WriteInteger( unsigned char * data, unsigned int value )
{
data[0] = (unsigned char) ( value >> 24 );
data[1] = (unsigned char) ( ( value >> 16 ) & 0xFF );
data[2] = (unsigned char) ( ( value >> 8 ) & 0xFF );
data[3] = (unsigned char) ( value & 0xFF );
}
void WriteHeader( unsigned char * header, unsigned int sequence, unsigned int ack, unsigned int ack_bits )
{
WriteInteger( header, sequence );
WriteInteger( header + 4, ack );
WriteInteger( header + 8, ack_bits );
}
void ReadInteger( const unsigned char * data, unsigned int & value )
{
value = ( ( (unsigned int)data[0] << 24 ) | ( (unsigned int)data[1] << 16 ) |
( (unsigned int)data[2] << 8 ) | ( (unsigned int)data[3] ) );
}
void ReadHeader( const unsigned char * header, unsigned int & sequence, unsigned int & ack, unsigned int & ack_bits )
{
ReadInteger( header, sequence );
ReadInteger( header + 4, ack );
ReadInteger( header + 8, ack_bits );
}
virtual void OnStop()
{
ClearData();
}
virtual void OnDisconnect()
{
ClearData();
}
private:
void ClearData()
{
reliabilitySystem.Reset();
}
#ifdef NET_UNIT_TEST
unsigned int packet_loss_mask; // mask sequence number, if non-zero, drop packet - for unit test only
#endif
ReliabilitySystem reliabilitySystem; // reliability system: manages sequence numbers and acks, tracks network stats etc.
};
}
#endif
/*
Unit Tests for Reliable Connection
From "Networking for Game Programmers" - http://www.gaffer.org/networking-for-game-programmers
Author: Glenn Fiedler <gaffer@gaffer.org>
*/
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#define NET_UNIT_TEST
#include "Net.h"
using namespace std;
using namespace net;
// check(): test assertion. In DEBUG builds it is plain assert; otherwise it
// prints a message and exits so checks stay active without asserts.
#ifdef DEBUG
#define check assert
#else
// The argument is parenthesized so check( a < b ) negates the whole condition
// (not just `a`), and do/while makes the macro a single statement that is
// safe inside an unbraced if/else.
#define check(n) do { if ( !( n ) ) { printf( "check failed\n" ); exit(1); } } while ( 0 )
#endif
#define CHECK_ACKS check( ack_count == 0 || (ack_count != 0 && acks) );
// Exercise PacketQueue::insert_sorted with sequences arriving in ascending,
// descending, shuffled and wrapping order, verifying the order after every
// insertion. Each scenario uses distinct sequences spanning less than half the
// sequence space, which sequence_more_recent needs to order them consistently.
void test_packet_queue()
{
    printf( "-----------------------------------------------------\n" );
    printf( "test packet queue\n" );
    printf( "-----------------------------------------------------\n" );

    const unsigned int MaximumSequence = 255;

    PacketQueue packetQueue;

    // insert one record (other fields zeroed) and re-verify the ordering
    auto insert_and_verify = [&]( unsigned int sequence )
    {
        PacketData data{};
        data.sequence = sequence;
        packetQueue.insert_sorted( data, MaximumSequence );
        packetQueue.verify_sorted( MaximumSequence );
    };

    printf( "check insert back\n" );
    for ( unsigned int i = 0; i < 100; ++i )
        insert_and_verify( i );

    printf( "check insert front\n" );
    packetQueue.clear();
    // descending order: each packet is older than the current front
    for ( unsigned int i = 100; i > 0; --i )
        insert_and_verify( i );

    printf( "check insert random\n" );
    packetQueue.clear();
    // distinct sequences [0,99] in shuffled order so insert_sorted must also
    // place packets mid-queue (arbitrary values over [0,255] could repeat and
    // have no consistent wrap-around order)
    std::vector<unsigned int> shuffled;
    for ( unsigned int i = 0; i < 100; ++i )
        shuffled.push_back( i );
    for ( size_t i = shuffled.size() - 1; i > 0; --i )
        std::swap( shuffled[i], shuffled[rand() % ( i + 1 )] );
    for ( unsigned int sequence : shuffled )
        insert_and_verify( sequence );

    printf( "check insert wrap around\n" );
    packetQueue.clear();
    for ( unsigned int i = 200; i <= 255; ++i )
        insert_and_verify( i );
    for ( unsigned int i = 0; i <= 50; ++i )
        insert_and_verify( i );
}
void test_reliability_system()
{
printf( "-----------------------------------------------------\n" );
printf( "test reliability system\n" );
printf( "-----------------------------------------------------\n" );
const int MaximumSequence = 255;
printf( "check bit index for sequence\n" );
check( ReliabilitySystem::bit_index_for_sequence( 99, 100, MaximumSequence ) == 0 );
check( ReliabilitySystem::bit_index_for_sequence( 90, 100, MaximumSequence ) == 9 );
check( ReliabilitySystem::bit_index_for_sequence( 0, 1, MaximumSequence ) == 0 );
check( ReliabilitySystem::bit_index_for_sequence( 255, 0, MaximumSequence ) == 0 );
check( ReliabilitySystem::bit_index_for_sequence( 255, 1, MaximumSequence ) == 1 );
check( ReliabilitySystem::bit_index_for_sequence( 254, 1, MaximumSequence ) == 2 );
check( ReliabilitySystem::bit_index_for_sequence( 254, 2, MaximumSequence ) == 3 );
printf( "check generate ack bits\n");
PacketQueue packetQueue;
for ( int i = 0; i < 32; ++i )
{
PacketData data;
data.sequence = i;
packetQueue.insert_sorted( data, MaximumSequence );
packetQueue.verify_sorted( MaximumSequence );
}
check( ReliabilitySystem::generate_ack_bits( 32, packetQueue, MaximumSequence ) == 0xFFFFFFFF );
check( ReliabilitySystem::generate_ack_bits( 31, packetQueue, MaximumSequence ) == 0x7FFFFFFF );
check( ReliabilitySystem::generate_ack_bits( 33, packetQueue, MaximumSequence ) == 0xFFFFFFFE );
check( ReliabilitySystem::generate_ack_bits( 16, packetQueue, MaximumSequence ) == 0x0000FFFF );
check( ReliabilitySystem::generate_ack_bits( 48, packetQueue, MaximumSequence ) == 0xFFFF0000 );
printf( "check generate ack bits with wrap\n");
packetQueue.clear();
for ( int i = 255 - 31; i <= 255; ++i )
{
PacketData data;
data.sequence = i;
packetQueue.insert_sorted( data, MaximumSequence );
packetQueue.verify_sorted( MaximumSequence );
}
check( packetQueue.size() == 32 );
check( ReliabilitySystem::generate_ack_bits( 0, packetQueue, MaximumSequence ) == 0xFFFFFFFF );
check( ReliabilitySystem::generate_ack_bits( 255, packetQueue, MaximumSequence ) == 0x7FFFFFFF );
check( ReliabilitySystem::generate_ack_bits( 1, packetQueue, MaximumSequence ) == 0xFFFFFFFE );
check( ReliabilitySystem::generate_ack_bits( 240, packetQueue, MaximumSequence ) == 0x0000FFFF );
check( ReliabilitySystem::generate_ack_bits( 16, packetQueue, MaximumSequence ) == 0xFFFF0000 );
printf( "check process ack (1)\n" );
{
PacketQueue pendingAckQueue;
for ( int i = 0; i < 33; ++i )
{
PacketData data;
data.sequence = i;
data.time = 0.0f;
pendingAckQueue.insert_sorted( data, MaximumSequence );
pendingAckQueue.verify_sorted( MaximumSequence );
}
PacketQueue ackedQueue;
std::vector<unsigned int> acks;
float rtt = 0.0f;
unsigned int acked_packets = 0;
ReliabilitySystem::process_ack( 32, 0xFFFFFFFF, pendingAckQueue, ackedQueue, acks, acked_packets, rtt, MaximumSequence );
check( acks.size() == 33 );
check( acked_packets == 33 );
check( ackedQueue.size() == 33 );
check( pendingAckQueue.size() == 0 );
ackedQueue.verify_sorted( MaximumSequence );
for ( unsigned int i = 0; i < acks.size(); ++i )
check( acks[i] == i );
unsigned int i = 0;
for ( PacketQueue::iterator itor = ackedQueue.begin(); itor != ackedQueue.end(); ++itor, ++i )
check( itor->sequence == i );
}
printf( "check process ack (2)\n" );
{
PacketQueue pendingAckQueue;
for ( int i = 0; i < 33; ++i )
{
PacketData data;
data.sequence = i;
data.time = 0.0f;
pendingAckQueue.insert_sorted( data, MaximumSequence );
pendingAckQueue.verify_sorted( MaximumSequence );
}
PacketQueue ackedQueue;
std::vector<unsigned int> acks;
float rtt = 0.0f;
unsigned int acked_packets = 0;
ReliabilitySystem::process_ack( 32, 0x0000FFFF, pendingAckQueue, ackedQueue, acks, acked_packets, rtt, MaximumSequence );
check( acks.size() == 17 );
check( acked_packets == 17 );
check( ackedQueue.size() == 17 );
check( pendingAckQueue.size() == 33 - 17 );
ackedQueue.verify_sorted( MaximumSequence );
unsigned int i = 0;
for ( PacketQueue::iterator itor = pendingAckQueue.begin(); itor != pendingAckQueue.end(); ++itor, ++i )
check( itor->sequence == i );
i = 0;
for ( PacketQueue::iterator itor = ackedQueue.begin(); itor != ackedQueue.end(); ++itor, ++i )
check( itor->sequence == i + 16 );
for ( unsigned int i = 0; i < acks.size(); ++i )
check( acks[i] == i + 16 );
}
printf( "check process ack (3)\n" );
{
PacketQueue pendingAckQueue;
for ( int i = 0; i < 32; ++i )
{
PacketData data;
data.sequence = i;
data.time = 0.0f;
pendingAckQueue.insert_sorted( data, MaximumSequence );
pendingAckQueue.verify_sorted( MaximumSequence );
}
PacketQueue ackedQueue;
std::vector<unsigned int> acks;
float rtt = 0.0f;
unsigned int acked_packets = 0;
ReliabilitySystem::process_ack( 48, 0xFFFF0000, pendingAckQueue, ackedQueue, acks, acked_packets, rtt, MaximumSequence );
check( acks.size() == 16 );
check( acked_packets == 16 );
check( ackedQueue.size() == 16 );
check( pendingAckQueue.size() == 16 );
ackedQueue.verify_sorted( MaximumSequence );
unsigned int i = 0;
for ( PacketQueue::iterator itor = pendingAckQueue.begin(); itor != pendingAckQueue.end(); ++itor, ++i )
check( itor->sequence == i );
i = 0;
for ( PacketQueue::iterator itor = ackedQueue.begin(); itor != ackedQueue.end(); ++itor, ++i )
check( itor->sequence == i + 16 );
for ( unsigned int i = 0; i < acks.size(); ++i )
check( acks[i] == i + 16 );
}
printf( "check process ack wrap around (1)\n" );
{
PacketQueue pendingAckQueue;
for ( int i = 255 - 31; i <= 256; ++i )
{
PacketData data;
data.sequence = i & 0xFF;
data.time = 0.0f;
pendingAckQueue.insert_sorted( data, MaximumSequence );
pendingAckQueue.verify_sorted( MaximumSequence );
}
check( pendingAckQueue.size() == 33 );
PacketQueue ackedQueue;
std::vector<unsigned int> acks;
float rtt = 0.0f;
unsigned int acked_packets = 0;
ReliabilitySystem::process_ack( 0, 0xFFFFFFFF, pendingAckQueue, ackedQueue, acks, acked_packets, rtt, MaximumSequence );
check( acks.size() == 33 );
check( acked_packets == 33 );
check( ackedQueue.size() == 33 );
check( pendingAckQueue.size() == 0 );
ackedQueue.verify_sorted( MaximumSequence );
for ( unsigned int i = 0; i < acks.size(); ++i )
check( acks[i] == ( (i+255-31) & 0xFF ) );
unsigned int i = 0;
for ( PacketQueue::iterator itor = ackedQueue.begin(); itor != ackedQueue.end(); ++itor, ++i )
check( itor->sequence == ( (i+255-31) & 0xFF ) );
}
printf( "check process ack wrap around (2)\n" );
{
PacketQueue pendingAckQueue;
for ( int i = 255 - 31; i <= 256; ++i )
{
PacketData data;
data.sequence = i & 0xFF;
data.time = 0.0f;
pendingAckQueue.insert_sorted( data, MaximumSequence );
pendingAckQueue.verify_sorted( MaximumSequence );
}
check( pendingAckQueue.size() == 33 );
PacketQueue ackedQueue;
std::vector<unsigned int> acks;
float rtt = 0.0f;
unsigned int acked_packets = 0;
ReliabilitySystem::process_ack( 0, 0x0000FFFF, pendingAckQueue, ackedQueue, acks, acked_packets, rtt, MaximumSequence );
check( acks.size() == 17 );
check( acked_packets == 17 );
check( ackedQueue.size() == 17 );
check( pendingAckQueue.size() == 33 - 17 );
ackedQueue.verify_sorted( MaximumSequence );
for ( unsigned int i = 0; i < acks.size(); ++i )
check( acks[i] == ( (i+255-15) & 0xFF ) );
unsigned int i = 0;
for ( PacketQueue::iterator itor = pendingAckQueue.begin(); itor != pendingAckQueue.end(); ++itor, ++i )
check( itor->sequence == i + 255 - 31 );
i = 0;
for ( PacketQueue::iterator itor = ackedQueue.begin(); itor != ackedQueue.end(); ++itor, ++i )
check( itor->sequence == ( (i+255-15) & 0xFF ) );
}
printf( "check process ack wrap around (3)\n" );
{
PacketQueue pendingAckQueue;
for ( int i = 255 - 31; i <= 255; ++i )
{
PacketData data;
data.sequence = i & 0xFF;
data.time = 0.0f;
pendingAckQueue.insert_sorted( data, MaximumSequence );
pendingAckQueue.verify_sorted( MaximumSequence );
}
check( pendingAckQueue.size() == 32 );
PacketQueue ackedQueue;
std::vector<unsigned int> acks;
float rtt = 0.0f;
unsigned int acked_packets = 0;
ReliabilitySystem::process_ack( 16, 0xFFFF0000, pendingAckQueue, ackedQueue, acks, acked_packets, rtt, MaximumSequence );
check( acks.size() == 16 );
check( acked_packets == 16 );
check( ackedQueue.size() == 16 );
check( pendingAckQueue.size() == 16 );
ackedQueue.verify_sorted( MaximumSequence );
for ( unsigned int i = 0; i < acks.size(); ++i )
check( acks[i] == ( (i+255-15) & 0xFF ) );
unsigned int i = 0;
for ( PacketQueue::iterator itor = pendingAckQueue.begin(); itor != pendingAckQueue.end(); ++itor, ++i )
check( itor->sequence == i + 255 - 31 );
i = 0;
for ( PacketQueue::iterator itor = ackedQueue.begin(); itor != ackedQueue.end(); ++itor, ++i )
check( itor->sequence == ( (i+255-15) & 0xFF ) );
}
}
/*
 * test_join
 *
 * Starts a client and a listening server on loopback and pumps both
 * endpoints (each sends one packet, everything received is drained, time is
 * advanced) until either both sides report IsConnected() or the client has
 * stopped connecting with ConnectFailed() set. Passes only if both ended up
 * connected.
 */
void test_join()
{
    printf( "-----------------------------------------------------\n" );
    printf( "test join\n" );
    printf( "-----------------------------------------------------\n" );

    const int serverPort = 30000;
    const int clientPort = 30001;
    const int protocolId = 0x11112222;
    const float stepTime = 0.001f;
    const float timeout = 1.0f;

    ReliableConnection client( protocolId, timeout );
    ReliableConnection server( protocolId, timeout );
    check( client.Start( clientPort ) );
    check( server.Start( serverPort ) );

    client.Connect( Address( 127, 0, 0, 1, serverPort ) );
    server.Listen();

    // keep pumping while the handshake is incomplete and has not failed
    while ( ( !client.IsConnected() || !server.IsConnected() ) &&
            ( client.IsConnecting() || !client.ConnectFailed() ) )
    {
        unsigned char toServer[] = "client to server";
        client.SendPacket( toServer, sizeof( toServer ) );
        unsigned char toClient[] = "server to client";
        server.SendPacket( toClient, sizeof( toClient ) );

        // received payloads are discarded; only connection state matters here
        unsigned char scratch[256];
        while ( client.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}
        while ( server.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}

        client.Update( stepTime );
        server.Update( stepTime );
    }

    check( client.IsConnected() );
    check( server.IsConnected() );
}
/*
 * test_join_timeout
 *
 * Starts only a client (no server is created in this test) and asks it to
 * connect to the server port. The client is pumped until it is no longer
 * connecting; it must then be disconnected with ConnectFailed() set.
 */
void test_join_timeout()
{
    printf( "-----------------------------------------------------\n" );
    printf( "test join timeout\n" );
    printf( "-----------------------------------------------------\n" );

    const int serverPort = 30000;
    const int clientPort = 30001;
    const int protocolId = 0x11112222;
    const float stepTime = 0.001f;
    const float timeout = 0.1f;

    ReliableConnection client( protocolId, timeout );
    check( client.Start( clientPort ) );
    client.Connect( Address( 127, 0, 0, 1, serverPort ) );

    while ( client.IsConnecting() )
    {
        unsigned char outgoing[] = "client to server";
        client.SendPacket( outgoing, sizeof( outgoing ) );

        unsigned char scratch[256];
        while ( client.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}

        client.Update( stepTime );
    }

    check( !client.IsConnected() );
    check( client.ConnectFailed() );
}
/*
 * test_join_busy
 *
 * Phase 1 connects a client and a server on loopback. Phase 2 starts a third
 * endpoint ("busy") on clientPort + 1 and points it at the same server while
 * the original pair keeps exchanging packets. The original pair must stay
 * connected and the third endpoint's attempt must end with ConnectFailed().
 */
void test_join_busy()
{
    printf( "-----------------------------------------------------\n" );
    printf( "test join busy\n" );
    printf( "-----------------------------------------------------\n" );

    const int serverPort = 30000;
    const int clientPort = 30001;
    const int protocolId = 0x11112222;
    const float stepTime = 0.001f;
    const float timeout = 0.1f;

    // phase 1: connect client to server
    ReliableConnection client( protocolId, timeout );
    ReliableConnection server( protocolId, timeout );
    check( client.Start( clientPort ) );
    check( server.Start( serverPort ) );
    client.Connect( Address( 127, 0, 0, 1, serverPort ) );
    server.Listen();

    while ( ( !client.IsConnected() || !server.IsConnected() ) &&
            ( client.IsConnecting() || !client.ConnectFailed() ) )
    {
        unsigned char toServer[] = "client to server";
        client.SendPacket( toServer, sizeof( toServer ) );
        unsigned char toClient[] = "server to client";
        server.SendPacket( toClient, sizeof( toClient ) );

        unsigned char scratch[256];
        while ( client.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}
        while ( server.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}

        client.Update( stepTime );
        server.Update( stepTime );
    }

    check( client.IsConnected() );
    check( server.IsConnected() );

    // phase 2: a second client tries the same server; expect that attempt to fail
    ReliableConnection busy( protocolId, timeout );
    check( busy.Start( clientPort + 1 ) );
    busy.Connect( Address( 127, 0, 0, 1, serverPort ) );

    while ( busy.IsConnecting() && !busy.IsConnected() )
    {
        unsigned char toServer[] = "client to server";
        client.SendPacket( toServer, sizeof( toServer ) );
        unsigned char toClient[] = "server to client";
        server.SendPacket( toClient, sizeof( toClient ) );
        unsigned char fromBusy[] = "i'm so busy!";
        busy.SendPacket( fromBusy, sizeof( fromBusy ) );

        unsigned char scratch[256];
        while ( client.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}
        while ( server.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}
        while ( busy.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}

        client.Update( stepTime );
        server.Update( stepTime );
        busy.Update( stepTime );
    }

    check( client.IsConnected() );
    check( server.IsConnected() );
    check( !busy.IsConnected() );
    check( busy.ConnectFailed() );
}
/*
 * test_rejoin
 *
 * Connects a client and a server, then stops sending so that both sides
 * drop the connection (only draining and Update() continue). The same
 * client then calls Connect() again, without another server.Listen(), and
 * both sides must end up connected once more.
 */
void test_rejoin()
{
    printf( "-----------------------------------------------------\n" );
    printf( "test rejoin\n" );
    printf( "-----------------------------------------------------\n" );

    const int serverPort = 30000;
    const int clientPort = 30001;
    const int protocolId = 0x11112222;
    const float stepTime = 0.001f;
    const float timeout = 0.1f;

    ReliableConnection client( protocolId, timeout );
    ReliableConnection server( protocolId, timeout );
    check( client.Start( clientPort ) );
    check( server.Start( serverPort ) );

    // connect client and server
    client.Connect( Address( 127, 0, 0, 1, serverPort ) );
    server.Listen();

    while ( ( !client.IsConnected() || !server.IsConnected() ) &&
            ( client.IsConnecting() || !client.ConnectFailed() ) )
    {
        unsigned char toServer[] = "client to server";
        client.SendPacket( toServer, sizeof( toServer ) );
        unsigned char toClient[] = "server to client";
        server.SendPacket( toClient, sizeof( toClient ) );

        unsigned char scratch[256];
        while ( client.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}
        while ( server.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}

        client.Update( stepTime );
        server.Update( stepTime );
    }

    check( client.IsConnected() );
    check( server.IsConnected() );

    // let the connection time out: no sends, just drain and tick both sides
    while ( client.IsConnected() || server.IsConnected() )
    {
        unsigned char scratch[256];
        while ( client.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}
        while ( server.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}

        client.Update( stepTime );
        server.Update( stepTime );
    }

    check( !client.IsConnected() );
    check( !server.IsConnected() );

    // reconnect the same client to the same server
    client.Connect( Address( 127, 0, 0, 1, serverPort ) );

    while ( ( !client.IsConnected() || !server.IsConnected() ) &&
            ( client.IsConnecting() || !client.ConnectFailed() ) )
    {
        unsigned char toServer[] = "client to server";
        client.SendPacket( toServer, sizeof( toServer ) );
        unsigned char toClient[] = "server to client";
        server.SendPacket( toClient, sizeof( toClient ) );

        unsigned char scratch[256];
        while ( client.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}
        while ( server.ReceivePacket( scratch, sizeof( scratch ) ) != 0 ) {}

        client.Update( stepTime );
        server.Update( stepTime );
    }

    check( client.IsConnected() );
    check( server.IsConnected() );
}
/*
 * test_payload
 *
 * Connects a client and a server on loopback (same handshake loop as the
 * join tests) and verifies that every payload delivered to either side is
 * byte-for-byte the message the peer sent, terminating NUL included.
 */
void test_payload()
{
    printf( "-----------------------------------------------------\n" );
    printf( "test payload\n" );
    printf( "-----------------------------------------------------\n" );
    const int ServerPort = 30000;
    const int ClientPort = 30001;
    const int ProtocolId = 0x11112222;
    const float DeltaTime = 0.001f;
    const float TimeOut = 0.1f;
    // exact messages each side expects to receive (sizeof includes the NUL)
    const char expectedAtClient[] = "server to client";
    const char expectedAtServer[] = "client to server";
    ReliableConnection client( ProtocolId, TimeOut );
    ReliableConnection server( ProtocolId, TimeOut );
    check( client.Start( ClientPort ) );
    check( server.Start( ServerPort ) );
    client.Connect( Address(127,0,0,1,ServerPort ) );
    server.Listen();
    while ( true )
    {
        if ( client.IsConnected() && server.IsConnected() )
            break;
        if ( !client.IsConnecting() && client.ConnectFailed() )
            break;
        unsigned char client_packet[] = "client to server";
        client.SendPacket( client_packet, sizeof( client_packet ) );
        unsigned char server_packet[] = "server to client";
        server.SendPacket( server_packet, sizeof( server_packet ) );
        // The receive buffers are uninitialized. strcmp() would only stay
        // within the received bytes if the sender's NUL arrived intact, and
        // it ignores the reported length. Instead, require the length
        // returned by ReceivePacket to equal the message size, and only then
        // compare exactly that many bytes.
        while ( true )
        {
            unsigned char packet[256];
            int bytes_read = client.ReceivePacket( packet, sizeof(packet) );
            if ( bytes_read == 0 )
                break;
            check( bytes_read == (int) sizeof( expectedAtClient ) &&
                   memcmp( packet, expectedAtClient, sizeof( expectedAtClient ) ) == 0 );
        }
        while ( true )
        {
            unsigned char packet[256];
            int bytes_read = server.ReceivePacket( packet, sizeof(packet) );
            if ( bytes_read == 0 )
                break;
            check( bytes_read == (int) sizeof( expectedAtServer ) &&
                   memcmp( packet, expectedAtServer, sizeof( expectedAtServer ) ) == 0 );
        }
        client.Update( DeltaTime );
        server.Update( DeltaTime );
    }
    check( client.IsConnected() );
    check( server.IsConnected() );
}
void test_acks()
{
printf( "-----------------------------------------------------\n" );
printf( "test acks\n" );
printf( "-----------------------------------------------------\n" );
const int ServerPort = 30000;
const int ClientPort = 30001;
const int ProtocolId = 0x11112222;
const float DeltaTime = 0.001f;
const float TimeOut = 0.1f;
const unsigned int PacketCount = 100;
ReliableConnection client( ProtocolId, TimeOut );
ReliableConnection server( ProtocolId, TimeOut );
check( client.Start( ClientPort ) );
check( server.Start( ServerPort ) );
client.Connect( Address(127,0,0,1,ServerPort ) );
server.Listen();
bool clientAckedPackets[PacketCount];
bool serverAckedPackets[PacketCount];
for ( unsigned int i = 0; i < PacketCount; ++i )
{
clientAckedPackets[i] = false;
serverAckedPackets[i] = false;
}
bool allPacketsAcked = false;
while ( true )
{
if ( !client.IsConnecting() && client.ConnectFailed() )
break;
if ( allPacketsAcked )
break;
unsigned char packet[256];
for ( unsigned int i = 0; i < sizeof(packet); ++i )
packet[i] = (unsigned char) i;
server.SendPacket( packet, sizeof(packet) );
client.SendPacket( packet, sizeof(packet) );
while ( true )
{
unsigned char packet[256];
int bytes_read = client.ReceivePacket( packet, sizeof(packet) );
if ( bytes_read == 0 )
break;
check( bytes_read == sizeof(packet) );
for ( unsigned int i = 0; i < sizeof(packet); ++i )
check( packet[i] == (unsigned char) i );
}
while ( true )
{
unsigned char packet[256];
int bytes_read = server.ReceivePacket( packet, sizeof(packet) );
if ( bytes_read == 0 )
break;
check( bytes_read == sizeof(packet) );
for ( unsigned int i = 0; i < sizeof(packet); ++i )
check( packet[i] == (unsigned char) i );
}
int ack_count = 0;
unsigned int * acks = NULL;
client.GetReliabilitySystem().GetAcks( &acks, ack_count );
CHECK_ACKS
for ( int i = 0; i < ack_count; ++i )
{
unsigned int ack = acks[i];
if ( ack < PacketCount )
{
check( clientAckedPackets[ack] == false );
clientAckedPackets[ack] = true;
}
}
server.GetReliabilitySystem().GetAcks( &acks, ack_count );
CHECK_ACKS
for ( int i = 0; i < ack_count; ++i )
{
unsigned int ack = acks[i];
if ( ack < PacketCount )
{
check( serverAckedPackets[ack] == false );
serverAckedPackets[ack] = true;
}
}
unsigned int clientAckCount = 0;
unsigned int serverAckCount = 0;
for ( unsigned int i = 0; i < PacketCount; ++i )
{
clientAckCount += clientAckedPackets[i];
serverAckCount += serverAckedPackets[i];
}
allPacketsAcked = clientAckCount == PacketCount && serverAckCount == PacketCount;
client.Update( DeltaTime );
server.Update( DeltaTime );
}
check( client.IsConnected() );
check( server.IsConnected() );
}
void test_ack_bits()
{
printf( "-----------------------------------------------------\n" );
printf( "test ack bits\n" );
printf( "-----------------------------------------------------\n" );
const int ServerPort = 30000;
const int ClientPort = 30001;
const int ProtocolId = 0x11112222;
const float DeltaTime = 0.001f;
const float TimeOut = 0.1f;
const unsigned int PacketCount = 100;
ReliableConnection client( ProtocolId, TimeOut );
ReliableConnection server( ProtocolId, TimeOut );
check( client.Start( ClientPort ) );
check( server.Start( ServerPort ) );
client.Connect( Address(127,0,0,1,ServerPort ) );
server.Listen();
bool clientAckedPackets[PacketCount];
bool serverAckedPackets[PacketCount];
for ( unsigned int i = 0; i < PacketCount; ++i )
{
clientAckedPackets[i] = false;
serverAckedPackets[i] = false;
}
bool allPacketsAcked = false;
while ( true )
{
if ( !client.IsConnecting() && client.ConnectFailed() )
break;
if ( allPacketsAcked )
break;
unsigned char packet[256];
for ( unsigned int i = 0; i < sizeof(packet); ++i )
packet[i] = (unsigned char) i;
for ( int i = 0; i < 10; ++i )
{
client.SendPacket( packet, sizeof(packet) );
while ( true )
{
unsigned char packet[256];
int bytes_read = client.ReceivePacket( packet, sizeof(packet) );
if ( bytes_read == 0 )
break;
check( bytes_read == sizeof(packet) );
for ( unsigned int i = 0; i < sizeof(packet); ++i )
check( packet[i] == (unsigned char) i );
}
int ack_count = 0;
unsigned int * acks = NULL;
client.GetReliabilitySystem().GetAcks( &acks, ack_count );
CHECK_ACKS
for ( int i = 0; i < ack_count; ++i )
{
unsigned int ack = acks[i];
if ( ack < PacketCount )
{
check( !clientAckedPackets[ack] );
clientAckedPackets[ack] = true;
}
}
client.Update( DeltaTime * 0.1f );
}
server.SendPacket( packet, sizeof(packet) );
while ( true )
{
unsigned char packet[256];
int bytes_read = server.ReceivePacket( packet, sizeof(packet) );
if ( bytes_read == 0 )
break;
check( bytes_read == sizeof(packet) );
for ( unsigned int i = 0; i < sizeof(packet); ++i )
check( packet[i] == (unsigned char) i );
}
int ack_count = 0;
unsigned int * acks = NULL;
server.GetReliabilitySystem().GetAcks( &acks, ack_count );
CHECK_ACKS
for ( int i = 0; i < ack_count; ++i )
{
unsigned int ack = acks[i];
if ( ack < PacketCount )
{
check( !serverAckedPackets[ack] );
serverAckedPackets[ack] = true;
}
}
unsigned int clientAckCount = 0;
unsigned int serverAckCount = 0;
for ( unsigned int i = 0; i < PacketCount; ++i )
{
if ( clientAckedPackets[i] )
clientAckCount++;
if ( serverAckedPackets[i] )
serverAckCount++;
}
// printf( "client ack count = %d, server ack count = %d\n", clientAckCount, serverAckCount );
allPacketsAcked = clientAckCount == PacketCount && serverAckCount == PacketCount;
server.Update( DeltaTime );
}
check( client.IsConnected() );
check( server.IsConnected() );
}
void test_packet_loss()
{
printf( "-----------------------------------------------------\n" );
printf( "test packet loss\n" );
printf( "-----------------------------------------------------\n" );
const int ServerPort = 30000;
const int ClientPort = 30001;
const int ProtocolId = 0x11112222;
const float DeltaTime = 0.001f;
const float TimeOut = 0.1f;
const unsigned int PacketCount = 100;
ReliableConnection client( ProtocolId, TimeOut );
ReliableConnection server( ProtocolId, TimeOut );
client.SetPacketLossMask( 1 );
server.SetPacketLossMask( 1 );
check( client.Start( ClientPort ) );
check( server.Start( ServerPort ) );
client.Connect( Address(127,0,0,1,ServerPort ) );
server.Listen();
bool clientAckedPackets[PacketCount];
bool serverAckedPackets[PacketCount];
for ( unsigned int i = 0; i < PacketCount; ++i )
{
clientAckedPackets[i] = false;
serverAckedPackets[i] = false;
}
bool allPacketsAcked = false;
while ( true )
{
if ( !client.IsConnecting() && client.ConnectFailed() )
break;
if ( allPacketsAcked )
break;
unsigned char packet[256];
for ( unsigned int i = 0; i < sizeof(packet); ++i )
packet[i] = (unsigned char) i;
for ( int i = 0; i < 10; ++i )
{
client.SendPacket( packet, sizeof(packet) );
while ( true )
{
unsigned char packet[256];
int bytes_read = client.ReceivePacket( packet, sizeof(packet) );
if ( bytes_read == 0 )
break;
check( bytes_read == sizeof(packet) );
for ( unsigned int i = 0; i < sizeof(packet); ++i )
check( packet[i] == (unsigned char) i );
}
int ack_count = 0;
unsigned int * acks = NULL;
client.GetReliabilitySystem().GetAcks( &acks, ack_count );
CHECK_ACKS
for ( int i = 0; i < ack_count; ++i )
{
unsigned int ack = acks[i];
if ( ack < PacketCount )
{
check( !clientAckedPackets[ack] );
check ( ( ack & 1 ) == 0 );
clientAckedPackets[ack] = true;
}
}
client.Update( DeltaTime * 0.1f );
}
server.SendPacket( packet, sizeof(packet) );
while ( true )
{
unsigned char packet[256];
int bytes_read = server.ReceivePacket( packet, sizeof(packet) );
if ( bytes_read == 0 )
break;
check( bytes_read == sizeof(packet) );
for ( unsigned int i = 0; i < sizeof(packet); ++i )
check( packet[i] == (unsigned char) i );
}
int ack_count = 0;
unsigned int * acks = NULL;
server.GetReliabilitySystem().GetAcks( &acks, ack_count );
CHECK_ACKS
for ( int i = 0; i < ack_count; ++i )
{
unsigned int ack = acks[i];
if ( ack < PacketCount )
{
check( !serverAckedPackets[ack] );
check( ( ack & 1 ) == 0 );
serverAckedPackets[ack] = true;
}
}
unsigned int clientAckCount = 0;
unsigned int serverAckCount = 0;
for ( unsigned int i = 0; i < PacketCount; ++i )
{
if ( ( i & 1 ) != 0 )
{
check( clientAckedPackets[i] == false );
check( serverAckedPackets[i] == false );
}
if ( clientAckedPackets[i] )
clientAckCount++;
if ( serverAckedPackets[i] )
serverAckCount++;
}
allPacketsAcked = clientAckCount == PacketCount / 2 && serverAckCount == PacketCount / 2;
server.Update( DeltaTime );
}
check( client.IsConnected() );
check( server.IsConnected() );
}
/*
 * test_sequence_wrap_around
 *
 * Runs a client/server pair whose reliability systems are constructed with
 * MaxSequence = 31, so sequence numbers live in [0,31]. Both sides exchange
 * 256-byte ramp packets every tick, and every ack either side reports must
 * lie in [0, MaxSequence]. The loop ends once each side has counted at
 * least PacketCount (256) acks in total, which with only 32 sequence values
 * forces the sequence space to be reused. It also ends if the client's
 * connection attempt fails.
 */
void test_sequence_wrap_around()
{
    printf( "-----------------------------------------------------\n" );
    printf( "test sequence wrap around\n" );
    printf( "-----------------------------------------------------\n" );
    const int ServerPort = 30000;
    const int ClientPort = 30001;
    const int ProtocolId = 0x11112222;
    const float DeltaTime = 0.05f;
    const float TimeOut = 1000.0f;
    const unsigned int PacketCount = 256;
    const unsigned int MaxSequence = 31; // [0,31]
    ReliableConnection client( ProtocolId, TimeOut, MaxSequence );
    ReliableConnection server( ProtocolId, TimeOut, MaxSequence );
    check( client.Start( ClientPort ) );
    check( server.Start( ServerPort ) );
    client.Connect( Address(127,0,0,1,ServerPort ) );
    server.Listen();
    // per-sequence ack counters, indexed by sequence number in [0, MaxSequence]
    unsigned int clientAckCount[MaxSequence+1];
    unsigned int serverAckCount[MaxSequence+1];
    for ( unsigned int i = 0; i <= MaxSequence; ++i )
    {
        clientAckCount[i] = 0;
        serverAckCount[i] = 0;
    }
    bool allPacketsAcked = false;
    while ( true )
    {
        if ( !client.IsConnecting() && client.ConnectFailed() )
            break;
        if ( allPacketsAcked )
            break;
        unsigned char packet[256];
        for ( unsigned int i = 0; i < sizeof(packet); ++i )
            packet[i] = (unsigned char) i;
        server.SendPacket( packet, sizeof(packet) );
        client.SendPacket( packet, sizeof(packet) );
        while ( true )
        {
            unsigned char packet[256];
            int bytes_read = client.ReceivePacket( packet, sizeof(packet) );
            if ( bytes_read == 0 )
                break;
            check( bytes_read == sizeof(packet) );
            for ( unsigned int i = 0; i < sizeof(packet); ++i )
                check( packet[i] == (unsigned char) i );
        }
        while ( true )
        {
            unsigned char packet[256];
            int bytes_read = server.ReceivePacket( packet, sizeof(packet) );
            if ( bytes_read == 0 )
                break;
            check( bytes_read == sizeof(packet) );
            for ( unsigned int i = 0; i < sizeof(packet); ++i )
                check( packet[i] == (unsigned char) i );
        }
        int ack_count = 0;
        unsigned int * acks = NULL;
        client.GetReliabilitySystem().GetAcks( &acks, ack_count );
        CHECK_ACKS
        for ( int i = 0; i < ack_count; ++i )
        {
            unsigned int ack = acks[i];
            check( ack <= MaxSequence );
            // check() is defined outside this test. If it can be compiled out
            // (e.g. an assert wrapper under NDEBUG) or does not abort, an
            // out-of-range ack would index past the array, so guard the write.
            if ( ack <= MaxSequence )
                clientAckCount[ack]++;
        }
        server.GetReliabilitySystem().GetAcks( &acks, ack_count );
        CHECK_ACKS
        for ( int i = 0; i < ack_count; ++i )
        {
            unsigned int ack = acks[i];
            check( ack <= MaxSequence );
            if ( ack <= MaxSequence )
                serverAckCount[ack]++;
        }
        unsigned int totalClientAcks = 0;
        unsigned int totalServerAcks = 0;
        for ( unsigned int i = 0; i <= MaxSequence; ++i )
        {
            totalClientAcks += clientAckCount[i];
            totalServerAcks += serverAckCount[i];
        }
        // NOTE: only total ack counts are checked here; the test does not
        // require that every sequence value in [0, MaxSequence] was acked
        allPacketsAcked = totalClientAcks >= PacketCount && totalServerAcks >= PacketCount;
        client.Update( DeltaTime );
        server.Update( DeltaTime );
    }
    check( client.IsConnected() );
    check( server.IsConnected() );
}
/*
 * tests
 *
 * Runs every test in a fixed order and prints "passed!" after the last one
 * returns. Whether a failing check() stops the run depends on check()'s
 * definition, which is not part of this function (TODO confirm it is fatal).
 */
void tests()
{
    // packet-queue / reliability-system unit tests (the latter needs no sockets)
    test_packet_queue();
    test_reliability_system();

    // connection lifecycle over loopback
    test_join();
    test_join_timeout();
    test_join_busy();
    test_rejoin();

    // payload delivery and acknowledgement
    test_payload();
    test_acks();
    test_ack_bits();
    test_packet_loss();
    test_sequence_wrap_around();

    const char * const rule = "-----------------------------------------------------\n";
    printf( "%s", rule );
    printf( "passed!\n" );
}
/*
 * main
 *
 * Initializes the socket layer, runs the whole test suite, then shuts the
 * socket layer down. If InitializeSockets() fails, a message is printed and
 * 1 is returned. Otherwise 0 is returned once tests() returns.
 * Command-line arguments are not used.
 */
int main( int argc, char * argv[] )
{
    (void) argc;
    (void) argv;

    const bool socketsReady = InitializeSockets();
    if ( socketsReady )
    {
        tests();
        ShutdownSockets();
        return 0;
    }

    printf( "failed to initialize sockets\n" );
    return 1;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment