Skip to content

Instantly share code, notes, and snippets.

@zgrge
Forked from njsmith/proxy.py
Last active April 6, 2018 14:54
Show Gist options
  • Save zgrge/be704c80b8775024dcfd072a2d005e09 to your computer and use it in GitHub Desktop.
Save zgrge/be704c80b8775024dcfd072a2d005e09 to your computer and use it in GitHub Desktop.
import trio
import struct
import traceback
################################################################
# This part is a helper for reading N bytes from a stream
################################################################
class UnexpectedEOFError(Exception):
    """Signals that a stream ended before the expected number of bytes arrived."""
# Utility function
async def receive_exactly(stream, num_bytes):
    """Read exactly ``num_bytes`` bytes from ``stream``.

    ``stream.receive_some`` may hand back fewer bytes than asked for, so it
    is called repeatedly until the full amount has been collected.

    Args:
        stream: Object exposing an awaitable ``receive_some(max_bytes)``.
        num_bytes: Number of bytes to read.

    Returns:
        A ``bytearray`` holding exactly ``num_bytes`` bytes.

    Raises:
        UnexpectedEOFError: If ``receive_some`` returns an empty chunk before
            ``num_bytes`` bytes have been collected.
    """
    collected = bytearray()
    remaining = num_bytes
    while remaining > 0:
        # Never request more than is still missing.
        piece = await stream.receive_some(remaining)
        if not piece:
            # An empty chunk means the peer closed the connection.
            raise UnexpectedEOFError("other side closed connection")
        collected.extend(piece)
        remaining = num_bytes - len(collected)
    # Sanity check: the loop must stop at exactly the requested size.
    assert len(collected) == num_bytes
    return collected
################################################################
# This uses receive_exactly to read messages that start with a 2-byte length
# field
################################################################
async def read_message(stream):
    """Read one length-prefixed message from ``stream``.

    A message is a 2-byte unsigned length field followed by that many bytes
    of body.

    Args:
        stream: Stream object passed through to ``receive_exactly``.

    Returns:
        The length field and the body concatenated, i.e. the complete
        message as it appeared on the wire.
    """
    length_prefix = await receive_exactly(stream, 2)
    # The size field is assumed to be big-endian; for a little-endian field
    # use "<H" instead.
    body_length = struct.unpack(">H", length_prefix)[0]
    payload = await receive_exactly(stream, body_length)
    return length_prefix + payload
################################################################
# This is the actual rewrite logic, you need to fill it in!
################################################################
def rewrite_request(request):
    """Return the request that should be forwarded in place of ``request``.

    Placeholder hook: as written it hands the request back unmodified.
    """
    # ... fill this in ...
    return request
################################################################
# And this is the I/O code to glue it together
################################################################
async def handle_one_client(client_stream):
    """Relay one request/response exchange between a client and the server.

    Opens a TCP connection to ``SERVER_HOST``:``SERVER_PORT`` (module-level
    names that must be defined elsewhere in the file), reads one message from
    the client, passes it through ``rewrite_request``, sends the result to
    the server, then reads one message back and sends it to the client.

    Any ``Exception`` raised along the way is reported with a printed
    traceback and is not propagated to the caller.

    Args:
        client_stream: Stream for the connected client.
    """
    try:
        connection = await trio.open_tcp_stream(SERVER_HOST, SERVER_PORT)
        async with connection as upstream:
            incoming = await read_message(client_stream)
            await upstream.send_all(rewrite_request(incoming))
            reply = await read_message(upstream)
            await client_stream.send_all(reply)
    except Exception:
        # how do you want to handle errors? maybe log them and then throw them
        # away?
        # ... fill this in ...
        print("Got an error:")
        traceback.print_exc()
async def main():
    """Serve TCP connections on ``PROXY_PORT``, handling each with ``handle_one_client``."""
    await trio.serve_tcp(handle_one_client, PROXY_PORT)


# Entry point: start the trio event loop and run the proxy.
trio.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment