-
-
Save adriangb/2f003410b05783924d3d6bf3ff18ad6f to your computer and use it in GitHub Desktop.
ASGI middleware to log request bodies
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Iterable, List, Protocol, Generator | |
import pytest | |
from starlette.responses import Response | |
from starlette.testclient import TestClient | |
from starlette.types import ASGIApp, Scope, Send, Receive, Message | |
class Logger(Protocol):
    """Structural interface for the logger handed to ``BodyLoggingMiddleware``.

    Any object with a compatible ``info(message)`` method satisfies it,
    for example the ``TestLogger`` recorder used by the tests below.
    """

    def info(self, message: str) -> None:
        """Accept a single log message (the decoded request body)."""
class BodyLoggingMiddleware:
    """ASGI middleware that logs the complete HTTP request body.

    Each ``http.request`` body chunk the wrapped app receives is recorded.
    When the app finishes, whether normally or by raising, any chunks it did
    not read are drained from ``receive``, so the logged body is complete.
    The concatenated body is then passed to ``logger.info`` as text.
    Non-HTTP scopes are forwarded to the app untouched and nothing is logged.
    """

    def __init__(
        self,
        app: "ASGIApp",
        logger: "Logger",
    ) -> None:
        self.app = app
        self.logger = logger

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
        """Run the wrapped app for one connection and log its request body.

        Args:
            scope: ASGI connection scope; only ``scope["type"] == "http"`` is logged.
            receive: ASGI receive callable, wrapped to record body chunks.
            send: ASGI send callable, passed through unchanged.

        Raises:
            Whatever the wrapped app raises; the body is still logged first.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Set once the final body chunk (more_body false/absent) or a
        # disconnect has been seen; after that there is nothing left to drain.
        done = False
        chunks: "List[bytes]" = []

        async def wrapped_receive() -> "Message":
            nonlocal done
            message = await receive()
            if message["type"] == "http.disconnect":
                done = True
                return message
            body = message.get("body", b"")
            # Per ASGI, a missing "more_body" means this is the last chunk.
            more_body = message.get("more_body", False)
            if not more_body:
                done = True
            chunks.append(body)
            return message

        try:
            await self.app(scope, wrapped_receive, send)
        finally:
            # Read whatever the app left unread so the full body is logged.
            while not done:
                await wrapped_receive()
            # errors="replace": a non-UTF-8 body must not raise here, which
            # would drop the log and mask any exception raised by the app.
            self.logger.info(b"".join(chunks).decode("utf-8", errors="replace"))
async def consume_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test app that reads the entire request body, then sends an empty response.

    ``more_body`` is optional in ASGI ``http.request`` messages and defaults
    to ``False``. A message carrying ``more_body: False`` explicitly is
    therefore also the final chunk, as is one that omits the key. An
    ``http.disconnect`` message has no ``more_body`` and also ends the loop.
    """
    more_body = True
    while more_body:
        message = await receive()
        more_body = message.get("more_body", False)
    await Response()(scope, receive, send)
async def consume_partial_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test app that reads only the first request message before responding.

    The remaining body chunks are deliberately left unread, so the
    middleware under test must drain them itself.
    """
    await receive()  # first message only
    response = Response()
    await response(scope, receive, send)
class TestException(Exception):
    """Error raised on purpose by the failing test apps.

    ``__test__ = False`` stops pytest from trying to collect this class as a
    test case just because its name starts with ``Test``. Otherwise pytest
    emits a ``PytestCollectionWarning``.
    """

    __test__ = False
async def consume_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test app that reads the entire request body and then raises ``TestException``.

    ``more_body`` is optional in ASGI ``http.request`` messages and defaults
    to ``False``, so both an explicit ``more_body: False`` and a missing key
    mark the final chunk. An ``http.disconnect`` message also ends the loop.
    """
    more_body = True
    while more_body:
        message = await receive()
        more_body = message.get("more_body", False)
    raise TestException
async def consume_partial_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Test app that reads one request message and then fails with ``TestException``.

    The rest of the body is left unread, so the middleware must drain it
    while the exception propagates.
    """
    await receive()  # first message only
    raise TestException()
class TestLogger:
    """``Logger`` implementation that appends every message to a caller-owned list.

    ``__test__ = False`` stops pytest from trying to collect this helper as
    a test class because of its ``Test`` prefix. Otherwise pytest warns
    about the ``__init__`` constructor.
    """

    __test__ = False

    def __init__(self, recorder: List[str]) -> None:
        # The list is shared with the caller so tests can inspect the logs.
        self.recorder = recorder

    def info(self, message: str) -> None:
        """Append *message* to the recorder list."""
        self.recorder.append(message)
@pytest.mark.parametrize(
    "chunks, expected_logs",
    [
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ],
)
@pytest.mark.parametrize(
    "app",
    [consume_body_app, consume_partial_body_app],
)
def test_body_logging_middleware_no_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
    """Check the full body is logged whether the app reads all of it or only part.

    The body is streamed as a generator so it arrives in several chunks.
    """
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    def chunk_gen() -> Generator[bytes, None, None]:
        yield from chunks

    # NOTE(review): passing a generator via ``data=`` relies on the
    # requests-based TestClient; httpx-based Starlette versions expect
    # ``content=`` instead. Confirm against the pinned Starlette version.
    resp = client.get("/", data=chunk_gen())
    assert resp.status_code == 200
    assert logs == expected_logs
@pytest.mark.parametrize(
    "chunks, expected_logs",
    [
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ],
)
@pytest.mark.parametrize(
    "app",
    [consume_body_and_error_app, consume_partial_body_and_error_app],
)
def test_body_logging_middleware_with_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
    """Check the full body is still logged when the app raises, and the error propagates."""
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    def chunk_gen() -> Generator[bytes, None, None]:
        yield from chunks

    # NOTE(review): ``data=`` with a generator relies on the requests-based
    # TestClient; httpx-based Starlette versions expect ``content=``.
    with pytest.raises(TestException):
        client.get("/", data=chunk_gen())
    assert logs == expected_logs
if __name__ == "__main__":
    import os

    # pytest.main() returns the exit code instead of exiting. Propagate it so
    # a failing run exits non-zero when this file is executed as a script.
    raise SystemExit(pytest.main(args=[os.path.abspath(__file__)]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment