This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Concurrent; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Net; | |
using System.Net.Sockets; | |
using System.Threading; | |
namespace ChannelBonder | |
{ | |
/// <summary>
/// Command-line entry point: dispatches to <see cref="Client"/> or <see cref="Server"/> mode.
/// </summary>
public class Program
{
    static void Main(string[] args)
    {
        // Both modes need the mode word, the routing spec and at least one local
        // address/port: with none, the client has no path to send over and the
        // server starts no listener thread and returns immediately.
        if (args.Length < 3)
        {
            Exit(Usage());
        }
        switch (args[0])
        {
            case "client":
                Client.Start(args);
                break;
            case "server":
                Server.Start(args);
                break;
            default:
                Exit(Usage());
                break;
        }
    }

    /// <summary>Command-line help text for both modes.</summary>
    private static string Usage()
    {
        return "ChannelBonder.exe client local_port:destination_ip:destination_port local_ip [local_ip ...]\n" +
               "ChannelBonder.exe server destination_ip:destination_port local_port [local_port ...]";
    }

    /// <summary>Prints <paramref name="message"/> to stderr and terminates the process.</summary>
    /// <param name="message">Text to print before exiting.</param>
    /// <param name="exitCode">
    /// Process exit code. Non-zero by default because every caller reports a usage error,
    /// which scripts must be able to detect.
    /// </param>
    private static void Exit(string message, int exitCode = 1)
    {
        Console.Error.WriteLine(message);
        Environment.Exit(exitCode);
    }
}
/// <summary>
/// Client side of the bonder: accepts local TCP connections and relays each one over
/// one <see cref="ReconnectableSocket"/> per configured local address.
/// </summary>
public class Client
{
    // Identifier source. System.Random is not thread-safe, so it must only be used from
    // the single accept-loop thread in Start(string[]).
    private static Random random = new Random();

    /// <summary>Runs client mode; never returns.</summary>
    /// <param name="args">
    /// args[1] is "local_port:destination_host:destination_port"; args[2..] are the local
    /// addresses each bonded path is bound to.
    /// </param>
    public static void Start(string[] args)
    {
        var config = args[1].Split(':');
        var localPort = int.Parse(config[0]);
        var remoteDns = config[1];
        var remotePort = int.Parse(config[2]);
        // Same list for every accepted connection, so compute it once.
        var localDnss = args.Skip(2).ToArray();
        Console.WriteLine($"Routing from {localPort} to {remoteDns}:{remotePort}");
        IPEndPoint localEndPoint = new IPEndPoint(0, localPort);
        Socket listener = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
        listener.Bind(localEndPoint);
        listener.Listen(100);
        while (true)
        {
            try
            {
                while (true)
                {
                    Socket socket = listener.Accept();
                    Console.WriteLine("New connection accepted to be scattered");
                    // Drawn here, on the accept thread, rather than inside the worker lambda:
                    // concurrent calls can corrupt Random and make every identifier identical.
                    // NOTE(review): identifiers are random, so two concurrent clients of one
                    // server could still collide; confirm that risk is acceptable.
                    var identifier = random.Next();
                    new Thread(() => Start(socket, localDnss, remoteDns, remotePort, identifier)).Start();
                }
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
            }
        }
    }

    /// <summary>
    /// Relays one accepted connection: opens a reconnectable path from each local address,
    /// all tagged with <paramref name="identifier"/>, and runs a <see cref="Scatter"/>
    /// until the local connection ends.
    /// </summary>
    /// <param name="clientSocket">The accepted local connection.</param>
    /// <param name="localDnss">Local addresses, one path each.</param>
    /// <param name="remoteDns">Bonder server host.</param>
    /// <param name="remotePort">Bonder server port.</param>
    /// <param name="identifier">Tag the server uses to group this connection's paths.</param>
    private static void Start(Socket clientSocket, string[] localDnss, string remoteDns, int remotePort, int identifier)
    {
        try
        {
            var senderSockets = localDnss
                .Select(localDns => new ReconnectableSocket(localDns, remoteDns, remotePort, identifier, true))
                .ToList();
            var scatter = new Scatter(clientSocket, senderSockets);
            scatter.Start();
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
            // Don't leak the accepted connection if setup failed; closing twice is harmless.
            clientSocket.Close();
        }
    }
}
/// <summary>
/// Server side of the bonder: accepts bonded paths on one or more ports, groups them by
/// the 4-byte identifier each path sends first, and relays each group to a single
/// connection to the destination.
/// </summary>
public class Server
{
    // One Scatter per client connection identifier.
    // NOTE(review): entries are never removed, so this grows with every bonded connection.
    private static ConcurrentDictionary<int, Scatter> scatters = new ConcurrentDictionary<int, Scatter>();

    /// <summary>
    /// Runs server mode: starts one listener thread per local port and joins them
    /// (the listeners loop forever).
    /// </summary>
    /// <param name="args">
    /// args[1] is "destination_host:destination_port"; args[2..] are the local ports to listen on.
    /// </param>
    public static void Start(string[] args)
    {
        var config = args[1].Split(':');
        var remoteDns = config[0];
        var remotePort = int.Parse(config[1]);
        var listenerThreads = args.Skip(2)
            .Select(int.Parse)
            .Select(localPort =>
            {
                var listenerThread = new Thread(() => Listen(localPort, remoteDns, remotePort));
                listenerThread.Start();
                return listenerThread;
            })
            .ToArray();
        foreach (var listenerThread in listenerThreads)
        {
            listenerThread.Join();
        }
    }

    /// <summary>
    /// Accepts connections on <paramref name="localPort"/> forever, handing each one to
    /// its own thread.
    /// </summary>
    private static void Listen(int localPort, string remoteDns, int remotePort)
    {
        Console.WriteLine($"Routing from {localPort} to {remoteDns}:{remotePort}");
        IPEndPoint localEndPoint = new IPEndPoint(0, localPort);
        while (true)
        {
            // A fresh socket per attempt: calling Bind again on the previous, already-bound
            // listener after a failure would throw on every retry.
            Socket listener = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            try
            {
                listener.Bind(localEndPoint);
                listener.Listen(100);
                while (true)
                {
                    Socket socket = listener.Accept();
                    Console.WriteLine("New connection accepted to be gathered");
                    new Thread(() => Start(socket, remoteDns, remotePort)).Start();
                }
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
                listener.Close();
                // Back off so a persistent failure (e.g. port already in use) doesn't spin.
                Thread.Sleep(1000);
            }
        }
    }

    /// <summary>
    /// Handles one incoming bonded path. It reads the path's 4-byte identifier and echoes
    /// it back. If a Scatter already exists for that identifier, the path joins it;
    /// otherwise the server connects to the destination and starts a new Scatter.
    /// </summary>
    private static void Start(Socket clientSocket, string remoteDns, int remotePort)
    {
        try
        {
            var identifierBuffer = new byte[4];
            // One Receive may return fewer than 4 bytes on a stream socket; ReadFully keeps
            // reading until all 4 arrive and throws if the peer closes first.
            Helpers.ReadFully(identifierBuffer, 0, 4, clientSocket);
            clientSocket.Send(identifierBuffer, 4, SocketFlags.None);
            var identifier = BitConverter.ToInt32(identifierBuffer, 0);
            Console.WriteLine("Received identifier " + identifier);
            // The lock makes lookup-or-create atomic, so paths with the same identifier
            // arriving concurrently share a single destination connection.
            lock (scatters)
            {
                Scatter scatter;
                if (scatters.TryGetValue(identifier, out scatter))
                {
                    Console.WriteLine("New gatherer for identifier " + identifier);
                    scatter.AddSocket(new ReconnectableSocket(clientSocket));
                }
                else
                {
                    IPEndPoint remoteEP = new IPEndPoint(Dns.GetHostEntry(remoteDns).AddressList[0], remotePort);
                    Socket senderSocket = new Socket(remoteEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                    try
                    {
                        senderSocket.Connect(remoteEP);
                    }
                    catch
                    {
                        // Don't leak the destination socket when the connect fails.
                        senderSocket.Close();
                        throw;
                    }
                    Console.WriteLine("Socket connected to {0} for identifier {1}", senderSocket.RemoteEndPoint.ToString(), identifier);
                    scatter = new Scatter(senderSocket, new List<ReconnectableSocket> { new ReconnectableSocket(clientSocket) });
                    scatters.TryAdd(identifier, scatter);
                    new Thread(() => scatter.Start()).Start();
                }
            }
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
            // Close the path so the peer sees the failure instead of waiting on it forever.
            clientSocket.Close();
        }
    }
}
/// <summary>
/// Framing helpers. A frame is an 8-byte header (chunk counter, then payload length,
/// both as BitConverter Int32s) followed by the payload bytes.
/// </summary>
public static class Helpers
{
    /// <summary>
    /// Wraps the first <paramref name="length"/> bytes of <paramref name="buffer"/> in a frame.
    /// </summary>
    /// <param name="buffer">Source of the payload.</param>
    /// <param name="length">Number of payload bytes to copy from the start of <paramref name="buffer"/>.</param>
    /// <param name="counter">Chunk number written to header bytes 0..3.</param>
    /// <returns>A new array of <paramref name="length"/> + 8 bytes.</returns>
    public static byte[] Enrich(byte[] buffer, int length, int counter)
    {
        var frame = new byte[length + 8];
        Buffer.BlockCopy(buffer, 0, frame, 8, length);
        BitConverter.GetBytes(counter).CopyTo(frame, 0);
        BitConverter.GetBytes(length).CopyTo(frame, 4);
        return frame;
    }

    /// <summary>
    /// Receives exactly <paramref name="length"/> bytes from <paramref name="socket"/> into
    /// <paramref name="array"/> starting at <paramref name="offset"/>, looping over partial reads.
    /// </summary>
    /// <exception cref="Exception">The peer closed the connection before enough bytes arrived.</exception>
    public static void ReadFully(byte[] array, int offset, int length, Socket socket)
    {
        var position = offset;
        var remaining = length;
        while (remaining > 0)
        {
            var chunk = socket.Receive(array, position, remaining, SocketFlags.None);
            if (chunk == 0)
            {
                throw new Exception("Socket closed");
            }
            position += chunk;
            remaining -= chunk;
        }
    }
}
/// <summary>
/// Relays one logical TCP stream over several redundant paths.
/// Outbound: every chunk read from the wrapped socket is numbered and sent on every path.
/// Inbound: frames arriving on any path are de-duplicated by number and written to the
/// wrapped socket strictly in number order.
/// </summary>
public class Scatter
{
    private Socket clientSocket;
    private ConcurrentBag<ReconnectableSocket> senderSockets;
    // Number stamped on the next outbound chunk; incremented only by the client thread in Start().
    private int counter;
    // Guards nextToDeliver, pending, and writes to clientSocket from the per-path reader threads.
    private readonly object deliveryLock = new object();
    // Number of the next inbound frame to write to clientSocket. The peer Scatter numbers its
    // chunks from 0 (its counter starts at 0 and is post-incremented).
    private int nextToDeliver;
    // Inbound frames that arrived ahead of nextToDeliver, keyed by frame number.
    // NOTE(review): if a frame is lost on every path, later frames accumulate here and the
    // stream stalls; forwarding them past the gap would corrupt the byte stream instead.
    private readonly Dictionary<int, Tuple<byte[], int>> pending = new Dictionary<int, Tuple<byte[], int>>();

    /// <param name="clientSocket">The single plain connection whose data is being relayed.</param>
    /// <param name="senderSockets">Initial set of bonded paths; more can be added with <see cref="AddSocket"/>.</param>
    public Scatter(Socket clientSocket, List<ReconnectableSocket> senderSockets)
    {
        this.clientSocket = clientSocket;
        this.senderSockets = new ConcurrentBag<ReconnectableSocket>(senderSockets);
    }

    /// <summary>
    /// Starts one reader thread per path and a thread that copies clientSocket to every path.
    /// Blocks until clientSocket fails or closes, then closes every path.
    /// </summary>
    public void Start()
    {
        Thread clientThread = new Thread(() =>
        {
            try
            {
                var buffer = new byte[100000];
                while (true)
                {
                    var read = clientSocket.Receive(buffer);
                    if (read == 0)
                    {
                        throw new Exception("Socket was closed");
                    }
                    // Every path carries every chunk; the receiving side keeps the first copy.
                    var newBuffer = Helpers.Enrich(buffer, read, counter++);
                    foreach (var socket in senderSockets)
                    {
                        try
                        {
                            socket.Send(newBuffer, read + 8);
                        }
                        catch (Exception e)
                        {
                            Console.WriteLine("Exception " + e);
                        }
                    }
                }
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
                foreach (var socket in senderSockets)
                {
                    socket.Close();
                }
                clientSocket.Close();
            }
        });
        foreach (var socket in senderSockets)
        {
            try
            {
                Thread senderThread = new Thread(() => KeepReading(socket));
                senderThread.Start();
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
            }
        }
        clientThread.Start();
        clientThread.Join();
        foreach (var socket in senderSockets)
        {
            socket.Close();
        }
    }

    /// <summary>Feeds every frame received on <paramref name="socket"/> to HandleReceived until that path fails.</summary>
    private void KeepReading(ReconnectableSocket socket)
    {
        try
        {
            while (true)
            {
                var tuple = socket.Receive();
                HandleReceived(tuple.Item1, tuple.Item2);
            }
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
        }
    }

    /// <summary>
    /// Accepts one inbound frame (<paramref name="length"/> bytes, including the 8-byte header)
    /// from any path. A frame that was already written or is already pending is dropped as a
    /// duplicate. Otherwise it is buffered, and every frame that is now next in sequence is
    /// written to clientSocket.
    /// </summary>
    private void HandleReceived(byte[] buffer, int length)
    {
        var messageNumber = BitConverter.ToInt32(buffer, 0);
        // One KeepReading thread per path calls this concurrently. Without the lock, the
        // thread that claimed frame N could still be about to write it while another thread
        // writes N+1, reordering the stream.
        lock (deliveryLock)
        {
            // Wrap-around-safe form of "messageNumber < nextToDeliver": already written.
            if (unchecked(messageNumber - nextToDeliver) < 0 || pending.ContainsKey(messageNumber))
            {
                return;
            }
            pending[messageNumber] = Tuple.Create(buffer, length);
            Tuple<byte[], int> next;
            while (pending.TryGetValue(nextToDeliver, out next))
            {
                pending.Remove(nextToDeliver);
                nextToDeliver = unchecked(nextToDeliver + 1);
                clientSocket.Send(next.Item1, 8, next.Item2 - 8, SocketFlags.None);
            }
        }
    }

    /// <summary>Adds another path to an already-running Scatter and starts reading from it.</summary>
    public void AddSocket(ReconnectableSocket socket)
    {
        this.senderSockets.Add(socket);
        Thread senderThread = new Thread(() => KeepReading(socket));
        senderThread.Start();
    }
}
/// <summary>
/// One bonded path: a TCP connection carrying length-prefixed frames, with its own send
/// and receive threads.
/// - Instances built from an endpoint description (canReconnect = true) reconnect and
///   redo the identifier handshake whenever the connection fails.
/// - Instances wrapping an accepted socket stop when that socket fails.
/// </summary>
public class ReconnectableSocket
{
    // Largest payload a frame may declare. Scatter reads at most 100000 bytes per chunk,
    // so a larger value means the stream is corrupt or out of sync.
    private const int MaxPayloadLength = 100000;
    // Pause between failed connection attempts, so an unreachable peer doesn't cause a busy loop.
    private const int ReconnectDelayMilliseconds = 1000;

    private Socket _socket;
    private string localDns;
    private string remoteDns;
    private int remotePort;
    private int identifier;
    // volatile: Close() clears it without taking the lock, so a reconnect loop running under
    // the lock in Open() sees it and stops.
    private volatile bool canReconnect;
    // Set once nothing will drain toSend any more; Send() then drops frames instead of queueing forever.
    private volatile bool sendStopped;
    private object locker = new object();
    private BlockingCollection<Tuple<byte[], int>> toSend = new BlockingCollection<Tuple<byte[], int>>();
    private BlockingCollection<Tuple<byte[], int>> received = new BlockingCollection<Tuple<byte[], int>>();
    private Thread sendingThread;
    private Thread readingThread;

    /// <summary>
    /// Creates a path that connects lazily from <paramref name="localDns"/> to
    /// <paramref name="remoteDns"/>:<paramref name="remotePort"/>, identifying itself with
    /// <paramref name="identifier"/>.
    /// </summary>
    public ReconnectableSocket(string localDns, string remoteDns, int remotePort, int identifier, bool canReconnect)
    {
        this.localDns = localDns;
        this.remoteDns = remoteDns;
        this.remotePort = remotePort;
        this.identifier = identifier;
        this.canReconnect = canReconnect;
        sendingThread = new Thread(KeepSending);
        sendingThread.Start();
        readingThread = new Thread(KeepReceiving);
        readingThread.Start();
    }

    /// <summary>
    /// Wraps an already-connected socket. canReconnect stays false, so the path stops when
    /// the socket fails.
    /// </summary>
    public ReconnectableSocket(Socket socket)
    {
        this._socket = socket;
        sendingThread = new Thread(KeepSending);
        sendingThread.Start();
        readingThread = new Thread(KeepReceiving);
        readingThread.Start();
    }

    /// <summary>
    /// Send thread: writes queued frames in order. On failure it reconnects and retries the
    /// current frame. Exits when no socket can be (re)opened or Close() completes the queue.
    /// </summary>
    private void KeepSending()
    {
        try
        {
            foreach (var tuple in toSend.GetConsumingEnumerable())
            {
                while (true)
                {
                    Socket socket = null;
                    try
                    {
                        socket = OpenIfNeeded();
                        if (socket == null)
                        {
                            return;
                        }
                        socket.Send(tuple.Item1, tuple.Item2, SocketFlags.None);
                        break;
                    }
                    catch (Exception e)
                    {
                        Console.WriteLine("Exception " + e);
                        ClearSocketForReconnect(socket);
                    }
                }
            }
        }
        finally
        {
            // Nothing drains toSend after this point.
            sendStopped = true;
        }
    }

    /// <summary>
    /// Receive thread: reads frames (8-byte header + payload) and queues them for Receive().
    /// On failure it reconnects. Exits when no socket can be (re)opened.
    /// </summary>
    private void KeepReceiving()
    {
        try
        {
            while (true)
            {
                Socket socket = null;
                try
                {
                    socket = OpenIfNeeded();
                    if (socket == null)
                    {
                        return;
                    }
                    var header = new byte[8];
                    Helpers.ReadFully(header, 0, 8, socket);
                    int length = BitConverter.ToInt32(header, 4);
                    if (length < 0 || length > MaxPayloadLength)
                    {
                        // The stream is out of sync; drop this connection rather than trust it.
                        throw new ProtocolViolationException("Invalid frame length " + length);
                    }
                    // Sized to the frame: a fixed 100000-byte buffer cannot hold a full-size
                    // payload plus its header, and the buffer stays queued until consumed.
                    var buffer = new byte[length + 8];
                    Array.Copy(header, buffer, 8);
                    Helpers.ReadFully(buffer, 8, length, socket);
                    received.Add(Tuple.Create(buffer, length + 8));
                }
                catch (Exception e)
                {
                    Console.WriteLine("Exception " + e);
                    ClearSocketForReconnect(socket);
                }
            }
        }
        finally
        {
            // Wake consumers blocked in Receive(): once empty, Take() throws instead of blocking forever.
            received.CompleteAdding();
        }
    }

    /// <summary>
    /// Forgets <paramref name="socket"/> as the current connection if it still is, and closes it.
    /// Closing also unblocks the other worker thread if it is still waiting on that socket.
    /// </summary>
    private void ClearSocketForReconnect(Socket socket)
    {
        lock (locker)
        {
            if (this._socket == socket)
            {
                this._socket = null;
            }
        }
        socket?.Close();
    }

    /// <summary>Returns the current socket, opening a new one if there is none.</summary>
    /// <returns>The socket, or null when this path may not reconnect.</returns>
    private Socket OpenIfNeeded()
    {
        lock (locker)
        {
            if (this._socket == null)
            {
                return Open();
            }
            return this._socket;
        }
    }

    /// <summary>
    /// Opens a new connection. It binds to localDns, connects to remoteDns:remotePort, sends
    /// the 4-byte identifier, and expects the server to echo it back. Failed attempts are
    /// retried after a short pause until one succeeds or reconnecting is disabled.
    /// </summary>
    /// <returns>The connected socket (also stored as the current socket), or null when this path may not reconnect.</returns>
    public Socket Open()
    {
        while (canReconnect)
        {
            Socket senderSocket = null;
            try
            {
                var localAddress = IPAddress.Parse(localDns);
                senderSocket = new Socket(localAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                senderSocket.Bind(new IPEndPoint(localAddress, 0));
                senderSocket.Connect(remoteDns, remotePort);
                Console.WriteLine("Socket connected to {0}", senderSocket.RemoteEndPoint.ToString());
                senderSocket.Send(BitConverter.GetBytes(identifier), 4, SocketFlags.None);
                // The echo may arrive in pieces; ReadFully throws if the server closes instead.
                var echoed = new byte[4];
                Helpers.ReadFully(echoed, 0, 4, senderSocket);
                if (BitConverter.ToInt32(echoed, 0) == identifier)
                {
                    this._socket = senderSocket;
                    return senderSocket;
                }
                Console.WriteLine("Invalid identifier received");
            }
            catch (Exception e)
            {
                Console.WriteLine("Exception " + e);
            }
            // This attempt failed: release its socket and pause before the next one.
            senderSocket?.Close();
            Thread.Sleep(ReconnectDelayMilliseconds);
        }
        return null;
    }

    /// <summary>
    /// Permanently shuts this path down: disables reconnecting, closes the current socket, and
    /// stops the send thread.
    /// </summary>
    public void Close()
    {
        // Cleared before taking the lock, so a reconnect loop in Open() (which runs under the
        // lock) gives up after its current attempt instead of holding the lock forever.
        this.canReconnect = false;
        this.sendStopped = true;
        try
        {
            lock (locker)
            {
                Console.WriteLine("Closing at " + Environment.StackTrace);
                this._socket?.Close();
                this._socket = null;
            }
            // Ends the send thread even if it is idle waiting for data.
            toSend.CompleteAdding();
        }
        catch (Exception e)
        {
            Console.WriteLine("Exception " + e);
        }
    }

    /// <summary>
    /// Queues the first <paramref name="length"/> bytes of <paramref name="buffer"/> for
    /// sending. Frames are dropped once the path has stopped.
    /// </summary>
    public void Send(byte[] buffer, int length)
    {
        if (sendStopped)
        {
            return;
        }
        try
        {
            toSend.Add(Tuple.Create(buffer, length));
        }
        catch (InvalidOperationException)
        {
            // Close() completed the queue concurrently; drop the frame as if sent after close.
        }
    }

    /// <summary>
    /// Blocks until a frame is available and returns it as (buffer, total length including the
    /// 8-byte header).
    /// </summary>
    /// <exception cref="InvalidOperationException">The path has stopped receiving and no frames remain.</exception>
    public Tuple<byte[], int> Receive()
    {
        return received.Take();
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment