Skip to content

Instantly share code, notes, and snippets.

@kosorin
Last active November 12, 2020 01:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kosorin/131472f513bda2a8a16cf9dccbc8fa95 to your computer and use it in GitHub Desktop.
Save kosorin/131472f513bda2a8a16cf9dccbc8fa95 to your computer and use it in GitHub Desktop.
TCP socket wrapper with stream parser
/// <summary>
/// TCP connection wrapper built on the <see cref="SocketAsyncEventArgs"/> pattern.
/// A single reusable token drives receiving. Send tokens are rented from a pool, and each one
/// carries its own <see cref="NetDataWriter"/> in <see cref="SocketAsyncEventArgs.UserToken"/>.
/// Every outgoing packet is prefixed with a <c>ushort</c> total length (see <c>WritePacket</c>).
/// </summary>
public class TcpSocket : IDisposable
{
    private readonly ProtocolProcessor _protocolProcessor = new ProtocolProcessor();
    private readonly Socket _socket;
    private readonly AutoResetEvent _connectEvent = new AutoResetEvent(false);
    private readonly SocketAsyncEventArgs _receiveToken;
    private readonly IObjectPool<SocketAsyncEventArgs> _sendTokenPool;

    // 0 while open, 1 once Close() has started. Close() is reachable from user code, from Dispose()
    // and from the receive-completion callback, so the transition is made atomically.
    private int _closed;

    // Set once Dispose has run.
    private bool _disposed;

    /// <summary>
    /// Creates an unconnected client socket targeting <paramref name="remoteEndPoint"/>.
    /// Call <see cref="Connect"/> to open the connection and start receiving.
    /// </summary>
    public TcpSocket(InternetEndPoint remoteEndPoint)
    {
        _socket = new Socket(remoteEndPoint.EndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
        _receiveToken = CreateReceiveToken();
        _sendTokenPool = new ObjectPool<SocketAsyncEventArgs>(CreateSendToken);
        RemoteEndPoint = remoteEndPoint;
    }

    /// <summary>
    /// Wraps an existing socket (presumably already connected, e.g. from an accept) and starts
    /// receiving immediately.
    /// </summary>
    /// <remarks>
    /// NOTE(review): receiving starts before this constructor returns, so before any caller can
    /// subscribe to <see cref="DataReceived"/> or <see cref="Disconnected"/>. Data that arrives in that
    /// window is passed to no handler. Consider moving StartReceive into an explicit start method.
    /// </remarks>
    internal TcpSocket(Socket socket)
    {
        _socket = socket;
        _receiveToken = CreateReceiveToken();
        _sendTokenPool = new ObjectPool<SocketAsyncEventArgs>(CreateSendToken);
        RemoteEndPoint = new InternetEndPoint(_socket.RemoteEndPoint);
        StartReceive();
    }

    /// <summary>The remote end point this socket connects (or is connected) to.</summary>
    public InternetEndPoint RemoteEndPoint { get; }

    /// <summary>
    /// Shuts down and closes the connection and disposes the connect event, the receive token and
    /// the send-token pool. Safe to call more than once: only the first call does anything.
    /// </summary>
    public void Close()
    {
        if (Interlocked.Exchange(ref _closed, 1) != 0)
        {
            return;
        }

        try
        {
            if (_socket.Connected)
            {
                _socket.Shutdown(SocketShutdown.Both);
            }
        }
        catch (SocketException)
        {
            // Shutdown is best-effort. It can throw when the peer has already reset the connection,
            // and this method is also called from the receive-completion path, where an exception
            // would escape the I/O callback.
        }

        _socket.Close();
        _socket.Dispose();
        _connectEvent.Dispose();
        _receiveToken.Dispose();
        _sendTokenPool.Dispose();
    }

    /// <summary>
    /// Raised when a receive completes with an error or with zero bytes. The socket is closed
    /// right after the handlers run.
    /// </summary>
    public event TypedEventHandler<TcpSocket> Disconnected;

    /// <summary>
    /// Connects synchronously to <see cref="RemoteEndPoint"/>. On success, starts receiving.
    /// </summary>
    /// <exception cref="NetException">The connection could not be established.</exception>
    public void Connect()
    {
        var token = CreateConnectToken();
        try
        {
            _connectEvent.Reset();
            // A false return means the operation completed synchronously. Completed is not raised
            // in that case, so signal the wait handle directly.
            if (!_socket.ConnectAsync(token))
            {
                ProcessConnect();
            }
            _connectEvent.WaitOne();
            if (token.SocketError != SocketError.Success)
            {
                throw new NetException("Could not connect.");
            }
            StartReceive();
        }
        catch (SocketException e)
        {
            throw new NetException("Could not connect.", e);
        }
        finally
        {
            // The connect token is single-use. Release it; it was previously leaked on every call.
            token.Dispose();
        }
    }

    /// <summary>Creates a one-shot token targeting <see cref="RemoteEndPoint"/>.</summary>
    private SocketAsyncEventArgs CreateConnectToken()
    {
        var token = new SocketAsyncEventArgs();
        token.RemoteEndPoint = RemoteEndPoint.EndPoint;
        token.Completed += IO_Completed;
        return token;
    }

    /// <summary>Wakes the thread blocked in <see cref="Connect"/>.</summary>
    private void ProcessConnect()
    {
        _connectEvent.Set();
    }

    /// <summary>
    /// Raised for every successful receive. The reader covers exactly the bytes received in that
    /// operation and views the shared receive buffer, which is reused by the next receive.
    /// </summary>
    public event TcpDataReceivedHandler DataReceived;

    /// <summary>Creates the reusable receive token with a 65535-byte buffer.</summary>
    private SocketAsyncEventArgs CreateReceiveToken()
    {
        var buffer = new byte[ushort.MaxValue];
        var token = new SocketAsyncEventArgs();
        token.Completed += IO_Completed;
        token.SetBuffer(buffer, 0, buffer.Length);
        return token;
    }

    /// <summary>
    /// Issues receives until one is pending (its completion arrives through <see cref="IO_Completed"/>)
    /// or until the connection ends.
    /// </summary>
    private void StartReceive()
    {
        try
        {
            // Loop instead of recursing: each synchronous completion is handled here. Recursing
            // through ProcessReceive could overflow the stack while data keeps arriving synchronously.
            while (!_socket.ReceiveAsync(_receiveToken))
            {
                if (!ProcessReceive(_receiveToken))
                {
                    return;
                }
            }
        }
        catch (ObjectDisposedException)
        {
            // The socket was closed (possibly by a handler); stop receiving.
        }
    }

    /// <summary>
    /// Handles one completed receive.
    /// </summary>
    /// <returns>
    /// <c>true</c> if data was delivered and another receive should be issued; <c>false</c> if the
    /// connection ended (in which case <see cref="Disconnected"/> was raised and the socket closed).
    /// </returns>
    private bool ProcessReceive(SocketAsyncEventArgs token)
    {
        if (token.SocketError == SocketError.Success && token.BytesTransferred > 0)
        {
            var reader = new NetDataReader(token.Buffer, token.Offset, token.BytesTransferred);
            DataReceived?.Invoke(reader);
            return true;
        }

        // TODO: Disconnected or error? (both cases currently end up here)
        Disconnected?.Invoke(this);
        Close();
        return false;
    }

    /// <summary>
    /// Serializes <paramref name="packet"/> with a ushort length prefix and sends it asynchronously.
    /// </summary>
    /// <remarks>
    /// If serialization throws, the rented token goes back to the pool and the exception propagates.
    /// Previously the packet was silently dropped.
    /// </remarks>
    public void SendPacket(byte channelId, IChannelPacket packet)
    {
        var token = _sendTokenPool.Rent();
        var writer = (NetDataWriter)token.UserToken;
        try
        {
            WritePacket(writer, channelId, packet);
        }
        catch
        {
            _sendTokenPool.Return(token);
            throw;
        }
        token.SetBuffer(writer.Data, writer.Offset, writer.Length);
        StartSend(token);
    }

    /// <summary>Pool factory: a send token that owns its own writer.</summary>
    private SocketAsyncEventArgs CreateSendToken()
    {
        var token = new SocketAsyncEventArgs
        {
            UserToken = new NetDataWriter(),
        };
        token.Completed += IO_Completed;
        return token;
    }

    /// <summary>
    /// Starts an asynchronous send. The token goes back to the pool exactly once: on synchronous
    /// completion, on a closed socket, or (for a pending send) from <see cref="IO_Completed"/>.
    /// </summary>
    private void StartSend(SocketAsyncEventArgs token)
    {
        bool pending;
        try
        {
            pending = _socket.SendAsync(token);
        }
        catch (ObjectDisposedException)
        {
            // Socket already closed: nothing was sent; just recycle the token.
            pending = false;
        }

        if (!pending)
        {
            ProcessSend(token);
        }
    }

    /// <summary>
    /// Returns a finished send token to the pool. NOTE(review): token.SocketError is not inspected,
    /// so send failures are currently ignored.
    /// </summary>
    private void ProcessSend(SocketAsyncEventArgs token)
    {
        _sendTokenPool.Return(token);
    }

    /// <summary>Writes the frame: ushort total length, then the protocol-encoded packet.</summary>
    private void WritePacket(NetDataWriter writer, byte channelId, IChannelPacket packet)
    {
        writer.Reset();
        writer.WriteUShort(_protocolProcessor.GetTotalLength(packet));
        _protocolProcessor.Write(writer, channelId, packet);
        writer.Flush();
    }

    /// <summary>Completion callback shared by all tokens; dispatches on the finished operation.</summary>
    private void IO_Completed(object sender, SocketAsyncEventArgs token)
    {
        switch (token.LastOperation)
        {
            case SocketAsyncOperation.Connect:
                ProcessConnect();
                break;
            case SocketAsyncOperation.Receive:
                if (ProcessReceive(token))
                {
                    StartReceive();
                }
                break;
            case SocketAsyncOperation.Send:
                ProcessSend(token);
                break;
            default:
                throw new InvalidOperationException("Unexpected socket async operation.");
        }
    }

    /// <summary>Closes the socket; equivalent to <see cref="Close"/>.</summary>
    public void Dispose()
    {
        Dispose(true);
    }

    private void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }
        if (disposing)
        {
            Close();
        }
        _disposed = true;
    }
}
/// <summary>
/// Incrementally reassembles length-prefixed packets from a TCP byte stream.
/// Each packet is a sizeof(ushort)-byte header, decoded with <c>ReadUShort</c>, followed by that
/// many payload bytes. Headers and payloads may be split across several readers.
/// </summary>
public class TcpStreamParser
{
    private enum State
    {
        Header,
        Data,
    }

    // Both views share one sizeof(ushort)-byte array: the writer accumulates header bytes and the
    // reader decodes them once the array is full.
    private readonly NetDataReader _headerReader;
    private readonly NetDataWriter _headerWriter;

    // Payload of the packet being assembled (or of the last completed one).
    private readonly NetDataWriter _dataBuffer;

    // Payload length announced by the most recent header.
    private int _dataLength;

    private State _state;

    public TcpStreamParser()
    {
        var headerBytes = new byte[sizeof(ushort)];
        _headerReader = new NetDataReader(headerBytes, 0, headerBytes.Length);
        _headerWriter = new NetDataWriter(headerBytes, 0, headerBytes.Length);
        _dataBuffer = new NetDataWriter();
        _dataLength = 0;
        _state = State.Header;
    }

    /// <summary>
    /// Gets the payload of the most recently completed packet. It stays valid until the header of
    /// the following packet has been fully parsed.
    /// </summary>
    public INetBuffer Buffer => _dataBuffer;

    /// <summary>
    /// Consumes bytes from <paramref name="reader"/> until one packet is complete or the reader is
    /// exhausted. At most one packet is produced per call, so call this repeatedly while it returns
    /// <c>true</c> to drain a reader that holds several packets.
    /// </summary>
    /// <param name="reader">Newly received bytes; its position is advanced past consumed input.</param>
    /// <returns><c>true</c> if a new packet is available in <see cref="Buffer"/>.</returns>
    public bool Next(NetDataReader reader)
    {
        if (_state == State.Header)
        {
            ParseHeader(reader);
        }

        // ParseHeader may just have completed the header, so re-check the state instead of using else.
        return _state == State.Data && ParseData(reader);
    }

    /// <summary>
    /// Copies as many header bytes as are available. When the header is complete, decodes the
    /// payload length and switches to the data state.
    /// </summary>
    private void ParseHeader(NetDataReader reader)
    {
        var missing = _headerWriter.Capacity - _headerWriter.Length;
        var take = Math.Min(missing, reader.Length - reader.Position);

        var copied = 0;
        while (copied < take)
        {
            _headerWriter.WriteByte(reader.ReadByte());
            copied++;
        }

        if (take != missing)
        {
            // Header still incomplete; wait for more input.
            return;
        }

        _headerReader.Seek();
        _dataLength = _headerReader.ReadUShort();
        // NOTE(review): if ReadUShort returns ushort, this check can never be true — confirm.
        if (_dataLength < 0)
        {
            throw new NetException("Bad TCP data.");
        }

        _state = State.Data;
        _dataBuffer.Reset();
    }

    /// <summary>
    /// Appends as much payload as is available.
    /// </summary>
    /// <returns><c>true</c> once the payload announced by the header is complete.</returns>
    private bool ParseData(NetDataReader reader)
    {
        var remaining = _dataLength - _dataBuffer.Length;
        var chunk = Math.Min(remaining, reader.Length - reader.Position);
        if (chunk > 0)
        {
            _dataBuffer.WriteBytes(reader.ReadBytes(chunk));
            remaining -= chunk;
        }

        if (remaining != 0)
        {
            return false;
        }

        // Packet complete: rearm for the next header. _dataBuffer is left intact so Buffer stays readable.
        _state = State.Header;
        _headerWriter.Reset();
        return true;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment