Skip to content

Instantly share code, notes, and snippets.

@njsmith
Created March 30, 2018 08:06
Show Gist options
  • Save njsmith/072e7bd83b571b55ef70d5d6d894ee77 to your computer and use it in GitHub Desktop.
import struct
import sys
import traceback

import trio
################################################################
# This part is a helper for reading exactly N bytes from a stream
################################################################
class UnexpectedEOFError(Exception):
    """Signal that the peer closed the stream before the expected bytes arrived."""
# Utility function
async def receive_exactly(stream, num_bytes):
    """Read exactly *num_bytes* bytes from *stream* and return them.

    ``stream.receive_some`` is awaited repeatedly, each time asking for no
    more than the number of bytes still missing, until the buffer is full.

    Args:
        stream: object with an async ``receive_some(max_bytes)`` method that
            returns an empty bytes-like object at end-of-file.
        num_bytes: how many bytes to collect.

    Returns:
        A ``bytearray`` of length *num_bytes*.

    Raises:
        UnexpectedEOFError: if ``receive_some`` returns an empty chunk before
            *num_bytes* bytes have been collected.
    """
    buf = bytearray()
    remaining = num_bytes
    while remaining > 0:
        chunk = await stream.receive_some(remaining)
        if not chunk:
            raise UnexpectedEOFError("other side closed connection")
        buf += chunk
        remaining = num_bytes - len(buf)
    # Sanity check: the stream must not hand back more than was requested.
    assert len(buf) == num_bytes
    return buf
################################################################
# This uses receive_exactly to read messages that start with a 2-byte length
# field giving the size of the body that follows
################################################################
async def read_message(stream, *, header_format="<H"):
    """Read one length-prefixed message from *stream*.

    The message starts with a fixed-size header holding a single unsigned
    integer: the number of body bytes that follow. The header is decoded
    with ``struct`` using *header_format*. The default ``"<H"`` is a 2-byte
    little-endian size field. Pass ``header_format=">H"`` for a big-endian
    size field, or e.g. ``"<I"`` for a 4-byte one, instead of editing this
    function.

    Args:
        stream: object accepted by ``receive_exactly``.
        header_format: ``struct`` format describing exactly one unsigned
            integer size field (keyword-only).

    Returns:
        The raw header followed by the body, as a single buffer, so the
        message can be forwarded verbatim.

    Raises:
        UnexpectedEOFError: propagated from ``receive_exactly`` if the stream
            ends partway through the message.
        struct.error: if *header_format* is not a valid ``struct`` format.
    """
    size_field = struct.Struct(header_format)
    header = await receive_exactly(stream, size_field.size)
    # Exactly one field is expected; a multi-field format fails to unpack here.
    (message_size,) = size_field.unpack(header)
    body = await receive_exactly(stream, message_size)
    return header + body
################################################################
# This is the actual rewrite logic; you need to fill it in!
################################################################
def rewrite_request(request):
    """Transform a client request before it is forwarded to the server.

    This is a placeholder: it currently hands back the very same object it
    was given, unchanged.

    Args:
        request: the raw message (header plus body) read from the client.

    Returns:
        The bytes to send upstream.
    """
    # TODO: implement the actual rewrite here.
    rewritten = request
    return rewritten
################################################################
# And this is the I/O code to glue it together
################################################################
async def handle_one_client(client_stream):
    """Proxy a single request/response exchange for one client connection.

    Reads one length-prefixed request from *client_stream* and passes it
    through ``rewrite_request``. It then opens a TCP connection to
    ``SERVER_HOST:SERVER_PORT``, forwards the request, reads one
    length-prefixed response, and relays that response back to the client.

    Any ``Exception`` raised along the way is reported on stderr and then
    discarded, so the error does not propagate out of this handler.
    """
    try:
        # Read the client's request *before* dialing the upstream server.
        # That way a client that sends nothing, or disconnects early, never
        # causes an upstream connection to be opened and held.
        request = await read_message(client_stream)
        rewritten_request = rewrite_request(request)
        async with await trio.open_tcp_stream(SERVER_HOST, SERVER_PORT) as server_stream:
            await server_stream.send_all(rewritten_request)
            response = await read_message(server_stream)
        # The upstream connection is already closed here; it is not needed
        # while we write the reply back to the client.
        await client_stream.send_all(response)
    except Exception:
        # TODO: decide how errors should really be handled. For now they are
        # logged and thrown away. The header goes to stderr so it stays next
        # to the traceback, which traceback.print_exc() writes to stderr.
        print("Got an error:", file=sys.stderr)
        traceback.print_exc()
async def main():
    """Accept TCP connections on PROXY_PORT and hand each one to handle_one_client.

    NOTE(review): PROXY_PORT, and SERVER_HOST / SERVER_PORT used by
    handle_one_client, are not defined anywhere in this file as shown. They
    must be defined at module level before this runs, or it fails with
    NameError.
    """
    await trio.serve_tcp(PROXY_PORT, handle_one_client)


if __name__ == "__main__":
    # Start the server only when run as a script, so the helpers above can be
    # imported (e.g. for testing) without starting the proxy.
    trio.run(main)
@smurfix
Copy link

smurfix commented Apr 14, 2018

Should we add receive_exactly to trio?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment