Skip to content

Instantly share code, notes, and snippets.

@joeydebreuk
Created March 17, 2021 19:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joeydebreuk/2e1333fb8da82220bca5300ee81d225c to your computer and use it in GitHub Desktop.
Save joeydebreuk/2e1333fb8da82220bca5300ee81d225c to your computer and use it in GitHub Desktop.
strawberry-django-pagination
"""
Usage like:
class SomeResolver(PaginationMixin, ModelResolver):
model = SomeModel
and
class SomeOtherResolver(PaginationMixin, ModelResolver):
model = SomeOtherModel
@strawberry.field
def related(
self,
root: Show,
before: Optional[ID] = UNSET,
after: Optional[ID] = UNSET,
first: Optional[int] = UNSET,
last: Optional[int] = UNSET,
) -> SomeResolver.get_connection_type():
return paginate_django_queryset(
queryset=root.related.all(),
connection_type=SomeResolver.get_connection_type(),
edge_type=SomeResolver.get_edge_type(),
before=before,
after=after,
first=first,
last=last,
)
"""
from typing import List, Optional, Type
import strawberry
from django.conf import settings
from django.db import models
from graphql import GraphQLError
from strawberry import ID
from strawberry.arguments import UNSET, is_unset
from strawberry.types import Info
from strawberry_django.resolvers import get_permission_classes
@strawberry.type(
    description=(
        "Information about pagination in a connection, "
        "based on https://graphql.github.io/learn/pagination/"
    )
)
class Pagination:
    """Page metadata built by ``paginate_django_queryset``.

    The cursors are the primary keys of the first and last rows on the page.
    """

    has_next_page: bool = strawberry.field(
        description="When paginating forwards, are there more items?"
    )
    has_previous_page: bool = strawberry.field(
        description="When paginating backwards, are there more items?"
    )
    # Passing ``start_cursor`` as ``before`` yields the rows preceding this
    # page (pk < start); passing ``end_cursor`` as ``after`` yields the rows
    # following it (pk > end). The descriptions below reflect that.
    start_cursor: Optional[str] = strawberry.field(
        description="When paginating backwards, the cursor to continue"
    )
    end_cursor: Optional[str] = strawberry.field(
        description="When paginating forwards, the cursor to continue"
    )
def _check_page_size(argument, value, limit, connection_type):
    """Raise ``GraphQLError`` if a ``first``/``last`` page size is out of range.

    ``argument`` is the GraphQL argument name used in the error message,
    ``value`` the requested number of records, ``limit`` the allowed maximum.
    """
    if value < 0:
        raise GraphQLError("Negative indexing not supported")
    if value > limit:
        raise GraphQLError(
            f"Requesting {value} records on the {connection_type.__name__} "
            f"exceeds the `{argument}` limit of {limit} records"
        )


def paginate_django_queryset(
    queryset: models.QuerySet,
    # connection_type is a strawberry type that implements page_info and edges
    connection_type: Type[strawberry.field],
    edge_type: Type[strawberry.field],
    first: Optional[int],
    last: Optional[int] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
):
    """Slice ``queryset`` into a single page and wrap it in ``connection_type``.

    Rows are ordered by primary key, and the cursors are primary keys.
    ``after``/``before`` keep only rows whose ``pk`` is strictly greater/less
    than the given cursor. At most one of ``first``/``last`` may be given.
    When neither is given, the first 50 rows are returned.

    Args:
        queryset: Rows to paginate. Any existing ordering is replaced by ``pk``.
        connection_type: Type instantiated with ``total_count``, ``page_info``
            and ``edges`` keyword arguments.
        edge_type: Type whose ``from_queryset`` wraps a single model instance.
        first: Number of rows to take from the start of the cursor-filtered rows.
        last: Number of rows to take from the end of the cursor-filtered rows.
        before: Exclusive upper bound on ``pk``.
        after: Exclusive lower bound on ``pk``.

    Returns:
        An instance of ``connection_type``. Its ``total_count`` counts the
        whole queryset, ignoring the cursors and the page size.

    Raises:
        GraphQLError: If both ``first`` and ``last`` are given, if either is
            negative, or if either exceeds
            ``settings.GRAPHQL_MAX_PAGINATION_LIMIT`` (default 500).
    """
    limit = getattr(settings, "GRAPHQL_MAX_PAGINATION_LIMIT", 500)
    first_is_set = first is not None and not is_unset(first)
    last_is_set = last is not None and not is_unset(last)

    if not first_is_set and not last_is_set:
        # Default pagination to the first 50 results.
        first = 50
        first_is_set = True
    if first_is_set and last_is_set:
        raise GraphQLError("Passing both `first` and `last` is not supported")
    if first_is_set:
        _check_page_size("first", first, limit, connection_type)
    if last_is_set:
        _check_page_size("last", last, limit, connection_type)

    # Cursors are primary keys, so rows must be ordered by pk for them to work.
    clean_queryset = queryset.order_by("pk")
    total_count = clean_queryset.count()

    queryset = clean_queryset
    if after:
        queryset = queryset.filter(pk__gt=after)
    if before:
        queryset = queryset.filter(pk__lt=before)

    if first_is_set:
        queryset = queryset[:first]
    else:
        # Exactly one of first/last is set here, so this is the `last` case.
        # Django querysets reject negative slices, so compute the offset.
        # Base it on the count of the cursor-filtered rows: `total_count`
        # ignores `before`/`after` and would overshoot, returning too few rows.
        filtered_count = queryset.count() if (after or before) else total_count
        queryset = queryset[max(filtered_count - last, 0):]

    # Evaluate the page exactly once; everything below works on this list.
    page = list(queryset)

    if page:
        start_cursor = page[0].pk
        end_cursor = page[-1].pk
        has_previous_page = clean_queryset.filter(pk__lt=start_cursor).exists()
        has_next_page = clean_queryset.filter(pk__gt=end_cursor).exists()
    else:
        start_cursor = end_cursor = None
        has_previous_page = has_next_page = False

    return connection_type(
        total_count=total_count,
        page_info=Pagination(
            start_cursor=start_cursor,
            end_cursor=end_cursor,
            has_next_page=has_next_page,
            has_previous_page=has_previous_page,
        ),
        edges=[edge_type.from_queryset(instance) for instance in page],
    )
class PaginationMixin:
    """Mixin that generates ``<Model>Edge`` / ``<Model>Connection`` strawberry types.

    The combining class is expected to provide ``model`` (a Django model
    class) and ``output_type`` (the strawberry type of one node). For
    ``list_field`` it must also be constructible as ``cls(info, root)`` with a
    ``list(filters=...)`` method. Presumably all of this comes from
    strawberry_django's ``ModelResolver``; confirm against the resolver base.
    """

    # Generated types, cached on the concrete class on first use.
    edge_type = None
    connection_type = None

    @classmethod
    def get_edge_type(cls):
        """Return the edge type for ``cls.model``, building and caching it on first call."""
        if cls.edge_type:
            return cls.edge_type

        name = f"{cls.get_pacalcase_name()}Edge"

        @strawberry.type(name=name)
        class Edge:
            # The wrapped model instance; private, so not exposed in the schema.
            instance: strawberry.Private[cls.model]

            @strawberry.field
            def cursor(self) -> int:
                # Use ``pk`` rather than ``id`` so the cursor matches the
                # pk-based filtering done by ``paginate_django_queryset``.
                # It also works for models whose primary key is not named ``id``.
                return self.instance.pk

            @strawberry.field
            def node(self) -> cls.output_type:
                return self.instance

            @staticmethod
            def from_queryset(queryset: models.Model) -> "Edge":
                # Despite the name, this receives one model instance (a single
                # row of the paginated queryset), not a QuerySet.
                return Edge(queryset)

        cls.edge_type = Edge
        return cls.edge_type

    @classmethod
    def get_pacalcase_name(cls) -> str:
        """Return the Django model's class name (``_meta.object_name``).

        The name is used as the prefix of the generated GraphQL type names.
        """
        return cls.model._meta.object_name

    @classmethod
    def get_connection_type(cls):
        """Return the connection type for ``cls.model``, building and caching it on first call."""
        if cls.connection_type:
            return cls.connection_type

        name = f"{cls.get_pacalcase_name()}Connection"

        @strawberry.type(name=name)
        class Connection:
            edges: List[Optional[cls.get_edge_type()]]
            page_info: Pagination
            total_count: int = strawberry.field(
                description="Identifies the total count of items in the connection."
            )

        # NOTE(review): unlike the edge type, an undecorated subclass of
        # ``Connection`` is cached here. Confirm whether ``Connection`` itself
        # would suffice.
        cls.connection_type = type(name, (Connection,), {})
        return cls.connection_type

    @classmethod
    def list_field(cls):
        """Build a strawberry field that resolves to a page of ``cls.model`` rows.

        The field uses the ``"view"`` permission classes returned by
        ``get_permission_classes``. It instantiates ``cls(info, root)`` and
        paginates the result of ``list(filters=...)`` with
        ``paginate_django_queryset``.
        """
        permission_classes = get_permission_classes(cls, "view")
        edge_type = cls.get_edge_type()
        connection_type = cls.get_connection_type()

        @strawberry.field(permission_classes=permission_classes)
        def list_field(
            info: Info,
            root,
            filters: Optional[List[str]] = None,
            before: Optional[ID] = UNSET,
            after: Optional[ID] = UNSET,
            first: Optional[int] = UNSET,
            last: Optional[int] = UNSET,
        ) -> connection_type:
            instance = cls(info, root)
            return paginate_django_queryset(
                queryset=instance.list(filters=filters),
                connection_type=connection_type,
                edge_type=edge_type,
                before=before,
                after=after,
                first=first,
                last=last,
            )

        return list_field
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment