Skip to content

Instantly share code, notes, and snippets.

@jerber
Last active February 1, 2022 21:38
Show Gist options
  • Save jerber/22c22d949d553be804e75a7703e17c2d to your computer and use it in GitHub Desktop.
Getting Strawberry to work with nested Pydantic Models
from __future__ import annotations
import typing as T
from enum import Enum
from pydantic import BaseModel
from pydantic.fields import ModelField
import strawberry
from strawberry.field import StrawberryField
from strawberry.types.types import TypeDefinition
def add_back_fields_to_straw_cls(
    straw_cls: strawberry.type, model_fields: T.List[ModelField]
) -> strawberry.type:
    """Attach self-referencing pydantic fields to an already-built strawberry type.

    ``to_strawberry`` hides fields whose type is the model being converted
    (e.g. ``Student.friends: List[Student]``) from strawberry's pydantic
    decorator. This puts them back on ``straw_cls``'s type definition, typed
    in terms of ``straw_cls`` itself.

    Args:
        straw_cls: The strawberry class generated for the pydantic model.
        model_fields: The pydantic (v1) fields to re-attach. Each must refer
            to the model ``straw_cls`` was generated from.

    Returns:
        ``straw_cls`` itself; its type definition is modified in place.

    Raises:
        NotImplementedError: If a field wraps the model in anything other
            than a plain value or a list (e.g. a dict or tuple).
    """
    import dataclasses

    from pydantic.fields import SHAPE_LIST, SHAPE_SINGLETON
    from strawberry.annotation import StrawberryAnnotation

    type_definition: TypeDefinition = straw_cls._type_definition
    for model_field in model_fields:
        # Rebuild the field's outer shape around the strawberry class. Using
        # the bare class would type `friends: List[Student]` as a single
        # StudentType.
        if model_field.shape == SHAPE_SINGLETON:
            annotation = straw_cls
        elif model_field.shape == SHAPE_LIST:
            annotation = T.List[straw_cls]
        else:
            raise NotImplementedError(
                f"cannot re-add field {model_field.name!r}: only plain and "
                f"List[...] self-references are supported"
            )
        if model_field.allow_none:
            annotation = T.Optional[annotation]

        # pydantic reports "no default" / "no factory" as None, whereas the
        # dataclass machinery under StrawberryField uses dataclasses.MISSING.
        # A None default_factory would otherwise be treated as a factory.
        if model_field.default_factory is not None:
            default = dataclasses.MISSING
            default_factory = model_field.default_factory
        elif not model_field.required:
            default = model_field.default
            default_factory = dataclasses.MISSING
        else:
            default = default_factory = dataclasses.MISSING

        type_definition.fields.append(
            StrawberryField(
                python_name=model_field.name,
                graphql_name=model_field.name,
                # StrawberryField resolves its type through an annotation
                # object, not a raw class.
                type_annotation=StrawberryAnnotation(annotation),
                default=default,
                default_factory=default_factory,
            )
        )
    return straw_cls
def to_strawberry(
    pydantic_model: T.Type[BaseModel],
    name: str,
    input: bool = False,
    fields: T.Optional[T.Iterable[str]] = None,
    _converted: T.Optional[T.Dict[T.Type[BaseModel], T.Any]] = None,
) -> strawberry.type:
    """Convert a pydantic (v1) model, and any nested models, into a strawberry type.

    Nested ``BaseModel`` fields are converted first (recursively), because the
    stock ``strawberry.experimental.pydantic`` decorator cannot handle
    unconverted nested models. Fields that refer back to ``pydantic_model``
    itself are temporarily hidden from the decorator and re-attached
    afterwards with ``add_back_fields_to_straw_cls``. The hidden fields are
    always restored on the pydantic model, so it keeps validating them and a
    later call sees them again.

    Args:
        pydantic_model: The pydantic model to convert.
        name: Name of the generated strawberry class.
        input: If not ``False``, build a strawberry input type instead of an
            output type.
        fields: Names of the fields to expose; defaults to every model field.
        _converted: Internal memo shared across the recursion. It maps each
            model to its strawberry type, or to ``None`` while that model is
            still being converted.

    Returns:
        The generated strawberry class.

    Raises:
        ValueError: If models refer to each other in a cycle (e.g.
            ``Student.teacher`` and ``Teacher.students``). Only direct
            self-reference is supported.
    """
    if _converted is None:
        _converted = {}
    # None marks this model as "in progress", so a cycle through another
    # model is reported instead of recursing forever.
    _converted[pydantic_model] = None

    model_fields = pydantic_model.__fields__
    requested = list(fields) if fields else list(model_fields)
    original_fields = dict(model_fields)  # snapshot to restore afterwards
    self_referencing: T.List[ModelField] = []
    try:
        for field_name, field_model in original_fields.items():
            t = field_model.type_
            if not isinstance(t, type):
                continue
            if issubclass(t, Enum):
                strawberry.enum(t)
            if not issubclass(t, BaseModel):
                continue
            if t is pydantic_model:
                # Hidden from the decorator only for the duration of this call.
                self_referencing.append(field_model)
                del model_fields[field_name]
                continue
            if t in _converted:
                straw = _converted[t]
                if straw is None:
                    raise ValueError(
                        f"cyclic reference between pydantic models at "
                        f"{pydantic_model.__name__}.{field_name} -> "
                        f"{t.__name__}; only direct self-reference is supported"
                    )
            else:
                straw = to_strawberry(
                    pydantic_model=t, name=t.__name__, input=input, _converted=_converted
                )
            # NOTE(review): this permanently swaps the pydantic field's type_
            # for the strawberry class. Presumably the decorator needs it to
            # map the nested model; confirm before relying on it.
            field_model.type_ = straw

        hidden_names = {f.name for f in self_referencing}
        pyd = strawberry.experimental.pydantic
        decorator = pyd.type if input is False else pyd.input
        cls = type(name, (object,), {})
        new_cls = decorator(
            model=pydantic_model,
            fields=[n for n in requested if n not in hidden_names],
        )(cls)
    finally:
        # Put every hidden field back, in its original order.
        model_fields.clear()
        model_fields.update(original_fields)

    # Only re-attach self-referencing fields the caller actually asked for.
    to_add_back = [f for f in self_referencing if f.name in requested]
    if to_add_back:
        new_cls = add_back_fields_to_straw_cls(
            straw_cls=new_cls, model_fields=to_add_back
        )
    _converted[pydantic_model] = new_cls
    return new_cls
# example cases
from pydantic import BaseModel, Field
class Teacher(BaseModel):
    """Example model: a teacher with just a name."""

    name: str
    # Uncommenting this makes Teacher and Student refer to each other, which
    # to_strawberry cannot convert (see test_custom_pydantic_func).
    # students: T.List[Student] = Field(default_factory=list)


class Student(BaseModel):
    """Example model with self-references (friends, best_friend) and a nested model (teacher)."""

    name: str
    friends: T.List[Student] = Field(default_factory=list)
    # Explicit Optional rather than the implicit `Teacher = None`.
    teacher: T.Optional[Teacher] = None
    best_friend: T.Optional[Student] = None


# With `from __future__ import annotations`, the annotations above are
# strings; resolve the references to Student/Teacher now that both exist.
Student.update_forward_refs()
Teacher.update_forward_refs()
def test_existing_pydantic_func():
    """Show that strawberry's built-in pydantic decorator fails on Student.

    Expected to raise: nested BaseModels must themselves be converted into
    strawberry types first, and the decorator does not do that. The custom
    to_strawberry function was written to do it recursively.
    """
    all_student_fields = list(Student.__fields__)
    make_type = strawberry.experimental.pydantic.type(Student, fields=all_student_fields)

    class StudentType:
        pass

    StudentType = make_type(StudentType)
def test_custom_pydantic_func():
    """Check that the type built by to_strawberry exposes ``friends`` as a dataclass field.

    This exercises the custom conversion attempt. It does not yet work as
    intended; a more fundamental way of going from pydantic to strawberry
    models is needed. It also breaks when Teacher.students is uncommented,
    because Student and Teacher then refer to each other.
    """
    converted = to_strawberry(pydantic_model=Student, name="StudentType", input=False)
    dataclass_field_names = converted.__dataclass_fields__
    assert "friends" in dataclass_field_names  # currently fails
# Starting a server with this schema fails when you try to query Student.friends.
StudentType = to_strawberry(pydantic_model=Student, name="StudentType", input=False)


@strawberry.type
class Query:
    @strawberry.field
    def get_student(self) -> StudentType:
        student_type = StudentType(name="Jon")
        # Set after construction rather than passed to the constructor;
        # presumably because the re-added `friends` field is not an __init__
        # argument of the generated class. TODO confirm.
        student_type.friends = []
        return student_type


schema = strawberry.Schema(query=Query)


if __name__ == "__main__":
    import traceback

    # test_existing_pydantic_func is documented to raise, which used to abort
    # the script before the next demo ran. Run each demo in isolation, report
    # every failure, and exit non-zero if any failed.
    failed = []
    for demo in (test_existing_pydantic_func, test_custom_pydantic_func):
        try:
            demo()
        except Exception:
            failed.append(demo.__name__)
            traceback.print_exc()
    if failed:
        raise SystemExit(f"failed: {', '.join(failed)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment