Skip to content

Instantly share code, notes, and snippets.

@ThirVondukr
Last active March 14, 2022 11:50
Show Gist options
Save ThirVondukr/4f37a8b5f67d6677621f3b3d7f455da6 to your computer and use it in GitHub Desktop.
Strawberry GraphQL Cursor Pagination
import base64
import enum
from typing import Any, Optional, Annotated

from sqlalchemy import select, tuple_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute, DeclarativeMeta
from sqlalchemy.sql import Select

from gql.modules.users._fields import UserOrder
def model_fields_enum(
    *columns: InstrumentedAttribute, name: str
) -> type[enum.Enum]:
    """Create an ``Enum`` class whose members are the given mapped columns.

    Each member is named after its column's ``name`` upper-cased, and its
    value is the column's mapped property (``column.prop``), so that
    ``Paginator`` can resolve a member back to its class attribute.

    Args:
        *columns: Mapped attributes to expose as enum members.
        name: Name of the generated enum class.

    Returns:
        A new ``enum.Enum`` subclass whose ``__module__`` is this module.
    """
    members: dict[str, Any] = {}
    for column in columns:
        members[column.name.upper()] = column.prop
    return enum.Enum(  # type: ignore
        name,
        names=members,
        module=__name__,
    )
def encode_cursor(
    model: Any,
    order_by: "list[InstrumentedAttribute]",
) -> str:
    """Build an opaque pagination cursor from a row's ordering values.

    Each ``order_by`` attribute is read from ``model`` by its mapped
    attribute key (not the DB column name, which may differ). The value is
    stringified and base64-encoded, and the parts are joined with ``":"``.
    ``":"`` never appears in standard base64 output, so the join is
    unambiguous.

    NOTE(review): values are stringified, so type information (e.g. None vs
    "None", int vs str) is lost; ``decode_cursor`` returns strings only.

    Args:
        model: ORM instance (a row) that the cursor points at.
        order_by: Mapped attributes whose values define the row's position.

    Returns:
        The cursor string; the empty string if ``order_by`` is empty.
    """
    parts = (
        # ``attr.key`` is the Python attribute name on the mapped class.
        base64.b64encode(str(getattr(model, attr.key)).encode()).decode()
        for attr in order_by
    )
    return ":".join(parts)
def decode_cursor(
    cursor: str,
    order_by: "list[InstrumentedAttribute]",
) -> list[Any]:
    """Decode a cursor produced by ``encode_cursor`` back into its values.

    Args:
        cursor: ``":"``-separated base64 parts, one per ordering attribute.
        order_by: Ordering attributes the cursor was built for. They are
            used only to check that the cursor has one part per attribute.

    Returns:
        The decoded values, one ``str`` per ``order_by`` attribute
        (``encode_cursor`` stringifies values, so types are not restored).

    Raises:
        ValueError: If the number of parts does not match ``order_by``, or a
            part is not strict base64 / valid UTF-8. ``binascii.Error`` and
            ``UnicodeDecodeError`` are both ``ValueError`` subclasses.
    """
    parts = cursor.split(":")
    if len(parts) != len(order_by):
        raise ValueError(
            f"Malformed cursor: expected {len(order_by)} part(s), got {len(parts)}"
        )
    # validate=True rejects non-alphabet characters instead of silently
    # discarding them, so tampered cursors fail loudly.
    return [base64.b64decode(part, validate=True).decode() for part in parts]
class Paginator:
    """Relay-style cursor pagination over an SQLAlchemy ``Select``.

    Rows are ordered lexicographically by the ``order_by`` columns. A row's
    cursor encodes exactly those column values (via ``encode_cursor``), so
    ``after``/``before`` filters are row-value comparisons on the same
    columns.
    """

    def __init__(
        self,
        query: Select,
        order_by: list[enum.Enum] | list[InstrumentedAttribute],
    ):
        """Store the base query and normalise ``order_by`` to attributes.

        Args:
            query: Base ``SELECT``. Ordering, cursor filters and ``LIMIT``
                are added on top of it.
            order_by: Mapped attributes, or enum members whose ``value`` is a
                mapped property (as produced by ``model_fields_enum``). Enum
                members are resolved to ``value.class_attribute``.
        """
        self.query = query
        self.order_by: list[InstrumentedAttribute] = [
            col
            if isinstance(col, InstrumentedAttribute)
            else col.value.class_attribute
            for col in order_by
        ]

    def _build_query(
        self,
        after: str,
        before: str,
        limit: Optional[int],
        backwards: bool,
    ) -> Select:
        """Add ordering, cursor filters and a look-ahead LIMIT to the base query."""
        if backwards:
            # Walk the window from the ``before`` end so that LIMIT keeps the
            # rows nearest to it; ``paginate`` restores ascending order.
            query = self.query.order_by(*(col.desc() for col in self.order_by))
        else:
            query = self.query.order_by(*self.order_by)

        # Compare all ordering columns as one row value, so ties on a leading
        # column are broken by the following ones. (A Python list comparison
        # here would only ever compare the first column.)
        # NOTE(review): cursor values come back as str; non-string columns
        # may need type-aware decoding depending on the DB driver - confirm.
        row = tuple_(*self.order_by)
        if after:
            query = query.filter(row > tuple(decode_cursor(after, self.order_by)))
        if before:
            query = query.filter(row < tuple(decode_cursor(before, self.order_by)))

        if limit is not None:
            # Fetch one extra row to learn whether another page exists.
            query = query.limit(limit + 1)
        return query

    async def paginate(
        self,
        after: str,
        before: str,
        first: Optional[int],
        last: Optional[int],
        session: AsyncSession,
    ) -> tuple[list[Edge[Any]], PageInfo]:
        """Fetch one page of results together with its ``PageInfo``.

        Args:
            after: Cursor to start strictly after; an empty string means no
                lower bound.
            before: Cursor to stop strictly before; an empty string means no
                upper bound.
            first: Return at most this many rows from the start of the window.
            last: Return at most this many rows from the end of the window.
                If neither ``first`` nor ``last`` is given, no LIMIT is applied.
            session: Async session used to execute the query.

        Returns:
            ``(edges, page_info)``, with edges in ascending ``order_by`` order.

        Raises:
            ValueError: If ``first`` or ``last`` is negative, or if both are
                given. Decoding a malformed cursor may also raise a
                ``ValueError`` subclass.
        """
        if first is not None and first < 0:
            raise ValueError("'first' must be non-negative")
        if last is not None and last < 0:
            raise ValueError("'last' must be non-negative")
        if first is not None and last is not None:
            raise ValueError("'first' and 'last' cannot be used together")

        backwards = last is not None
        limit = last if backwards else first

        query = self._build_query(after, before, limit, backwards)
        nodes = list(await session.scalars(query))

        # The extra look-ahead row, if present, is the one farthest from the
        # starting cursor; drop it before restoring ascending order.
        has_more = limit is not None and len(nodes) > limit
        if has_more:
            nodes = nodes[:limit]
        if backwards:
            nodes.reverse()

        cursors = [encode_cursor(node, self.order_by) for node in nodes]
        edges = [
            Edge(node=node, cursor=cursor)
            for node, cursor in zip(nodes, cursors)
        ]
        page_info = PageInfo(
            has_previous_page=backwards and has_more,
            has_next_page=not backwards and has_more,
            start_cursor=cursors[0] if cursors else "",
            end_cursor=cursors[-1] if cursors else "",
        )
        return edges, page_info
@inject
async def all_users(
    session: Annotated[AsyncSession, Inject],
    order_by: list[UserOrder],
    after: str = "",
    before: str = "",
    first: Optional[int] = None,
    last: Optional[int] = None,
) -> Connection[Edge[UserType]]:
    """Return a cursor-paginated connection over every ``User`` row.

    Args:
        session: Injected async session used to run the query.
        order_by: Ordering fields; these also define the cursor contents.
        after: Cursor to start after ("" for no lower bound).
        before: Cursor to stop before ("" for no upper bound).
        first: Page size counted from the start of the window.
        last: Page size counted from the end of the window.

    Returns:
        A ``Connection`` holding the page's edges and its ``page_info``.
    """
    edges, page_info = await Paginator(select(User), order_by).paginate(
        session=session,
        after=after,
        before=before,
        first=first,
        last=last,
    )
    return Connection(edges=edges, page_info=page_info)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment