Last active: January 18, 2024 22:03
-
-
Save tcbegley/38267d40130f3633528ce4b00903652e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from uuid import UUID | |
from litestar import Litestar, get | |
from litestar.contrib.sqlalchemy.base import UUIDBase | |
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig | |
from litestar.contrib.sqlalchemy.plugins import ( | |
AsyncSessionConfig, | |
SQLAlchemyAsyncConfig, | |
SQLAlchemyPlugin, | |
) | |
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository | |
from litestar.di import Provide | |
from sqlalchemy import Column, ForeignKey, Table, select | |
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker | |
from sqlalchemy.orm import Mapped, relationship, selectinload | |
# Many-to-many link table between students and teachers; it is used as the
# ``secondary`` table of ``Student.teachers``.
association_table = Table(
    "association_table",
    UUIDBase.metadata,
    # The two foreign-key columns together form a composite primary key, so the
    # same student/teacher pair cannot be linked twice (a duplicate link would
    # otherwise surface as a duplicate teacher when joining through this table).
    # Assumes UUIDBase derives the table names "student" / "teacher" from the
    # model class names -- TODO confirm.
    Column("student_id", ForeignKey("student.id"), primary_key=True),
    # NOTE(review): named "teacher" rather than "teacher_id", unlike its sibling
    # column; renaming would change the database schema, so it is kept as-is.
    Column("teacher", ForeignKey("teacher.id"), primary_key=True),
)
class Student(UUIDBase):
    """A student, linked to any number of teachers through ``association_table``."""

    name: Mapped[str]
    # ``lazy="selectin"`` loads the collection with a separate SELECT ... IN
    # query issued alongside the parent rows, instead of on first access.
    teachers: Mapped[list[Teacher]] = relationship(
        "Teacher",
        lazy="selectin",
        secondary=association_table,
    )
class Teacher(UUIDBase):
    """A teacher; students reference it through ``Student.teachers``."""

    name: Mapped[str]
# Async SQLAlchemy configuration backed by a local SQLite file (aiosqlite driver).
# ``expire_on_commit=False`` keeps already-loaded attributes readable after a
# commit instead of expiring them.
sqlalchemy_config = SQLAlchemyAsyncConfig(
    session_config=AsyncSessionConfig(expire_on_commit=False),
    connection_string="sqlite+aiosqlite:///test.sqlite",
)
async def on_startup() -> None:
    """Reset the database schema and seed it with sample data.

    Drops and recreates every table registered on ``UUIDBase.metadata``, then
    inserts three teachers and two students linked to them.
    """
    engine = sqlalchemy_config.get_engine()

    # Schema reset: drop everything, then recreate it, on one transactional
    # connection.
    async with engine.begin() as conn:
        for ddl in (UUIDBase.metadata.drop_all, UUIDBase.metadata.create_all):
            await conn.run_sync(ddl)

    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    async with session_factory() as db_session:
        smith = Teacher(name="Mr Smith")
        jones = Teacher(name="Mrs Jones")
        clever = Teacher(name="Dr Clever")
        # Teachers are never added directly; they reach the session by
        # cascading from each student's ``teachers`` collection.
        db_session.add_all(
            [
                Student(name="Alice", teachers=[smith, clever]),
                Student(name="Bob", teachers=[smith, jones]),
            ]
        )
        await db_session.commit()
class StudentRepository(SQLAlchemyAsyncRepository[Student]):
    """Async repository whose queries target the ``Student`` model."""

    model_type = Student
class TeacherRepository(SQLAlchemyAsyncRepository[Teacher]):
    """Async repository whose queries target the ``Teacher`` model."""

    model_type = Teacher
async def provide_students_repo(db_session: AsyncSession) -> StudentRepository:
    """Dependency provider returning a ``StudentRepository`` bound to ``db_session``.

    ``db_session`` is injected by name; presumably the SQLAlchemy plugin
    supplies it -- confirm against the plugin's session dependency key.
    """
    repository = StudentRepository(session=db_session)
    return repository
async def provide_teachers_repo(db_session: AsyncSession) -> TeacherRepository:
    """Dependency provider returning a ``TeacherRepository`` bound to ``db_session``.

    No custom base statement is passed, so the repository falls back to its
    default statement over ``Teacher``. A ``select(Teacher).join(Student)`` base
    cannot work: the ``teacher`` and ``student`` tables share no direct foreign
    key (they are linked only through ``association_table``), so SQLAlchemy has
    no ON clause to infer and every query built from it fails. Queries that
    need teachers scoped to a student should build that join explicitly.
    """
    return TeacherRepository(session=db_session)
@get("/students")
async def get_students(students_repo: StudentRepository) -> list[Student]:
    """Return every student; each student's teachers are eagerly loaded (selectin)."""
    all_students = await students_repo.list()
    return all_students
@get("/students/{student_id:uuid}/teachers")
async def get_teachers_for_student(
    student_id: UUID, teachers_repo: TeacherRepository
) -> list[Teacher]:
    """Return the teachers linked to the student identified by ``student_id``.

    Keyword filters passed to ``teachers_repo.list`` are resolved against
    ``Teacher`` attributes, and ``Teacher`` has no ``student_id``. The student
    filter is therefore expressed as an explicit statement instead. It starts
    from ``Student``, follows the ``Student.teachers`` many-to-many
    relationship (which supplies the join through ``association_table``), and
    keeps only rows belonging to the requested student. An unknown
    ``student_id`` yields an empty list.
    """
    statement = (
        select(Teacher)
        .join_from(Student, Student.teachers)
        .where(Student.id == student_id)
    )
    # Passing ``statement`` overrides whatever base statement the repository
    # was constructed with.
    return await teachers_repo.list(statement=statement)
# Application wiring:
# - the two route handlers;
# - repository providers keyed by the parameter names the handlers declare;
# - the SQLAlchemy plugin (presumably the source of the injected ``db_session``);
# - the startup hook that resets and seeds the database.
app = Litestar(
    route_handlers=[get_students, get_teachers_for_student],
    dependencies={
        "students_repo": Provide(provide_students_repo),
        "teachers_repo": Provide(provide_teachers_repo),
    },
    plugins=[SQLAlchemyPlugin(config=sqlalchemy_config)],
    on_startup=[on_startup],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.