Skip to content

Instantly share code, notes, and snippets.

@marty0678
Created March 8, 2024 21:57
Show Gist options
  • Save marty0678/306de10831da9e429c2b58a0eb760bdc to your computer and use it in GitHub Desktop.
Save marty0678/306de10831da9e429c2b58a0eb760bdc to your computer and use it in GitHub Desktop.
Django Rest Framework Base Model CRUD Class
# Views
class UserProfileViewSet(BaseModelViewSet):
    """Viewset exposing UserProfile objects through the API.

    Turns on list, retrieve and partial update (PATCH) plus paginated list
    responses. ``LIST_CACHE_KEY`` is the key prefix used for list responses
    whenever list caching is active.
    """

    ENABLE_LIST = True
    ENABLE_RETRIEVE = True
    ENABLE_PARTIAL_UPDATE = True
    ENABLE_PAGINATION = True
    LIST_CACHE_KEY = "users:user-profiles-list"
    serializer_class = UserProfileSerializer
    # Permission classes fall back to DEFAULT_PERMISSION_CLASSES in settings.py
    # (or whatever a parent class declares). To require additional ones, set
    # e.g. ``permission_classes = [IsAuthenticated]`` here. Every entry must be
    # a permission class: DRF instantiates each one on every request.

    def get_queryset(self):
        """Return the profiles this viewset operates on."""
        return UserProfile.objects.all()  # Or however you want to filter it
# Urls
from django.urls import include, path
from rest_framework.routers import DefaultRouter

# NOTE(review): in a standalone urls.py, UserProfileViewSet must also be
# imported from its views module.

# The router generates the standard list/detail routes and the routes for any
# @action-decorated extra actions on each registered viewset.
router = DefaultRouter()
router.register("user-profile", UserProfileViewSet, basename="users-api-user-profile")

urlpatterns = [
    path("", include(router.urls)),
]
from rest_framework.pagination import PageNumberPagination
class BaseModelPagination(PageNumberPagination):
    """Page-number pagination used by BaseModelViewSet when pagination is enabled.

    Clients select a page with ``?page=N``. By default a page holds 25
    results; ``?page_size=M`` overrides that, up to a maximum of 100.
    """

    page_query_param = "page"
    page_size = 25
    page_size_query_param = "page_size"
    max_page_size = 100
from rest_framework import serializers
class BaseBulkDeleteSerializer(serializers.Serializer):
    """Validate a bulk-delete payload of the form ``{"ids": [1, 2, ...]}``."""

    # ListField is DRF's field for a list of primitive values. ListSerializer
    # is the container behind ``many=True`` serializers and is not intended to
    # be declared as a plain field.
    ids = serializers.ListField(child=serializers.IntegerField())
from django.conf import settings
from django.db.models.deletion import RestrictedError
from django.db import transaction
from django.views.decorators.cache import cache_page
from rest_framework import viewsets, status
from rest_framework.response import Response
from rest_framework.exceptions import MethodNotAllowed, ValidationError
from rest_framework.decorators import action
from .paginators import BaseModelPagination
from .serializers import BaseBulkDeleteSerializer
class BaseModelViewSet(viewsets.ModelViewSet):
    """Base class for all model viewsets exposed through the API.

    Each standard CRUD action and each custom bulk action is gated by an
    ``ENABLE_*`` class attribute. All toggles default to ``False``, and a
    request to a disabled action raises ``MethodNotAllowed``. Setting
    ``ENABLE_ALL = True`` turns on every action toggle at once; pagination is
    controlled separately by ``ENABLE_PAGINATION``.

    Designed to be registered with a DRF router, which builds the routes and
    URLs, including the ``bulk-create``, ``bulk-update`` and ``bulk-delete``
    extra actions defined below.
    """

    ENABLE_ALL = False
    # Standard actions
    ENABLE_LIST = False
    ENABLE_RETRIEVE = False
    ENABLE_CREATE = False
    ENABLE_UPDATE = False
    ENABLE_PARTIAL_UPDATE = False
    ENABLE_DESTROY = False
    # Custom actions
    ENABLE_BULK_CREATE = False
    ENABLE_BULK_UPDATE = False
    ENABLE_BULK_DELETE = False
    # Control
    ENABLE_PAGINATION = False
    LIST_CACHE_KEY = None
    LIST_CACHE_TIME = 60 * 15  # 15 minutes

    # Toggles switched on by ENABLE_ALL (ENABLE_PAGINATION is not among them).
    _ACTION_TOGGLES = (
        "ENABLE_LIST",
        "ENABLE_RETRIEVE",
        "ENABLE_CREATE",
        "ENABLE_UPDATE",
        "ENABLE_PARTIAL_UPDATE",
        "ENABLE_DESTROY",
        "ENABLE_BULK_CREATE",
        "ENABLE_BULK_UPDATE",
        "ENABLE_BULK_DELETE",
    )

    def __init__(self, **kwargs):
        # Toggles are set on the instance, so the class defaults stay untouched.
        if self.ENABLE_ALL:
            for toggle in self._ACTION_TOGGLES:
                setattr(self, toggle, True)
        if self.ENABLE_PAGINATION:
            self.pagination_class = BaseModelPagination
        # NOTE(review): when ENABLE_PAGINATION is False, the inherited
        # pagination_class (DRF's DEFAULT_PAGINATION_CLASS setting) still
        # applies. Confirm that is intended.
        super().__init__(**kwargs)

    def get_serializer_context(self):
        """Add the requesting user to the serializer context on every request."""
        context = super().get_serializer_context()
        context.update({"user": self.request.user})
        return context

    def list(self, request, *args, **kwargs):
        """List objects, serving from the page cache when it is configured.

        Caching applies only when the subclass sets ``LIST_CACHE_KEY`` and the
        ``USE_LIST_CACHE`` setting is truthy. A missing setting counts as off.
        """
        if not self.ENABLE_LIST:
            raise MethodNotAllowed(request.method)
        if self.LIST_CACHE_KEY and getattr(settings, "USE_LIST_CACHE", False):
            # NOTE(review): cache_page keys on the URL (plus Vary headers), not
            # on the user. Only enable it for querysets that are the same for
            # every requester.
            cached_list = cache_page(
                self.LIST_CACHE_TIME, key_prefix=self.LIST_CACHE_KEY
            )(super().list)
            return cached_list(request, *args, **kwargs)
        return super().list(request, *args, **kwargs)

    def retrieve(self, request, *args, **kwargs):
        if not self.ENABLE_RETRIEVE:
            raise MethodNotAllowed(request.method)
        return super().retrieve(request, *args, **kwargs)

    def create(self, request, *args, **kwargs):
        if not self.ENABLE_CREATE:
            raise MethodNotAllowed(request.method)
        return super().create(request, *args, **kwargs)

    def update(self, request, *args, **kwargs):
        """Handle PUT, and PATCH forwarded by ``partial_update``.

        DRF's ``partial_update`` calls this method with ``partial=True``, so
        the toggle checked depends on that flag. A PUT is therefore rejected
        when only ``ENABLE_PARTIAL_UPDATE`` is set.
        """
        partial = kwargs.get("partial", False)
        enabled = self.ENABLE_PARTIAL_UPDATE if partial else self.ENABLE_UPDATE
        if not enabled:
            raise MethodNotAllowed(request.method)
        super().update(request, *args, **kwargs)
        # The response from super().update() is discarded. The instance is
        # re-fetched and serialized afresh so read-only fields are included.
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    def partial_update(self, request, *args, **kwargs):
        if not self.ENABLE_PARTIAL_UPDATE:
            raise MethodNotAllowed(request.method)
        return super().partial_update(request, *args, **kwargs)

    def destroy(self, request, *args, **kwargs):
        """Delete one object, reporting on_delete=RESTRICT violations as validation errors."""
        if not self.ENABLE_DESTROY:
            raise MethodNotAllowed(request.method)
        try:
            return super().destroy(request, *args, **kwargs)
        except RestrictedError as e:
            raise ValidationError(e.args[0]) from e

    @action(detail=False, methods=["post"], url_path="bulk-create")
    def bulk_create(self, request, *args, **kwargs):
        """Create several objects from a list payload in one request.

        Mirrors DRF's ``create`` with ``many=True``. Saving runs inside a
        transaction, so a failure on one item leaves none of them created.
        """
        if not self.ENABLE_BULK_CREATE:
            raise MethodNotAllowed(request.method)
        serializer = self.get_serializer(data=request.data, many=True)
        serializer.is_valid(raise_exception=True)
        with transaction.atomic():
            self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(
            serializer.data, status=status.HTTP_201_CREATED, headers=headers
        )

    @action(detail=False, methods=["patch"], url_path="bulk-update")
    def bulk_update(self, request, *args, **kwargs):
        """Partially update several objects in one request.

        The body must be a list of objects, each carrying an ``id`` key plus
        the fields to change. Every object is validated before anything is
        written, and the writes run in one transaction, so either all updates
        are saved or none are.

        NOTE: this still saves each instance individually and could be
        optimised.
        """
        if not self.ENABLE_BULK_UPDATE:
            raise MethodNotAllowed(request.method)
        # Iterating a dict body would yield its keys, so reject anything that
        # is not a list before unpacking.
        if not isinstance(request.data, list):
            raise ValidationError("The data must be a list of instances.")
        # Map each id to the rest of its payload: {id: {field: value, ...}}
        try:
            data = {
                item["id"]: {key: value for key, value in item.items() if key != "id"}
                for item in request.data
            }
        except KeyError:
            raise ValidationError("All instances must have a id.") from None
        except TypeError:
            raise ValidationError("The data must be a list of instances.") from None

        instances = list(self.get_queryset().filter(id__in=data.keys()))
        found_ids = {instance.id for instance in instances}
        missing_ids = [pk for pk in data if pk not in found_ids]
        if missing_ids:
            raise ValidationError(f"The following ids were not found: {missing_ids}")

        # Validate everything up front so no write happens if any item fails.
        validated_serializers = []
        for instance in instances:
            serializer = self.get_serializer(
                instance=instance, data=data[instance.id], partial=True
            )
            serializer.is_valid(raise_exception=True)
            validated_serializers.append(serializer)

        with transaction.atomic():
            for serializer in validated_serializers:
                self.perform_update(serializer)

        return_serializer = self.get_serializer(instance=instances, many=True)
        return Response(return_serializer.data)

    @action(detail=False, methods=["delete"], url_path="bulk-delete")
    def bulk_delete(self, request, *args, **kwargs):
        """Delete every object whose id appears in the payload's ``ids`` list.

        Objects that exist are deleted in a single transaction. Ids that
        match nothing are reported in the response:
        - 204 when every id was found,
        - 200 with a detail message when only some were found,
        - 404 when none were.
        """
        if not self.ENABLE_BULK_DELETE:
            raise MethodNotAllowed(request.method)
        serializer = BaseBulkDeleteSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        requested_ids = serializer.validated_data["ids"]

        instances = list(self.get_queryset().filter(id__in=requested_ids))
        # Record the ids BEFORE deleting. Model.delete(), which is what DRF's
        # default perform_destroy calls, resets each instance's primary key to
        # None. Reading instance.id afterwards would report every deleted row
        # as "not found".
        found_ids = {instance.id for instance in instances}
        missing_ids = [pk for pk in requested_ids if pk not in found_ids]

        # Wrap in a transaction so we never delete some objects and not others.
        try:
            with transaction.atomic():
                for instance in instances:
                    self.perform_destroy(instance)
        except RestrictedError as e:
            raise ValidationError(e.args[0]) from e

        if missing_ids:
            detail = f"The following ids were not found: {missing_ids}."
            if instances:
                detail += " Other ids were deleted."
            return Response(
                {"detail": detail},
                status=status.HTTP_200_OK if instances else status.HTTP_404_NOT_FOUND,
            )
        return Response(status=status.HTTP_204_NO_CONTENT)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment