Created
October 28, 2022 02:17
-
-
Save turtiesio/4ed6a093f26d2ba6682764ddd826e116 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
import ssl
import weakref
from typing import Optional, Tuple
class SNIStore(ssl.SSLContext):
    """SSL context that remembers the SNI server name each client requested.

    On construction a server-name callback is installed; it records the name
    it receives, keyed by the per-connection SSL object, in a
    ``weakref.WeakKeyDictionary`` so an entry goes away once its SSL object
    is garbage-collected.
    """

    def __new__(cls, protocol) -> "SNIStore":
        self = super().__new__(cls, protocol)
        # SSL object -> server name exactly as delivered to _cb.
        self._store = weakref.WeakKeyDictionary()
        # Legacy callback API (sni_callback is the modern one); kept so the
        # server name is passed through in the same form as before.
        self.set_servername_callback(self._cb)
        return self

    def _cb(self, sslobj, servername: Optional[str], sslctx) -> None:
        """Record *servername* for the connection identified by *sslobj*.

        Must return ``None``: the ssl module treats any other return value as
        a TLS alert and aborts the handshake.
        """
        self._store[sslobj] = servername

    def get_sni(self, sslobj) -> Optional[str]:
        """Return the server name recorded for *sslobj*, or ``None`` if none was recorded."""
        return self._store.get(sslobj)
def resolve(
    transport: asyncio.BaseTransport,
    *,
    http_upstream: Tuple[str, int] = ("127.0.0.1", 8000),
    default_upstream: Tuple[str, int] = ("127.0.0.1", 22),
) -> Tuple[str, int]:
    """Choose the upstream ``(host, port)`` for an accepted TLS connection.

    The decision uses values negotiated in the TLS handshake (SNI and ALPN),
    not any application-layer request data.

    Args:
        transport: The accepted TLS transport; its ``"ssl_object"`` extra info
            is the connection's SSL object.
        http_upstream: Destination when the negotiated ALPN protocol is
            ``http/1.1`` or ``h2`` (default suits ``python -m http.server``).
        default_upstream: Destination for every other connection, including
            clients that negotiated no ALPN protocol (default: local ssh).

    Returns:
        The ``(host, port)`` tuple to connect to.
    """
    ssl_obj = transport.get_extra_info("ssl_object")
    # Only an SNIStore-style context exposes get_sni(); a plain SSLContext
    # simply yields no SNI instead of raising AttributeError.
    get_sni = getattr(ssl_obj.context, "get_sni", None)
    sni = get_sni(ssl_obj) if get_sni is not None else None
    alpn = ssl_obj.selected_alpn_protocol()
    logging.info("Client connected: %s with SNI: %s ALPN: %s",
                 transport.get_extra_info("peername"), sni, alpn)
    # TODO: Route on the SNI value as well.
    if alpn in ("http/1.1", "h2"):
        return http_upstream
    return default_upstream
async def forward(reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
                  *, chunk_size: int = 4096) -> None:
    """Copy bytes from *reader* to *writer* until EOF, then close *writer*.

    A connection error (reset, abort, broken pipe) on either stream is treated
    as the end of the relay rather than propagated. This stops one vanished
    peer from surfacing as an unhandled exception in the caller.

    Args:
        reader: Source stream.
        writer: Destination stream; it is always closed before returning.
        chunk_size: Maximum number of bytes read per iteration.
    """
    try:
        while True:
            data = await reader.read(chunk_size)
            if not data:
                break
            writer.write(data)
            await writer.drain()
    except ConnectionError as exc:
        logging.info("Stream ended abruptly: %r", exc)
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            # Closing a transport whose peer is already gone may raise again;
            # don't let cleanup mask the outcome of the copy loop.
            pass
async def client_connected_cb(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Handle one accepted client: choose an upstream, then relay both directions.

    The upstream address comes from ``resolve`` (TLS SNI/ALPN). If connecting
    fails with an ``OSError`` (e.g. connection refused), the error is logged.
    The client connection is then closed instead of being left open with
    nothing servicing it.
    """
    client_peer = writer.get_extra_info("peername")
    try:
        up_reader, up_writer = await asyncio.open_connection(*resolve(writer.transport))
    except OSError as exc:
        logging.error("Failed to connect to upstream for %s: %s", client_peer, exc)
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            pass  # the client may already be gone; nothing left to clean up
        return

    upstream_peer = up_writer.get_extra_info("peername")
    logging.info("Connected to upstream: %s", upstream_peer)
    logging.info("Forwarding data from %s to %s", client_peer, upstream_peer)
    try:
        # Two independent one-way pumps; each closes its destination when done.
        await asyncio.gather(
            forward(reader, up_writer),
            forward(up_reader, writer),
        )
    finally:
        logging.info("Connection closed: %s", client_peer)
async def main():
    """Run the TLS multiplexer on port 443 (all interfaces) until cancelled.

    Loads the certificate chain and private key from the certbot "live"
    directory below, advertises the accepted ALPN protocols, and hands every
    accepted connection to ``client_connected_cb``.
    """
    cert_dir = "/etc/letsencrypt/live/EXAMPLE.com"
    alpn_protocols = ["http/1.1", "h2", "ssh"]
    bind_host, bind_port = "", 443

    context = SNIStore(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(f"{cert_dir}/fullchain.pem", f"{cert_dir}/privkey.pem")
    context.set_alpn_protocols(alpn_protocols)

    logging.info("Starting server on %s:%s", bind_host, bind_port)
    listener = await asyncio.start_server(
        client_connected_cb,
        host=bind_host,
        port=bind_port,
        ssl=context,
    )
    await listener.serve_forever()
if __name__ == "__main__":
    # Script entry point: configure root logging, then run the server loop.
    log_format = "%(asctime)-15s | %(levelname)-8s | [%(name)s] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment