Skip to content

Instantly share code, notes, and snippets.

@attakei

attakei/.readme.md

Last active Apr 7, 2020
Embed
What would you like to do?

FastAPI middleware for forwarding the hostname

Overview

FastAPI middleware that overrides the Host header value with the X-Forwarded-Host header value, if it exists.

With this middleware, FastAPI (Starlette) builds its trailing-slash redirect responses using the host given in the X-Forwarded-Host header.

from typing import List, Tuple
from starlette.types import ASGIApp, Receive, Scope, Send
# Raw header list as (name, value) byte pairs, the shape stored in scope["headers"].
Headers = List[Tuple[bytes, bytes]]
class ForwardedHostMiddleware:
    """ASGI middleware that substitutes ``X-Forwarded-Host`` for ``Host``.

    For ``http`` and ``websocket`` connections the wrapped application is
    called with a header list in which the ``host`` entry carries the value
    of ``x-forwarded-host`` (when that header is present and non-empty).
    Every other scope type is passed through untouched.
    """

    def __init__(self, app: ASGIApp):
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Remap the host header for HTTP/WebSocket scopes, then delegate.

        The scope received from the caller is never modified: the remapped
        headers are placed on a shallow copy, so the change is visible only
        to the wrapped application and does not leak to outer middleware.
        """
        if scope["type"] in ("http", "websocket"):
            # Shallow copy: only the "headers" key is rebound, and
            # remap_headers() builds a new list rather than editing in place.
            scope = dict(scope)
            scope["headers"] = self.remap_headers(
                scope["headers"], b"host", b"x-forwarded-host"
            )
        await self.app(scope, receive, send)

    def remap_headers(self, src: Headers, before: bytes, after: bytes) -> Headers:
        """Return a new header list where ``before`` takes the value of ``after``.

        All entries named ``before`` or ``after`` are removed from their
        original positions (names are compared byte-for-byte). A single
        ``before`` entry is then appended at the end of the list, carrying:

        * the value of ``after`` if it was present with a non-empty value;
        * otherwise the original value of ``before`` if it was non-empty;
        * otherwise no entry is appended at all.

        When a name occurs several times, its last occurrence wins.
        ``src`` itself is not modified.
        """
        remapped = []
        before_value = None
        after_value = None
        for header in src:
            name, value = header
            if name == before:
                before_value = value
            elif name == after:
                after_value = value
            else:
                remapped.append(header)
        # Truthiness (not "is not None") is deliberate: an empty forwarded
        # value must not override a usable original value.
        chosen = after_value or before_value
        if chosen:
            remapped.append((before, chosen))
        return remapped
"""Tests for ``attakei_net.routes.uploads``.
"""
from typing import Tuple
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient
from attakei_net.middleware import ForwardedHostMiddleware
def configure_client() -> Tuple[FastAPI, TestClient]:
    """Build a fresh app with ``ForwardedHostMiddleware`` installed.

    Returns the application together with a test client bound to it, so a
    test can register routes on the app and then issue requests.
    """
    application = FastAPI()
    application.add_middleware(ForwardedHostMiddleware)
    return application, TestClient(application)
def test_forwarded():
    """The slash redirect points at the host given in ``X-Forwarded-Host``."""
    app, client = configure_client()

    router = APIRouter()

    @router.get("/{path}/")
    def _ok(path):
        return "OK"

    app.include_router(router)

    # "/sample" lacks the trailing slash, so the router answers with a redirect.
    forwarded = {"X-Forwarded-Host": "test2"}
    response = client.get("/sample", headers=forwarded, allow_redirects=False)

    assert response.status_code == 307
    location = response.headers["location"]
    assert location.startswith("http://test2/")
def test_no_forwarded():
    """Without ``X-Forwarded-Host`` the redirect keeps the original host."""
    app, client = configure_client()

    router = APIRouter()

    @router.get("/{path}/")
    def _ok(path):
        return "OK"

    app.include_router(router)

    # "/sample" lacks the trailing slash, so the router answers with a redirect.
    response = client.get("/sample", allow_redirects=False)

    assert response.status_code == 307
    location = response.headers["location"]
    assert location.startswith("http://testserver/")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.