"""FastAPI middleware that overrides the Host header value with the
X-Forwarded-Host header value, if that header exists.

This lets FastAPI (Starlette) build responses that use the X-Forwarded-Host
host, e.g. the redirect it issues to add a trailing slash.
"""
from typing import List, Tuple | |
from starlette.types import ASGIApp, Receive, Scope, Send | |
# ASGI request headers: (name, value) byte-string pairs.
Headers = List[Tuple[bytes, bytes]]


class ForwardedHostMiddleware:
    """ASGI middleware that replaces ``Host`` with ``X-Forwarded-Host``.

    For HTTP and WebSocket connections that carry an ``X-Forwarded-Host``
    header, the inner application sees that value as its ``Host`` header.
    Connections without the header keep their original ``Host``.
    """

    def __init__(self, app: ASGIApp):
        # The wrapped (inner) ASGI application.
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only "http" and "websocket" scopes are rewritten; any other scope
        # type is forwarded unchanged.
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return
        # Build a new scope instead of assigning into the caller's dict:
        # the ASGI spec asks middleware to copy the scope before modifying
        # it, otherwise the change leaks back to outer layers.
        scope = {
            **scope,
            "headers": self.remap_headers(
                scope["headers"], b"host", b"x-forwarded-host"
            ),
        }
        await self.app(scope, receive, send)

    def remap_headers(self, src: Headers, before: bytes, after: bytes) -> Headers:
        """Return a copy of *src* in which header *before* takes *after*'s value.

        Header names are compared byte-for-byte, so *src* is expected to use
        the same (lower-case) spelling as *before* and *after*.

        Args:
            src: ASGI header pairs; *src* itself is not modified.
            before: Name of the header to override, e.g. ``b"host"``.
            after: Name of the header supplying the override, e.g.
                ``b"x-forwarded-host"``.

        Returns:
            A new list holding every other header in its original order,
            followed by at most one ``(before, value)`` pair. ``value`` is
            the last comma-separated entry of the last *after* header (with
            surrounding whitespace removed) when that entry is non-empty,
            otherwise the last non-empty *before* value. If neither exists,
            no *before* pair is appended.
        """
        remapped: Headers = []
        before_value = None
        after_value = None
        for header in src:
            name, value = header
            if name == before:
                before_value = value
            elif name == after:
                after_value = value
            else:
                remapped.append(header)
        if after_value:
            # A comma-joined list ("a, b") is never a valid Host. When proxies
            # append to the header, the last entry was added by the proxy
            # nearest this app, while earlier entries may come from the client.
            after_value = after_value.rsplit(b",", 1)[-1].strip()
        if after_value:
            remapped.append((before, after_value))
        elif before_value:
            remapped.append((before, before_value))
        return remapped
"""Tests for ``attakei_net.middleware``.
"""
from typing import Tuple | |
from fastapi import APIRouter, FastAPI | |
from fastapi.testclient import TestClient | |
from attakei_net.middleware import ForwardedHostMiddleware | |
def configure_client() -> Tuple[FastAPI, TestClient]:
    """Create a bare FastAPI app wrapped in ``ForwardedHostMiddleware``.

    Returns:
        The application, so a test can register its own routes, and a
        ``TestClient`` bound to that application.
    """
    application = FastAPI()
    application.add_middleware(ForwardedHostMiddleware)
    return application, TestClient(application)
def test_forwarded():
    """The trailing-slash redirect should target the X-Forwarded-Host host."""
    app, client = configure_client()
    sample_router = APIRouter()

    @sample_router.get("/{path}/")
    def _ok(path):
        return "OK"

    app.include_router(sample_router)
    response = client.get(
        "/sample",
        allow_redirects=False,
        headers={"X-Forwarded-Host": "test2"},
    )
    assert response.status_code == 307
    assert response.headers["location"].startswith("http://test2/")
def test_no_forwarded():
    """Without X-Forwarded-Host the redirect keeps the client's own Host."""
    app, client = configure_client()
    sample_router = APIRouter()

    @sample_router.get("/{path}/")
    def _ok(path):
        return "OK"

    app.include_router(sample_router)
    response = client.get("/sample", allow_redirects=False)
    assert response.status_code == 307
    assert response.headers["location"].startswith("http://testserver/")