Skip to content

Instantly share code, notes, and snippets.

@afish
Created Dec 21, 2021
Embed
What would you like to do?
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading;
namespace ChannelBonder
{
public class Program
{
    /// <summary>
    /// Entry point. The first argument selects the role ("client" or "server"); the full argument
    /// array is handed to that role's Start method. Any other input prints usage and exits with code 1.
    /// </summary>
    static void Main(string[] args)
    {
        if (args.Length < 2)
        {
            Exit(Usage(), 1);
        }
        switch (args[0])
        {
            case "client":
                Client.Start(args);
                break;
            case "server":
                Server.Start(args);
                break;
            default:
                Exit(Usage(), 1);
                break;
        }
    }

    // Command-line synopsis. Both roles accept one or more trailing values (local IPs to bind the
    // bonded channels to, or local ports to listen on) and do nothing useful without at least one.
    private static string Usage()
    {
        return "ChannelBonder.exe client local_port:destination_ip:destination_port local_ip [local_ip ...]\nChannelBonder.exe server destination_ip:destination_port local_port [local_port ...]";
    }

    // Prints the message to stderr and terminates the process. A non-zero code lets calling
    // scripts detect that the program was misused.
    private static void Exit(string message, int exitCode)
    {
        Console.Error.WriteLine(message);
        Environment.Exit(exitCode);
    }
}
public class Client
{
    // Source of per-connection identifiers. System.Random is not thread-safe, so it is only ever
    // touched from the single accept thread in Start(string[]).
    private static Random random = new Random();

    /// <summary>
    /// Runs the client role. Listens on local_port (all interfaces) and, for every accepted
    /// connection, opens one channel per local_ip to destination_ip:destination_port. It then
    /// scatters the connection's data over all of them. Never returns.
    /// </summary>
    /// <param name="args">args[1] is "local_port:destination_ip:destination_port"; args[2..] are local IPs to bind channels to.</param>
    public static void Start(string[] args)
    {
        var config = args[1].Split(':');
        var localPort = int.Parse(config[0]);
        var remoteDns = config[1];
        var remotePort = int.Parse(config[2]);
        // Loop-invariant: the channel source addresses are the same for every accepted connection.
        var localDnss = args.Skip(2).ToArray();
        Console.WriteLine($"Routing from {localPort} to {remoteDns}:{remotePort}");
        if (localDnss.Length == 0)
        {
            Console.WriteLine("No local addresses given; accepted connections will have no channels to send on");
        }
        IPEndPoint localEndPoint = new IPEndPoint(0, localPort);
        Socket listener = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
        listener.Bind(localEndPoint);
        listener.Listen(100);
        while (true)
        {
            try
            {
                Socket socket = listener.Accept();
                Console.WriteLine("New connection accepted to be scattered");
                // Draw the identifier here, not inside the worker lambda: workers run concurrently
                // and System.Random must not be used from several threads at once.
                var identifier = random.Next();
                new Thread(() => Start(socket, localDnss, remoteDns, remotePort, identifier)).Start();
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
            }
        }
    }

    // Per-connection worker: builds one reconnecting channel per local address, all tagged with
    // the same identifier so the server can group them, then pumps data until the connection ends.
    private static void Start(Socket clientSocket, string[] localDnss, string remoteDns, int remotePort, int identifier)
    {
        try
        {
            var senderSockets = localDnss.Select(localDns => new ReconnectableSocket(localDns, remoteDns, remotePort, identifier, true)).ToList();
            var scatter = new Scatter(clientSocket, senderSockets);
            scatter.Start();
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
            // Do not leak the accepted connection if setup failed; closing an already-closed socket is harmless.
            clientSocket.Close();
        }
    }
}
public class Server
{
    // Active scatters keyed by the 4-byte identifier clients send first on every channel.
    // NOTE(review): entries are never removed, so finished scatters stay reachable for the process lifetime.
    private static ConcurrentDictionary<int, Scatter> scatters = new ConcurrentDictionary<int, Scatter>();

    /// <summary>
    /// Runs the server role. Listens on every local_port and gathers the channels that share an
    /// identifier into one connection to destination_ip:destination_port. Blocks until every
    /// listener thread has ended.
    /// </summary>
    /// <param name="args">args[1] is "destination_ip:destination_port"; args[2..] are local ports to listen on.</param>
    public static void Start(string[] args)
    {
        var config = args[1].Split(':');
        var remoteDns = config[0];
        var remotePort = int.Parse(config[1]);
        var localPorts = args.Skip(2).Select(int.Parse).ToArray();
        if (localPorts.Length == 0)
        {
            Console.WriteLine("No local ports given; nothing to listen on");
            return;
        }
        var listenerThreads = localPorts
            .Select(localPort => new Thread(() => Listen(localPort, remoteDns, remotePort)))
            .ToArray();
        foreach (var thread in listenerThreads)
        {
            thread.Start();
        }
        foreach (var thread in listenerThreads)
        {
            thread.Join();
        }
    }

    // Binds once to localPort on all interfaces, then accepts channel connections forever,
    // handing each to its own worker thread.
    private static void Listen(int localPort, string remoteDns, int remotePort)
    {
        Console.WriteLine($"Routing from {localPort} to {remoteDns}:{remotePort}");
        IPEndPoint localEndPoint = new IPEndPoint(0, localPort);
        Socket listener = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
        try
        {
            // Bind/Listen happen exactly once: re-binding an already-bound socket after an Accept
            // failure can never succeed and would spin forever logging exceptions.
            listener.Bind(localEndPoint);
            listener.Listen(100);
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
            listener.Close();
            return;
        }
        while (true)
        {
            try
            {
                Socket socket = listener.Accept();
                Console.WriteLine("New connection accepted to be gathered");
                new Thread(() => Start(socket, remoteDns, remotePort)).Start();
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
            }
        }
    }

    // Per-channel worker: reads and echoes the 4-byte identifier, then either attaches the channel
    // to the existing scatter for that identifier or opens the destination connection and starts
    // a new scatter.
    private static void Start(Socket clientSocket, string remoteDns, int remotePort)
    {
        // Becomes true once a ReconnectableSocket owns clientSocket; before that we must close it ourselves on failure.
        var handedOff = false;
        try
        {
            var identifierBuffer = new byte[4];
            // TCP may deliver the 4 identifier bytes in several pieces; a single Receive is not enough.
            Helpers.ReadFully(identifierBuffer, 0, 4, clientSocket);
            clientSocket.Send(identifierBuffer, 4, SocketFlags.None);
            var identifier = BitConverter.ToInt32(identifierBuffer, 0);
            Console.WriteLine("Received identifier " + identifier);
            // Serializes lookup-or-create so two channels with the same identifier cannot both open a destination connection.
            lock (scatters)
            {
                Scatter scatter;
                if (scatters.TryGetValue(identifier, out scatter))
                {
                    Console.WriteLine("New gatherer for identifier " + identifier);
                    var channel = new ReconnectableSocket(clientSocket);
                    handedOff = true;
                    scatter.AddSocket(channel);
                }
                else
                {
                    // GetHostAddresses returns IP literals as-is; GetHostEntry would attempt a reverse lookup for them.
                    IPEndPoint remoteEP = new IPEndPoint(Dns.GetHostAddresses(remoteDns)[0], remotePort);
                    Socket senderSocket = new Socket(remoteEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                    try
                    {
                        senderSocket.Connect(remoteEP);
                    }
                    catch (Exception)
                    {
                        senderSocket.Close();
                        throw;
                    }
                    Console.WriteLine("Socket connected to {0} for identifier {1}", senderSocket.RemoteEndPoint.ToString(), identifier);
                    var channel = new ReconnectableSocket(clientSocket);
                    handedOff = true;
                    scatter = new Scatter(senderSocket, new List<ReconnectableSocket> { channel });
                    scatters.TryAdd(identifier, scatter);
                    new Thread(() => scatter.Start()).Start();
                }
            }
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
            if (!handedOff)
            {
                clientSocket.Close();
            }
        }
    }
}
public static class Helpers
{
    /// <summary>
    /// Builds a wire frame: bytes 0-3 hold <paramref name="counter"/>, bytes 4-7 hold
    /// <paramref name="length"/> (both via BitConverter, i.e. host byte order), followed by the
    /// first <paramref name="length"/> bytes of <paramref name="buffer"/>.
    /// </summary>
    /// <returns>A new array of <paramref name="length"/> + 8 bytes.</returns>
    public static byte[] Enrich(byte[] buffer, int length, int counter)
    {
        var frame = new byte[length + 8];
        Array.Copy(BitConverter.GetBytes(counter), 0, frame, 0, 4);
        Array.Copy(BitConverter.GetBytes(length), 0, frame, 4, 4);
        Array.Copy(buffer, 0, frame, 8, length);
        return frame;
    }

    /// <summary>
    /// Receives exactly <paramref name="length"/> bytes from <paramref name="socket"/> into
    /// <paramref name="array"/> starting at <paramref name="offset"/>, looping over short reads.
    /// </summary>
    /// <exception cref="System.IO.EndOfStreamException">The peer closed the connection before enough bytes arrived.</exception>
    public static void ReadFully(byte[] array, int offset, int length, Socket socket)
    {
        int remaining = length;
        while (remaining > 0)
        {
            var read = socket.Receive(array, offset, remaining, SocketFlags.None);
            if (read == 0)
            {
                // Specific type instead of bare Exception; still caught by every catch (Exception) caller.
                throw new System.IO.EndOfStreamException("Socket closed");
            }
            offset += read;
            remaining -= read;
        }
    }
}
public class Scatter
{
    private Socket clientSocket;
    private ConcurrentBag<ReconnectableSocket> senderSockets;
    // Sequence number stamped on outgoing frames; only touched by the single pump thread.
    private int counter;
    // Message numbers already written to clientSocket. Guarded by forwardLock.
    // NOTE(review): never pruned, so it grows with every message for the lifetime of the Scatter.
    private readonly HashSet<int> forwarded = new HashSet<int>();
    // Makes "first copy wins" dedup and the write to clientSocket one atomic step across reader threads.
    private readonly object forwardLock = new object();

    /// <summary>
    /// Pairs a plain socket with the bonded channels. Data read from <paramref name="clientSocket"/>
    /// is framed and copied to every channel; frames arriving on any channel are de-duplicated
    /// and written back to <paramref name="clientSocket"/>.
    /// </summary>
    public Scatter(Socket clientSocket, List<ReconnectableSocket> senderSockets)
    {
        this.clientSocket = clientSocket;
        this.senderSockets = new ConcurrentBag<ReconnectableSocket>(senderSockets);
    }

    /// <summary>
    /// Starts one reader thread per channel, then pumps clientSocket into the channels on a
    /// dedicated thread. Blocks until the pump stops, then closes all channels and clientSocket.
    /// </summary>
    public void Start()
    {
        foreach (var socket in senderSockets)
        {
            try
            {
                Thread senderThread = new Thread(() => KeepReading(socket));
                senderThread.Start();
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
            }
        }
        Thread clientThread = new Thread(PumpClientToSenders);
        clientThread.Start();
        clientThread.Join();
        foreach (var socket in senderSockets)
        {
            socket.Close();
        }
        clientSocket.Close();
    }

    // Reads chunks of up to 100000 bytes from clientSocket, frames each with the next sequence
    // number, and queues the frame on every channel. Returns when clientSocket closes or fails.
    private void PumpClientToSenders()
    {
        try
        {
            var buffer = new byte[100000];
            while (true)
            {
                var read = clientSocket.Receive(buffer);
                if (read == 0)
                {
                    Console.WriteLine("Socket was closed");
                    return;
                }
                var frame = Helpers.Enrich(buffer, read, counter++);
                foreach (var socket in senderSockets)
                {
                    try
                    {
                        socket.Send(frame, read + 8);
                    }
                    catch (Exception e)
                    {
                        // One broken channel must not stop delivery on the others.
                        Console.WriteLine("Exception " + e);
                    }
                }
            }
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
        }
    }

    // Drains frames from one channel until its Receive throws.
    private void KeepReading(ReconnectableSocket socket)
    {
        try
        {
            while (true)
            {
                var tuple = socket.Receive();
                HandleReceived(tuple.Item1, tuple.Item2);
            }
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
        }
    }

    // Forwards a frame's payload (bytes 8..length) to clientSocket unless its message number
    // (bytes 0-3) was already forwarded. Called concurrently from every channel's reader thread.
    private void HandleReceived(byte[] buffer, int length)
    {
        var messageNumber = BitConverter.ToInt32(buffer, 0);
        // Check-and-send must be atomic. Otherwise thread A can claim N, thread B then skips N
        // (duplicate) and sends N+1 before A has written N, reordering the byte stream.
        // The lock also keeps concurrent Sends from interleaving on clientSocket.
        lock (forwardLock)
        {
            if (forwarded.Add(messageNumber))
            {
                clientSocket.Send(buffer, 8, length - 8, SocketFlags.None);
            }
        }
    }

    /// <summary>Adds another channel (e.g. a late-arriving or reconnected one) and starts reading from it.</summary>
    public void AddSocket(ReconnectableSocket socket)
    {
        this.senderSockets.Add(socket);
        Thread senderThread = new Thread(() => KeepReading(socket));
        senderThread.Start();
    }
}
public class ReconnectableSocket
{
    // Upper bound accepted for a frame's payload length read off the wire. Scatter produces at most
    // 100000-byte payloads; the cap only stops a corrupt length prefix from causing a huge allocation.
    private const int MaxFrameLength = 16 * 1024 * 1024;
    // Pause between failed connection attempts so an unreachable server is not hammered in a tight loop.
    private static readonly TimeSpan ReconnectDelay = TimeSpan.FromSeconds(1);

    private Socket _socket;
    private string localDns;
    private string remoteDns;
    private int remotePort;
    private int identifier;
    // Volatile because Close() clears it without holding `locker`; a retrying Open() holds the lock and polls it.
    private volatile bool canReconnect;
    private object locker = new object();
    private BlockingCollection<Tuple<byte[], int>> toSend = new BlockingCollection<Tuple<byte[], int>>();
    private BlockingCollection<Tuple<byte[], int>> received = new BlockingCollection<Tuple<byte[], int>>();
    private Thread sendingThread;
    private Thread readingThread;

    /// <summary>
    /// Creates an outgoing channel bound to <paramref name="localDns"/> (an IP address string) that
    /// connects lazily to <paramref name="remoteDns"/>:<paramref name="remotePort"/>. On connect it
    /// announces <paramref name="identifier"/> and reconnects after failures while
    /// <paramref name="canReconnect"/> holds.
    /// </summary>
    public ReconnectableSocket(string localDns, string remoteDns, int remotePort, int identifier, bool canReconnect)
    {
        this.localDns = localDns;
        this.remoteDns = remoteDns;
        this.remotePort = remotePort;
        this.identifier = identifier;
        this.canReconnect = canReconnect;
        StartWorkers();
    }

    /// <summary>Wraps an already-connected socket; it is never re-established once it fails.</summary>
    public ReconnectableSocket(Socket socket)
    {
        this._socket = socket;
        StartWorkers();
    }

    // Launches the dedicated send and receive threads.
    private void StartWorkers()
    {
        sendingThread = new Thread(KeepSending);
        sendingThread.Start();
        readingThread = new Thread(KeepReceiving);
        readingThread.Start();
    }

    // Writes queued frames in order, reconnecting and retrying the current frame on failure.
    // Exits once the channel can no longer be (re)opened or the queue is completed.
    private void KeepSending()
    {
        try
        {
            foreach (var tuple in toSend.GetConsumingEnumerable())
            {
                while (true)
                {
                    Socket socket = null;
                    try
                    {
                        socket = OpenIfNeeded();
                        if (socket == null)
                        {
                            return;
                        }
                        socket.Send(tuple.Item1, tuple.Item2, SocketFlags.None);
                        break;
                    }
                    catch (Exception e)
                    {
                        Console.WriteLine("Exception " + e);
                        ClearSocketForReconnect(socket);
                    }
                }
            }
        }
        finally
        {
            // Nobody drains the queue any more: refuse new items so Send() drops them instead of piling up memory.
            toSend.CompleteAdding();
        }
    }

    // Reads [counter:4][length:4][payload:length] frames and queues (frame, 8 + length) for Receive().
    // Exits once the channel can no longer be (re)opened.
    private void KeepReceiving()
    {
        try
        {
            while (true)
            {
                Socket socket = null;
                try
                {
                    socket = OpenIfNeeded();
                    if (socket == null)
                    {
                        return;
                    }
                    var header = new byte[8];
                    Helpers.ReadFully(header, 0, 8, socket);
                    int length = BitConverter.ToInt32(header, 4);
                    if (length < 0 || length > MaxFrameLength)
                    {
                        throw new System.IO.InvalidDataException("Invalid frame length " + length);
                    }
                    // Sized per frame: a fixed 100000-byte buffer cannot hold a full 100000-byte payload plus its 8-byte header.
                    var frame = new byte[length + 8];
                    Array.Copy(header, frame, 8);
                    Helpers.ReadFully(frame, 8, length, socket);
                    received.Add(Tuple.Create(frame, length + 8));
                }
                catch (Exception e)
                {
                    Console.WriteLine("Exception " + e);
                    ClearSocketForReconnect(socket);
                }
            }
        }
        finally
        {
            // Wakes consumers blocked in Receive() so they can stop instead of waiting forever.
            received.CompleteAdding();
        }
    }

    // Closes a socket that failed and, if it is still the current one, forgets it so the next
    // OpenIfNeeded() reconnects.
    private void ClearSocketForReconnect(Socket socket)
    {
        if (socket == null)
        {
            return;
        }
        lock (locker)
        {
            if (this._socket == socket)
            {
                this._socket = null;
            }
        }
        // Release the OS handle now rather than leaving it to finalization; Close is idempotent.
        socket.Close();
    }

    // Returns the current socket, connecting first if there is none. Null means "closed for good".
    private Socket OpenIfNeeded()
    {
        lock (locker)
        {
            if (this._socket == null)
            {
                return Open();
            }
            return this._socket;
        }
    }

    /// <summary>
    /// Connects from localDns to remoteDns:remotePort and performs the identifier handshake (send 4 bytes,
    /// expect the same 4 echoed back). Retries with a delay until it succeeds or reconnecting is disabled.
    /// </summary>
    /// <returns>The connected socket, or null if reconnecting is disabled or localDns is not a valid IP address.</returns>
    public Socket Open()
    {
        if (!canReconnect)
        {
            return null;
        }
        IPAddress localAddress;
        if (!IPAddress.TryParse(localDns, out localAddress))
        {
            // Retrying cannot fix a malformed address, so give up on this channel permanently.
            Console.WriteLine("Invalid local address " + localDns);
            canReconnect = false;
            return null;
        }
        // Re-checked every attempt so Close() can stop the retries even while this thread holds `locker`.
        while (canReconnect)
        {
            Socket senderSocket = null;
            try
            {
                senderSocket = new Socket(localAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                senderSocket.Bind(new IPEndPoint(localAddress, 0));
                senderSocket.Connect(remoteDns, remotePort);
                Console.WriteLine("Socket connected to {0}", senderSocket.RemoteEndPoint.ToString());
                var identifierBuffer = BitConverter.GetBytes(identifier);
                senderSocket.Send(identifierBuffer, 4, SocketFlags.None);
                var echoBuffer = new byte[4];
                // The echo may arrive in several pieces; read all 4 bytes before judging it.
                Helpers.ReadFully(echoBuffer, 0, 4, senderSocket);
                if (BitConverter.ToInt32(echoBuffer, 0) == identifier)
                {
                    this._socket = senderSocket;
                    return senderSocket;
                }
                Console.WriteLine("Invalid identifier received");
                senderSocket.Close();
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
                senderSocket?.Close();
            }
            Thread.Sleep(ReconnectDelay);
        }
        return null;
    }

    /// <summary>
    /// Permanently shuts the channel: disables reconnects, stops accepting frames to send, and
    /// closes the current socket so both worker threads exit.
    /// </summary>
    public void Close()
    {
        // Set outside the lock: a retrying Open() holds `locker` and polls this flag to give up.
        this.canReconnect = false;
        try
        {
            toSend.CompleteAdding();
            lock (locker)
            {
                Console.WriteLine("Closing at " + Environment.StackTrace);
                this._socket?.Close();
                this._socket = null;
            }
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
        }
    }

    /// <summary>Queues the first <paramref name="length"/> bytes of <paramref name="buffer"/> for sending.</summary>
    public void Send(byte[] buffer, int length)
    {
        try
        {
            toSend.Add(Tuple.Create(buffer, length));
        }
        catch (InvalidOperationException)
        {
            // The channel is closed for good; the frame is still carried by the other bonded channels.
        }
    }

    /// <summary>
    /// Blocks until a received frame is available and returns (frame, total length including the 8-byte header).
    /// Throws InvalidOperationException once the channel has stopped and every queued frame has been taken.
    /// </summary>
    public Tuple<byte[], int> Receive()
    {
        return received.Take();
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment