Skip to content

Instantly share code, notes, and snippets.

@codebutler
Last active October 30, 2023 16:17
Show Gist options
  • Save codebutler/d4e1d41318634eb37e6692a5f980740c to your computer and use it in GitHub Desktop.
Save codebutler/d4e1d41318634eb37e6692a5f980740c to your computer and use it in GitHub Desktop.
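# Extends DRF serialization (djangorestframework-dataclasses' UnionField) and
# drf-spectacular schema generation to support discriminated union types without
# any additional boilerplate: each union member declares its discriminator as a
# Literal-typed field, e.g. `type: Literal["car"] = "car"`.
# See the test for example usage.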
#
# From https://drf-spectacular.readthedocs.io/en/latest/blueprints.html:
# Simply copy&paste the snippets into your codebase. The extensions register
# themselves automatically. Just be sure that the python interpreter sees them
# at least once. To that end, we suggest creating a PROJECT/schema.py file and
# importing it in your PROJECT/__init__.py (same directory as settings.py and urls.py)
# with import PROJECT.schema. Now you are all set.
from typing import Literal, get_args, get_origin, get_type_hints

from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.utils import PolymorphicProxySerializer
from rest_framework import serializers
from rest_framework_dataclasses.fields import UnionField
# Keep a reference to the library's own UnionField.get_discriminator before it
# is replaced further down. The Literal-based lookup falls back to it for union
# members that have no discriminator field. This line must run before the patch.
_orig_get_discriminator = UnionField.get_discriminator
def _get_discriminator_from_literal(tp: type, discriminator_field_name: str):
if discriminator_field_name not in tp.__annotations__:
return _orig_get_discriminator(UnionField({}), tp)
literal_hint = tp.__annotations__[discriminator_field_name]
if get_origin(literal_hint) is not Literal:
raise AttributeError(
f"{tp} has a {discriminator_field_name} attribute that is not a Literal"
)
return get_args(literal_hint)[0]
# Upstream proposal: https://github.com/oxan/djangorestframework-dataclasses/pull/87
def _get_discriminator(self, tp: type):
    """Replacement for ``UnionField.get_discriminator`` that honours Literal fields.

    Looks up the discriminator of ``tp`` from the field named by this
    UnionField's ``discriminator_field_name``.
    """
    field_name = self.discriminator_field_name
    return _get_discriminator_from_literal(tp, field_name)


# Patch the class itself, so every UnionField instance that does not override
# the method resolves discriminators through the Literal-aware lookup.
UnionField.get_discriminator = _get_discriminator
def _declared_dataclass(serializer):
    """Return the dataclass a serializer was built for, or ``None``."""
    declared = getattr(serializer, "dataclass", None)
    if declared is None:
        # Serializers declared via ``class Meta: dataclass = ...`` may not
        # carry the attribute directly. NOTE(review): confirm against the
        # installed djangorestframework-dataclasses version.
        declared = getattr(getattr(serializer, "Meta", None), "dataclass", None)
    return declared


def _union_component_name(union_field) -> str:
    """Name the OpenAPI component generated for ``union_field``.

    The name is the owning dataclass's name plus the title-cased field name,
    e.g. ``Vehicle`` + ``metadata`` -> ``VehicleMetadata``.
    """
    field_name = union_field.field_name
    owner = union_field.parent
    # A union inside a container field (e.g. ``list[A | B]``) is bound to the
    # container with an empty field name and has no ``dataclass`` of its own.
    # Climb to the nearest serializer that declares a dataclass, and keep the
    # first non-empty field name seen on the way.
    while owner is not None and _declared_dataclass(owner) is None:
        field_name = field_name or owner.field_name
        owner = owner.parent
    dataclass_type = _declared_dataclass(owner) if owner is not None else None
    prefix = dataclass_type.__name__ if dataclass_type is not None else ""
    return prefix + (field_name or "").title()


# https://github.com/tfranzel/drf-spectacular/issues/1081
class UnionFieldFix(OpenApiSerializerFieldExtension):
    """drf-spectacular extension documenting ``UnionField`` as a ``oneOf`` union.

    Each union member gets its own component. The union itself becomes a
    component with a discriminator on the field's ``discriminator_field_name``,
    whose mapping uses the Literal-derived discriminator values.
    """

    target_class = "rest_framework_dataclasses.fields.UnionField"

    def map_serializer_field(self, auto_schema, direction):
        """Resolve the union as a polymorphic component and return its ``$ref``."""
        union_field: UnionField = self.target
        discriminator_field_name = union_field.discriminator_field_name
        mapping = {
            field_class_name: _get_discriminator_from_literal(
                field_class, discriminator_field_name
            )
            for field_class, field_class_name in union_field.type_mapping.items()
        }
        for field_class_name, field_serializer in union_field.child_fields.items():
            # Replace each member's discriminator field with a read-only method
            # field that always returns the member's constant discriminator.
            # NOTE(review): this appears to be how drf-spectacular derives the
            # mapping value for a list of serializers (see linked issue); confirm.
            # The default argument binds the value per iteration, avoiding
            # late-binding closures.
            field_serializer._get_discriminator_field_type = (  # noqa: SLF001
                lambda obj, captured_name=mapping[field_class_name]: captured_name
            )
            field_serializer.fields[discriminator_field_name] = (
                serializers.SerializerMethodField("_get_discriminator_field_type")
            )
        union_serializer = PolymorphicProxySerializer(
            component_name=_union_component_name(union_field),
            # Pass a list: ``serializers`` is documented as a sequence or dict,
            # and a dict values view is neither.
            serializers=list(union_field.child_fields.values()),
            resource_type_field_name=discriminator_field_name,
        )
        union_component = auto_schema.resolve_serializer(union_serializer, direction)
        return union_component.ref
openapi: 3.0.3
components:
  schemas:
    Example:
      type: object
      properties:
        vehicles:
          type: array
          items:
            $ref: '#/components/schemas/Vehicle'
      required:
      - vehicles
    CarVehicleMetadata:
      type: object
      properties:
        type:
          type: string
          readOnly: true
        license_plate:
          type: string
      required:
      - license_plate
      - type
    PlaneVehicleMetadata:
      type: object
      properties:
        type:
          type: string
          readOnly: true
        flight_number:
          type: string
      required:
      - flight_number
      - type
    TrainVehicleMetadata:
      type: object
      properties:
        type:
          type: string
          readOnly: true
        train_number:
          type: string
        carriage_number:
          type: string
      required:
      - carriage_number
      - train_number
      - type
    Vehicle:
      type: object
      properties:
        name:
          type: string
        metadata:
          $ref: '#/components/schemas/VehicleMetadata'
      required:
      - metadata
      - name
    VehicleMetadata:
      oneOf:
      - $ref: '#/components/schemas/CarVehicleMetadata'
      - $ref: '#/components/schemas/TrainVehicleMetadata'
      - $ref: '#/components/schemas/PlaneVehicleMetadata'
      discriminator:
        propertyName: type
        mapping:
          car: '#/components/schemas/CarVehicleMetadata'
          train: '#/components/schemas/TrainVehicleMetadata'
          plane: '#/components/schemas/PlaneVehicleMetadata'
from dataclasses import dataclass
from io import BytesIO
from typing import Literal
import pytest
from rest_framework.parsers import JSONParser
from rest_framework.renderers import JSONRenderer
from rest_framework_dataclasses.serializers import DataclassSerializer
# kw_only=True is what allows the defaulted discriminator to precede required fields.
@dataclass(frozen=True, kw_only=True)
class CarVehicleMetadata:
    """Metadata for a car; ``type`` is its union discriminator."""

    # Discriminator: the Literal value is read by the union machinery, and the
    # default means callers need not pass it.
    type: Literal["car"] = "car"
    license_plate: str
# kw_only=True is what allows the defaulted discriminator to precede required fields.
@dataclass(frozen=True, kw_only=True)
class TrainVehicleMetadata:
    """Metadata for a train; ``type`` is its union discriminator."""

    # Discriminator: the Literal value is read by the union machinery, and the
    # default means callers need not pass it.
    type: Literal["train"] = "train"
    train_number: str
    carriage_number: str
# kw_only=True is what allows the defaulted discriminator to precede required fields.
@dataclass(frozen=True, kw_only=True)
class PlaneVehicleMetadata:
    """Metadata for a plane; ``type`` is its union discriminator."""

    # Discriminator: the Literal value is read by the union machinery, and the
    # default means callers need not pass it.
    type: Literal["plane"] = "plane"
    flight_number: str
# Union of every vehicle metadata variant; members are told apart by their
# Literal-typed ``type`` field.
VehicleMetadata = CarVehicleMetadata | TrainVehicleMetadata | PlaneVehicleMetadata
@dataclass(frozen=True, kw_only=True)
class Vehicle:
    """A named vehicle whose ``metadata`` is one of the ``VehicleMetadata`` variants."""

    name: str
    # Union-typed field; serialized with the variant's ``type`` discriminator.
    metadata: VehicleMetadata
@dataclass(frozen=True, kw_only=True)
class Example:
    """Top-level payload used by the round-trip test: a list of vehicles."""

    vehicles: list[Vehicle]
@pytest.fixture()
def example_data():
    """Provide an ``Example`` holding one vehicle per metadata variant.

    The order is car, train, plane.
    """
    car = Vehicle(
        name="car1",
        metadata=CarVehicleMetadata(license_plate="ABC123"),
    )
    train = Vehicle(
        name="train1",
        metadata=TrainVehicleMetadata(train_number="123", carriage_number="A"),
    )
    plane = Vehicle(
        name="plane1",
        metadata=PlaneVehicleMetadata(flight_number="123"),
    )
    return Example(vehicles=[car, train, plane])
def test_serialization(example_data):
    """Round-trip ``example_data`` through JSON and check nothing is lost.

    Serializes to JSON, compares against the exact expected bytes (including
    key order), then parses and deserializes back and compares with the input.
    """
    serializer = DataclassSerializer(dataclass=Example)
    wire = JSONRenderer().render(serializer.to_representation(example_data))
    assert wire == (
        b'{"vehicles":[{"name":"car1","metadata":{"type":"car","license_plate":"ABC123"}},'
        b'{"name":"train1","metadata":{"type":"train","train_number":"123","carriage_number":"A"}},'
        b'{"name":"plane1","metadata":{"type":"plane","flight_number":"123"}}]}'
    )

    # Parse the rendered bytes back and rebuild the dataclass tree from them.
    parsed = JSONParser().parse(BytesIO(wire))
    deserializer = DataclassSerializer(data=parsed, dataclass=Example)
    deserializer.is_valid(raise_exception=True)
    rebuilt = deserializer.save()
    assert example_data == rebuilt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment