# Source: GitHub gist tcbegley/38267d40130f3633528ce4b00903652e
# (author: @tcbegley), last active January 18, 2024.
from __future__ import annotations
from uuid import UUID
from litestar import Litestar, get
from litestar.contrib.sqlalchemy.base import UUIDBase
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig
from litestar.contrib.sqlalchemy.plugins import (
AsyncSessionConfig,
SQLAlchemyAsyncConfig,
SQLAlchemyPlugin,
)
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
from litestar.di import Provide
from sqlalchemy import Column, ForeignKey, Table, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import Mapped, relationship, selectinload
# Many-to-many link table between students and teachers: each row records that
# one student is taught by one teacher. The composite primary key prevents the
# same (student, teacher) pair from being stored twice and makes both columns
# NOT NULL.
association_table = Table(
    "association_table",
    UUIDBase.metadata,
    Column("student_id", ForeignKey("student.id"), primary_key=True),
    # Named ``teacher_id`` to mirror ``student_id``.
    Column("teacher_id", ForeignKey("teacher.id"), primary_key=True),
)
class Student(UUIDBase):
    """A student and the teachers who teach them.

    The ``id`` primary key comes from ``UUIDBase`` (``association_table``
    references ``student.id`` without this class declaring it).
    """

    # Non-optional ``Mapped[str]`` -> a NOT NULL column.
    name: Mapped[str]
    # Many-to-many through ``association_table``. ``lazy="selectin"`` eagerly
    # loads the teachers with a follow-up SELECT ... IN query whenever students
    # are loaded — presumably so they can be read (e.g. during response
    # serialization) without lazy-load IO on the async session.
    teachers: Mapped[list[Teacher]] = relationship(
        secondary=association_table, lazy="selectin"
    )
class Teacher(UUIDBase):
    """A teacher. The ``id`` primary key comes from ``UUIDBase``.

    No relationship back to ``Student`` is declared here; the link between
    the two lives on ``Student.teachers`` via ``association_table``.
    """

    # Non-optional ``Mapped[str]`` -> a NOT NULL column.
    name: Mapped[str]
# Async SQLAlchemy configuration, used both by the SQLAlchemy plugin and by
# ``on_startup`` (which obtains its engine from it). Sessions are configured
# with expire_on_commit=False so loaded ORM objects keep their attribute
# values after a commit.
sqlalchemy_config = SQLAlchemyAsyncConfig(
    session_config=AsyncSessionConfig(expire_on_commit=False),
    connection_string="sqlite+aiosqlite:///test.sqlite",
)
async def on_startup() -> None:
    """Recreate the schema and seed it with demo students and teachers.

    Runs on every application start: every table in ``UUIDBase.metadata`` is
    dropped and recreated, so existing data in the database is discarded.
    """
    engine = sqlalchemy_config.get_engine()
    # Bind the factory to the engine once, instead of creating it unbound and
    # passing ``bind=`` at call time. The name also avoids confusion with
    # SQLAlchemy's ``sessionmaker`` class.
    session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)

    async with engine.begin() as conn:
        await conn.run_sync(UUIDBase.metadata.drop_all)
        await conn.run_sync(UUIDBase.metadata.create_all)

    mr_smith = Teacher(name="Mr Smith")
    mrs_jones = Teacher(name="Mrs Jones")
    dr_clever = Teacher(name="Dr Clever")

    # ``db_session.begin()`` commits on normal exit and rolls back if an
    # exception escapes. The teachers are persisted through the students'
    # ``teachers`` relationship.
    async with session_factory() as db_session, db_session.begin():
        db_session.add_all(
            [
                Student(name="Alice", teachers=[mr_smith, dr_clever]),
                Student(name="Bob", teachers=[mr_smith, mrs_jones]),
            ]
        )
class StudentRepository(SQLAlchemyAsyncRepository[Student]):
    """Async repository over the ``Student`` model."""

    model_type = Student
class TeacherRepository(SQLAlchemyAsyncRepository[Teacher]):
    """Async repository over the ``Teacher`` model.

    Keyword filters passed to its methods refer to ``Teacher`` attributes.
    """

    model_type = Teacher
async def provide_students_repo(db_session: AsyncSession) -> StudentRepository:
    """Dependency provider: wrap the injected ``db_session`` in a StudentRepository."""
    repository = StudentRepository(session=db_session)
    return repository
async def provide_teachers_repo(db_session: AsyncSession) -> TeacherRepository:
    """Dependency provider: wrap the injected ``db_session`` in a TeacherRepository.

    No custom base statement is set, so the repository's own default for
    ``Teacher`` is used. A default of ``select(Teacher).join(Student)`` cannot
    work here: the ``teacher`` and ``student`` tables share no foreign key
    (they are linked only through ``association_table``), so SQLAlchemy cannot
    infer an ON clause. An inner join would also repeat each teacher once per
    student. Handlers that need a join should pass an explicit ``statement=``
    to the repository method instead.
    """
    return TeacherRepository(session=db_session)
@get("/students")
async def get_students(students_repo: StudentRepository) -> list[Student]:
    """Return all students; each student's ``teachers`` are eagerly loaded."""
    all_students = await students_repo.list()
    return all_students
@get("/students/{student_id:uuid}/teachers")
async def get_teachers_for_student(
    student_id: UUID, teachers_repo: TeacherRepository
) -> list[Teacher]:
    """Return the teachers linked to the student whose id is ``student_id``.

    Keyword filters on the repository (``list(student_id=...)``) are matched
    against attributes of ``Teacher``, which has no ``student_id``. The filter
    is therefore expressed as an explicit statement instead: start from
    ``Student``, follow the ``Student.teachers`` relationship (through
    ``association_table``) and keep only the requested student. An unknown
    ``student_id`` yields an empty list.
    """
    statement = (
        select(Teacher)
        .select_from(Student)
        .join(Student.teachers)
        .where(Student.id == student_id)
    )
    # The explicit ``statement`` is used in place of the repository's default
    # base statement.
    return await teachers_repo.list(statement=statement)
# Application wiring. The two providers take a ``db_session`` that is not
# declared in ``dependencies`` — presumably supplied by the SQLAlchemy plugin.
# ``on_startup`` recreates and seeds the database.
app = Litestar(
    route_handlers=[get_students, get_teachers_for_student],
    dependencies={
        "teachers_repo": Provide(provide_teachers_repo),
        "students_repo": Provide(provide_students_repo),
    },
    plugins=[SQLAlchemyPlugin(config=sqlalchemy_config)],
    on_startup=[on_startup],
)
# Discussion of this gist is on GitHub (sign in to comment).