Skip to content

Instantly share code, notes, and snippets.

@Object905
Last active May 15, 2021 17:55
Show Gist options
  • Save Object905/324ad346be59d6cbe8fa83aac58e9429 to your computer and use it in GitHub Desktop.
Save Object905/324ad346be59d6cbe8fa83aac58e9429 to your computer and use it in GitHub Desktop.
Generic pagination for FastAPI and SQLAlchemy (easily adapted to other ORMs/DB drivers)
import math
from typing import Generic, List, Optional, TypeVar
from fastapi import Query
from pydantic.generics import GenericModel
from starlette.datastructures import URL
# Page size used when the request does not carry a `pageSize` query parameter.
DEFAULT_PAGE_SIZE = 25
# Upper bound enforced on the `pageSize` query parameter.
MAX_PAGE_SIZE = 100
# Type of a single element in the `data` list of a paginated response.
ItemT = TypeVar("ItemT")
class Paginated(GenericModel, Generic[ItemT]):
    """Response envelope for one page of results.

    Parametrize it with the item schema when declaring a response model,
    e.g. ``Paginated[User]``.
    """

    # Total number of items matched by the query, across all pages.
    item_count: int
    # Number of pages needed to hold ``item_count`` items.
    page_count: int
    # Zero-based index of the page carried in ``data``.
    current_page: int
    # Links to the neighbouring pages. They default to None explicitly
    # because PageInfo.paginate leaves them out when there is no such page.
    previous: Optional[str] = None
    next: Optional[str] = None
    # Items of the current page.
    data: List[ItemT]
class PageInfo:
    """Paging parameters taken from the query string, plus pagination helpers.

    Meant to be injected as a FastAPI dependency (``page: PageInfo = Depends()``),
    in which case ``page`` and ``pageSize`` are read from the request's query
    string and validated. Pages are zero-based.
    """

    # Names under which the paging parameters appear in the query string.
    _PAGE_PARAM = "page"
    _PAGE_SIZE_PARAM = "pageSize"

    def __init__(
        self,
        page: int = Query(0, ge=0),
        page_size: int = Query(
            DEFAULT_PAGE_SIZE,
            # A size of zero would make total_page_count divide by zero and a
            # negative one would reach the database as a negative LIMIT.
            ge=1,
            le=MAX_PAGE_SIZE,
            alias=_PAGE_SIZE_PARAM,
        ),
    ):
        self.page = page
        self.page_size = page_size

    def paginate(self, sqla_query, url: URL) -> Paginated[ItemT]:
        """Run ``sqla_query`` for the current page and wrap the result.

        Args:
            sqla_query: query object supporting ``limit``/``offset``/``all`` and
                ``order_by(None).count()``. NOTE(review): LIMIT/OFFSET paging
                presumably needs a deterministic ordering on the query to give
                stable pages; confirm against callers.
            url: URL of the current request; used as the base for the
                ``next``/``previous`` links so other query parameters survive.

        Returns:
            A ``Paginated`` model. ``next``/``previous`` are only set when the
            corresponding page exists.
        """
        items = self.get_current_page_items(sqla_query)
        item_count = self.get_total_item_count(sqla_query)

        paging_kwargs = {
            "item_count": item_count,
            "page_count": self.total_page_count(item_count),
            "current_page": self.page,
        }
        # The link keys are omitted (not set to None) when there is no such
        # page, so they stay "unset" on the resulting model.
        if self.has_next_page(len(items), item_count):
            paging_kwargs["next"] = self.next_page_url(url)
        if self.has_previous_page():
            paging_kwargs["previous"] = self.previous_page_url(url)

        return Paginated(data=items, **paging_kwargs)

    @staticmethod
    def get_total_item_count(sqla_query):
        """Return the total row count of ``sqla_query`` with its ordering cleared."""
        return sqla_query.order_by(None).count()

    def get_current_page_items(self, sqla_query):
        """Fetch the rows of the current page using LIMIT/OFFSET."""
        return sqla_query.limit(self.page_size).offset(self.offset()).all()

    def offset(self) -> int:
        """Return the number of items that precede the current page."""
        return self.page * self.page_size

    def total_page_count(self, total_item_count: int) -> int:
        """Return how many pages are needed for ``total_item_count`` items.

        Raises:
            ZeroDivisionError: if ``page_size`` is 0 (only possible when the
                instance was built by hand, bypassing query validation).
        """
        # Ceiling division done on integers, so no float rounding is involved.
        return -(-total_item_count // self.page_size)

    def has_next_page(self, current_page_item_count: int, total_item_count: int) -> bool:
        """Tell whether items remain after the ones on the current page."""
        seen_item_count = self.offset() + current_page_item_count
        return seen_item_count < total_item_count

    def has_previous_page(self) -> bool:
        """Tell whether the current page is not the first one."""
        return self.page > 0

    def next_page_url(self, base_url: URL) -> str:
        """Return ``base_url`` rewritten to point at the following page."""
        return self.next().change_paging_url_params(base_url)

    def previous_page_url(self, base_url: URL) -> str:
        """Return ``base_url`` rewritten to point at the preceding page.

        Raises:
            ValueError: if the current page is page 0.
        """
        return self.prev().change_paging_url_params(base_url)

    def change_paging_url_params(self, url: URL) -> str:
        """Return ``url`` with its paging parameters replaced by this page's.

        All other query parameters are kept. The parameters are written under
        the names the endpoint actually reads (``page`` and ``pageSize``).
        """
        # Strip every spelling of the paging parameters, including the
        # attribute-style ``page_size`` key found in older generated links.
        stale_keys = [*self.dict(), self._PAGE_SIZE_PARAM]
        without_paging = url.remove_query_params(stale_keys)
        return str(without_paging.include_query_params(**self._query_params()))

    def _query_params(self) -> dict:
        """Return the paging values keyed by their query-string names."""
        return {
            self._PAGE_PARAM: self.page,
            self._PAGE_SIZE_PARAM: self.page_size,
        }

    def dict(self) -> dict:
        """Return the paging values keyed by attribute name."""
        return {"page": self.page, "page_size": self.page_size}

    def next(self) -> "PageInfo":
        """Return a new ``PageInfo`` for the following page, same page size."""
        return PageInfo(page=self.page + 1, page_size=self.page_size)

    def prev(self) -> "PageInfo":
        """Return a new ``PageInfo`` for the preceding page, same page size.

        Raises:
            ValueError: if the current page is page 0.
        """
        if self.page == 0:
            raise ValueError("Can't go before page 0")
        return PageInfo(page=self.page - 1, page_size=self.page_size)
# ### USAGE
# @router.get("/users", response_model=Paginated[User])
# def get_users(request: Request, page: PageInfo = Depends(), db: Session = Depends()):
#     # query_users should return a SQLAlchemy query
#     return page.paginate(query_users(db), request.url)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment