Skip to content

Instantly share code, notes, and snippets.

@dsschneidermann
Created July 22, 2021 09:48
Show Gist options
  • Save dsschneidermann/ddf9c9e4a782c2f6769d503ff0a0a42e to your computer and use it in GitHub Desktop.
Save dsschneidermann/ddf9c9e4a782c2f6769d503ff0a0a42e to your computer and use it in GitHub Desktop.
Django Rest Framework JSON:API AutoPrefetchRecursiveMixin
import logging
import typing

from django.db.models.fields.related_descriptors import (
    ForwardManyToOneDescriptor,
    ManyToManyDescriptor,
    ReverseManyToOneDescriptor,
    ReverseOneToOneDescriptor,
)
from django.utils.module_loading import import_string
from rest_framework import viewsets as rest_framework_viewsets
from rest_framework.utils import model_meta
from rest_framework_json_api.utils import get_included_resources
# Type checking that the mixins are inheriting from a ViewSet, while in runtime they
# have to inherit from object to be a valid mixin.
if typing.TYPE_CHECKING:
_Mixin_Base = (
rest_framework_viewsets.ViewSet
) # pylint: disable=invalid-name # pragma: no cover
else:
_Mixin_Base = object # pylint: disable=invalid-name
class AutoPrefetchRecursiveMixin(_Mixin_Base):
    """Automatically prefetch reverse and to-many relations, recursively.

    If JSON:API results contain relations with reverse links (e.g. foreign keys
    defined on the other model), rendering them incurs N+1 lookups to get the
    related item ids. This mixin always prefetches such relations, so only one
    lookup per relation is incurred instead of one per result row.

    Relations are collected from the view's own serializer and from every
    serializer reached through the requested ``include`` paths (following each
    serializer's ``included_serializers``).
    """

    # This is a modified version of the method from
    # rest_framework_json_api.views.AutoPrefetchMixin
    def get_queryset(self, *args, **kwargs):
        """Return the parent queryset with ``prefetch_related`` lookups added.

        Prefetches OneToOne and ManyToMany relations, recursively, whenever the
        relationship is not a simple forward foreign key. Include paths that do
        not resolve to a chain of relation descriptors on the queryset's model
        are skipped.
        """
        # Honour get_serializer_class() overrides (e.g. per-action serializers).
        serializer_class = self.get_serializer_class()
        # list() in case included_resources was declared as a tuple.
        requested = list(get_included_resources(self.request, serializer_class))
        recursive = self._get_recursive_prefetch_paths(serializer_class, requested)
        # Requested includes first, then the discovered paths in a stable order;
        # dict.fromkeys drops duplicates while preserving that order.
        included_resources = list(dict.fromkeys(requested + sorted(recursive)))

        qs = super().get_queryset(*args, **kwargs)
        for included in included_resources:
            if self._is_prefetchable_path(qs.model, included):
                qs = qs.prefetch_related(included.replace(".", "__"))
        return qs

    @classmethod
    def _get_recursive_prefetch_paths(cls, serializer_class, included_resources):
        """Return the set of dotted paths of reverse/to-many relations to prefetch.

        ``"__all__"`` stands for ``serializer_class`` itself: its relations are
        added without a prefix. Every other include path is walked level by
        level through ``included_serializers``. The reverse/to-many relations
        of each level's model are added, prefixed with the path walked so far.
        A level whose serializer cannot be found stops that path and logs a
        warning.
        """
        prefetch_paths = set()
        for included in [*included_resources, "__all__"]:
            if included == "__all__":
                prefetch_paths.update(cls._relation_paths(serializer_class, ""))
                continue

            included_serializers = getattr(serializer_class, "included_serializers", {})
            prefix = ""
            for level in included.split("."):
                prefix += f"{level}."
                level_serializer = included_serializers.get(level, None)
                if not level_serializer:
                    logging.getLogger(__name__).warning(
                        "cannot find serializer '%s' at step '%s' for include query: '%s'",
                        level,
                        prefix,
                        included,
                    )
                    break
                # included_serializers values may be dotted import paths.
                if isinstance(level_serializer, str):
                    level_serializer = import_string(level_serializer)
                prefetch_paths.update(cls._relation_paths(level_serializer, prefix))
                # The next level is looked up on this level's serializer.
                included_serializers = getattr(level_serializer, "included_serializers", {})
        return prefetch_paths

    @staticmethod
    def _relation_paths(serializer_class, prefix):
        """Yield ``prefix + name`` for each reverse or to-many relation of the model.

        The model is ``serializer_class.Meta.model``. Relation names ending in
        ``_set`` are skipped (presumably reverse accessors without an explicit
        ``related_name``).
        """
        info = model_meta.get_field_info(serializer_class.Meta.model)
        for name, relation in info.relations.items():
            if name.endswith("_set"):
                continue
            if relation.reverse or relation.to_many:
                yield f"{prefix}{name}"

    @staticmethod
    def _is_prefetchable_path(model, included):
        """Return True if every segment of the dotted ``included`` path is a relation.

        Walks the path starting at ``model``. Each segment must be a forward
        (FK/O2O/M2M) or reverse (FK/O2O/M2M) relation descriptor on the current
        model. Only then is the path safe to pass to ``prefetch_related``.
        """
        levels = included.split(".")
        last_index = len(levels) - 1
        level_model = model
        for index, level in enumerate(levels):
            if not hasattr(level_model, level):
                return False
            field = getattr(level_model, level)
            field_class = field.__class__

            is_forward_relation = issubclass(
                field_class, (ForwardManyToOneDescriptor, ManyToManyDescriptor)
            )
            is_reverse_relation = issubclass(
                field_class, (ReverseManyToOneDescriptor, ReverseOneToOneDescriptor)
            )
            if not (is_forward_relation or is_reverse_relation):
                return False

            # Compare by position, not by name: in a path that repeats a name
            # (e.g. "author.entries.author") a name check would accept the
            # path at its first segment without validating the rest.
            if index == last_index:
                return True

            if issubclass(field_class, ReverseOneToOneDescriptor):
                model_field = field.related.field
            else:
                model_field = field.field

            # ManyToManyDescriptor also serves the reverse side of an M2M; there
            # the model to continue from is the one declaring the M2M field.
            is_reverse_m2m = isinstance(field, ManyToManyDescriptor) and field.reverse
            if is_forward_relation and not is_reverse_m2m:
                level_model = model_field.related_model
            else:
                level_model = model_field.model
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment