# Convert dataclasses to strawberry types.
# Originally shared by @nrbnlulu as GitHub gist
# c9abf7842218996f9d065709a76c4cee (created May 7, 2024 05:46).
import inspect
import types
from collections.abc import Callable
from typing import Literal, dataclass_transform

import strawberry
from strawberry.type import get_object_definition
@dataclass_transform()
def dataclass_to_strawberry[T: type](kind: Literal["input", "type"] = "type") -> Callable[[T], T]:
def inner(cls: T) -> T:
strawberry_bases = []
extend_dict = {}
extended_annotations = {}
for klass in reversed(cls.__bases__):
if get_object_definition(klass):
strawberry_bases.append(klass)
else:
extend_dict.update(klass.__dict__)
extended_annotations.update(klass.__annotations__)
# strawberry doesn't support slots
# so we need to remove them from the class
# before passing it to strawberry
extend_dict.pop("__slots__", None)
extended_annotations.update(cls.__annotations__)
# remove dataclasses fields from the class
extend_dict = {
k: v
for k, v in extend_dict.items()
if k not in extended_annotations and not k.startswith("__dataclass")
}
extend_dict.pop("__dataclasses_fields", None)
extend_dict.pop("__init__", None)
extend_dict.pop("__dict__", None)
cls.__annotations__ = extended_annotations
extend_dict.update(cls.__dict__)
return strawberry.type(
type(cls.__name__, tuple(strawberry_bases), extend_dict), is_input=kind == "input"
)
return inner
@dataclass_transform()
def dataclass_to_strawberry_input[T: type]() -> Callable[[T], T]:
def inner(cls: T) -> T:
return dataclass_to_strawberry("input")(cls)
return inner
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment