Skip to content

Instantly share code, notes, and snippets.

@AndrewIngram
Last active April 13, 2021 10:35
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AndrewIngram/b1a6e66ce92d2d0befd2f2f65eb62ca5 to your computer and use it in GitHub Desktop.
Save AndrewIngram/b1a6e66ce92d2d0befd2f2f65eb62ca5 to your computer and use it in GitHub Desktop.
Proper cursor-based pagination with Graphene + Django. Graphene-Django's stock connections use limit/offset logic under the hood, which makes the whole cursor-based connection model largely pointless.
import datetime
import operator
from base64 import b64decode as _unbase64
from base64 import b64encode as _base64
from functools import reduce
from django.db.models import Q
from graphene import relay
from graphql_relay.connection import connectiontypes
def base64(s):
    """Return the standard Base64 encoding of the text ``s`` as a ``str``.

    ``s`` is encoded to UTF-8 bytes first; the (ASCII) Base64 output is turned
    back into ``str`` so callers never have to deal with ``bytes``.
    """
    utf8_bytes = s.encode('utf-8')
    encoded = _base64(utf8_bytes)
    return str(encoded, 'utf-8')
def unbase64(s):
    """Decode the Base64 string ``s`` and return its payload as UTF-8 text.

    Malformed/mis-padded Base64 surfaces as ``binascii.Error`` and a payload
    that is not valid UTF-8 as ``UnicodeDecodeError`` (both ``ValueError``
    subclasses).
    """
    payload = _unbase64(s)
    return str(payload, 'utf-8')
def get_attribute(instance, name):
    """Resolve ``name`` on ``instance``, following Django-style ``__`` paths.

    A direct attribute named ``name`` wins. Otherwise the name is split at the
    first ``__``: the head is looked up on the current object and the
    remainder is resolved on that value, repeating until a direct match is
    found. Returns ``None`` if any step is missing (including a ``None``
    intermediate value).
    """
    current, remaining = instance, name
    while not hasattr(current, remaining):
        head, separator, tail = remaining.partition("__")
        # No further "__" to descend through, or the next hop does not exist.
        if not separator or not hasattr(current, head):
            return None
        current, remaining = getattr(current, head), tail
    return getattr(current, remaining)
def attr_from_sort(sort):
    """Strip the leading ``-`` (descending marker) from an ``order_by`` key.

    ``"-created"`` becomes ``"created"``; ``"created"`` is returned unchanged.
    An empty string raises ``IndexError``.
    """
    descending = sort[0] == '-'
    return sort[1:] if descending else sort
def build_q_objects(sort, cursor_parts):
    """Build the keyset-pagination conditions for rows after a cursor.

    For sort keys ``(k1, k2, ..., kn)`` with cursor values ``v1..vn`` this
    returns, shortest prefix first::

        [Q(k1__gt=v1), Q(k1=v1, k2__gt=v2), ..., Q(k1=v1, ..., kn__gt=vn)]

    where ``__gt`` becomes ``__lt`` for a key written with a leading ``-``.
    OR-ing the list selects rows strictly after the cursor in ``sort`` order.
    An empty ``sort`` raises ``IndexError``; a value missing from
    ``cursor_parts`` raises ``KeyError``.
    """
    q_objects = []
    prefix = sort
    while True:
        pivot = prefix[-1]
        pivot_attr = attr_from_sort(pivot)
        lookup = 'lt' if pivot[0] == '-' else 'gt'
        conditions = {f"{pivot_attr}__{lookup}": cursor_parts[pivot_attr]}
        # Every key before the pivot must equal the cursor's value.
        for key in prefix[:-1]:
            equal_attr = attr_from_sort(key)
            conditions[equal_attr] = cursor_parts[equal_attr]
        q_objects.append(Q(**conditions))
        if len(prefix) == 1:
            break
        prefix = prefix[:-1]
    # Produced longest prefix first; callers get shortest prefix first.
    q_objects.reverse()
    return q_objects
def cursor_string_from_parts(parts, sort):
    """Serialise sort-key values into an opaque Base64 cursor string.

    ``parts`` maps attribute names (sort keys without their ``-`` prefix) to
    string values. Values are written in ``sort`` order, separated by ``|``.
    Backslashes and ``|`` characters inside a value are backslash-escaped, so
    a value that itself contains the separator (e.g. a name like "a|b") can no
    longer shift the following fields when the cursor is parsed back.
    """
    escaped = []
    for key in sort:
        value = parts[attr_from_sort(key)]
        escaped.append(value.replace('\\', '\\\\').replace('|', '\\|'))
    return base64('|'.join(escaped))


def _split_cursor_bits(text):
    """Split a decoded cursor on unescaped separators and undo the escaping."""
    bits = []
    current = []
    chars = iter(text)
    for char in chars:
        if char == '\\':
            following = next(chars, '')
            if following in ('\\', '|'):
                current.append(following)
            else:
                # Not an escape this module produces: keep it verbatim so
                # cursors issued before escaping was introduced still decode.
                current.append(char + following)
        elif char == '|':
            bits.append(''.join(current))
            current = []
        else:
            current.append(char)
    bits.append(''.join(current))
    return bits


def parts_from_cursor_string(cursor, sort):
    """Decode a cursor produced by ``cursor_string_from_parts``.

    Returns a dict mapping each sort attribute (without ``-``) to its string
    value. Extra trailing fields beyond ``len(sort)`` are ignored.

    Raises:
        ValueError: if the cursor holds fewer fields than ``sort`` has keys.
            Decoding errors from ``unbase64`` (``binascii.Error`` /
            ``UnicodeDecodeError``, both ``ValueError`` subclasses) propagate.
    """
    bits = _split_cursor_bits(unbase64(cursor))
    # Cursors come from clients; report truncation clearly instead of IndexError.
    if len(bits) < len(sort):
        raise ValueError(
            f"Invalid cursor: expected {len(sort)} fields, got {len(bits)}"
        )
    return {attr_from_sort(key): bit for key, bit in zip(sort, bits)}
def cursor_string_from_obj(obj, sort):
    """Build the cursor string identifying ``obj``'s position in ``sort`` order.

    Each sort key (minus any ``-`` prefix) is resolved on ``obj`` via
    ``get_attribute`` (so ``__`` paths work). ``datetime`` values are
    serialised with ``isoformat()``; everything else goes through ``str()``,
    which means a missing/``None`` value becomes the literal ``"None"``.
    """
    parts = {}
    for attr_name in map(attr_from_sort, sort):
        value = get_attribute(obj, attr_name)
        parts[attr_name] = (
            value.isoformat()
            if isinstance(value, datetime.datetime)
            else str(value)
        )
    return cursor_string_from_parts(parts, sort)
def filter_queryset(qs, cursor, sort):
    """Restrict ``qs`` to rows that come strictly after ``cursor`` in ``sort`` order.

    The cursor is decoded into per-key values and turned into a list of
    keyset conditions, which are OR-ed together into a single filter.
    """
    decoded = parts_from_cursor_string(cursor, sort)
    conditions = build_q_objects(sort, decoded)
    return qs.filter(reduce(operator.or_, conditions))
class QuerysetConnectionField(relay.ConnectionField):
    """Relay connection field that paginates with real keyset cursors.

    Instead of mapping cursors onto LIMIT/OFFSET, each edge's cursor encodes
    the values of the connection's sort keys, and ``after`` becomes a filter
    selecting rows strictly past that edge in sort order.

    The connection type must provide two classmethods:

    * ``get_queryset(root, info, **args)``: the base queryset to paginate.
    * ``get_sort(**args)``: a sequence of ``order_by`` keys (``-`` prefix for
      descending).

    NOTE: the cursor filter compares the last sort key strictly, so the sort
    should be a total order (e.g. end with a unique field such as ``pk``);
    otherwise rows that tie with the cursor row on every key are skipped.

    Only forward pagination (``first`` / ``after``) is supported.
    """

    def __init__(self, type, *args, **kwargs):
        super().__init__(type, *args, **kwargs)
        # Validate class methods. Explicit raises (not `assert`) so the check
        # still runs under `python -O`.
        for method_name in ('get_queryset', 'get_sort'):
            if not hasattr(type, method_name):
                raise TypeError(
                    f'Connection type {type} needs a `{method_name}` method'
                )

    @classmethod
    def connection_resolver(cls, resolver, connection_type, root, info, **args):
        """Resolve one page of the connection using keyset pagination.

        ``resolver`` is accepted for signature compatibility with graphene's
        ``ConnectionField`` but is not called; rows come from
        ``connection_type.get_queryset``.

        Raises:
            ValueError: if neither or both of ``first``/``last`` are given,
                if ``last`` is given (unsupported), or if ``first`` is
                negative. Errors from decoding a malformed ``after`` cursor
                propagate.
        """
        # Unwrap e.g. a NonNull wrapper (the field may be declared required).
        if hasattr(connection_type, 'of_type'):
            connection_type = connection_type.of_type

        first = args.get('first')
        last = args.get('last')
        after = args.get('after')
        # NOTE: `before` is not implemented and is currently ignored.

        sort = connection_type.get_sort(**args)

        # Validate connection arguments. These must not be `assert`s: under
        # `python -O` they would vanish and an unsliced queryset (the whole
        # table) would be returned.
        if not (first or last):
            raise ValueError(
                'You must provide a `first` or `last` value to properly '
                f'paginate the `{info.field_name}` connection.'
            )
        if first and last:
            raise ValueError(
                'You cannot define both `first` and `last` values on '
                f'`{info.field_name}` connection.'
            )
        if last:
            raise ValueError('`last` argument is not supported')
        if first < 0:
            raise ValueError(
                f'`first` must be a non-negative integer on `{info.field_name}` connection.'
            )

        qs = connection_type.get_queryset(root, info, **args).order_by(*sort)
        if after:
            qs = filter_queryset(qs, after, sort)

        # Fetch one row past the page: its presence means there is a next
        # page. This replaces a separate COUNT(*) query, which would scan the
        # whole remaining result set and defeat cursor pagination.
        nodes = list(qs[:first + 1].iterator())
        has_next_page = len(nodes) > first
        nodes = nodes[:first]

        edge_type = connection_type.Edge or connectiontypes.Edge
        edges = [
            edge_type(node=node, cursor=cursor_string_from_obj(node, sort))
            for node in nodes
        ]

        page_info = relay.PageInfo(
            start_cursor=edges[0].cursor if edges else None,
            end_cursor=edges[-1].cursor if edges else None,
            has_previous_page=False,  # TODO: not computed
            has_next_page=has_next_page,
        )
        return connection_type(edges=edges, page_info=page_info)
# The example below refers to the `graphene` package by name, but only
# `relay` is imported from it at the top of the file.
import graphene


class User(graphene.ObjectType):
    name = graphene.String(required=True)


class AllUsersConnection(graphene.relay.Connection):
    """Example connection over all users, paginated by name."""

    @classmethod
    def get_queryset(cls, root, info, **kwargs):
        # `root` is unused here.
        # NOTE(review): `DjangoUser` is not defined or imported in this file;
        # presumably a Django user model imported elsewhere -- confirm.
        return DjangoUser.objects.all()

    @classmethod
    def get_sort(cls, **kwargs):
        # Names are not unique. The cursor filter is strict on the last sort
        # key, so sorting by name alone would skip users whose name ties with
        # the cursor row; "pk" makes the order total.
        return ("name", "pk")

    class Meta:
        node = User


class Query(graphene.ObjectType):
    all_users = QuerysetConnectionField(AllUsersConnection, required=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment