Skip to content

Instantly share code, notes, and snippets.

@jackton1
Last active October 6, 2023 22:37
Show Gist options
  • Star 22 You must be signed in to star a gist
  • Fork 6 You must be signed in to fork a gist
  • Save jackton1/dea24643d461bfc87755c2504abbab11 to your computer and use it in GitHub Desktop.
Save jackton1/dea24643d461bfc87755c2504abbab11 to your computer and use it in GitHub Desktop.
Optimize the querysets of Django REST Framework model views.
from django.db import ProgrammingError, models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query import normalize_prefetch_lookups
from rest_framework import serializers
from rest_framework.utils import model_meta
class OptimizeModelViewSetMetaclass(type):
    """
    Add ``select_related``/``prefetch_related`` calls to a model view's ``queryset``.

    When a view class declares both ``queryset`` and a ``serializer_class`` that
    subclasses ``serializers.ModelSerializer`` directly in its class body, the
    metaclass looks at ``serializer_class.Meta.model`` and at the fields the
    serializer renders. It then replaces ``queryset`` with a copy that:

    * ``prefetch_related``-s every rendered to-many relation (forward
      ``ManyToManyField`` and reverse ``ForeignKey``/``ManyToManyField``
      accessors), and
    * ``select_related``-s every rendered single-valued forward relation
      (``ForeignKey`` / ``OneToOneField``).

    Rendered fields come from ``Meta.fields`` (an explicit list/tuple or
    ``'__all__'``) or, failing that, from ``Meta.exclude``. A declared serializer
    field with a ``source`` is mapped to the first attribute of that source,
    e.g. ``source='publisher.name'`` -> ``publisher``.

    Attributes inherited from base classes are not inspected. Any view class
    without a ``ModelSerializer`` is created unchanged.
    """

    @staticmethod
    def get_many_to_many_rel(info, meta_fields):
        """
        Return the names in ``meta_fields`` that are to-many relations of the model.

        :param info: ``FieldInfo`` returned by ``model_meta.get_field_info``.
        :param meta_fields: iterable of candidate field names.
        :return: matching names, in ``meta_fields`` order, for ``prefetch_related``.
        """
        # info.relations merges forward and reverse relations.
        to_many = {name for name, relation in info.relations.items() if relation.to_many}
        return [lookup for lookup in meta_fields if lookup in to_many]

    @staticmethod
    def get_lookups(fields, strict=False):
        """
        Pair each lookup in ``fields`` with its first path segment.

        ``'author__publisher'`` becomes ``('author', 'author__publisher')``.
        With ``strict=True``, only lookups that span a relation (contain
        ``LOOKUP_SEP``) are kept. The metaclass itself does not call this.
        """
        pairs = [(lookup.split(LOOKUP_SEP, 1)[0], lookup) for lookup in fields]
        if strict:
            pairs = [pair for pair in pairs if LOOKUP_SEP in pair[1]]
        return pairs

    @staticmethod
    def get_many_to_one_rel(info, meta_fields):
        """
        Return the names in ``meta_fields`` that are forward ``ForeignKey`` relations.

        ``OneToOneField`` subclasses ``ForeignKey``, so one-to-one fields match too.
        The result keeps ``meta_fields`` order.
        """
        foreign_keys = {
            name
            for name, relation in info.forward_relations.items()
            if isinstance(relation.model_field, models.ForeignKey)
        }
        return [lookup for lookup in meta_fields if lookup in foreign_keys]

    @staticmethod
    def get_one_to_one_or_one_to_many_rel(info, meta_fields):
        """
        Return the names in ``meta_fields`` that are single-valued forward relations.

        These are the model's ``ForeignKey`` and ``OneToOneField`` fields, which
        ``select_related`` can follow. The result follows
        ``info.forward_relations`` order.
        """
        wanted = set(meta_fields)
        return [
            name
            for name, relation in info.forward_relations.items()
            if name in wanted and not relation.to_many
        ]

    @staticmethod
    def _get_candidate_field_names(serializer_class, info):
        """
        Return the model attribute names the serializer may render, de-duplicated in order.

        Handles ``Meta.fields`` as a list/tuple or ``'__all__'`` (DRF's ALL_FIELDS),
        and ``Meta.exclude``. Returns ``[]`` when neither is usable.
        """
        declared_fields = getattr(serializer_class, '_declared_fields', {})
        meta = getattr(serializer_class, 'Meta', None)
        fields = getattr(meta, 'fields', None)
        exclude = getattr(meta, 'exclude', None)

        if fields == '__all__':
            # ModelSerializer renders declared fields plus every forward relation.
            rendered = [*declared_fields, *info.forward_relations]
        elif fields is not None:
            if isinstance(fields, str):
                # Any other string is invalid for DRF; never split it into characters.
                return []
            rendered = list(fields)
        elif exclude is not None:
            excluded = set(exclude)
            rendered = [
                field_name
                for field_name in (*declared_fields, *info.forward_relations)
                if field_name not in excluded
            ]
        else:
            return []

        candidates = []
        for field_name in rendered:
            source = getattr(declared_fields.get(field_name), 'source', None)
            # source='*' passes the whole instance, so there is no attribute to map.
            if source and source != '*':
                field_name = source.split('.', 1)[0]
            candidates.append(field_name)
        return list(dict.fromkeys(candidates))

    @classmethod
    def _optimize_queryset(cls, queryset, serializer_class):
        """
        Return ``queryset`` with relation loading added for ``serializer_class``.

        If the serializer has no ``Meta.model``, the queryset is returned untouched.
        """
        meta = getattr(serializer_class, 'Meta', None)
        model = getattr(meta, 'model', None)
        if model is None:
            return queryset

        info = model_meta.get_field_info(model)
        candidates = cls._get_candidate_field_names(serializer_class, info)
        prefetch_lookups = cls.get_many_to_many_rel(info, candidates)
        select_lookups = cls.get_one_to_one_or_one_to_many_rel(info, candidates)

        try:
            optimized = queryset
            if prefetch_lookups:
                optimized = optimized.prefetch_related(*normalize_prefetch_lookups(prefetch_lookups))
            if select_lookups:
                optimized = optimized.select_related(*select_lookups)
            return optimized.all()
        except ProgrammingError:
            # Deliberate best effort: on a database ProgrammingError, keep the
            # declared queryset instead of failing class creation.
            return queryset

    def __new__(cls, name, bases, attrs):
        serializer_class = attrs.get('serializer_class')
        queryset = attrs.get('queryset')
        if (
            queryset is not None
            and serializer_class
            and issubclass(serializer_class, serializers.ModelSerializer)
        ):
            attrs['queryset'] = cls._optimize_queryset(queryset, serializer_class)
        return super(OptimizeModelViewSetMetaclass, cls).__new__(cls, name, bases, attrs)
@jackton1
Copy link
Author

jackton1 commented Mar 23, 2018

Usage

models.py

from django.db import models


class Book(models.Model):
    """Example model for the usage snippets.

    Of the fields shown, none is a ForeignKey, OneToOneField or ManyToManyField
    (``author`` is a plain CharField). So, as written, there is no relation for an
    optimizing metaclass to pass to ``select_related``/``prefetch_related``.
    Add relational fields to see an effect.
    """
    title = models.CharField(max_length=100)
    author = models.CharField(max_length=100)
    publication_date = models.DateField()
    # ...

serializer.py

from rest_framework import serializers

from app.models import Book


class BookSerializer(serializers.ModelSerializer):
    """Model serializer exposing every field of ``Book`` via ``fields = '__all__'``."""
    class Meta:
        model = Book
        fields = '__all__'

viewset.py

from drf_optimize import OptimizeModelViewSetMetaclass
from rest_framework import viewsets

from app.models import Book
from app.api.serializer import BookSerializer


class OptimizedBookViewSet(viewsets.ModelViewSet, metaclass=OptimizeModelViewSetMetaclass):
    """Book viewset whose ``queryset`` is rebuilt by the metaclass when the class is created."""
    # Both attributes must be declared in this class body: the metaclass reads
    # them from the class namespace, not from base classes.
    serializer_class = BookSerializer
    queryset = Book.objects.all()

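NOTE: This works with all model-based view classes (see https://www.django-rest-framework.org/api-guide/viewsets/#api-reference). The metaclass only inspects `queryset` and `serializer_class` when they are declared directly in the view's class body. Values inherited from a base class are not optimized.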

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment