Skip to content

Instantly share code, notes, and snippets.

@frankie567
Last active October 3, 2023 11:26
Show Gist options
  • Save frankie567/7aad9491f47cd7442cd8e1e9073f6457 to your computer and use it in GitHub Desktop.
Save frankie567/7aad9491f47cd7442cd8e1e9073f6457 to your computer and use it in GitHub Desktop.
Type-hinted sorting fields dependency for FastAPI
from enum import StrEnum
from typing import Literal

from fastapi import FastAPI

from .sorting import Sorting
class SortingField(StrEnum):
FIELD_A = "field_a"
FIELD_B = "field_b"
app = FastAPI()


@app.get("/sorting-literal")
async def test_sorting_literal(sorting: Sorting[Literal["field_a", "field_b"]]):
    """Example endpoint whose sortable fields are declared with a ``Literal``."""
    ...


@app.get("/sorting-enum")
async def test_sorting_enum(sorting: Sorting[SortingField]):
    """Example endpoint whose sortable fields come from the ``SortingField`` enum."""
    ...
import enum
from functools import cached_property
from typing import (
Annotated,
Generic,
Literal,
TypeGuard,
TypeVar,
get_args,
get_origin,
)
from fastapi import Depends, Query
# Type of an allowed sort field; bound to `str` because field identifiers are
# strings (a `Literal` of strings or a `str`-valued enum).
ASF = TypeVar("ASF", bound=str)

# A single parsed sort criterion: (field, True if the order is descending).
_SortCriterion = tuple[ASF, bool]
# The parsed `sort` query parameter: one criterion per comma-separated entry.
SortingType = list[_SortCriterion]
class SortingGetterInvalidConfiguration(Exception):
    """Raised when ``Sorting``/``SortingGetter`` is parametrized with an
    unsupported type, i.e. anything other than a ``Literal`` or an ``Enum``.
    """

    def __init__(self) -> None:
        super().__init__(
            "The type you provided to Sorting is not supported. "
            "Please use a `Literal` or an `Enum`."
        )
class SortingFieldNotAllowed(Exception):
    """Raised when a client asks to sort by a field that is not allowed.

    Attributes:
        field: the rejected field name.
        allowed_fields: the set of field names that would have been accepted.
    """

    def __init__(self, field: str, allowed_fields: set[str]) -> None:
        self.field = field
        self.allowed_fields = allowed_fields
        # Sort the names so the message is stable: iteration order of a set of
        # strings depends on hash randomization and varies between processes.
        message = (
            f'You cannot sort by the field "{field}". '
            f"Allowed fields are: {', '.join(sorted(allowed_fields))}"
        )
        super().__init__(message)
class SortingGetter(Generic[ASF]):
    """FastAPI dependency that parses the ``sort`` query parameter.

    Create it through a subscripted alias, e.g.
    ``SortingGetter[Literal["field_a", "field_b"]]()`` or
    ``SortingGetter[MyStrEnum]()``: the type argument recorded on the
    instance (``__orig_class__``) defines the set of allowed fields.

    ``sort`` is a comma-separated list of field names; a leading ``-`` marks
    a field as descending, e.g. ``"field_a,-field_b"`` is parsed to
    ``[("field_a", False), ("field_b", True)]``.
    """

    def __call__(self, sort: str | None = Query(None)) -> SortingType[ASF]:
        """Parse *sort* into ``(field, is_descending)`` tuples, in order.

        Returns an empty list when the parameter is absent or empty.

        Raises:
            SortingFieldNotAllowed: if a field is not in ``allowed_fields``.
            SortingGetterInvalidConfiguration: if the getter was not
                parametrized with a ``Literal`` or an ``Enum``.
        """
        # `Query(None)` makes the parameter optional; without this guard a
        # request with no `?sort=` crashed with AttributeError on `None.split`.
        if not sort:
            return []

        sorting: SortingType[ASF] = []
        for field in sort.split(","):
            is_desc = field.startswith("-")
            if is_desc:
                field = field[1:]
            if not self._is_allowed_field(field):
                # NOTE(review): FastAPI reports unhandled exceptions as 500;
                # register an exception handler if a 4xx response is wanted.
                raise SortingFieldNotAllowed(field, self.allowed_fields)
            sorting.append((field, is_desc))
        return sorting

    def _is_allowed_field(self, field: str) -> TypeGuard[ASF]:
        """Return True if *field* is one of ``allowed_fields`` (narrows to ASF)."""
        return field in self.allowed_fields

    @cached_property
    def allowed_fields(self) -> set[str]:
        """Field names accepted by this getter.

        Derived from the type argument the instance was created with: the
        members of a ``Literal``, or the values of an ``Enum``.

        Raises:
            SortingGetterInvalidConfiguration: if the instance was created
                without a type argument, or with an unsupported one.
        """
        # `__orig_class__` only exists when the instance was created through
        # a subscripted alias (`SortingGetter[X]()`), not `SortingGetter()`.
        orig_class = getattr(self, "__orig_class__", None)
        if orig_class is None:
            raise SortingGetterInvalidConfiguration()
        generic_type = get_args(orig_class)[0]
        if get_origin(generic_type) is Literal:
            return set(get_args(generic_type))
        if isinstance(generic_type, enum.EnumType):
            return {item.value for item in generic_type}
        raise SortingGetterInvalidConfiguration()
# Public annotation for endpoint parameters: `sorting: Sorting[X]`, where X is
# a `Literal` of field names or a `str`-valued `Enum`.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static type checkers see a plain generic alias, so `Sorting[X]` is
    # understood as `list[tuple[X, bool]]`.
    Sorting = Annotated[SortingType[ASF], Depends(SortingGetter[ASF])]
else:

    class Sorting:
        """Runtime implementation of the ``Sorting[X]`` annotation.

        Subscripting a module-level ``Annotated[SortingType[ASF], ...]`` alias
        substitutes the TypeVar only in the annotated type, never inside the
        ``Depends(...)`` metadata. The dependency would therefore never learn
        ``X``, and ``allowed_fields`` would always raise. Building the
        annotation per subscription avoids that.

        The annotation also passes an *instance* of the getter. With an
        instance, FastAPI invokes ``SortingGetter.__call__``; with the class
        (or its generic alias), FastAPI would call the class itself.
        """

        def __class_getitem__(cls, sort_fields):
            return Annotated[
                SortingType[sort_fields],
                Depends(SortingGetter[sort_fields]()),
            ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment