Skip to content

Instantly share code, notes, and snippets.

@adriangb
Last active June 22, 2022 06:13
Show Gist options
  • Save adriangb/2f003410b05783924d3d6bf3ff18ad6f to your computer and use it in GitHub Desktop.
Save adriangb/2f003410b05783924d3d6bf3ff18ad6f to your computer and use it in GitHub Desktop.
ASGI middleware to log request bodies
from typing import Iterable, List, Protocol, Generator
import pytest
from starlette.responses import Response
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Scope, Send, Receive, Message
class Logger(Protocol):
    """Structural type for any object that exposes an ``info(message)`` method."""

    def info(self, message: str) -> None:
        """Record *message*; implementations decide where it goes."""
class BodyLoggingMiddleware:
    """ASGI middleware that records each HTTP request body and logs it.

    The ``receive`` callable handed to the wrapped app is replaced by a
    wrapper that keeps a copy of every body chunk passing through it. When
    the app returns or raises, any part of the body the app did not read is
    drained, and the collected bytes are logged as one message.
    """

    def __init__(
        self,
        app: ASGIApp,
        logger: Logger,
    ) -> None:
        """Store the wrapped ASGI *app* and the *logger* that receives bodies."""
        self.app = app
        self.logger = logger

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, then log the request body for HTTP scopes.

        Non-HTTP scopes are passed straight through untouched.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        done = False
        chunks: List[bytes] = []

        async def wrapped_receive() -> Message:
            """Forward one message from ``receive`` and record its body chunk."""
            nonlocal done
            message = await receive()
            if message["type"] == "http.disconnect":
                # Nothing more will arrive; stop without recording a chunk.
                done = True
                return message
            body = message.get("body", b"")
            more_body = message.get("more_body", False)
            if not more_body:
                done = True
            chunks.append(body)
            return message

        try:
            await self.app(scope, wrapped_receive, send)
        finally:
            # The app may have returned or raised before reading the whole
            # body; pull the remaining chunks so the log entry is complete.
            while not done:
                await wrapped_receive()
            # Decode leniently: a body that is not valid UTF-8 must not raise
            # here, because an exception in ``finally`` would replace the
            # app's own exception and skip the log call.
            self.logger.info(b"".join(chunks).decode(errors="replace"))
async def consume_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test ASGI app: read the whole request body, then send a default ``Response``."""
    done = False
    while not done:
        msg = await receive()
        # Check the flag's value, not just its presence: a final chunk may
        # carry an explicit ``"more_body": False``. Messages without the key
        # (including "http.disconnect") also end the loop.
        done = not msg.get("more_body", False)
    await Response()(scope, receive, send)
async def consume_partial_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test ASGI app: read a single message, then send a default ``Response``."""
    # Deliberately read only one message so the rest of the body is left unread.
    await receive()
    response = Response()
    await response(scope, receive, send)
class TestException(Exception):
    """Marker exception raised by the failing test apps below."""

    # The ``Test`` prefix matches pytest's class-collection pattern; opt out
    # explicitly so pytest does not try to collect this helper (and warn).
    __test__ = False
async def consume_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test ASGI app: read the whole request body, then raise ``TestException``."""
    done = False
    while not done:
        msg = await receive()
        # Check the flag's value, not just its presence: a final chunk may
        # carry an explicit ``"more_body": False``. Messages without the key
        # (including "http.disconnect") also end the loop.
        done = not msg.get("more_body", False)
    raise TestException
async def consume_partial_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test ASGI app: read a single message, then raise ``TestException``."""
    # Deliberately read only one message so the rest of the body is left unread.
    _ = await receive()
    raise TestException()
class TestLogger:
    """Logger stand-in that appends every ``info`` message to a caller-owned list."""

    # The ``Test`` prefix matches pytest's class-collection pattern; opt out
    # explicitly so pytest does not try to collect this helper (and warn
    # about its ``__init__``).
    __test__ = False

    def __init__(self, recorder: List[str]) -> None:
        """Keep a reference to *recorder*; it is mutated in place by ``info``."""
        self.recorder = recorder

    def info(self, message: str) -> None:
        """Append *message* to the recorder list."""
        self.recorder.append(message)
@pytest.mark.parametrize(
    "chunks, expected_logs",
    [
        # Every chunk is bytes, matching the ``Iterable[bytes]`` annotation.
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ],
)
@pytest.mark.parametrize(
    "app",
    [consume_body_app, consume_partial_body_app],
)
def test_body_logging_middleware_no_errors(
    chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp
) -> None:
    """The full body is logged whether the app reads all of it or only part."""
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    def chunk_gen() -> Generator[bytes, None, None]:
        """Yield the request body one chunk at a time."""
        yield from chunks

    resp = client.get("/", data=chunk_gen())
    assert resp.status_code == 200
    assert logs == expected_logs
@pytest.mark.parametrize(
    "chunks, expected_logs",
    [
        # Every chunk is bytes, matching the ``Iterable[bytes]`` annotation.
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ],
)
@pytest.mark.parametrize(
    "app",
    [consume_body_and_error_app, consume_partial_body_and_error_app],
)
def test_body_logging_middleware_with_errors(
    chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp
) -> None:
    """The full body is still logged when the wrapped app raises."""
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    def chunk_gen() -> Generator[bytes, None, None]:
        """Yield the request body one chunk at a time."""
        yield from chunks

    with pytest.raises(TestException):
        client.get("/", data=chunk_gen())
    # Checked outside the ``raises`` block: inside it, this line would never
    # run once the request raises.
    assert logs == expected_logs
if __name__ == "__main__":
    # Running this file directly hands it to pytest by absolute path.
    import os

    this_file = os.path.abspath(__file__)
    pytest.main(args=[this_file])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment