Last active
October 30, 2023 16:17
-
-
Save codebutler/d4e1d41318634eb37e6692a5f980740c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Extends DRF serialization to support union types without any additional boilerplate.
# See the accompanying test for example usage.
# | |
# From https://drf-spectacular.readthedocs.io/en/latest/blueprints.html: | |
# Simply copy&paste the snippets into your codebase. The extensions register | |
# themselves automatically. Just be sure that the python interpreter sees them | |
# at least once. To that end, we suggest creating a PROJECT/schema.py file and | |
# importing it in your PROJECT/__init__.py (same directory as settings.py and urls.py) | |
# with import PROJECT.schema. Now you are all set. | |
from typing import Literal, get_args, get_origin, get_type_hints

from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.utils import PolymorphicProxySerializer
from rest_framework import serializers
from rest_framework_dataclasses.fields import UnionField
# Keep a reference to the library's own ``UnionField.get_discriminator`` before
# this module replaces it further down. It serves as the fallback for union
# members that do not declare the discriminator field at all.
# NOTE(review): if this module body were ever executed a second time (e.g.
# importlib.reload), this would capture the already-patched function and the
# fallback would recurse — confirm the module is only imported once.
_orig_get_discriminator = UnionField.get_discriminator
def _get_discriminator_from_literal(tp: type, discriminator_field_name: str): | |
if discriminator_field_name not in tp.__annotations__: | |
return _orig_get_discriminator(UnionField({}), tp) | |
literal_hint = tp.__annotations__[discriminator_field_name] | |
if get_origin(literal_hint) is not Literal: | |
raise AttributeError( | |
f"{tp} has a {discriminator_field_name} attribute that is not a Literal" | |
) | |
return get_args(literal_hint)[0] | |
# https://github.com/oxan/djangorestframework-dataclasses/pull/87
def _get_discriminator(self, tp: type):
    """Drop-in replacement for ``UnionField.get_discriminator``.

    Looks up the discriminator for ``tp`` from its ``Literal``-annotated
    discriminator attribute, using the field name configured on this
    ``UnionField`` instance.
    """
    field_name = self.discriminator_field_name
    return _get_discriminator_from_literal(tp, field_name)


# Install the replacement on the class so every UnionField uses it.
setattr(UnionField, "get_discriminator", _get_discriminator)
# https://github.com/tfranzel/drf-spectacular/issues/1081
class UnionFieldFix(OpenApiSerializerFieldExtension):
    """Describe ``UnionField`` as a discriminated ``oneOf`` schema component.

    The union's child serializers are wrapped in a
    ``PolymorphicProxySerializer`` whose ``resource_type_field_name`` is the
    union's discriminator field. The resulting component is resolved, and a
    ``$ref`` to it is returned.
    """

    target_class = "rest_framework_dataclasses.fields.UnionField"

    def map_serializer_field(self, auto_schema, direction):
        """Resolve the union field into a component and return its ``$ref``.

        Side effect: each child serializer of the union gets a
        ``_get_discriminator_field_type`` attribute, and its discriminator
        field is replaced by a ``SerializerMethodField`` backed by it. The
        method returns that child's discriminator value regardless of input.
        NOTE(review): this is presumably so drf-spectacular can read a fixed
        discriminator value per sub-serializer — confirm.
        """
        union_field: UnionField = self.target
        discriminator = union_field.discriminator_field_name
        mapping = {
            field_class_name: _get_discriminator_from_literal(field_class, discriminator)
            for field_class, field_class_name in union_field.type_mapping.items()
        }
        for field_class_name, field_serializer in union_field.child_fields.items():
            # Bind the value as a default argument so every lambda keeps its own
            # discriminator rather than the loop variable's final value.
            field_serializer._get_discriminator_field_type = (  # noqa: SLF001
                lambda obj, captured_name=mapping[field_class_name]: captured_name
            )
            field_serializer.fields[discriminator] = serializers.SerializerMethodField(
                "_get_discriminator_field_type",
            )
        union_serializer = PolymorphicProxySerializer(
            component_name=self._union_component_name(union_field),
            # Hand over a concrete list instead of a live dict view.
            serializers=list(union_field.child_fields.values()),
            resource_type_field_name=discriminator,
        )
        union_component = auto_schema.resolve_serializer(union_serializer, direction)
        return union_component.ref

    @staticmethod
    def _union_component_name(union_field) -> str:
        """Return ``<OwningDataclass><FieldName>`` as the union's component name.

        For a union bound directly to a dataclass serializer, the result is
        ``parent.dataclass.__name__ + field_name.title()`` (e.g.
        ``VehicleMetadata``), exactly as before. A union nested in a container
        field (e.g. ``list[A | B]``) is bound to that container with an empty
        field name. For such a union, climb to the nearest ancestor that
        exposes a dataclass, and use the first non-empty field name seen on
        the way. If no such ancestor exists, the root node's class name is
        used instead of crashing.
        """
        node = union_field
        field_name = union_field.field_name
        while True:
            parent = getattr(node, "parent", None)
            if parent is None:
                return type(node).__name__ + (field_name or "").title()
            # NOTE(review): serializers configured via ``Meta.dataclass`` may
            # leave the ``dataclass`` attribute unset — confirm.
            dataclass_type = getattr(parent, "dataclass", None) or getattr(
                getattr(parent, "Meta", None), "dataclass", None
            )
            if dataclass_type is not None:
                return dataclass_type.__name__ + (field_name or "").title()
            node = parent
            field_name = field_name or getattr(node, "field_name", None)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Presumably the schema drf-spectacular generates for the ``Example`` dataclass
# from the accompanying test, with the union extension installed.
# ``Vehicle.metadata`` becomes the ``VehicleMetadata`` component: a ``oneOf`` over
# the three metadata variants, discriminated on ``type``.
# NOTE(review): each variant's ``type`` is both readOnly and required. Clients
# generating request bodies from this schema may omit it, although the
# deserializer presumably needs it to pick a variant — confirm.
openapi: 3.0.3
components:
  schemas:
    Example:
      type: object
      properties:
        vehicles:
          type: array
          items:
            $ref: '#/components/schemas/Vehicle'
      required:
      - vehicles
    CarVehicleMetadata:
      type: object
      properties:
        type:
          type: string
          readOnly: true
        license_plate:
          type: string
      required:
      - license_plate
      - type
    PlaneVehicleMetadata:
      type: object
      properties:
        type:
          type: string
          readOnly: true
        flight_number:
          type: string
      required:
      - flight_number
      - type
    TrainVehicleMetadata:
      type: object
      properties:
        type:
          type: string
          readOnly: true
        train_number:
          type: string
        carriage_number:
          type: string
      required:
      - carriage_number
      - train_number
      - type
    Vehicle:
      type: object
      properties:
        name:
          type: string
        metadata:
          $ref: '#/components/schemas/VehicleMetadata'
      required:
      - metadata
      - name
    VehicleMetadata:
      oneOf:
      - $ref: '#/components/schemas/CarVehicleMetadata'
      - $ref: '#/components/schemas/TrainVehicleMetadata'
      - $ref: '#/components/schemas/PlaneVehicleMetadata'
      discriminator:
        propertyName: type
        mapping:
          car: '#/components/schemas/CarVehicleMetadata'
          train: '#/components/schemas/TrainVehicleMetadata'
          plane: '#/components/schemas/PlaneVehicleMetadata'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from io import BytesIO | |
from typing import Literal | |
import pytest | |
from rest_framework.parsers import JSONParser | |
from rest_framework.renderers import JSONRenderer | |
from rest_framework_dataclasses.serializers import DataclassSerializer | |
@dataclass(frozen=True, kw_only=True)
class CarVehicleMetadata:
    """Car variant of ``VehicleMetadata``; its discriminator is ``type == "car"``."""

    # kw_only=True is what allows this defaulted field to precede the required
    # field below; the Literal annotation supplies the discriminator value.
    type: Literal["car"] = "car"
    license_plate: str
@dataclass(frozen=True, kw_only=True)
class TrainVehicleMetadata:
    """Train variant of ``VehicleMetadata``; its discriminator is ``type == "train"``."""

    # Defaulted discriminator first; allowed because fields are keyword-only.
    type: Literal["train"] = "train"
    train_number: str
    carriage_number: str
@dataclass(frozen=True, kw_only=True)
class PlaneVehicleMetadata:
    """Plane variant of ``VehicleMetadata``; its discriminator is ``type == "plane"``."""

    # Defaulted discriminator first; allowed because fields are keyword-only.
    type: Literal["plane"] = "plane"
    flight_number: str
# Union of every metadata variant. Annotating a dataclass field with it yields a
# union field whose members are told apart by their ``type`` Literal.
VehicleMetadata = CarVehicleMetadata | TrainVehicleMetadata | PlaneVehicleMetadata


@dataclass(frozen=True, kw_only=True)
class Vehicle:
    """A named vehicle whose ``metadata`` may be any ``VehicleMetadata`` member."""

    name: str
    metadata: VehicleMetadata
@dataclass(frozen=True, kw_only=True)
class Example:
    """Top-level payload used by the round-trip test: a list of vehicles."""

    vehicles: list[Vehicle]
@pytest.fixture()
def example_data():
    """Build an ``Example`` holding one vehicle of each metadata variant, in order car, train, plane."""
    car = Vehicle(
        name="car1",
        metadata=CarVehicleMetadata(license_plate="ABC123"),
    )
    train = Vehicle(
        name="train1",
        metadata=TrainVehicleMetadata(train_number="123", carriage_number="A"),
    )
    plane = Vehicle(
        name="plane1",
        metadata=PlaneVehicleMetadata(flight_number="123"),
    )
    return Example(vehicles=[car, train, plane])
def test_serialization(example_data):
    """Round-trip ``example_data`` through JSON and check both directions.

    The expected discriminators ("car", "train", "plane") are the ``type``
    Literal values, which requires the union-support module to have been
    imported so ``UnionField.get_discriminator`` is patched.
    NOTE(review): this test module does not import that module itself.
    """
    expected_json = (
        b'{"vehicles":[{"name":"car1","metadata":{"type":"car","license_plate":"ABC123"}},'
        b'{"name":"train1","metadata":{"type":"train","train_number":"123","carriage_number":"A"}},'
        b'{"name":"plane1","metadata":{"type":"plane","flight_number":"123"}}]}'
    )

    # Serialize: dataclass instance -> primitives -> JSON bytes.
    writer = DataclassSerializer(dataclass=Example)
    payload = JSONRenderer().render(writer.to_representation(example_data))
    assert payload == expected_json

    # Deserialize: JSON bytes -> primitives -> validated dataclass instance.
    primitives = JSONParser().parse(BytesIO(payload))
    reader = DataclassSerializer(data=primitives, dataclass=Example)
    reader.is_valid(raise_exception=True)
    assert example_data == reader.save()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment