Last active July 17, 2024 04:08
import inspect
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints
from fastapi import APIRouter, Depends
from pydantic.typing import is_classvar
from starlette.routing import Route, WebSocketRoute
# Marker attribute set on a class after _init_cbv has patched it; checked on entry
# so the patching is applied at most once per class.
CBV_CLASS_KEY = "__cbv_class__"
# Type variable so the decorators are typed as returning the same class they receive.
T = TypeVar("T")
def cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]:
    """Return a class decorator that turns the decorated class into a class-based view.

    The returned decorator hands the class to ``_cbv`` together with ``router``,
    the router its methods were registered on, and returns whatever ``_cbv``
    returns (the class itself).
    """

    def _apply(view_cls: Type[T]) -> Type[T]:
        # Bind the router captured by the enclosing call.
        return _cbv(router, view_cls)

    return _apply
def _cbv(router: APIRouter, cls: Type[T]) -> Type[T]:
    """Rewrite the routes on ``router`` whose endpoints are methods of ``cls``.

    Each matching route is taken off ``router``, its endpoint signature is
    rewritten (see ``_update_cbv_route_endpoint_signature``) so the first
    parameter defaults to ``Depends(cls)``, and the routes are then added back
    to ``router`` through ``include_router``.

    Args:
        router: Router on which the class's methods were registered
            (e.g. with ``@router.get(...)`` inside the class body).
        cls: The class being turned into a class-based view.

    Returns:
        ``cls`` itself, after ``_init_cbv`` has patched it in place.
    """
    # Patch __init__/__signature__ so class-level annotated dependencies become
    # constructor parameters before the class is used as a dependency.
    _init_cbv(cls)
    cbv_router = APIRouter()
    class_functions = {func for _, func in inspect.getmembers(cls, inspect.isfunction)}
    # Walk router.routes (not the class members) so that:
    #   * registration order, and hence path-matching precedence, is preserved;
    #   * a method carrying several route decorators has *all* its routes rewritten.
    cbv_routes = [
        route
        for route in router.routes
        if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in class_functions
    ]
    for route in cbv_routes:
        router.routes.remove(route)
        _update_cbv_route_endpoint_signature(cls, route)
        cbv_router.routes.append(route)
    # NOTE(review): re-adding via include_router presumably makes FastAPI rebuild
    # each route (and its dependency graph) from the new __signature__ — confirm
    # against the FastAPI version in use.
    router.include_router(cbv_router)
    return cls
def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None:
    """Rewrite ``route.endpoint.__signature__`` so FastAPI supplies an instance of ``cls``.

    The first parameter (normally ``self``) gets ``Depends(cls)`` as its default.
    Every remaining parameter is made keyword-only, which keeps the signature
    valid now that the first parameter has a default. The endpoint function
    object is modified in place; nothing is returned.

    Args:
        cls: The class-based-view class whose instance should be injected.
        route: A route whose endpoint is a function defined on ``cls``.
            Assumes the endpoint takes at least one parameter; ``old_parameters[0]``
            raises ``IndexError`` otherwise.
    """
    old_endpoint = route.endpoint
    old_signature = inspect.signature(old_endpoint)
    old_parameters: List[inspect.Parameter] = list(old_signature.parameters.values())
    old_first_parameter = old_parameters[0]
    new_first_parameter = old_first_parameter.replace(default=Depends(cls))
    new_parameters = [new_first_parameter] + [
        parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:]
    ]
    new_signature = old_signature.replace(parameters=new_parameters)
    setattr(route.endpoint, "__signature__", new_signature)
def _init_cbv(cls: Type[Any]) -> None:
    """Patch ``cls`` in place so it can be used as a FastAPI dependency.

    * ``cls.__signature__`` becomes the original ``__init__`` parameters (minus
      ``self``, ``*args`` and ``**kwargs``), followed by one keyword-only
      parameter per non-``ClassVar`` annotated class attribute. Each of these
      parameters defaults to the class attribute's value (for example a
      ``Depends(...)``), or to ``...`` when the attribute has no value.
    * ``cls.__init__`` is wrapped. The wrapper pops those attributes out of
      ``kwargs``, sets them on the instance, and then calls the original
      initializer with the remaining arguments.

    The class is marked with ``CBV_CLASS_KEY``, so repeat calls are no-ops.
    """
    if getattr(cls, CBV_CLASS_KEY, False):  # pragma: no cover
        return  # Already initialized
    old_init: Callable[..., Any] = cls.__init__
    old_signature = inspect.signature(old_init)
    old_parameters = list(old_signature.parameters.values())[1:]  # drop `self` parameter
    # Variadic parameters cannot be described to FastAPI, so they are dropped.
    new_parameters = [
        x
        for x in old_parameters
        if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    dependency_names: List[str] = []
    for name, hint in get_type_hints(cls).items():
        if is_classvar(hint):
            # ClassVar annotations describe shared class state, not per-instance dependencies.
            continue
        parameter_kwargs = {"default": getattr(cls, name, Ellipsis)}
        dependency_names.append(name)
        new_parameters.append(
            inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs)
        )
    new_signature = old_signature.replace(parameters=new_parameters)

    def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
        # Consume the class-level dependencies so the original __init__ only
        # receives the arguments it declared itself.
        for dep_name in dependency_names:
            dep_value = kwargs.pop(dep_name)
            setattr(self, dep_name, dep_value)
        old_init(self, *args, **kwargs)

    setattr(cls, "__signature__", new_signature)
    setattr(cls, "__init__", new_init)
    setattr(cls, CBV_CLASS_KEY, True)
from fastapi import APIRouter, Depends, FastAPI
from starlette.testclient import TestClient
from fastapi_cbv import cbv
router = APIRouter()


def dependency() -> int:
    """Trivial dependency used by both the class attribute and the __init__ parameter."""
    return 1


# @cbv(router) is required: without it `self` on `f` would be treated as a request
# parameter and the class-level `x` dependency would never be injected.
@cbv(router)
class CBV:
    # Class-level dependency; cbv turns it into an instance attribute per request.
    x: int = Depends(dependency)

    def __init__(self, z: int = Depends(dependency)):
        self.y = 1
        self.z = z

    @router.get("/", response_model=int)
    def f(self) -> int:
        # x (class dependency) + y (set in __init__) + z (__init__ dependency) == 3
        return self.x + self.y + self.z


app = FastAPI()
# The routes live on `router`; they must be mounted on the app or "/" returns 404.
app.include_router(router)
client = TestClient(app)
assert client.get("/").content == b"3"
