Skip to content

Instantly share code, notes, and snippets.

@turtiesio
Created October 28, 2022 02:17
Show Gist options
  • Save turtiesio/4ed6a093f26d2ba6682764ddd826e116 to your computer and use it in GitHub Desktop.
Save turtiesio/4ed6a093f26d2ba6682764ddd826e116 to your computer and use it in GitHub Desktop.
import asyncio
import logging
import ssl
import weakref
from pathlib import Path
from typing import Optional, Tuple
class SNIStore(ssl.SSLContext):
    """SSL context that remembers the SNI server name of each connection.

    A server-name callback installed at construction time records the name
    the client requested, keyed by the connection's SSL object. The mapping
    is a ``weakref.WeakKeyDictionary``, so an entry goes away together with
    the SSL object it belongs to.
    """

    def __new__(cls, protocol=ssl.PROTOCOL_TLS_SERVER) -> "SNIStore":
        """Create the context and install the SNI-recording callback.

        :param protocol: TLS protocol constant forwarded to ``ssl.SSLContext``.
            Defaults to ``ssl.PROTOCOL_TLS_SERVER`` because the server-name
            callback is only used on the server side of a handshake.
        """
        self = super().__new__(cls, protocol)
        self._store = weakref.WeakKeyDictionary()
        self.set_servername_callback(self._cb)
        return self

    def _cb(self, sslobj, servername, sslctx):
        """Server-name callback: remember *servername* for *sslobj*.

        *sslctx* is part of the callback signature and is unused here.
        Returning ``None`` lets the handshake continue normally.
        """
        self._store[sslobj] = servername

    def get_sni(self, sslobj) -> Optional[str]:
        """Return the server name recorded for *sslobj*.

        Returns ``None`` when nothing was recorded for that object.
        """
        return self._store.get(sslobj)
def resolve(transport: asyncio.BaseTransport) -> Tuple[str, int]:
    """Choose the upstream ``(host, port)`` for a connection from its TLS handshake.

    The SNI name (when the context is an ``SNIStore``) and the negotiated ALPN
    protocol are read from the transport's ``ssl_object`` and logged.
    Connections that negotiated ``http/1.1`` or ``h2`` are routed to
    ``127.0.0.1:8000``. Everything else goes to ``127.0.0.1:22``: ``ssh``, no
    ALPN at all, or a transport without an SSL object.

    :param transport: transport of an accepted client connection.
    :returns: ``(host, port)`` of the upstream to connect to.
    """
    sslobj = transport.get_extra_info("ssl_object")
    sni = alpn = None
    if sslobj is not None:
        # Don't blindly downcast: the context may not be an SNIStore, so
        # only query get_sni when the context actually provides it.
        get_sni = getattr(sslobj.context, "get_sni", None)
        if get_sni is not None:
            sni = get_sni(sslobj)
        alpn = sslobj.selected_alpn_protocol()
    logging.info("Client connected: %s with SNI: %s ALPN: %s",
                 transport.get_extra_info("peername"), sni, alpn)
    # TODO: Do something with the SNI and ALPN values
    if alpn in ("http/1.1", "h2"):
        return "127.0.0.1", 8000  # to do test with 'python -m http.server'
    return "127.0.0.1", 22  # forward traffic to ssh server
async def forward(reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
                  chunk_size: int = 4096) -> None:
    """Copy bytes from *reader* to *writer* until EOF, then close *writer*.

    Some OS-level errors end the relay instead of propagating. These include
    peer resets, broken pipes and TLS errors (``ssl.SSLError`` is an
    ``OSError``). For a proxy these are an ordinary way for one direction of
    a connection to end. They are logged at DEBUG level.

    :param reader: stream to read from.
    :param writer: stream to write to; it is always closed before returning.
    :param chunk_size: maximum number of bytes read per iteration.
    """
    try:
        while True:
            data = await reader.read(chunk_size)
            if not data:
                break
            writer.write(data)
            await writer.drain()
    except OSError as exc:
        logging.debug("Relay stopped: %r", exc)
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except OSError as exc:
            # wait_closed() can raise while closing; inside `finally` that would
            # mask any in-flight exception, so only log it.
            logging.debug("Error while closing stream: %r", exc)
async def client_connected_cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Handle one accepted client: route it by SNI/ALPN and relay its traffic.

    Opens a connection to the upstream chosen by ``resolve`` and runs
    ``forward`` in both directions concurrently until both finish. If the
    upstream connection fails with an ``OSError`` (e.g. connection refused),
    the failure is logged and the client connection is closed rather than
    left open.
    """
    client_peer = writer.transport.get_extra_info("peername")
    try:
        up_reader, up_writer = await asyncio.open_connection(*resolve(writer.transport))
    except OSError as exc:
        logging.error("Could not connect to upstream for %s: %s", client_peer, exc)
        writer.close()
        try:
            await writer.wait_closed()
        except OSError as close_exc:
            # Best-effort cleanup: the client may already be gone.
            logging.debug("Error while closing client stream: %r", close_exc)
        return
    upstream_peer = up_writer.transport.get_extra_info("peername")
    logging.info("Connected to upstream: %s", upstream_peer)
    logging.info("Forwarding data from %s to %s", client_peer, upstream_peer)
    await asyncio.gather(
        forward(reader, up_writer),
        forward(up_reader, writer),
    )
    logging.info("Connection closed")
async def main(host: str = "", port: int = 443,
               cert_path: str = "/etc/letsencrypt/live/EXAMPLE.com",
               alpn_protocols: Tuple[str, ...] = ("http/1.1", "h2", "ssh")):
    """Run the SNI/ALPN-routing TLS proxy until it is cancelled.

    :param host: interface to bind; ``""`` is passed through to
        ``asyncio.start_server`` unchanged.
    :param port: TCP port to listen on.
    :param cert_path: directory containing ``fullchain.pem`` and ``privkey.pem``.
    :param alpn_protocols: ALPN protocol names the server accepts. ``resolve``
        routes connections on the protocol that was negotiated.
    """
    cert_dir = Path(cert_path)
    ctx = SNIStore(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(cert_dir / "fullchain.pem", cert_dir / "privkey.pem")
    ctx.set_alpn_protocols(list(alpn_protocols))
    logging.info("Starting server on %s:%s", host, port)
    server = await asyncio.start_server(client_connected_cb, host=host, port=port, ssl=ctx)
    # The context manager closes the listening server however serve_forever() exits.
    async with server:
        await server.serve_forever()
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)-15s | %(levelname)-8s | [%(name)s] %(message)s",
    )
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop this server; exit without a traceback.
        logging.info("Server stopped")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment