Skip to content

Instantly share code, notes, and snippets.

@ashleysommer
Created September 30, 2021 06:18
Show Gist options
  • Save ashleysommer/477d2ee24f0421e56986a6f758397dcb to your computer and use it in GitHub Desktop.
Save ashleysommer/477d2ee24f0421e56986a6f758397dcb to your computer and use it in GitHub Desktop.
A simple reverse-proxy server in Sanic using HTTPX streaming responses
#!/bin/python3
#
"""Example Sanic Reverse Proxy"""
from urllib.parse import urlsplit, urlunsplit
from sanic import Sanic, Request
from sanic.response import stream, HTTPResponse
import httpx
# Proxy wiring: the server listens on "port-to" and forwards requests to the
# backend at http://<hostname-from>:<port-from>.
env = {
    "hostname-from": "localhost",
    "port-from": 8000,
    "port-to": 8080,
}

# The Sanic application acting as the reverse proxy.
host = Sanic("host")
@host.before_server_start
def setup(app, loop):
    """Prepare the shared proxy state before the server starts.

    Stores on ``app.ctx`` the backend base URL (built from ``env``) and a
    single ``httpx.AsyncClient`` that the request handlers reuse.

    Args:
        app: The Sanic application being started.
        loop: The event loop supplied by Sanic; not used here.
    """
    app.ctx.remote = f"http://{env['hostname-from']}:{env['port-from']}"
    app.ctx.client = httpx.AsyncClient()


@host.after_server_stop
async def teardown(app, loop):
    """Close the shared HTTPX client created by ``setup``.

    Without this the client's pooled connections are never released when
    the server shuts down.

    Args:
        app: The Sanic application being stopped.
        loop: The event loop supplied by Sanic; not used here.
    """
    # ``setup`` may not have run (e.g. startup failed), so look the client
    # up defensively instead of assuming the attribute exists.
    client = getattr(app.ctx, "client", None)
    if client is not None:
        await client.aclose()
# Status codes for which the response is returned without a body.
_REDIRECT_STATUSES = (301, 302, 303, 307, 308)
# Status codes whose ``Location`` header may need re-targeting at the proxy.
_LOCATION_STATUSES = (201,) + _REDIRECT_STATUSES


def _public_netloc(request_host, scheme, port):
    """Return the ``host[:port]`` a client should use to reach this proxy.

    Args:
        request_host: The host the client addressed (``request.host``).
        scheme: URL scheme of the location being rewritten; may be empty.
        port: The port this proxy listens on.

    Returns:
        ``request_host`` unchanged when it already ends in an explicit port
        or when ``port`` is the default for ``scheme``; otherwise
        ``request_host`` with ``:port`` appended.
    """
    default_port = 443 if scheme in ("https", "wss") else 80
    # rpartition + isdigit keeps bare IPv6 literals such as "[::1]" from
    # being mistaken for "host:port".
    _, colon, tail = request_host.rpartition(":")
    host_has_port = bool(colon) and tail.isdigit()
    if host_has_port or port == default_port:
        return request_host
    return f"{request_host}:{port}"


def _rewrite_location(loc, request_host):
    """Re-target a backend ``Location`` value at the proxy.

    Args:
        loc: The ``Location`` header value received from the backend.
        request_host: The host the client addressed (``request.host``).

    Returns:
        The rewritten URL when ``loc`` names the backend configured in
        ``env`` (matching hostname, and matching port if one is given);
        ``None`` when ``loc`` is relative or points somewhere else.
    """
    scheme, netloc, url, query, _fragment = urlsplit(loc)
    backend_host, colon, backend_port = netloc.partition(":")
    if backend_host != env["hostname-from"]:
        return None
    if colon and backend_port != str(env["port-from"]):
        return None
    new_netloc = _public_netloc(request_host, scheme, env["port-to"])
    # The fragment component is dropped from the rewritten URL.
    return urlunsplit((scheme, new_netloc, url, query, ""))


@host.route("/<path:path>", methods=("GET", "HEAD", "OPTIONS"))
async def proxy(request: Request, path):
    """Forward a request to the backend and relay its response.

    The request is replayed against ``app.ctx.remote`` with the same method,
    path and query string. All request headers except ``Host`` are forwarded
    and an ``x-forwarded-for`` header carrying ``request.conn_info.client``
    is appended. A ``Location`` header that points at the backend is
    rewritten to point at this proxy.

    Args:
        request: The incoming Sanic request.
        path: The request path below ``/`` (may contain slashes, or be empty
            when called from ``index``).

    Returns:
        A body-less ``HTTPResponse`` for HEAD/OPTIONS requests and redirect
        statuses; otherwise a streaming response that relays the raw
        upstream body chunk by chunk.

    Raises:
        Exception: Any error raised while processing the upstream headers is
            re-raised after the upstream response has been closed.
    """
    app_ctx = request.app.ctx
    to = "/".join((app_ctx.remote, path))
    if request.query_string:
        # Forward the query string too; otherwise "/a?b=1" reaches the
        # backend as "/a".
        to = f"{to}?{request.query_string}"
    headers = [(k, v) for k, v in request.headers.items() if k.lower() != "host"]
    headers.append(("x-forwarded-for", request.conn_info.client))
    # NOTE(review): httpx 0.20+ renamed this keyword to ``follow_redirects``;
    # confirm against the httpx version this deployment pins.
    resp_ctx = app_ctx.client.stream(
        request.method, to, headers=headers, allow_redirects=False
    )
    # The context manager is entered by hand because the response has to
    # outlive this function while its body is being streamed to the client.
    httpx_resp = await resp_ctx.__aenter__()
    try:
        if httpx_resp.status_code in _LOCATION_STATUSES:
            # Only the first of any comma-joined values is considered.
            loc = httpx_resp.headers.get("location", "").split(",")[0]
            new_loc = _rewrite_location(loc, request.host) if loc else None
            if new_loc is not None:
                del httpx_resp.headers["location"]
                httpx_resp.headers["location"] = new_loc
        resp_headers = httpx_resp.headers.multi_items()
    except Exception:
        await resp_ctx.__aexit__(None, None, None)
        raise
    status = httpx_resp.status_code
    if request.method in ("HEAD", "OPTIONS") or status in _REDIRECT_STATUSES:
        # No body is relayed for these, so release the upstream connection now.
        await httpx_resp.aclose()
        return HTTPResponse(None, status=status, headers=resp_headers)

    async def passthrough(response):
        """Relay the raw upstream body, always closing the upstream response."""
        try:
            async for chunk in httpx_resp.aiter_raw():
                await response.stream.send(chunk, end_stream=False)
        finally:
            # Runs even if the client disconnects mid-stream, so the
            # upstream connection is not leaked.
            await httpx_resp.aclose()

    return stream(passthrough, status=status, headers=resp_headers)
@host.route("/", methods=("GET", "HEAD", "OPTIONS"))
async def index(request):
    """Handle requests for the site root by delegating to ``proxy``.

    Args:
        request: The incoming Sanic request.

    Returns:
        Whatever ``proxy`` returns for an empty path.
    """
    root_response = await proxy(request, "")
    return root_response
# Start the server only when this file is executed as a script, so that
# importing the module (for example by tooling or by spawned worker
# processes) does not start another server.
if __name__ == "__main__":
    host.run(host="0.0.0.0", port=env['port-to'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment