-
-
Save wshayes/8e2341bb245a4125b294f6bd5da2df2d to your computer and use it in GitHub Desktop.
[FastAPI app with response shape wrapping] #fastapi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From FastAPI Gitter:
dmontagu @dmontagu 00:14
@wshayes @intrepidOlivia here is a fully self-contained, working implementation of a wrapped response:
https://gist.github.com/dmontagu/9abbeb86fd53556e2c3d9bf8908f81bb
You can set context data and errors on the Starlette Request, and they get added to the response at the end.
(@intrepidOlivia if you save the contents of that gist to main.py, it should be possible to run it via `uvicorn main:app --reload`.)
If the endpoint failed in an expected way and you want to return a StandardResponse with no data field, pass the payload model class you want wrapped (e.g. `StandardResponse(MyResponse1, request)`) instead of an instance.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import lru_cache | |
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union | |
from pydantic import BaseModel, create_model | |
from starlette.requests import Request | |
from starlette.responses import JSONResponse | |
from fastapi import FastAPI | |
from fastapi.encoders import jsonable_encoder | |
# One entry in the "errors" list of every standard response envelope.
# (Documented with comments rather than a docstring: pydantic would copy a
# class docstring into the generated JSON schema's "description".)
class Error(BaseModel):
    kind: str  # short category label, e.g. an exception class name or "expected"
    detail: str  # human-readable message
# Type variable for the wrapped payload; restricted to pydantic models.
T = TypeVar("T", bound=BaseModel)

# Field types of the envelope: the list of accumulated errors and the free-form
# per-request context mapping.
ErrorsT = List[Error]
ContextT = Dict[str, Any]
@lru_cache(maxsize=None)
def get_standard_response_model(cls: Type[BaseModel]) -> Type[BaseModel]:
    """Return the envelope model that wraps payload model ``cls``.

    The generated model is named ``StandardData[<cls name>]`` and has three
    fields: ``context`` (required ``ContextT``), ``errors`` (required
    ``ErrorsT``) and ``data`` (``Optional[cls]``, defaulting to ``None``).

    Results are memoized so that repeated calls with the same ``cls`` return
    the very same class object. This is why the cache is unbounded. With the
    default ``maxsize=128``, an app with more payload models would evict
    entries, and a later call would build a second, distinct class with the
    same name. The key space is bounded by the number of model classes in the
    app, so no size limit is needed.

    Raises:
        TypeError: if ``cls`` is not a pydantic ``BaseModel`` subclass.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    # The ``isinstance(cls, type)`` guard stops ``issubclass`` from raising a
    # less helpful TypeError for non-class arguments.
    if not (isinstance(cls, type) and issubclass(cls, BaseModel)):
        raise TypeError(f"expected a pydantic BaseModel subclass, got {cls!r}")
    return create_model(
        f"StandardData[{cls.__name__}]",
        context=(ContextT, ...),
        errors=(ErrorsT, ...),
        data=(Optional[cls], None),
    )
class StandardResponse(Generic[T]):
    """Factory for the standard ``{context, errors, data}`` response envelope.

    ``StandardResponse[X]`` evaluates to the envelope model generated for
    payload model ``X`` by ``get_standard_response_model``, so it can be used
    as a FastAPI ``response_model``. Calling ``StandardResponse(...)`` never
    produces an instance of this class: ``__new__`` returns an instance of
    that generated envelope model instead.
    """

    def __class_getitem__(cls, item: Type[BaseModel]) -> Type[BaseModel]:
        # Implicitly a classmethod. It replaces Generic's subscription so that
        # ``StandardResponse[X]`` yields the concrete envelope model.
        return get_standard_response_model(item)

    def __new__(cls, data: Union[T, Type[T]], request: Optional[Request] = None) -> "StandardResponse[T]":
        """Build an envelope instance around ``data``.

        Args:
            data: either a payload model instance, which becomes the ``data``
                field, or a payload model class. A class produces an envelope
                of that type with ``data=None``, for endpoints that failed in
                an expected way.
            request: when given, ``context`` and ``errors`` are taken from
                ``request.state.context`` / ``request.state.errors``. When
                omitted, both are empty.

        Returns:
            An instance of ``get_standard_response_model(<payload class>)``.

        Raises:
            TypeError: if ``data`` is neither a ``BaseModel`` instance nor a
                ``BaseModel`` subclass.
        """
        if request is not None:
            context = request.state.context  # type: ignore
            errors = request.state.errors  # type: ignore
        else:
            context = {}
            errors = []

        response_data: Optional[BaseModel]
        if isinstance(data, BaseModel):
            model_cls = type(data)
            response_data = data
        elif isinstance(data, type) and issubclass(data, BaseModel):
            # A bare model class means "no payload": emit data=None.
            model_cls = data
            response_data = None
        else:
            # Explicit error instead of ``assert`` so it survives ``python -O``.
            raise TypeError(f"expected a BaseModel instance or subclass, got {data!r}")

        response_type = get_standard_response_model(model_cls)
        return response_type(context=context, errors=errors, data=response_data)  # type: ignore
# Example text payload returned by the demo endpoints.
# (Comments instead of a docstring, which pydantic would put into the schema.)
class MyResponse1(BaseModel):
    text: str  # the message returned to the client
# Example numeric payload returned by the "/2" demo endpoint.
# (Comments instead of a docstring, which pydantic would put into the schema.)
class MyResponse2(BaseModel):
    number: int  # the value returned to the client
# Application instance; the routes, middleware and exception handler below are
# registered on it via decorators.
app = FastAPI()
@app.get("/1", response_model=StandardResponse[MyResponse1])
def get_response_1(request: Request) -> StandardResponse[MyResponse1]:
    # Successful call: tag the request context, then wrap a MyResponse1
    # payload together with the request's context/errors.
    # (Comment, not docstring: FastAPI would publish a docstring in OpenAPI.)
    add_context(request, "endpoint", "1")
    return StandardResponse(MyResponse1(text="hello world"), request=request)
@app.get("/2", response_model=StandardResponse[MyResponse2])
def get_response_2(request: Request) -> StandardResponse[MyResponse2]:
    # Successful call: tag the request context, then wrap a MyResponse2
    # payload together with the request's context/errors.
    # (Comment, not docstring: FastAPI would publish a docstring in OpenAPI.)
    add_context(request, "endpoint", "2")
    return StandardResponse(MyResponse2(number=42), request=request)
@app.get("/expected-error", response_model=StandardResponse[MyResponse1])
def get_expected_error(request: Request):
    # Anticipated failure: record an error on the request, then return an
    # envelope with no payload. Passing the model *class* instead of an
    # instance makes StandardResponse emit data=None.
    add_context(request, "endpoint", "expected-error")
    add_error(request, kind="expected", detail="expected error")
    return StandardResponse(MyResponse1, request=request)
@app.get("/unexpected-error", response_model=StandardResponse[MyResponse1])
def get_unexpected_error(request: Request):
    # Unhandled failure: an error is recorded on the request first, then an
    # exception escapes the endpoint. The app-level Exception handler turns it
    # into a 500 envelope containing both the exception and the recorded error.
    add_context(request, "endpoint", "unexpected-error")
    add_error(request, "expected", "expected error")
    raise RuntimeError("whoops")
def add_context(request: Request, key: str, value: Any):
    """Store ``value`` under ``key`` in the request's context mapping.

    Expects ``request.state.context`` to already exist (set up per request by
    the HTTP middleware). Otherwise the attribute lookup raises.
    """
    context = request.state.context  # type: ignore
    context[key] = value
def add_error(request: Request, kind: str, detail: str):
    """Append an ``Error(kind=kind, detail=detail)`` to the request's error list.

    Expects ``request.state.errors`` to already exist (set up per request by
    the HTTP middleware).
    """
    # Look up the list before building the Error, matching the original
    # evaluation order.
    errors = request.state.errors  # type: ignore
    errors.append(Error(kind=kind, detail=detail))
@app.middleware("http")
async def context_middleware(request: Request, call_next):
    """Attach fresh per-request containers, then run the rest of the stack.

    Sets ``request.state.context`` to an empty ``Dict[str, Any]`` and
    ``request.state.errors`` to an empty ``List[Error]``. Endpoints fill them
    through ``add_context`` / ``add_error``.
    """
    state = request.state
    state.context = {}  # type: ignore
    state.errors = []  # type: ignore
    response = await call_next(request)
    return response
@app.exception_handler(Exception)
async def validation_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Turn any uncaught exception into a 500 response in the standard envelope.

    Despite its name, this handler is registered for ``Exception`` and so
    handles every uncaught error, not only validation errors. The name is kept
    for compatibility.

    The JSON body has the keys ``errors``, ``context`` and ``data``:
    ``errors`` holds an entry describing ``exc`` followed by any errors already
    recorded on the request, ``context`` is the request's context mapping, and
    ``data`` is ``None``.
    """
    state = request.state
    # The per-request containers are created by the HTTP middleware. If they
    # were never populated (e.g. the failure happened outside that
    # middleware), fall back to empty ones. Otherwise the error handler itself
    # would raise AttributeError and mask the original exception.
    recorded_errors = getattr(state, "errors", [])
    context = getattr(state, "context", {})

    # NOTE(review): str(exc) is sent to the client verbatim, which may expose
    # internal details in production. Consider a generic message plus logging.
    errors = [Error(kind=type(exc).__name__, detail=str(exc))] + list(recorded_errors)
    return JSONResponse(
        jsonable_encoder(
            {
                "errors": errors,
                "context": context,
                "data": None,
            }
        ),
        status_code=500,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment