Skip to content

Instantly share code, notes, and snippets.

@SanskarSans
Created October 6, 2020 00:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SanskarSans/a60b4c2244fd6217e5babf64e37fc849 to your computer and use it in GitHub Desktop.
Save SanskarSans/a60b4c2244fd6217e5babf64e37fc849 to your computer and use it in GitHub Desktop.
import json
from typing import Any, Dict, Iterable, List, Tuple, Union

from ariadne.graphql import GraphQLError
from django.db.models import Q
from graphql_relay.utils import base64, unbase64

# Mapping of connection arguments; this module reads the keys
# "first", "last", "before", "after" and "sort_by" from it.
ConnectionArguments = Dict[str, Any]
def filter_by_query_param(queryset, query, search_fields):
    """Restrict ``queryset`` to rows where any search field contains ``query``.

    Keyword arguments:
    queryset - queryset to be filtered
    query - search string; a falsy value disables filtering
    search_fields - field names matched with a case-insensitive ``icontains``

    Returns the queryset unchanged when ``query`` is falsy, otherwise the
    result of ``queryset.filter`` with the per-field conditions OR-ed together.
    """
    if not query:
        return queryset

    # One "<field>__icontains" lookup per search field, all with the same needle.
    lookups = {f"{field}__icontains": query for field in search_fields}

    condition = Q()
    for lookup, needle in lookups.items():
        condition |= Q(**{lookup: needle})
    return queryset.filter(condition)
def to_global_cursor(values):
    """Encode one or more sort-field values as an opaque cursor string.

    :param values: a single value or an iterable of values. ``None`` entries
        are kept as ``None``; every other entry is converted with ``str``.
    :return: the JSON-encoded list of values passed through the ``base64``
        helper imported from ``graphql_relay.utils``.
    """
    # str and bytes are iterable, but a lone string is one value, not a
    # sequence of characters, so wrap it like any other scalar.
    if isinstance(values, (str, bytes)) or not isinstance(values, Iterable):
        values = [values]
    values = [value if value is None else str(value) for value in values]
    return base64(json.dumps(values))
def from_global_cursor(cursor) -> List[str]:
    """Decode a cursor produced by ``to_global_cursor``.

    :param cursor: encoded cursor string.
    :return: the JSON-decoded payload of the cursor.
    :raises ValueError: ``json.loads`` raises ``json.JSONDecodeError`` (a
        ``ValueError`` subclass) when the decoded text is not valid JSON.
    """
    # The leftover debugging print() calls were removed: they wrote
    # client-supplied cursor data to stdout on every request.
    return json.loads(unbase64(cursor))
def get_field_value(instance, field_name: str):
    """Resolve a ``field__related_field`` style path against ``instance``.

    Each ``__``-separated segment is looked up with ``getattr`` on the result
    of the previous segment. If the final attribute is callable it is called
    with no arguments and its result is returned formatted as a string;
    otherwise the attribute itself is returned.

    :raises AttributeError: if any segment of the path does not exist.
    """
    value = instance
    for segment in field_name.split("__"):
        value = getattr(value, segment)
    return ("%s" % value()) if callable(value) else value
def _prepare_filter_expression(
    field_name: str,
    index: int,
    cursor: List[str],
    sorting_fields: List[str],
    sorting_direction: str,
) -> Tuple[Q, Dict[str, Union[str, bool]]]:
    """Build the filter pieces for one sorting field of a cursor.

    :param field_name: the sorting field at position ``index``.
    :param index: position of ``field_name`` in ``sorting_fields``.
    :param cursor: decoded cursor values, one per sorting field.
    :param sorting_fields: all fields used for sorting.
    :param sorting_direction: lookup suffix, ``"gt"`` or ``"lt"``.
    :return: ``(extra_expression, field_expression)`` where
        ``field_expression`` holds keyword lookups and ``extra_expression``
        is a ``Q`` that is only populated for the ``"gt"`` direction.
    """
    # Every field that precedes this one must equal its cursor value.
    exact_lookups: Dict[str, Union[str, bool]] = {
        sorting_fields[position]: value
        for position, value in enumerate(cursor[:index])
    }
    pivot = cursor[index]

    null_aware = Q()
    if sorting_direction == "gt":
        # Either greater than the cursor value or NULL.
        # NOTE(review): presumably NULLs sort after non-NULL values in
        # ascending order on the target database - confirm.
        null_aware |= Q(**{f"{field_name}__{sorting_direction}": pivot})
        null_aware |= Q(**{f"{field_name}__isnull": True})
    elif pivot is None:
        # Non-"gt" direction with a NULL cursor value: match non-NULL rows.
        exact_lookups[f"{field_name}__isnull"] = False
    else:
        exact_lookups[f"{field_name}__{sorting_direction}"] = pivot

    return null_aware, exact_lookups
def _prepare_filter(
    cursor: List[str], sorting_fields: List[str], sorting_direction: str
) -> Q:
    """Combine the per-field cursor conditions into a single ``Q`` filter.

    :param cursor: list of values decoded from the page cursor.
    :param sorting_fields: list of fields that were used for sorting.
    :param sorting_direction: lookup direction, ``'lt'`` or ``'gt'``.
    :return: Q() in the following format (shown for ``'gt'``)
        (OR: ('first_field__gt', 'first_value_from_cursor'),
            (AND: ('second_field__gt', 'second_value_from_cursor'),
                ('first_field', 'first_value_from_cursor')),
            (AND: ('third_field__gt', 'third_value_from_cursor'),
                ('second_field', 'second_value_from_cursor'),
                ('first_field', 'first_value_from_cursor'))
        )
    """
    combined = Q()
    ascending = sorting_direction == "gt"
    for position, name in enumerate(sorting_fields):
        # For "gt", a field whose cursor value is None adds no condition.
        if cursor[position] is None and ascending:
            continue
        extra, lookups = _prepare_filter_expression(
            name, position, cursor, sorting_fields, sorting_direction
        )
        combined |= Q(extra, **lookups)
    return combined
def _get_sorting_fields(sort_by, qs):
    """Return the list of field names the connection is sorted by.

    :param sort_by: mapping read for the keys ``"field"`` and
        ``"attribute_id"``.
    :param qs: queryset; only used when sorting by attribute, in which case
        ``qs.model.sort_by_attribute_fields()`` supplies the fields.
    :return: ``sort_by["field"]`` as a list (a non-list value is wrapped).
    :raises ValueError: when neither a field nor an attribute id is given.
    """
    fields = sort_by.get("field")
    if fields:
        return fields if isinstance(fields, list) else [fields]

    if sort_by.get("attribute_id") is not None:
        return qs.model.sort_by_attribute_fields()

    raise ValueError("Error while preparing cursor values.")
def _get_sorting_direction(sort_by, last=None):
    """Return the lookup suffix (``"lt"`` or ``"gt"``) used for cursor filters.

    ``"lt"`` is returned for a ``"DESC"`` direction and ``"gt"`` for anything
    else; a truthy ``last`` inverts the result.
    """
    descending = sort_by.get("direction", "") == 'DESC'
    inverted = bool(last)
    # XOR: exactly one of "descending" / "paging with last" selects "lt".
    return "lt" if descending != inverted else "gt"
def _get_page_info(matching_records, cursor, first, last):
    """Build the page-info dict for a slice of records.

    :param matching_records: the fetched records (before trimming).
    :param cursor: the cursor the page was requested with, if any.
    :param first: requested page size when paging forwards.
    :param last: requested page size when paging backwards.
    :return: dict with ``has_previous_page``, ``has_next_page`` and the
        ``start_cursor`` / ``end_cursor`` keys initialised to ``None``.
    """
    limit = first or last
    # More records than requested were fetched, so another page exists.
    more_available = limit is not None and len(matching_records) > limit
    arrived_via_cursor = bool(cursor)

    if first:
        has_next, has_previous = more_available, arrived_via_cursor
    elif last:
        has_next, has_previous = arrived_via_cursor, more_available
    else:
        has_next = has_previous = False

    return {
        "has_previous_page": has_previous,
        "has_next_page": has_next,
        "start_cursor": None,
        "end_cursor": None,
    }
def _get_edges_for_connection(qs, args, sorting_fields):
    """Materialise ``qs`` into connection edges plus page info.

    :param qs: iterable of records (already filtered and sliced by the caller).
    :param args: connection arguments; ``before``, ``after``, ``first`` and
        ``last`` are read.
    :param sorting_fields: field names whose values make up each edge cursor.
    :return: ``(edges, page_info)`` where each edge is a dict with ``node``
        and ``cursor`` keys.
    """
    first = args.get("first")
    last = args.get("last")
    cursor = args.get("after") or args.get("before")
    requested_count = first or last

    records = list(qs)
    if last:
        # Backward paging: flip the fetched records, then drop the leading
        # record unless there were no more records than requested.
        # NOTE(review): presumably the leading record is the look-ahead row
        # fetched by the caller - confirm.
        records.reverse()
        skip = 0 if len(records) <= requested_count else 1
        window = slice(skip, None)
    else:
        window = slice(0, requested_count)

    # Page info is computed on the untrimmed list so the extra row is counted.
    page_info = _get_page_info(records, cursor, first, last)

    edges = [
        {
            "node": record,
            "cursor": to_global_cursor(
                [get_field_value(record, field) for field in sorting_fields]
            ),
        }
        for record in records[window]
    ]
    if edges:
        page_info["start_cursor"] = edges[0]['cursor']
        page_info["end_cursor"] = edges[-1]['cursor']
    return edges, page_info
def connection_from_queryset_slice(
    qs,
    args: ConnectionArguments = None,
):
    """Create a connection object from a QuerySet.

    :param qs: queryset to paginate.
        NOTE(review): no ordering is applied here; presumably the caller has
        already ordered ``qs`` by the sorting fields - confirm.
    :param args: connection arguments; the keys ``before``, ``after``,
        ``first``, ``last`` and ``sort_by`` are read. ``None`` is treated as
        an empty mapping.
    :return: dict with ``edges`` and ``page_info`` keys.
    :raises GraphQLError: when the received cursor cannot be decoded, is not a
        list, or does not have one value per sorting field.
    :raises ValueError: when ``sort_by`` names neither a field nor an
        attribute id.
    """
    args = args or {}
    before = args.get("before")
    after = args.get("after")
    first = args.get("first")
    last = args.get("last")
    # NOTE: the connection arguments are not validated here
    # (the original `_validate_connection_args(args)` call is disabled).

    requested_count = first or last
    # Fetch one record more than requested so page info can tell whether
    # another page exists.
    end_margin = requested_count + 1 if requested_count else None

    cursor = after or before
    if cursor:
        try:
            cursor = from_global_cursor(cursor)
        except ValueError as error:
            # Decoding failures are ValueError subclasses; report them the
            # same way as a cursor of the wrong length.
            raise GraphQLError("Received cursor is invalid.") from error
    else:
        cursor = None

    # An explicit ``sort_by=None`` is treated like a missing key.
    sort_by = args.get("sort_by") or {}
    sorting_fields = _get_sorting_fields(sort_by, qs)
    sorting_direction = _get_sorting_direction(sort_by, last)

    # The cursor is client-supplied: it must decode to a list with exactly
    # one value per sorting field.
    if cursor is not None and (
        not isinstance(cursor, list) or len(cursor) != len(sorting_fields)
    ):
        raise GraphQLError("Received cursor is invalid.")

    filter_kwargs = (
        _prepare_filter(cursor, sorting_fields, sorting_direction) if cursor else Q()
    )
    qs = qs.filter(filter_kwargs)
    qs = qs[:end_margin]
    edges, page_info = _get_edges_for_connection(qs, args, sorting_fields)

    return {
        "edges": edges,
        "page_info": page_info
    }
@SanskarSans
Copy link
Author

Can anyone convert this to support SQLAlchemy Core? It might be helpful to others as well.

Thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment