Skip to content

Instantly share code, notes, and snippets.

@DurandA
Created November 9, 2023 21:22
Show Gist options
  • Save DurandA/202c1b15b2d3bd2a7f3c9f5e6af8c677 to your computer and use it in GitHub Desktop.
Save DurandA/202c1b15b2d3bd2a7f3c9f5e6af8c677 to your computer and use it in GitHub Desktop.
Create a Pydantic object from a SQLAlchemy model using AsyncAttrs
import types
from typing import Any, List, Type, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.asyncio import (AsyncAttrs, AsyncEngine, AsyncSession,
                                    async_sessionmaker, create_async_engine)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import RelationshipProperty, relationship, sessionmaker
from sqlalchemy.orm.decl_api import DeclarativeMeta
# Declarative base shared by the ORM models below. Mixing in AsyncAttrs gives
# every mapped instance an ``awaitable_attrs`` accessor, which
# load_relationships_and_create_pydantic() awaits to load relationships.
# NOTE(review): ``sqlalchemy.ext.declarative.declarative_base`` is the legacy
# import path (SQLAlchemy 2.0 points to ``sqlalchemy.orm.declarative_base`` /
# ``DeclarativeBase``) — consider migrating.
Base: DeclarativeMeta = declarative_base(cls=AsyncAttrs)
# ORM model for the ``parents`` table. One Parent row is referenced by many
# Child rows through ``children.parent_id``.
class Parent(Base):
    __tablename__ = 'parents'
    # ``Any`` annotations on legacy Column() attributes — presumably to keep
    # type checkers from complaining about Column vs. int/str; TODO confirm.
    id: Any = Column(Integer, primary_key=True)
    name: Any = Column(String)
    # One-to-many collection, mirrored by Child.parent (back_populates).
    # load_relationships_and_create_pydantic() reads it via
    # ``awaitable_attrs.children`` rather than plain attribute access.
    children = relationship("Child", back_populates="parent")
# ORM model for the ``children`` table; each row points at one Parent.
class Child(Base):
    __tablename__ = 'children'
    # ``Any`` annotations mirror Parent's legacy Column() style.
    id: Any = Column(Integer, primary_key=True)
    name: Any = Column(String)
    # Foreign key to parents.id; nullable (no nullable=False given), so a Child
    # may have no parent.
    parent_id: Any = Column(Integer, ForeignKey('parents.id'))
    # Many-to-one side of Parent.children.
    parent = relationship("Parent", back_populates="children")
# Pydantic view of a Child row. Field names match the Child ORM attributes,
# which is how load_relationships_and_create_pydantic() maps them.
class ChildPydantic(BaseModel):
    id: int
    name: str
# Pydantic view of a Parent row. Field names match the Parent ORM attributes.
class ParentPydantic(BaseModel):
    id: int
    name: str
    # Filled from the Parent.children relationship; the list's item type tells
    # the loader which Pydantic model to build for each Child.
    children: List[ChildPydantic]
T = TypeVar('T', bound=BaseModel)


def _strip_optional(annotation: Any) -> Any:
    """Return the sole non-``None`` member of ``Optional[X]`` / ``X | None``.

    Any other annotation (including unions with several non-``None``
    members) is returned unchanged.
    """
    origin = get_origin(annotation)
    is_union = origin is Union or (
        hasattr(types, "UnionType") and origin is types.UnionType
    )
    if is_union:
        members = [arg for arg in get_args(annotation) if arg is not type(None)]
        if len(members) == 1:
            return members[0]
    return annotation


def _relationship_model(annotation: Any, *, collection: bool) -> Any:
    """Return the Pydantic model type to build for a relationship field.

    ``Optional`` wrappers are removed. For a collection relationship, the
    element type of ``List[Model]`` / ``list[Model]`` is returned.

    Raises:
        TypeError: if a collection annotation carries no element type.
    """
    annotation = _strip_optional(annotation)
    if not collection:
        return annotation
    args = get_args(annotation)
    if not args:
        raise TypeError(
            f"cannot determine the item model of collection annotation {annotation!r}"
        )
    return _strip_optional(args[0])


async def load_relationships_and_create_pydantic(
    db_obj,
    pydantic_model: Type[T],
    session: AsyncSession
) -> T:
    """Build *pydantic_model* from the ORM instance *db_obj*, loading relationships.

    Every field declared on *pydantic_model*, including inherited fields, is
    looked up on *db_obj* by name:

    * If the name is a SQLAlchemy relationship, it is awaited through
      ``db_obj.awaitable_attrs``, so the ORM class must mix in ``AsyncAttrs``.
      It is then converted recursively into the model named by the field's
      annotation: ``List[Model]`` for collections, ``Model`` or
      ``Optional[Model]`` for scalar relationships.
    * Otherwise the attribute is read with ``getattr``.

    ``None`` values and missing attributes are left out of the constructor
    arguments, so the Pydantic field's default (if any) applies.

    Args:
        db_obj: mapped ORM instance.
        pydantic_model: Pydantic v2 model class to instantiate.
        session: not used directly; forwarded to the recursive calls.

    Returns:
        A validated instance of *pydantic_model*.

    Note:
        There is no cycle detection. If the Pydantic models reference each
        other in both directions (e.g. parent -> children -> parent), the
        recursion does not terminate.
    """
    # Mapper attributes depend only on the ORM class: look them up once.
    mapper_attrs = inspect(db_obj.__class__).mapper.attrs
    loaded_attrs = {}
    # model_fields (Pydantic v2) includes inherited fields and resolved
    # annotations, unlike the class's raw __annotations__.
    for field_name, field_info in pydantic_model.model_fields.items():
        prop = mapper_attrs[field_name] if field_name in mapper_attrs else None
        if isinstance(prop, RelationshipProperty):
            # Await via awaitable_attrs: a plain attribute access would try a
            # lazy load outside the async context.
            relationship_data = await getattr(db_obj.awaitable_attrs, field_name)
            if relationship_data is None:
                # Unset many-to-one: nothing to convert (same None policy as
                # the plain-attribute branch below).
                continue
            if prop.uselist:
                item_model = _relationship_model(field_info.annotation, collection=True)
                loaded_attrs[field_name] = [
                    await load_relationships_and_create_pydantic(item, item_model, session)
                    for item in relationship_data
                ]
            else:
                target_model = _relationship_model(field_info.annotation, collection=False)
                loaded_attrs[field_name] = await load_relationships_and_create_pydantic(
                    relationship_data, target_model, session
                )
        else:
            attr_value = getattr(db_obj, field_name, None)
            if attr_value is not None:
                loaded_attrs[field_name] = attr_value
    return pydantic_model(**loaded_attrs)
async def main():
    """Demo: insert a parent with one child, then convert it to Pydantic.

    Uses the local SQLite file ``./test.db`` through the ``aiosqlite`` driver.
    Tables are created if missing (never dropped), so rows from earlier runs
    remain. Prints the resulting ``model_dump()``. The engine's connection
    pool is disposed before returning, even on error.
    """
    DATABASE_URL = "sqlite+aiosqlite:///./test.db"
    engine: AsyncEngine = create_async_engine(DATABASE_URL, echo=True)
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        # expire_on_commit=False keeps loaded attributes (e.g. parent.id)
        # readable after commit without an implicit refresh.
        async_session = async_sessionmaker(engine, expire_on_commit=False)
        async with async_session() as session:
            parent = Parent(name='Parent1')
            session.add(parent)
            await session.commit()
            child = Child(name='Child1', parent_id=parent.id)
            session.add(child)
            await session.commit()
            stmt = select(Parent).where(Parent.id == parent.id)
            result = await session.execute(stmt)
            parent_obj = result.scalars().first()
            parent_pydantic = await load_relationships_and_create_pydantic(
                parent_obj, ParentPydantic, session
            )
            print(parent_pydantic.model_dump())
    finally:
        # AsyncEngine keeps pooled connections open. Dispose explicitly so they
        # are closed while the event loop is still running.
        await engine.dispose()
import asyncio

# Run the demo only when executed as a script. Importing this module (e.g. to
# reuse the models or the loader) then does not create or modify ./test.db.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment