Last active
April 8, 2024 20:11
-
-
Save IMBlues/e36e792159729f429f9abf656ba24d10 to your computer and use it in GitHub Desktop.
Make your Django REST framework support dependency injection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import functools | |
from collections import namedtuple | |
from dataclasses import dataclass, field | |
from typing import TYPE_CHECKING, Optional, Type | |
from django.conf import settings | |
from django.utils.module_loading import import_string | |
from rest_framework import status | |
try: | |
from drf_yasg.utils import swagger_auto_schema | |
SKIP_SWAGGER_SCHEMA = False | |
except ImportError: | |
SKIP_SWAGGER_SCHEMA = True | |
from rest_framework.serializers import BaseSerializer | |
if TYPE_CHECKING: | |
from rest_framework.request import Request | |
# Return wrapper for view methods: `data` is the payload handed to the output
# serializer, `params` are extra keyword arguments for that serializer.
ResponseParams = namedtuple("ResponseParams", ["data", "params"])
@dataclass
class SerializerInjector:
    """An injector that supplies serializers to a view method as dependencies.

    Attributes:
        in_cls: Serializer class used to validate incoming request data.
        out_cls: Serializer class used to render the view's return value.
        config: Optional overrides for the keys in ``_default_config_value``
            (``data_from``, ``return_validated_data``, ``remain_request``).
        in_raw_params: Keyword arguments passed to ``is_valid`` on the input
            serializer; defaults to ``{"raise_exception": True}`` when empty.
        out_raw_params: Keyword arguments passed to ``out_cls`` when building
            the response.
    """

    in_cls: "Optional[Type[BaseSerializer]]"
    out_cls: "Optional[Type[BaseSerializer]]"
    config: Optional[dict] = field(default_factory=dict)
    in_raw_params: Optional[dict] = field(default_factory=dict)
    out_raw_params: Optional[dict] = field(default_factory=dict)

    # Deliberately un-annotated so it stays a plain class attribute rather
    # than becoming a dataclass field.
    _default_config_value = {"data_from": "query_params", "return_validated_data": True, "remain_request": True}

    def __post_init__(self):
        # Copy the incoming dicts: update_out_params() mutates out_raw_params
        # and must never write through to a dict owned by the caller.
        self.in_raw_params = dict(self.in_raw_params) if self.in_raw_params else {"raise_exception": True}
        self.config = dict(self.config or {})
        self.out_raw_params = dict(self.out_raw_params or {})

        # Expose every known config key as an attribute, falling back to the
        # class-level default when the caller did not override it.
        for name, default in self._default_config_value.items():
            setattr(self, name, self.config.get(name, default))

        # The response class is configurable through Django settings.
        resp_cls_path = getattr(settings, "SERIALIZER_INJECTOR_RESP_CLS", "rest_framework.response.Response")
        self.resp_cls = import_string(resp_cls_path)

    def __str__(self):
        def _name(cls):
            # in_cls/out_cls are classes (or None), so read __name__ from the
            # object itself; going through __class__ would name the metaclass.
            return getattr(cls, "__name__", repr(cls))

        return f"Injector<In:{_name(self.in_cls)}, Out:{_name(self.out_cls)}>"

    def update_out_params(self, params: dict):
        """Merge ``params`` into the stored output-serializer keyword arguments.

        The change is kept on this instance, so it also applies to every
        later ``get_response`` call made through it.
        """
        self.out_raw_params.update(params)

    def get_serializer_instance(self, request: "Request") -> "BaseSerializer":
        """Build the input serializer from the request and run validation.

        The data is read from the request attribute named by ``data_from``;
        ``is_valid`` is called with ``in_raw_params``.
        """
        slz_obj = self.in_cls(data=getattr(request, self.data_from))  # type: ignore
        slz_obj.is_valid(**self.in_raw_params)
        return slz_obj

    def get_validated_data(self, request: "Request") -> dict:
        """Return ``validated_data`` of the input serializer for ``request``."""
        return self.get_serializer_instance(request).validated_data

    def get_in_params(self, request: "Request") -> dict:
        """Return the extra keyword arguments to inject into the view method.

        Yields ``{"validated_data": ...}`` when ``return_validated_data`` is
        truthy, otherwise ``{"serializer_instance": ...}``.
        """
        if self.return_validated_data:  # type: ignore
            return {"validated_data": self.get_validated_data(request)}
        return {"serializer_instance": self.get_serializer_instance(request)}

    def get_response(self, data):
        """Serialize ``data`` with ``out_cls`` and wrap it in ``resp_cls``."""
        return self.resp_cls(data=self.out_cls(data, **self.out_raw_params).data)
def serializer_inject(
    in_cls: Optional[Type[BaseSerializer]] = None,
    out_cls: Optional[Type[BaseSerializer]] = None,
    config: Optional[dict] = None,
    in_params: Optional[dict] = None,
    out_params: Optional[dict] = None,
    swagger_params: Optional[dict] = None,
):
    """Decorate a view method so serializers are injected around it.

    Args:
        in_cls: Serializer class validating the request; when given, its
            result is passed to the view as ``validated_data`` or
            ``serializer_instance`` (see ``config``).
        out_cls: Serializer class rendering the view's return value; when
            omitted, the view's return value is passed through untouched.
        config: Overrides for ``data_from``, ``return_validated_data`` and
            ``remain_request``.
        in_params: Keyword arguments for ``is_valid`` on the input serializer.
        out_params: Keyword arguments for constructing the output serializer.
        swagger_params: Extra keyword arguments for ``swagger_auto_schema``;
            they take precedence over the ones derived from the serializers.

    Returns:
        A decorator for view methods called as ``(view, request, ...)``.
    """

    def decorator_serializer_inject(func):
        injector = SerializerInjector(in_cls, out_cls, config, in_params, out_params)

        if not SKIP_SWAGGER_SCHEMA:
            default_params = {}
            if in_cls:
                if injector.data_from == "query_params":
                    default_params = {"query_serializer": in_cls()}
                else:
                    default_params = {"request_body": in_cls()}
            if out_cls:
                default_params.update({"responses": {status.HTTP_200_OK: out_cls()}})
            # Explicit swagger params win over the derived defaults.
            default_params.update(swagger_params or {})
            func = swagger_auto_schema(**default_params)(func)

        @functools.wraps(func)
        def decorated(*args, **kwargs):
            in_content = {}
            if in_cls:
                # args[0] is the view instance, args[1] the request.
                in_content.update(injector.get_in_params(args[1]))

            if not injector.remain_request:
                # Drop only the request: keep the view instance and any
                # positional arguments that follow the request.
                args = args[:1] + args[2:]

            original_data = func(*args, **kwargs, **in_content)
            if not out_cls:
                return original_data

            # Support runtime serializer params, like "context". They are
            # merged per call instead of being written into the shared
            # injector, so they cannot leak into other requests.
            if isinstance(original_data, ResponseParams):
                runtime_params = {**injector.out_raw_params, **original_data.params}
                serializer = out_cls(original_data.data, **runtime_params)
                return injector.resp_cls(data=serializer.data)

            return injector.get_response(original_data)

        return decorated

    return decorator_serializer_inject
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Haha, I wondered about that too. Dependency injection is not a particularly trendy concept, and the DRF community should consider supporting such syntactic sugar natively. Maybe sometime I'll raise an issue with the community and see if they'll accept it.
I have also updated some of the code in our SDK project (the code in this gist has not been updated). You can install the SDK directly (although we only have a README in Chinese, I believe the code is clean and easy to read) or just copy it into your own project. Enjoy!