Skip to content

Instantly share code, notes, and snippets.

@a-luna
Last active November 28, 2023 12:40
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save a-luna/e5f275de9a5b111f08b5e38be7042f04 to your computer and use it in GitHub Desktop.
Save a-luna/e5f275de9a5b111f08b5e38be7042f04 to your computer and use it in GitHub Desktop.
C# extension methods that wrap the Socket APM (Begin/End) methods in awaitable TPL tasks. All methods return Task<Result> or Task<Result<T>> objects; the code for the Result class is included as well.
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    public static partial class TplSocketExtensions
    {
        /// <summary>
        /// Waits for the next incoming connection on a listening socket by wrapping the
        /// BeginAccept/EndAccept pair in a task.
        /// </summary>
        /// <param name="socket">The socket to accept on (expected to be bound and listening).</param>
        /// <returns>
        /// Ok carrying the socket for the accepted connection, or a failed Result whose Error holds
        /// the message and type of a SocketException or InvalidOperationException
        /// (ObjectDisposedException derives from InvalidOperationException, so it is reported too).
        /// Any other exception propagates to the caller.
        /// </returns>
        public static async Task<Result<Socket>> AcceptAsync(this Socket socket)
        {
            try
            {
                var accepted = await Task<Socket>.Factory
                    .FromAsync(socket.BeginAccept, socket.EndAccept, null)
                    .ConfigureAwait(false);

                return Result.Ok(accepted);
            }
            catch (Exception ex) when (ex is SocketException || ex is InvalidOperationException)
            {
                return Result.Fail<Socket>($"{ex.Message} ({ex.GetType()})");
            }
        }
    }
}
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    public static partial class TplSocketExtensions
    {
        /// <summary>
        /// Connects <paramref name="socket"/> to <paramref name="remoteIpAddress"/> on
        /// <paramref name="port"/>, giving up once <paramref name="timeoutMs"/> milliseconds elapse.
        /// </summary>
        /// <param name="socket">The socket to connect.</param>
        /// <param name="remoteIpAddress">Remote host string handed to Socket.BeginConnect.</param>
        /// <param name="port">Remote port.</param>
        /// <param name="timeoutMs">Timeout in milliseconds, passed to Task.Delay.</param>
        /// <returns>
        /// Ok when the connect completes before the timeout; otherwise a failed Result holding the
        /// message and type of the SocketException or TimeoutException.
        /// </returns>
        /// <remarks>
        /// On timeout the BeginConnect operation is not cancelled; this method just stops waiting
        /// for it. Callers that should not reuse the socket after a timeout can close it.
        /// </remarks>
        public static async Task<Result> ConnectWithTimeoutAsync(this Socket socket, string remoteIpAddress, int port, int timeoutMs)
        {
            try
            {
                var connecting = Task.Factory.FromAsync(
                    socket.BeginConnect,
                    socket.EndConnect,
                    remoteIpAddress,
                    port,
                    null);
                var deadline = Task.Delay(timeoutMs);

                var firstFinished = await Task.WhenAny(connecting, deadline).ConfigureAwait(false);
                if (firstFinished != connecting)
                {
                    throw new TimeoutException();
                }

                // Awaiting the finished task surfaces any SocketException raised by EndConnect.
                await connecting.ConfigureAwait(false);
            }
            catch (Exception ex) when (ex is SocketException || ex is TimeoutException)
            {
                return Result.Fail($"{ex.Message} ({ex.GetType()})");
            }

            return Result.Ok();
        }
    }
}
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    public static partial class TplSocketExtensions
    {
        /// <summary>
        /// Receives data into <paramref name="buffer"/> by wrapping BeginReceive/EndReceive in a task.
        /// </summary>
        /// <param name="socket">The socket to read from.</param>
        /// <param name="buffer">Destination buffer.</param>
        /// <param name="offset">Index in <paramref name="buffer"/> at which to start writing.</param>
        /// <param name="size">Maximum number of bytes to receive.</param>
        /// <param name="socketFlags">Flags passed through to BeginReceive.</param>
        /// <returns>
        /// Ok carrying the byte count returned by EndReceive, or a failed Result holding the
        /// message and type of a SocketException. Other exceptions propagate.
        /// </returns>
        public static async Task<Result<int>> ReceiveAsync(
            this Socket socket,
            byte[] buffer,
            int offset,
            int size,
            SocketFlags socketFlags)
        {
            int bytesReceived;
            try
            {
                var asyncResult = socket.BeginReceive(buffer, offset, size, socketFlags, null, null);

                // ConfigureAwait(false) matches every other await in this library: the continuation
                // does not need the caller's synchronization context, and capturing it risks
                // deadlocks for callers that block on the returned task.
                bytesReceived = await Task<int>.Factory
                    .FromAsync(asyncResult, ar => socket.EndReceive(ar))
                    .ConfigureAwait(false);
            }
            catch (SocketException ex)
            {
                return Result.Fail<int>($"{ex.Message} ({ex.GetType()})");
            }

            return Result.Ok(bytesReceived);
        }
    }
}
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    public static partial class TplSocketExtensions
    {
        /// <summary>
        /// Receives data into <paramref name="buffer"/>, giving up once
        /// <paramref name="timeoutMs"/> milliseconds elapse.
        /// </summary>
        /// <param name="socket">The socket to read from.</param>
        /// <param name="buffer">Destination buffer.</param>
        /// <param name="offset">Index in <paramref name="buffer"/> at which to start writing.</param>
        /// <param name="size">Maximum number of bytes to receive.</param>
        /// <param name="socketFlags">Flags passed through to BeginReceive.</param>
        /// <param name="timeoutMs">Timeout in milliseconds, passed to Task.Delay.</param>
        /// <returns>
        /// Ok carrying the byte count returned by EndReceive, or a failed Result holding the
        /// message and type of a SocketException or TimeoutException.
        /// </returns>
        /// <remarks>
        /// On timeout the pending BeginReceive is not cancelled; it may still write into
        /// <paramref name="buffer"/> after this method has returned.
        /// </remarks>
        public static async Task<Result<int>> ReceiveWithTimeoutAsync(
            this Socket socket,
            byte[] buffer,
            int offset,
            int size,
            SocketFlags socketFlags,
            int timeoutMs)
        {
            try
            {
                var pendingReceive = socket.BeginReceive(buffer, offset, size, socketFlags, null, null);
                var receiving = Task<int>.Factory.FromAsync(pendingReceive, ar => socket.EndReceive(ar));
                var deadline = Task.Delay(timeoutMs);

                var firstFinished = await Task.WhenAny(receiving, deadline).ConfigureAwait(false);
                if (firstFinished != receiving)
                {
                    throw new TimeoutException();
                }

                var byteCount = await receiving.ConfigureAwait(false);
                return Result.Ok(byteCount);
            }
            catch (Exception ex) when (ex is SocketException || ex is TimeoutException)
            {
                return Result.Fail<int>($"{ex.Message} ({ex.GetType()})");
            }
        }
    }
}
namespace AaronLuna.Common.Result
{
    /// <summary>
    /// Outcome of an operation that produces no value: a success, or a failure carrying an
    /// error message. Instances are created through the static Ok/Fail factories.
    /// </summary>
    public class Result
    {
        protected Result(bool success, string error)
        {
            Success = success;
            Error = error;
        }

        /// <summary>True when the operation succeeded.</summary>
        public bool Success { get; }

        /// <summary>The failure message; empty for results created by Ok.</summary>
        public string Error { get; }

        /// <summary>Convenience negation of <see cref="Success"/>.</summary>
        public bool Failure => !Success;

        /// <summary>Creates a failed result carrying <paramref name="message"/>.</summary>
        public static Result Fail(string message) => new Result(false, message);

        /// <summary>Creates a failed typed result whose Value is <c>default(T)</c>.</summary>
        public static Result<T> Fail<T>(string message) => new Result<T>(default(T), false, message);

        /// <summary>Creates a successful result with an empty error string.</summary>
        public static Result Ok() => new Result(true, string.Empty);

        /// <summary>Creates a successful typed result carrying <paramref name="value"/>.</summary>
        public static Result<T> Ok<T>(T value) => new Result<T>(value, true, string.Empty);

        /// <summary>
        /// Returns the first failed result in <paramref name="results"/> (scanning left to right),
        /// or a new successful result when none failed (including when the array is empty).
        /// </summary>
        public static Result Combine(params Result[] results)
        {
            for (var index = 0; index < results.Length; index++)
            {
                var candidate = results[index];
                if (!candidate.Success)
                {
                    return candidate;
                }
            }

            return Ok();
        }
    }

    /// <summary>
    /// A result that also carries a value of type <typeparamref name="T"/>; failed results
    /// created by the Fail factory hold <c>default(T)</c>.
    /// </summary>
    public class Result<T> : Result
    {
        protected internal Result(T value, bool success, string error)
            : base(success, error)
        {
            Value = value;
        }

        /// <summary>The value produced by the operation.</summary>
        public T Value { get; }
    }
}
namespace AaronLuna.Common.Result
{
    using System;

    /// <summary>
    /// Chaining helpers for <see cref="Result"/>: an OnSuccess continuation runs only when the
    /// incoming result succeeded; otherwise the failure is passed along (as-is, or re-wrapped as a
    /// failed typed result with the same Error).
    /// </summary>
    public static class ResultExtensions
    {
        /// <summary>Returns <paramref name="func"/>'s result on success; otherwise the original failure.</summary>
        public static Result OnSuccess(this Result result, Func<Result> func) =>
            result.Failure ? result : func();

        /// <summary>Runs <paramref name="action"/> on success and returns a new Ok; otherwise returns the failure.</summary>
        public static Result OnSuccess(this Result result, Action action)
        {
            if (result.Success)
            {
                action();
                return Result.Ok();
            }

            return result;
        }

        /// <summary>Passes the value to <paramref name="action"/> on success and returns a new Ok; otherwise returns the failure.</summary>
        public static Result OnSuccess<T>(this Result<T> result, Action<T> action)
        {
            if (result.Success)
            {
                action(result.Value);
                return Result.Ok();
            }

            return result;
        }

        /// <summary>Wraps <paramref name="func"/>'s value in Ok on success; otherwise a typed failure with the same Error.</summary>
        public static Result<T> OnSuccess<T>(this Result result, Func<T> func) =>
            result.Failure ? Result.Fail<T>(result.Error) : Result.Ok(func());

        /// <summary>Returns <paramref name="func"/>'s typed result on success; otherwise a typed failure with the same Error.</summary>
        public static Result<T> OnSuccess<T>(this Result result, Func<Result<T>> func) =>
            result.Failure ? Result.Fail<T>(result.Error) : func();

        /// <summary>Feeds the value to <paramref name="func"/> on success; otherwise returns the failure.</summary>
        public static Result OnSuccess<T>(this Result<T> result, Func<T, Result> func) =>
            result.Failure ? result : func(result.Value);

        /// <summary>Runs <paramref name="action"/> only when the result failed; always returns the same result.</summary>
        public static Result OnFailure(this Result result, Action action)
        {
            if (result.Success)
            {
                return result;
            }

            action();
            return result;
        }

        /// <summary>Runs <paramref name="action"/> unconditionally and returns the same result.</summary>
        public static Result OnBoth(this Result result, Action<Result> action)
        {
            action(result);
            return result;
        }

        /// <summary>Maps the result, success or failure, through <paramref name="func"/>.</summary>
        public static T OnBoth<T>(this Result result, Func<Result, T> func) => func(result);
    }
}
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    public static partial class TplSocketExtensions
    {
        /// <summary>
        /// Sends the file at <paramref name="filePath"/> over <paramref name="socket"/> by wrapping
        /// BeginSendFile/EndSendFile in a task.
        /// </summary>
        /// <param name="socket">The connected socket to send on.</param>
        /// <param name="filePath">Path of the file handed to BeginSendFile.</param>
        /// <returns>
        /// Ok once EndSendFile completes, or a failed Result holding the message and type of a
        /// SocketException. Any other exception propagates to the caller.
        /// </returns>
        public static async Task<Result> SendFileAsync(this Socket socket, string filePath)
        {
            try
            {
                var sendFileTask = Task.Factory.FromAsync(socket.BeginSendFile, socket.EndSendFile, filePath, null);
                await sendFileTask.ConfigureAwait(false);
                return Result.Ok();
            }
            catch (SocketException ex)
            {
                return Result.Fail($"{ex.Message} ({ex.GetType()})");
            }
        }
    }
}
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    public static partial class TplSocketExtensions
    {
        /// <summary>
        /// Sends <paramref name="size"/> bytes from <paramref name="buffer"/>, giving up once
        /// <paramref name="timeoutMs"/> milliseconds elapse.
        /// </summary>
        /// <param name="socket">The connected socket to send on.</param>
        /// <param name="buffer">Source buffer.</param>
        /// <param name="offset">Index in <paramref name="buffer"/> of the first byte to send.</param>
        /// <param name="size">Number of bytes to send.</param>
        /// <param name="socketFlags">Flags passed through to BeginSend.</param>
        /// <param name="timeoutMs">Timeout in milliseconds, passed to Task.Delay.</param>
        /// <returns>
        /// Ok when the send completes successfully before the timeout; otherwise a failed Result
        /// holding the message and type of the SocketException or TimeoutException.
        /// </returns>
        /// <remarks>
        /// On timeout the pending BeginSend is not cancelled; this method just stops waiting for it.
        /// </remarks>
        public static async Task<Result> SendWithTimeoutAsync(
            this Socket socket,
            byte[] buffer,
            int offset,
            int size,
            SocketFlags socketFlags,
            int timeoutMs)
        {
            try
            {
                var asyncResult = socket.BeginSend(buffer, offset, size, socketFlags, null, null);
                var sendTask = Task<int>.Factory.FromAsync(asyncResult, ar => socket.EndSend(ar));
                if (sendTask != await Task.WhenAny(sendTask, Task.Delay(timeoutMs)).ConfigureAwait(false))
                {
                    throw new TimeoutException();
                }

                // Task.WhenAny never throws for the task that finished first, so the completed send
                // must be awaited explicitly. Otherwise a SocketException raised by EndSend is
                // silently dropped and a failed send is reported as Ok.
                await sendTask.ConfigureAwait(false);
            }
            catch (SocketException ex)
            {
                return Result.Fail($"{ex.Message} ({ex.GetType()})");
            }
            catch (TimeoutException ex)
            {
                return Result.Fail($"{ex.Message} ({ex.GetType()})");
            }

            return Result.Ok();
        }
    }
}
namespace AaronLuna.TplSockets
{
    using System.Linq;
    using System.Net;
    using System.Net.Sockets;
    using System.Text;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    /// <summary>
    /// End-to-end example of the TPL socket extension methods: a listening socket and a client
    /// socket on this host exchange one ASCII text message over TCP.
    /// </summary>
    public class TplSocketExample
    {
        const int BufferSize = 8 * 1024;
        const int ConnectTimeoutMs = 3000;
        const int ReceiveTimeoutMs = 3000;
        const int SendTimeoutMs = 3000;

        Socket _listenSocket;
        Socket _clientSocket;
        Socket _transferSocket;

        /// <summary>
        /// Listens on port 7003 of the first IPv4 address resolved for this host, connects a client
        /// socket to it, sends a text message from the client, and checks that the accepted
        /// (server-side) socket received the same text. Every socket opened here is closed before
        /// the method returns.
        /// </summary>
        /// <returns>
        /// Ok when the received text matches the sent text; otherwise a failed Result describing
        /// which stage went wrong.
        /// </returns>
        public async Task<Result> SendAndReceiveTextMesageAsync()
        {
            var serverPort = 7003;
            var ipHostInfo = Dns.GetHostEntry(Dns.GetHostName());
            var ipAddress = ipHostInfo.AddressList
                .FirstOrDefault(ip => ip.AddressFamily == AddressFamily.InterNetwork);

            // FirstOrDefault yields null when the host has no IPv4 address, and
            // new IPEndPoint(null, ...) would throw.
            if (ipAddress == null)
            {
                return Result.Fail("Error: No IPv4 address was found for the local host");
            }

            var ipEndPoint = new IPEndPoint(ipAddress, serverPort);

            _listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            _clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            // The transfer socket comes from Accept (step 5); pre-allocating one only leaked a handle.
            _transferSocket = null;

            try
            {
                // Step 1: Bind a socket to a local TCP port and Listen for incoming connections
                _listenSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
                _listenSocket.Bind(ipEndPoint);
                _listenSocket.Listen(5);

                // Step 2: Start accepting the next incoming connection (ServerAcceptTask)
                // NOTE: This task is not awaited yet so the client can connect in the meantime
                var acceptTask = AcceptConnectionTask();

                // Step 3: With another socket, connect to the bound socket and await the result (ClientConnectTask)
                var connectResult =
                    await _clientSocket.ConnectWithTimeoutAsync(
                        ipAddress.ToString(),
                        serverPort,
                        ConnectTimeoutMs).ConfigureAwait(false);

                // Step 4: The accept has no timeout, so if the client failed to connect, awaiting
                // ServerAcceptTask would never finish. Report the failure immediately; closing the
                // listen socket in the finally block ends the pending accept.
                if (connectResult.Failure)
                {
                    return Result.Fail("There was an error connecting to the server/accepting connection from the client");
                }

                var acceptResult = await acceptTask.ConfigureAwait(false);
                if (acceptResult.Failure)
                {
                    return Result.Fail("There was an error connecting to the server/accepting connection from the client");
                }

                // Step 5: Store the transfer socket now that ServerAcceptTask was successful
                _transferSocket = acceptResult.Value;

                // Step 6: Start receiving data from the transfer socket (ServerReceiveBytesTask)
                // NOTE: This task is not awaited yet so the client can send in the meantime
                var receiveTask = ReceiveMessageAsync();

                // Step 7: Encode a string message before sending it to the server
                var messageToSend = "this is a text message from a socket";
                var messageData = Encoding.ASCII.GetBytes(messageToSend);

                // Step 8: Send the message data to the server and await the result (ClientSendBytesTask)
                var sendResult =
                    await _clientSocket.SendWithTimeoutAsync(
                        messageData,
                        0,
                        messageData.Length,
                        SocketFlags.None,
                        SendTimeoutMs).ConfigureAwait(false);

                // Step 9: Await the result of ServerReceiveBytesTask
                var receiveResult = await receiveTask.ConfigureAwait(false);

                // Step 10: If either ServerReceiveBytesTask or ClientSendBytesTask did not complete successfully, stop and report the error
                if (Result.Combine(sendResult, receiveResult).Failure)
                {
                    return Result.Fail("There was an error sending/receiving data from the client");
                }

                // Step 11: Compare the string that was received to what was sent, report an error if not matching
                if (messageToSend != receiveResult.Value)
                {
                    return Result.Fail("Error: Message received from client did not match what was sent");
                }

                // Step 12: Report the entire task was successful since all subtasks were successful
                return Result.Ok();
            }
            finally
            {
                // Release every socket this run opened so OS handles and the bound port are not leaked.
                _transferSocket?.Close();
                _clientSocket.Close();
                _listenSocket.Close();
            }
        }

        /// <summary>Accepts the next connection on <c>_listenSocket</c> via the AcceptAsync extension.</summary>
        async Task<Result<Socket>> AcceptConnectionTask()
        {
            return await _listenSocket.AcceptAsync().ConfigureAwait(false);
        }

        /// <summary>
        /// Performs a single timed receive (up to <c>BufferSize</c> bytes) on <c>_transferSocket</c>
        /// and decodes the received bytes as ASCII.
        /// </summary>
        /// <returns>
        /// Ok with the decoded text; the receive's own failure (e.g. timeout) when it failed; or a
        /// failure when the receive completed with zero bytes.
        /// </returns>
        async Task<Result<string>> ReceiveMessageAsync()
        {
            var buffer = new byte[BufferSize];
            var receiveResult =
                await _transferSocket.ReceiveWithTimeoutAsync(
                    buffer,
                    0,
                    BufferSize,
                    SocketFlags.None,
                    ReceiveTimeoutMs).ConfigureAwait(false);

            // Propagate the real cause (timeout, socket error) instead of masking it as "no data".
            if (receiveResult.Failure)
            {
                return Result.Fail<string>(receiveResult.Error);
            }

            var bytesReceived = receiveResult.Value;
            if (bytesReceived == 0)
            {
                return Result.Fail<string>("Error reading message from client, no data was received");
            }

            var message = Encoding.ASCII.GetString(buffer, 0, bytesReceived);
            return Result.Ok(message);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment