Created July 22, 2021 09:48
-
-
Save dsschneidermann/ddf9c9e4a782c2f6769d503ff0a0a42e to your computer and use it in GitHub Desktop.
Django Rest Framework JSON:API AutoPrefetchRecursiveMixin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import typing

from django.db.models.fields.related_descriptors import (
    ForwardManyToOneDescriptor,
    ManyToManyDescriptor,
    ReverseManyToOneDescriptor,
    ReverseOneToOneDescriptor,
)
from django.utils.module_loading import import_string
from rest_framework import viewsets as rest_framework_viewsets
from rest_framework.utils import model_meta
from rest_framework_json_api.utils import get_included_resources
# Type checking that the mixins are inheriting from a ViewSet, while in runtime they | |
# have to inherit from object to be a valid mixin. | |
if typing.TYPE_CHECKING: | |
_Mixin_Base = ( | |
rest_framework_viewsets.ViewSet | |
) # pylint: disable=invalid-name # pragma: no cover | |
else: | |
_Mixin_Base = object # pylint: disable=invalid-name | |
class AutoPrefetchRecursiveMixin(_Mixin_Base):
    """
    Keep ``prefetch_related`` active for reverse and to-many relations.

    If JSON:API results contain relations with reverse links (e.g. foreign keys
    defined on the other model), every result row incurs its own lookup to get
    the related item ids (N+1 queries). This mixin prefetches such relations on
    the view's own model and, recursively, on every model reached through the
    request's ``include`` paths. Only one lookup is then incurred per relation
    instead of one per result row.

    Must be mixed into a view that defines ``serializer_class`` and ``request``
    and whose parent class provides ``get_queryset``.
    """

    # This is a modified version of the method from
    # rest_framework_json_api.views.AutoPrefetchMixin
    def get_queryset(self, *args, **kwargs):
        """Return the parent queryset with recursive ``prefetch_related`` lookups.

        Prefetches the paths requested through the JSON:API ``include`` query
        parameter. It also prefetches every reverse or to-many relation of the
        models along those paths, and of the view's own model. Relation names
        ending in ``_set`` (default reverse accessors) are skipped. Only paths
        that resolve to Django relation descriptors are actually prefetched.
        """
        serializer_class = self.serializer_class
        requested = list(get_included_resources(self.request, serializer_class))
        discovered = self._collect_recursive_prefetch_paths(
            serializer_class, requested
        )
        # Explicit includes first, then the discovered relations. Sorting gives
        # a deterministic order (they come from a set). dict.fromkeys drops
        # duplicates while keeping order.
        lookups = list(dict.fromkeys([*requested, *sorted(discovered)]))

        qs = super().get_queryset(*args, **kwargs)
        for lookup in lookups:
            if self._is_prefetchable_path(qs.model, lookup):
                qs = qs.prefetch_related(lookup.replace(".", "__"))
        return qs

    @staticmethod
    def _collect_recursive_prefetch_paths(serializer_class, included_resources):
        """Return dotted paths of reverse/to-many relations along include paths.

        :param serializer_class: the view's serializer class. Its
            ``included_serializers`` map include names to serializer classes or
            dotted import strings.
        :param included_resources: dotted include paths from the request.
        :returns: a set of dotted relation paths, e.g. ``{"comments",
            "author.books"}``.

        The pseudo-path ``"__all__"`` is processed too, so the relations of the
        view's own model are collected without any prefix.
        """
        logger = logging.getLogger(__name__)
        prefetch_paths = set()
        for included in [*included_resources, "__all__"]:
            included_serializers = getattr(
                serializer_class, "included_serializers", {}
            )
            levels = included.split(".")
            prefix = ""
            for level in levels:
                prefix += f"{level}."
                if level == "__all__" and len(levels) == 1:
                    # Relations of the current view's model: no path prefix.
                    serializer = serializer_class
                    level = ""
                    prefix = ""
                else:
                    # Resolve this level through the previous level's serializer.
                    serializer = included_serializers.get(level, None)
                if not serializer:
                    logger.warning(
                        "cannot find serializer '%s' at step '%s' "
                        "for include query: '%s'",
                        level,
                        prefix,
                        included,
                    )
                    break
                if isinstance(serializer, str):
                    serializer = import_string(serializer)

                info = model_meta.get_field_info(serializer.Meta.model)
                for name, relation in info.relations.items():
                    # Default reverse accessors (``<model>_set``) are skipped.
                    if name.endswith("_set"):
                        continue
                    if relation.reverse or relation.to_many:
                        prefetch_paths.add(f"{prefix}{name}")

                # The next level is looked up among this serializer's includes.
                included_serializers = getattr(
                    serializer, "included_serializers", {}
                )
        return prefetch_paths

    @staticmethod
    def _is_prefetchable_path(model, path):
        """Return True if every segment of dotted ``path`` is a relation on ``model``.

        Each segment must be a Django related-object descriptor on the model
        reached by the previous segments. This mirrors the check done by
        rest_framework_json_api's ``AutoPrefetchMixin``.
        """
        levels = path.split(".")
        last_index = len(levels) - 1
        level_model = model
        for index, level in enumerate(levels):
            descriptor = getattr(level_model, level, None)
            is_forward_relation = isinstance(
                descriptor, (ForwardManyToOneDescriptor, ManyToManyDescriptor)
            )
            is_reverse_relation = isinstance(
                descriptor, (ReverseManyToOneDescriptor, ReverseOneToOneDescriptor)
            )
            if not (is_forward_relation or is_reverse_relation):
                return False
            # Compare by position: a name may legitimately repeat inside a path.
            if index == last_index:
                return True

            if isinstance(descriptor, ReverseOneToOneDescriptor):
                model_field = descriptor.related.field
            else:
                model_field = descriptor.field
            # ManyToManyDescriptor serves both sides of a many-to-many relation.
            # Its ``reverse`` flag marks the accessor that lives on the target
            # model; there the next model is the one declaring the field.
            if is_forward_relation and not getattr(descriptor, "reverse", False):
                level_model = model_field.related_model
            else:
                level_model = model_field.model
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.