Skip to content

Instantly share code, notes, and snippets.

@adriangb
Forked from Kludex/main.py
Last active June 2, 2022 16:58
Show Gist options
  • Save adriangb/b21424afee4b2464399e3592fe86b601 to your computer and use it in GitHub Desktop.
Save adriangb/b21424afee4b2464399e3592fe86b601 to your computer and use it in GitHub Desktop.
Initial implementation of a Hook system to build middlewares.
from typing import (
Awaitable,
Iterable,
Mapping,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
cast,
)
from asgiref.typing import (
ASGI3Application,
ASGIReceiveCallable,
ASGISendCallable,
ASGISendEvent,
ASGIReceiveEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
Scope,
)
from starlette.datastructures import MutableHeaders
class ScopeHook(Protocol):
    """Structural type for a hook that is given only the connection scope.

    This is the type accepted by ``HookMiddleware``'s ``scope_hook``
    parameter. Implementations may be plain functions (returning ``None``)
    or return an awaitable that resolves to ``None``.
    """

    def __call__(self, __scope: Scope) -> Optional[Awaitable[None]]:
        """Handle ``__scope`` (positional-only by naming convention)."""
class SendHook(Protocol):
    """Structural type for a hook that observes an outgoing ASGI event.

    This is the type accepted by ``HookMiddleware``'s ``send_hook``
    parameter. Implementations receive the scope and the send event, and
    may be plain functions (returning ``None``) or return an awaitable
    that resolves to ``None``.
    """

    def __call__(
        self, __scope: Scope, __message: ASGISendEvent
    ) -> Optional[Awaitable[None]]:
        """Handle ``__message`` for ``__scope`` (both positional-only by convention)."""
class ReceiveHook(Protocol):
    """Structural type for a hook that observes an incoming ASGI event.

    This is the type accepted by ``HookMiddleware``'s ``receive_hook``
    parameter. Implementations receive the scope and the receive event,
    and may be plain functions (returning ``None``) or return an awaitable
    that resolves to ``None``.
    """

    def __call__(
        self, __scope: Scope, __message: ASGIReceiveEvent
    ) -> Optional[Awaitable[None]]:
        """Handle ``__message`` for ``__scope`` (both positional-only by convention)."""
class HookMiddleware:
    """ASGI middleware that runs optional hooks around a wrapped application.

    Every hook may either return ``None`` (plain function) or an awaitable;
    an awaitable result is awaited before processing continues.

    * ``scope_hook(scope)`` runs once per call, before the application.
    * ``receive_hook(scope, message)`` runs for every event obtained from
      ``receive``, before the event is handed to the application.
    * ``send_hook(scope, message)`` runs for every event the application
      sends, before the event is forwarded to the real ``send``. Because the
      hook runs first, in-place changes it makes to the message are what
      gets forwarded.
    """

    def __init__(
        self,
        app: ASGI3Application,
        scope_hook: Optional[ScopeHook] = None,
        send_hook: Optional[SendHook] = None,
        receive_hook: Optional[ReceiveHook] = None,
    ) -> None:
        """Store the wrapped application and the (optional) hooks.

        Args:
            app: The ASGI application to wrap.
            scope_hook: Called with the scope before ``app`` runs.
            send_hook: Called with the scope and each outgoing event.
            receive_hook: Called with the scope and each incoming event.
        """
        self._app = app
        self._scope_hook = scope_hook
        self._send_hook = send_hook
        self._receive_hook = receive_hook

    @staticmethod
    async def _settle(maybe_awaitable: Optional[Awaitable[None]]) -> None:
        """Await a hook's result when it returned an awaitable."""
        if maybe_awaitable is not None:
            await maybe_awaitable

    async def __call__(
        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        """Run the wrapped application with the configured hooks applied."""
        scope_hook = self._scope_hook
        send_hook = self._send_hook
        receive_hook = self._receive_hook

        if scope_hook is not None:
            await self._settle(scope_hook(scope))

        async def wrapped_receive() -> ASGIReceiveEvent:
            message = await receive()
            if receive_hook is not None:
                await self._settle(receive_hook(scope, message))
            return message

        async def wrapped_send(message: ASGISendEvent) -> None:
            if send_hook is not None:
                await self._settle(send_hook(scope, message))
            # Forward the event; without this the server never sees any
            # response from the wrapped application.
            await send(message)

        await self._app(scope, wrapped_receive, wrapped_send)
def http_response_start_filter(hook: SendHook) -> SendHook:
    """Restrict ``hook`` to ``http.response.start`` events.

    Returns an async send hook that ignores every message whose ``"type"``
    is not ``"http.response.start"``; for matching messages it calls
    ``hook`` and awaits the result if the hook returned something other
    than ``None``.
    """

    async def wrapped_hook(scope: Scope, message: ASGISendEvent) -> None:
        if message["type"] != "http.response.start":
            return
        outcome = hook(scope, message)
        if outcome is not None:
            await outcome

    return wrapped_hook
def add_headers(
    headers: Union[Iterable[Tuple[str, str]], Mapping[str, str]]
) -> SendHook:
    """Build a send hook that appends fixed headers to a message.

    ``headers`` is either a mapping of name to value or an iterable of
    ``(name, value)`` pairs. It is consumed eagerly, once, when this
    function is called, so one-shot iterables are safe to pass.

    The returned hook wraps the message in Starlette's ``MutableHeaders``
    (passing the message through its ``scope=`` keyword, hence the type
    ignore) and appends each pair.
    NOTE(review): this presumably requires the message to carry a
    ``"headers"`` entry, i.e. to be a response-start event -- confirm, or
    compose with a filter that only lets such events through.
    """
    pairs = headers.items() if isinstance(headers, Mapping) else headers
    # Unpack each entry so malformed pairs fail here rather than per message.
    items = [(name, value) for name, value in pairs]

    def wrapped_send(scope: Scope, message: ASGISendEvent) -> None:
        resp_headers = MutableHeaders(scope=message)  # type: ignore
        for name, value in items:
            resp_headers.append(name, value)

    return wrapped_send
async def app(
    scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
    """Minimal example ASGI application.

    Ignores ``scope`` and ``receive`` and sends a 200 response start event
    with no headers, followed by a single, final body event containing
    ``b"Hello, world!"``.
    """
    start_event = HTTPResponseStartEvent(
        type="http.response.start", status=200, headers=[]
    )
    body_event = HTTPResponseBodyEvent(
        type="http.response.body", body=b"Hello, world!", more_body=False
    )
    await send(start_event)
    await send(body_event)
# Example wiring: wrap `app` in HookMiddleware with a single send hook.
# `add_headers` builds a hook that appends an `x-foo: bar` header, and
# `http_response_start_filter` restricts that hook to
# `http.response.start` messages. No scope or receive hook is configured.
wrapped_app = HookMiddleware(
app=app,
send_hook=http_response_start_filter(add_headers({"x-foo": "bar"})),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment