Skip to content

Instantly share code, notes, and snippets.

@exhuma
Last active April 14, 2024 10:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save exhuma/ee51a71e186d07041eded0e90b7c8fbd to your computer and use it in GitHub Desktop.
Save exhuma/ee51a71e186d07041eded0e90b7c8fbd to your computer and use it in GitHub Desktop.
Automatic Pagination with FastAPI and SQLAlchemy
from fastapi import Depends, FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Query, Session
from pagination import PaginatedList, paginated_get
# Declarative base class that the ORM models of this example inherit from.
# NOTE(review): declarative_base is imported from sqlalchemy.ext.declarative,
# which appears to be deprecated in SQLAlchemy 2.0 in favour of
# sqlalchemy.orm.declarative_base — confirm against the installed version.
Base = declarative_base()


# --- The SQLAlchemy Model ----------------------------------------------------
class CustomerDbModel(Base):
    # ORM mapping for the "customers" table.
    __tablename__ = "customers"

    # Integer primary key.
    # NOTE(review): index=True on a primary key is presumably redundant, as
    # primary keys are usually indexed already — confirm before changing DDL.
    id = Column(Integer, primary_key=True, index=True)
    # Customer name; this is the only field exposed by CustomerApiModel.
    name = Column(String)
# --- Core Business Logic -----------------------------------------------------
def get_customers(db: Session) -> Query[CustomerDbModel]:
    """
    Return an unpaginated query that selects every customer.

    This function deliberately knows nothing about pagination or any other
    API-layer concern; it only builds the SQLAlchemy query. A real-world
    version may add joins, filters, sorting, etc. How HTTP request values
    (filters, sort keys, page numbers, ...) are turned into query arguments is
    decided solely in the FastAPI route definition, which keeps the core
    business logic decoupled from the API layer.

    :param db: The SQLAlchemy session the query is built on.
    :return: The query object; pagination is applied later by the caller.
    """
    return db.query(CustomerDbModel)
# --- FastAPI -----------------------------------------------------------------
def get_db():
    # FastAPI dependency: provide one session per request and close it once
    # the request is done. Leaving the ``with`` block calls ``close()`` on the
    # session, also when the request handler raises.
    # NOTE(review): Session() is created without a bind/engine, so queries
    # presumably fail until one is configured — confirm for the real app.
    with Session() as session:
        yield session
class CustomerApiModel(BaseModel):
    # API (response) representation of a single customer.
    #
    # ``from_attributes`` lets ``model_validate`` read the fields from the
    # attributes of arbitrary objects such as ``CustomerDbModel`` rows.
    # Without it, Pydantic v2 only accepts dicts or model instances, so using
    # ``CustomerApiModel.model_validate`` as the ``api_mapper`` for ORM rows
    # fails validation on every row.
    model_config = {"from_attributes": True}

    name: str
APP = FastAPI()


# ``paginated_get`` registers this function as a GET route on APP and adds the
# pagination query parameters to it; the function itself only has to return
# the (unpaginated) query.
@paginated_get(
    app=APP,
    path="/customers",
    # Maps each row of the query to the API model. ``model_validate`` keeps
    # the example short; any "row -> Pydantic model" callable works here.
    api_mapper=CustomerApiModel.model_validate,
    response_model=PaginatedList[CustomerApiModel],
)
def customers(db: Session = Depends(get_db)):
    # Reuse the business-layer query as-is; limiting/offsetting is done by
    # the decorator.
    unpaginated_query = get_customers(db)
    return unpaginated_query
"""
This module provides a decorator that paginates a SQLAlchemy query and returns
the results as a Pydantic model.
"""
import inspect
from typing import Any, Callable
from fastapi import APIRouter, FastAPI
from pydantic import BaseModel
from sqlalchemy.orm import Query
# Rank of every parameter kind in a valid Python signature (positional-only
# first, ``**kwargs`` last). Used as the sort key when two signatures are
# merged into one.
_PARAMETER_ORDER = {
    kind: rank
    for rank, kind in enumerate(
        (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.KEYWORD_ONLY,
            inspect.Parameter.VAR_KEYWORD,
        ),
        start=1,
    )
}
class PaginatedList[T](BaseModel):
    # Response envelope for one page of results, as built by paginated_get.
    # T is the type of a single item (e.g. CustomerApiModel).
    #
    # The items of the requested page, already passed through api_mapper.
    items: list[T]
    # The requested page number; 1-based (page 1 starts at offset 0).
    page: int
    # Number of rows matched by the unpaginated query (its ``count()``).
    total_items: int
def paginated_get[
T, U: BaseModel
](
app: FastAPI | APIRouter,
path: str,
api_mapper: Callable[[T], U],
*args,
default_page_size: int = 25,
**kwargs,
):
"""
Register a route with FastAPI that returns automatically paginates an
SQLAlchemy .
:param app: The FastAPI application or router to register the route with.
:param path: The path to register the route at.
:param api_mapper: A function that maps the SQLAlchemy model (of one item of
the query) to a Pydantic model.
:param default_page_size: The default number of items per page.
:param args: Additional positional arguments to pass to the FastAPI route.
:param kwargs: Additional keyword arguments to pass to the FastAPI route.
Example:
>>> from fastapi import FastAPI
>>> from pydantic import BaseModel
>>> from sqlalchemy.orm import Query
>>> from flightlogs.pagination import PaginatedList, paginated_get
>>>
>>> app = FastAPI()
>>>
>>> class CustomerApiModel(BaseModel):
... name: str
>>>
>>> def mapper(data: CustomerDbModel) -> CustomerApiModel:
... return CustomerApiModel(name=data.label)
>>>
>>> def get_customers(db: Session) -> Query[CustomerDbModel]:
... result = db.query(CustomerDbModel)
... return result
>>>
>>> @paginated_get(
... app,
... "/customers",
... api_mapper=mapper,
... response_model=PaginatedList[CustomerApiModel],
... )
... def customers(db: Session = Depends(get_db)):
... return get_customers(db)
"""
def decorator(
func: Callable[..., Query[Any]]
) -> Callable[..., PaginatedList[U]]:
"""
Decorator that registers a route with FastAPI that returns a paginated
list of items.
:param func: The function that returns the SQLAlchemy query to paginate.
"""
def wrapper(
*inner_args,
page: int = 1,
per_page: int = default_page_size,
**inner_kwargs,
) -> PaginatedList[U]:
data = func(*inner_args, **inner_kwargs)
total_count = data.count()
data = data.limit(per_page)
data = data.offset((page - 1) * per_page)
items = [api_mapper(row) for row in data]
return PaginatedList(
items=items,
page=page,
total_items=total_count,
)
# We need to merge the signatures of the original function and the
# wrapper function. This exposes the parameters to FastAPI. This enable
# all the functionality of FastAPI, such as automatic validation and
# documentation.
#
# When doing this, we need to make sure that the parameters are in the
# correct order for Python itself.
func_sig = inspect.signature(func)
wrapper_sig = inspect.signature(wrapper)
wrapper_params = list(wrapper_sig.parameters.values())
wrapper_params.extend(func_sig.parameters.values())
wrapper_params = [
p
for p in sorted(
wrapper_params,
key=lambda p: _PARAMETER_ORDER.get(p.kind, 0),
)
if p.kind
not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
)
]
wrapper_sig = wrapper_sig.replace(parameters=wrapper_params)
# We can't use "functools.wraps" here, because it would copy the
# signature of the wrapper function, undoing the work we did above.
wrapper.__signature__ = wrapper_sig
wrapper.__doc__ = func.__doc__
wrapper.__name__ = func.__name__
route = app.get(path, *args, **kwargs)(wrapper)
return route
return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment