Skip to content

Instantly share code, notes, and snippets.

@GeeWee
Last active April 15, 2019 18:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GeeWee/5f82fd11d5e9c95bfe61993c07624ccc to your computer and use it in GitHub Desktop.
Save GeeWee/5f82fd11d5e9c95bfe61993c07624ccc to your computer and use it in GitHub Desktop.
import inspect
import logging
from collections import OrderedDict
from json import loads, dumps
from pprint import pprint
from typing import Type, Union

from django.core.exceptions import FieldError
from django.test import TestCase
from rest_framework.relations import RelatedField, ManyRelatedField
from rest_framework.serializers import ModelSerializer, Serializer, BaseSerializer, ListSerializer
from rest_framework.test import APIRequestFactory
from rest_framework.utils.serializer_helpers import ReturnList
# This snippet automatically prefetches all related objects a serializer needs.
# Example usage in your ViewSet's get_queryset method:
def get_queryset(self):
    """
    Example ViewSet ``get_queryset`` showing how to use ``prefetch``.

    Copy this method into a DRF ViewSet and replace ``YOUR_MODEL`` with the model
    the ViewSet serves. It must be a method: it calls ``self.get_serializer_class()``.

    :return: the base queryset with the related lookups the ViewSet's serializer
        needs already applied.
    """
    qs = YOUR_MODEL.objects.all()  # Or whatever queryset you want to use
    # prefetch() takes the queryset first and the serializer second; it adds the
    # select_related/prefetch_related lookups derived from the serializer's fields.
    qs = prefetch(qs, self.get_serializer_class())
    return qs
def prefetch(queryset, serializer: Type[ModelSerializer]):
    """
    Return *queryset* with every join and prefetch that *serializer* needs applied.

    :param queryset: the Django queryset to optimise.
    :param serializer: the serializer (class or instance) that will render the
        queryset; its fields are walked by ``_prefetch``.
    :return: ``queryset.select_related(...).prefetch_related(...)`` using the
        lookups computed by ``_prefetch``.
    """
    joins, prefetches = _prefetch(serializer)
    with_joins = queryset.select_related(*joins)
    return with_joins.prefetch_related(*prefetches)
def _serializer_fields(serializer):
    """
    Return the ``(name, field)`` pairs of a serializer class or instance.

    A class is instantiated with no arguments; a ``ListSerializer`` is unwrapped to
    its ``child``. Returns ``None`` when the object exposes no ``fields`` at all,
    e.g. the ``child_relation`` (a ``PrimaryKeyRelatedField`` or similar) of a
    many-to-many field.
    """
    instance = serializer() if inspect.isclass(serializer) else serializer
    try:
        return getattr(instance, 'child', instance).fields.items()
    except AttributeError:
        return None


def _source_attrs(field):
    """
    Return the attribute chain a field reads its value from.

    ``source='author.profile'`` yields ``['author', 'profile']``; ``source='*'``
    (the field is rendered from the parent object itself) yields ``[]``.
    """
    attrs = getattr(field, 'source_attrs', None)
    if attrs is None:
        # Field not bound yet: mirror how DRF's Field.bind derives source_attrs.
        source = field.source
        attrs = [] if source == '*' else source.split('.')
    return attrs


def _prefetch(serializer: Union[Type[BaseSerializer], BaseSerializer], path=None, indentation=0):
    """
    Work out which queryset lookups a serializer needs to avoid N+1 queries.

    Walks the serializer's fields recursively. Relations reached only through
    to-one relations are collected for ``select_related``. A to-many relation
    (``ListSerializer`` / ``ManyRelatedField``) and everything nested below it is
    collected for ``prefetch_related``.

    :param serializer: a serializer class or instance. Classes are instantiated
        with no arguments; a ``ListSerializer`` is unwrapped to its ``child``.
    :param path: Django lookup path (``a__b``) of the relation this serializer is
        reached through, or ``None`` for the root serializer.
    :param indentation: number of spaces used to indent debug log lines for
        nested levels.
    :return: ``(select_related, prefetch_related)`` -- a tuple of two sets of
        Django lookup strings.
    """
    logger = logging.getLogger(__name__)
    pad = ' ' * indentation
    prepend = f'{path}__' if path is not None else ''
    class_name = getattr(serializer, '__name__', serializer.__class__.__name__)
    logger.debug('%sLOOKING AT SERIALIZER: %s from path: %s', pad, class_name, prepend)
    select_related = set()
    prefetch_related = set()

    fields = _serializer_fields(serializer)
    if fields is None:
        # No further fields, e.g. a PrimaryKeyRelatedField passed in as the
        # nested representation of a ManyToManyField.
        return (select_related, prefetch_related)

    for name, field_instance in fields:
        field_type_name = field_instance.__class__.__name__
        logger.debug('%s Field "%s", type: %s, src: "%s"', pad, name, field_type_name, field_instance.source)
        # Only relational fields and nested serializers can require extra lookups.
        if not isinstance(field_instance, (BaseSerializer, RelatedField, ManyRelatedField)):
            continue
        logger.debug('%sFound: %s (%s) - recursing deeper', pad, field_type_name, type(field_instance))

        source_attrs = _source_attrs(field_instance)
        if not source_attrs:
            # source='*': the field renders the *current* object, not a related one
            # (e.g. the HyperlinkedIdentityField behind a HyperlinkedModelSerializer's
            # ``url``). There is no relation to join here. A nested serializer may
            # still need lookups, and those are relative to the current path.
            if isinstance(field_instance, BaseSerializer) and not isinstance(field_instance, ListSerializer):
                select, prefetch = _prefetch(field_instance, path, indentation + 4)
                select_related |= select
                prefetch_related |= prefetch
            continue

        # Use the field's source (not its name) and Django's ``__`` separator, so
        # that renamed fields and dotted sources produce valid lookups.
        field_path = prepend + '__'.join(source_attrs)

        if isinstance(field_instance, RelatedField):
            # A relation rendered by a related field (pk, slug, hyperlink, ...).
            logger.debug('%s Found related field: %s', pad, field_type_name)
            select_related.add(field_path)
        elif isinstance(field_instance, (ListSerializer, ManyRelatedField)):
            # Multiple entities: use prefetch_related, and do so for everything
            # nested below this relation as well.
            logger.debug('%s Found *:m relation: %s', pad, field_type_name)
            prefetch_related.add(field_path)
            # A ManyRelatedField only exposes its underlying field via child_relation.
            nested_field = getattr(field_instance, 'child_relation', field_instance)
            select, prefetch = _prefetch(nested_field, field_path, indentation + 4)
            prefetch_related |= select
            prefetch_related |= prefetch
        else:
            # A single nested object: join it and keep walking its fields.
            logger.debug('%s Found *:1 relation: %s', pad, field_type_name)
            select_related.add(field_path)
            select, prefetch = _prefetch(field_instance, field_path, indentation + 4)
            select_related |= select
            prefetch_related |= prefetch
    return (select_related, prefetch_related)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment