# SQLAlchemy column type that stores Pydantic 2+ models as YAML.
#
# Source: GitHub Gist Tantas/fcffc5521d2ea8ed6e01d4d144cb4c95 by @Tantas
# (last active February 18, 2024).
import json
from io import BytesIO, IOBase, StringIO
from typing import Type, TypeVar
from pydantic import BaseModel, TypeAdapter
from ruamel.yaml import YAML
from sqlalchemy import Text, TypeDecorator
# Type variable bound to pydantic models, so that ``from_yaml`` is typed as
# returning an instance of the same model class it is given.
T = TypeVar("T", bound=BaseModel)
def from_yaml(model_type: Type[T], raw: str | bytes | IOBase) -> T:
if isinstance(raw, str):
stream = StringIO(raw)
elif isinstance(raw, IOBase):
stream = raw
else:
stream = BytesIO(raw)
reader = YAML(typ="safe", pure=True)
return TypeAdapter(model_type).validate_python(reader.load(stream))
def to_yaml(model: BaseModel) -> str:
    """Serialize *model* to a block-style YAML string.

    Fields whose value is ``None`` are left out. The model is first dumped to
    JSON and parsed back with :func:`json.loads`, so the YAML writer only ever
    receives plain JSON values (dicts, lists, strings, numbers, booleans).

    Returns:
        The YAML document as text.
    """
    plain = json.loads(model.model_dump_json(exclude_none=True))

    dumper = YAML(typ="safe", pure=True)
    dumper.default_flow_style = False  # block style rather than inline {...}
    # All-None arguments; presumably leaves ruamel's default indentation as-is.
    dumper.indent(mapping=None, sequence=None, offset=None)
    # Keep the model's field order instead of sorting mapping keys.
    dumper.sort_base_mapping_type_on_output = False

    buffer = StringIO()
    dumper.dump(plain, buffer)
    return buffer.getvalue()
class PydanticYaml(TypeDecorator):
    """SQLAlchemy column type storing a pydantic model as YAML in a TEXT column.

    On write, a model instance is serialized with ``to_yaml``; on read, the
    stored text is parsed and validated back into ``model`` with ``from_yaml``.
    ``None`` passes through unchanged in both directions.
    """

    impl = Text
    # The only extra constructor state is ``model`` (a class, hence hashable),
    # so SQLAlchemy can safely include this type in its statement cache key.
    cache_ok = True

    def __init__(self, model: type[BaseModel], *args, **kwargs):
        """Create the column type.

        Args:
            model: The pydantic model class that stored values are validated into.
            *args, **kwargs: Forwarded to the underlying ``Text`` type
                (for example ``length`` or ``collation``).
        """
        # Unpack the arguments. Passing the tuple/dict themselves made ``Text``
        # receive ``length=()`` and ``collation={}``, which broke TEXT DDL
        # generation on MySQL and used to require forcing ``impl.length = None``.
        super().__init__(*args, **kwargs)
        self.model = model

    def process_bind_param(self, value, dialect):
        """Serialize a model instance to YAML text before it is sent to the DB.

        Returns:
            ``None`` for ``None``; otherwise the YAML string from ``to_yaml``.

        Raises:
            ValueError: If ``value`` is neither ``None`` nor a pydantic model.
        """
        if value is None:
            return None
        if not isinstance(value, BaseModel):
            # Previously ``assert ValueError(...)``, which can never fail.
            raise ValueError("Value must be type BaseModel.")
        return to_yaml(value)

    def process_result_value(self, value, dialect):
        """Parse YAML text loaded from the DB back into an instance of ``self.model``.

        Returns:
            ``None`` for ``None``; otherwise the validated model instance.
        """
        if value is None:
            return None
        return from_yaml(self.model, value)
"""
# Usage example.
from pydantic import BaseModel
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class PydanticType(BaseModel):
    field: str


class Entity(Base):
    __tablename__ = "entity"

    id: Mapped[int] = mapped_column(primary_key=True)
    data: Mapped[PydanticType] = mapped_column(PydanticYaml(PydanticType))
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment