Skip to content

Instantly share code, notes, and snippets.

@Towdium
Created September 21, 2021 11:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Towdium/0d2a4deaf841cd3e7bc39342ab346704 to your computer and use it in GitHub Desktop.
Save Towdium/0d2a4deaf841cd3e7bc39342ab346704 to your computer and use it in GitHub Desktop.
Prototype of DSAgent, part of the DSTest project: the simplest client-driven UDP tunneling
import asyncio
import socket
from asyncio.exceptions import IncompleteReadError
from enum import Enum
from queue import Queue
class Type(Enum):
    """Message type codes; the first byte of every framed tunnel message."""

    PACKET = 1    # a tunnelled UDP datagram (sent in both directions)
    SETUP = 2     # client -> agent: 12-byte UDP socket setup request
    SETUPACK = 3  # agent -> client: the setup request was processed
def int2byte(val, length):
    """Encode the non-negative integer *val* as exactly *length* big-endian bytes."""
    encoded = int.to_bytes(val, length, byteorder='big')
    return encoded
def byte2int(bytes):
    """Decode the byte string *bytes* as an unsigned big-endian integer."""
    decoded = int.from_bytes(bytes, byteorder='big')
    return decoded
async def readmsg(reader):
    """Yield ``(type, payload)`` messages read from *reader* until the stream ends.

    Wire format: one type byte, one length byte, then the payload. A length
    byte of 255 is an escape: the real length follows as a 2-byte big-endian
    integer.

    The generator finishes quietly when the stream ends, even in the middle of
    a message or before the first header. It also finishes quietly when the
    connection is reset.

    Args:
        reader: an ``asyncio.StreamReader`` (anything with ``readexactly``).

    Yields:
        ``(typ, data)`` where *typ* is the raw type byte (an int) and *data*
        is the payload as ``bytes``.
    """
    try:
        # The first header read is inside the try as well, so a peer that
        # disconnects before sending anything ends the stream instead of
        # raising out of the caller's ``async for``.
        while True:
            header = await reader.readexactly(2)
            typ, length = header[0], header[1]
            if length == 255:
                length = int.from_bytes(await reader.readexactly(2), 'big')
            data = await reader.readexactly(length)
            yield typ, data
    except (IncompleteReadError, ConnectionError):
        # EOF (possibly mid-message) or a reset connection ends the stream.
        return
def writepkt(writer, pkt):
    """Queue *pkt* on *writer* as a PACKET message; the writer is not drained.

    The extended-length header is always used: the PACKET type byte, the 255
    escape byte, then the payload length as a 2-byte big-endian integer.

    Args:
        writer: stream writer (anything with ``write``).
        pkt: payload as a bytes-like object of at most 65535 bytes.

    Raises:
        ValueError: if *pkt* does not fit the 2-byte length field.
    """
    size = len(pkt)
    if size > 0xFFFF:
        # The framing cannot express this length; fail with a clear message
        # rather than a cryptic byte-range error.
        raise ValueError(f'packet too large for tunnel framing: {size} > 65535 bytes')
    header = bytes([Type.PACKET.value, 255]) + size.to_bytes(2, 'big')
    writer.write(header + pkt)
class Connection:
    """Agent-side session for one TCP client that drives a single UDP socket.

    The client sends a SETUP message describing where the agent's UDP socket
    should be bound and/or connected. Afterwards, PACKET messages from the
    client are sent out of the connected UDP socket. Datagrams received on the
    UDP socket are relayed back to the client as PACKET messages.
    """

    def __init__(self):
        self.sock = None        # UDP socket, created on SETUP
        self.connected = False  # True after the UDP socket was connect()ed
        self.bound = False      # True after the UDP socket was bind()ed
        self.writer = None      # stream writer of the TCP client
        self.closed = False     # set by cleanup(); tells the receiver thread to stop
        self._relay = None      # task relaying queued datagrams to the client
        self._receiver = None   # executor future running the blocking recv loop

    @staticmethod
    def _checked(writer):
        """Patch ``writer.write`` to raise ConnectionResetError once the writer is closing."""
        unchecked = writer.write

        def write(*args, **kwargs):
            if writer.is_closing():
                raise ConnectionResetError('Connection lost')
            unchecked(*args, **kwargs)

        writer.write = write

    @staticmethod
    def _sock2buffer(sock, buffer, stopped=None):
        """Copy datagrams from *sock* into the thread-safe *buffer* on a worker thread.

        Args:
            sock: the UDP socket to read with blocking ``recv``.
            buffer: a ``queue.Queue`` shared with the relay task.
            stopped: optional zero-argument callable; once it returns True the
                thread exits instead of forwarding further data.

        Returns:
            The executor future running the receive loop.
        """
        def receive():
            while stopped is None or not stopped():
                try:
                    data = sock.recv(4096)
                except ConnectionRefusedError:
                    # On a connected UDP socket this reports an ICMP error caused
                    # by an earlier datagram; it is not fatal, so keep receiving.
                    continue
                except OSError:
                    # Other ConnectionErrors, or the socket has been closed.
                    return
                if stopped is not None and stopped():
                    # A shut-down socket typically makes recv() return b''
                    # immediately; stop rather than spin forwarding empty packets.
                    return
                # This thread is the only producer, so put() cannot block when
                # full() is False. Dropping (UDP semantics) instead of blocking
                # keeps the thread responsive to `stopped` if the client stops
                # consuming.
                if not buffer.full():
                    buffer.put(data)

        return asyncio.get_running_loop().run_in_executor(None, receive)

    @staticmethod
    def _buffer2writer(buffer, writer):
        """Start a task forwarding datagrams from *buffer* to *writer* as PACKET messages.

        The queue is filled by a worker thread, so it is polled. When the queue
        is empty, the task drains the writer and sleeps for 1 ms.

        Returns:
            The task; the caller must keep a reference to it.
        """
        async def relay():
            try:
                while True:
                    if buffer.empty():
                        await writer.drain()
                        await asyncio.sleep(0.001)
                    else:
                        writepkt(writer, buffer.get())
            except ConnectionError:
                pass  # the client went away; there is nobody left to relay to

        return asyncio.create_task(relay())

    def _setup(self, data):
        """Create, bind and/or connect the UDP socket described by a SETUP payload.

        Payload layout (12 bytes), big-endian:
        source IPv4 (4) + source port (2) + destination IPv4 (4) + destination port (2).
        A source port of 0 skips ``bind``. A destination port of 0 skips ``connect``.
        Replies to the client with SETUPACK.

        Raises:
            ValueError: if the payload is not 12 bytes long.
        """
        # An explicit check instead of `assert`, which is stripped under -O.
        if len(data) != 12:
            raise ValueError(f'SETUP payload must be 12 bytes, got {len(data)}')
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sport = byte2int(data[4:6])
        if sport != 0:
            self.sock.bind((socket.inet_ntop(socket.AF_INET, data[0:4]), sport))
            self.bound = True
        dport = byte2int(data[10:12])
        if dport != 0:
            self.sock.connect((socket.inet_ntop(socket.AF_INET, data[6:10]), dport))
            self.connected = True
        # Start relaying only once the socket is fully configured, so nothing is
        # left running if bind()/connect() fails.
        buffer = Queue(4096)
        self._receiver = Connection._sock2buffer(self.sock, buffer, lambda: self.closed)
        self._relay = Connection._buffer2writer(buffer, self.writer)
        self.writer.write(bytearray([Type.SETUPACK.value, 0]))
        print('Connection established')

    def cleanup(self):
        """Stop relaying and release the UDP socket.

        Safe to call more than once, and before any SETUP was received.
        """
        self.closed = True  # read by the receiver thread
        if self._relay is not None:
            self._relay.cancel()
        if self.sock is None:
            return
        try:
            # shutdown() wakes a recv() blocked in the receiver thread, which
            # close() alone may not do. It can raise OSError (e.g. ENOTCONN on
            # an unconnected datagram socket), and that must not skip close().
            self.sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        finally:
            self.sock.close()

    async def serve(self, reader, writer):
        """Handle one tunnel client until its TCP stream ends.

        SETUP messages configure the UDP socket. PACKET messages are sent out
        of it once it is connected; PACKETs that arrive earlier are dropped.
        The UDP socket and the TCP writer are always released on exit.
        """
        Connection._checked(writer)
        self.writer = writer
        try:
            async for typ, data in readmsg(reader):
                if typ == Type.SETUP.value:
                    self._setup(data)
                elif typ == Type.PACKET.value and self.connected:
                    try:
                        self.sock.send(data)
                    except ConnectionRefusedError:
                        # ICMP error from an earlier datagram; UDP delivery is
                        # best-effort, so keep the tunnel open.
                        pass
        finally:
            self.cleanup()
            writer.close()
            print("Connection lost")
class Tunnel:
    """Client side of the tunnel, talking to a DSAgent over TCP.

    Call :meth:`connect` first. After that, :meth:`send` forwards a datagram
    through the agent's UDP socket, and :meth:`receive` returns the next
    datagram the agent relayed back.
    """

    def __init__(self):
        self.writer = None      # stream writer of the TCP connection to the agent
        self.buffer = None      # asyncio.Queue of received datagram payloads
        self._read_task = None  # task parsing messages coming from the agent

    async def send(self, packet):
        """Send *packet* to the agent as a PACKET message and wait for the writer to drain."""
        writepkt(self.writer, packet)
        await self.writer.drain()

    async def receive(self):
        """Return the next datagram payload relayed by the agent.

        NOTE(review): waits indefinitely once the agent connection has ended,
        because nothing feeds the queue anymore.
        """
        return await self.buffer.get()

    async def connect(self, agtip, agtport, rsip, rsport, rdip, rdport):
        """Open the TCP connection to the agent and configure its UDP socket.

        Args:
            agtip, agtport: address of the DSAgent TCP server.
            rsip, rsport: IPv4 address/port the agent binds its UDP socket to
                (port 0: don't bind).
            rdip, rdport: IPv4 address/port the agent connects its UDP socket to
                (port 0: don't connect).

        Returns once the agent has acknowledged with SETUPACK.

        Raises:
            ConnectionError: if the agent stream ends before SETUPACK arrives.
        """
        self.buffer = asyncio.Queue()
        reader, writer = await asyncio.open_connection(agtip, agtport)
        self.writer = writer
        data = bytearray([Type.SETUP.value, 12])
        data += socket.inet_pton(socket.AF_INET, rsip) + int2byte(rsport, 2)
        data += socket.inet_pton(socket.AF_INET, rdip) + int2byte(rdport, 2)
        writer.write(data)
        await writer.drain()
        setup = asyncio.Event()

        async def read():
            async for typ, payload in readmsg(reader):
                if typ == Type.PACKET.value:
                    if not self.buffer.full():
                        await self.buffer.put(payload)
                elif typ == Type.SETUPACK.value:
                    setup.set()

        # Keep a reference: the event loop only holds weak references to tasks.
        self._read_task = asyncio.create_task(read())
        acked = asyncio.create_task(setup.wait())
        # Wait for the ACK, but stop waiting if the reader finishes first
        # (agent closed the stream); otherwise this would hang forever.
        await asyncio.wait({acked, self._read_task}, return_when=asyncio.FIRST_COMPLETED)
        if not setup.is_set():
            acked.cancel()
            # The read task is done here; chain any exception it ended with.
            cause = self._read_task.exception()
            raise ConnectionError('agent closed the connection before SETUPACK') from cause

    def close(self):
        """Stop reading from the agent and close the TCP connection.

        Safe to call before :meth:`connect` and more than once.
        """
        if self._read_task is not None:
            self._read_task.cancel()
        if self.writer is not None:
            self.writer.close()
if __name__ == "__main__":
    def exception_handler(loop, context):
        """Event-loop exception handler that suppresses KeyboardInterrupt reports."""
        if isinstance(context.get('exception'), KeyboardInterrupt):
            return
        loop.default_exception_handler(context)

    async def run():
        """Serve tunnel clients on 0.0.0.0:9999 until interrupted."""
        asyncio.get_running_loop().set_exception_handler(exception_handler)

        def on_client(client_reader, client_writer):
            # One fresh Connection per TCP client.
            return Connection().serve(client_reader, client_writer)

        server = await asyncio.start_server(on_client, '0.0.0.0', 9999)
        await server.serve_forever()

    try:
        asyncio.run(run())
    except KeyboardInterrupt:
        print('Terminated by user')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment