Created
March 8, 2024 21:57
-
-
Save marty0678/306de10831da9e429c2b58a0eb760bdc to your computer and use it in GitHub Desktop.
Django Rest Framework Base Model CRUD Class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Views | |
class UserProfileViewSet(BaseModelViewSet):
    """Example view set for the UserProfile model.

    Switches on the list, retrieve and partial-update toggles of
    BaseModelViewSet, together with pagination and a list cache key.
    """

    serializer_class = UserProfileSerializer

    # Additional permission classes as needed, defaults defined in settings.py but
    # you could put them on the parent class too
    permission_classes = [...]

    # Control toggles
    LIST_CACHE_KEY = "users:user-profiles-list"
    ENABLE_PAGINATION = True

    # Enabled actions
    ENABLE_PARTIAL_UPDATE = True
    ENABLE_RETRIEVE = True
    ENABLE_LIST = True

    def get_queryset(self):
        """Return the UserProfile rows this view set operates on."""
        # Or however you want to filter it
        profiles = UserProfile.objects.all()
        return profiles
# Urls
# `path` and `include` are used below, so they must be imported alongside the router.
from django.urls import include, path
from rest_framework.routers import DefaultRouter

# Register the view set on a router; the router's generated URLs are mounted
# at the root of this URLconf.
router = DefaultRouter()
router.register("user-profile", UserProfileViewSet, basename="users-api-user-profile")

urlpatterns = [
    path("", include(router.urls)),
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rest_framework.pagination import PageNumberPagination | |
class BaseModelPagination(PageNumberPagination):
    """Page-number pagination settings used by the base model view set."""

    # Query-string parameter names
    page_query_param = "page"
    page_size_query_param = "page_size"

    # Default page size and the upper bound for a client-supplied size
    page_size = 25
    max_page_size = 100
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from rest_framework import serializers | |
class BaseBulkDeleteSerializer(serializers.Serializer):
    """Validates the payload of a bulk delete request: ``{"ids": [1, 2, ...]}``."""

    # ListField is the DRF field for a list of primitive values; ListSerializer
    # is meant for wrapping a serializer instantiated with many=True.
    ids = serializers.ListField(child=serializers.IntegerField())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.conf import settings | |
from django.db.models.deletion import RestrictedError | |
from django.db import transaction | |
from django.views.decorators.cache import cache_page | |
from rest_framework import viewsets, status | |
from rest_framework.response import Response | |
from rest_framework.exceptions import MethodNotAllowed, ValidationError | |
from rest_framework.decorators import action | |
from .paginators import BaseModelPagination | |
from .serializers import BaseBulkDeleteSerializer | |
class BaseModelViewSet(viewsets.ModelViewSet):
    """Base class for all model views that are sent via the API.

    Feature toggles to enable different parts of CRUD. Are all disabled by default.
    Designed to be used with a Router to auto configure routes and URLs.

    An action whose toggle is off raises ``MethodNotAllowed`` when called.
    """

    ENABLE_ALL = False

    # Standard actions
    ENABLE_LIST = False
    ENABLE_RETRIEVE = False
    ENABLE_CREATE = False
    ENABLE_UPDATE = False
    ENABLE_PARTIAL_UPDATE = False
    ENABLE_DESTROY = False

    # Custom actions
    ENABLE_BULK_CREATE = False
    ENABLE_BULK_UPDATE = False
    ENABLE_BULK_DELETE = False

    # Control
    ENABLE_PAGINATION = False
    LIST_CACHE_KEY = None
    LIST_CACHE_TIME = 60 * 15  # 15 minutes

    def __init__(self, **kwargs):
        """Expand ``ENABLE_ALL`` into the individual toggles and set up pagination."""
        if self.ENABLE_ALL:
            self.ENABLE_LIST = True
            self.ENABLE_RETRIEVE = True
            self.ENABLE_CREATE = True
            self.ENABLE_UPDATE = True
            self.ENABLE_PARTIAL_UPDATE = True
            self.ENABLE_DESTROY = True
            self.ENABLE_BULK_CREATE = True
            self.ENABLE_BULK_UPDATE = True
            self.ENABLE_BULK_DELETE = True

        if self.ENABLE_PAGINATION:
            self.pagination_class = BaseModelPagination

        super().__init__(**kwargs)

    def get_serializer_context(self):
        """Adds the user to the serializer context on every request."""
        context = super().get_serializer_context()
        context.update({"user": self.request.user})
        return context

    def list(self, request, *args, **kwargs):
        """List the queryset, optionally through Django's page cache.

        Caching is applied only when the subclass sets ``LIST_CACHE_KEY`` and the
        ``USE_LIST_CACHE`` setting is truthy; a missing setting means "no cache".
        """
        if not self.ENABLE_LIST:
            raise MethodNotAllowed(request.method)

        list_view = super().list
        if self.LIST_CACHE_KEY and getattr(settings, "USE_LIST_CACHE", False):
            # Cache the page response based on the subclass's settings.
            # NOTE(review): confirm that the cached response cannot be shared
            # between users when get_queryset() filters per user.
            list_view = cache_page(
                self.LIST_CACHE_TIME, key_prefix=self.LIST_CACHE_KEY
            )(list_view)
        return list_view(request, *args, **kwargs)

    def retrieve(self, request, *args, **kwargs):
        """Return a single instance when ``ENABLE_RETRIEVE`` is on."""
        if self.ENABLE_RETRIEVE:
            return super().retrieve(request, *args, **kwargs)
        raise MethodNotAllowed(request.method)

    def create(self, request, *args, **kwargs):
        """Create a single instance when ``ENABLE_CREATE`` is on."""
        if self.ENABLE_CREATE:
            return super().create(request, *args, **kwargs)
        raise MethodNotAllowed(request.method)

    def update(self, request, *args, **kwargs):
        """Handle both full (PUT) and partial (PATCH) updates.

        DRF's ``partial_update`` delegates to this method with ``partial=True``,
        so the toggle that is checked depends on that flag: a full update needs
        ``ENABLE_UPDATE`` and a partial update needs ``ENABLE_PARTIAL_UPDATE``.
        """
        is_partial = kwargs.get("partial", False)
        enabled = self.ENABLE_PARTIAL_UPDATE if is_partial else self.ENABLE_UPDATE
        if not enabled:
            raise MethodNotAllowed(request.method)

        super().update(request, *args, **kwargs)
        # On PATCH responses, DRF will not serialize read only fields,
        # so we need to manually serialize the instance to include them
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    def partial_update(self, request, *args, **kwargs):
        """Partially update an instance when ``ENABLE_PARTIAL_UPDATE`` is on."""
        if self.ENABLE_PARTIAL_UPDATE:
            return super().partial_update(request, *args, **kwargs)
        raise MethodNotAllowed(request.method)

    def destroy(self, request, *args, **kwargs):
        """Delete an instance, reporting a restricted delete as a validation error."""
        if not self.ENABLE_DESTROY:
            raise MethodNotAllowed(request.method)
        try:
            return super().destroy(request, *args, **kwargs)
        except RestrictedError as exc:
            raise ValidationError(exc.args[0]) from exc

    @action(detail=False, methods=["post"], url_path="bulk-create")
    def bulk_create(self, request, *args, **kwargs):
        """A modified create action that performs a bulk create of the model.

        Copied from DRF's source code but with many=True added.
        """
        if not self.ENABLE_BULK_CREATE:
            raise MethodNotAllowed(request.method)

        serializer = self.get_serializer(data=request.data, many=True)
        serializer.is_valid(raise_exception=True)
        # Create all rows or none of them
        with transaction.atomic():
            self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(
            serializer.data, status=status.HTTP_201_CREATED, headers=headers
        )

    @action(detail=False, methods=["patch"], url_path="bulk-update")
    def bulk_update(self, request, *args, **kwargs):
        """Updates a list of objects in bulk. Must be a list of objects with a id
        key. Only performs the update in the database if all validations pass.

        NOTE: This needs optimization. It's not efficient to update each instance one by one.
        """
        if not self.ENABLE_BULK_UPDATE:
            raise MethodNotAllowed(request.method)

        # Separate out the id from the payload we are going to pass to the serializer
        # Structure is: {'id': **rest_of_payload}
        try:
            payload_by_id = {
                item["id"]: {
                    key: value for key, value in item.items() if key != "id"
                }
                for item in request.data
            }
        except KeyError:
            raise ValidationError("All instances must have a id.")
        except TypeError:
            raise ValidationError("The data must be a list of instances.")

        # Gets all instances based on the ids (evaluated once, reused below)
        instances = list(self.get_queryset().filter(id__in=list(payload_by_id)))

        # Validate if any ids are not found
        found_ids = {instance.id for instance in instances}
        missing_ids = [pk for pk in payload_by_id if pk not in found_ids]
        if missing_ids:
            raise ValidationError(
                f"The following ids were not found: {missing_ids}"
            )

        # Now we validate each instance
        serializers_to_save = []
        for instance in instances:
            serializer = self.get_serializer(
                instance=instance, data=payload_by_id[instance.id], partial=True
            )
            serializer.is_valid(raise_exception=True)
            serializers_to_save.append(serializer)

        # Now we know we have no validation errors, we can update the instances.
        # Wrap in a transaction so a failing save does not leave some rows updated.
        with transaction.atomic():
            for serializer in serializers_to_save:
                self.perform_update(serializer)

        return_serializer = self.get_serializer(instance=instances, many=True)
        return Response(return_serializer.data)

    @action(detail=False, methods=["delete"], url_path="bulk-delete")
    def bulk_delete(self, request, *args, **kwargs):
        """Receives a list of IDs and deletes them all.

        Reports errors to the user: 404 when none of the ids exist, 200 with a
        detail message when only some exist, 204 when all were deleted.
        """
        if not self.ENABLE_BULK_DELETE:
            raise MethodNotAllowed(request.method)

        serializer = BaseBulkDeleteSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        requested_ids = serializer.validated_data["ids"]

        instances = list(self.get_queryset().filter(id__in=requested_ids))
        # Record the ids BEFORE deleting: Django resets the primary key of a
        # deleted instance to None, so reading them afterwards finds nothing.
        found_ids = {instance.id for instance in instances}

        try:
            # Wrap in a transaction so we don't delete some and not others
            with transaction.atomic():
                for instance in instances:
                    self.perform_destroy(instance)
        except RestrictedError as exc:
            # Same reporting as the single-object destroy action
            raise ValidationError(exc.args[0]) from exc

        # Alert the user of any missing ids (we still delete the ones that exist)
        missing_ids = [pk for pk in requested_ids if pk not in found_ids]
        if missing_ids:
            # Dynamically return the proper response if there are ids missing
            detail = f"The following ids were not found: {missing_ids}."
            if instances:
                detail += " Other ids were deleted."
            return Response(
                {"detail": detail},
                status=(
                    status.HTTP_200_OK if instances else status.HTTP_404_NOT_FOUND
                ),
            )

        return Response(status=status.HTTP_204_NO_CONTENT)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment