# Created March 8, 2024 17:25
from typing import List, Optional
import aiosqlite
from fastapi import FastAPI, HTTPException, Depends
from sqlalchemy import Column, Integer, String, Table, ForeignKey, select, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_scoped_session, async_sessionmaker
from sqlalchemy.orm import relationship, mapped_column, Mapped, backref, DeclarativeBase, selectinload
from pydantic import BaseModel as BaseSchema, ConfigDict
from contextlib import asynccontextmanager
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in this module.

    Every subclass inherits an integer ``id`` primary-key column unless it
    declares its own ``id`` (``TagModel`` re-declares it; ``TagsLinkModel``
    relies on this one for its primary key).
    """

    id: Mapped[int] = mapped_column(primary_key=True)
class TagsLinkModel(Base):
    """Association table for directed tag-to-tag links.

    Each row records that tag ``tag_left_id`` links to tag ``tag_right_id``.
    The unique constraint prevents storing the same directed link twice; the
    surrogate ``id`` primary key is inherited from ``Base``.
    """

    __tablename__ = "tag_links"
    __table_args__ = (UniqueConstraint("tag_left_id", "tag_right_id", name="tag_links_uc"),)

    # Both ends reference tags.id. ON DELETE CASCADE removes link rows when a
    # tag is deleted, on backends that enforce foreign keys.
    tag_left_id: Mapped[int] = mapped_column(
        ForeignKey("tags.id", ondelete="CASCADE"), nullable=False
    )
    tag_right_id: Mapped[int] = mapped_column(
        ForeignKey("tags.id", ondelete="CASCADE"), nullable=False
    )
class TagModel(Base):
    """A tag that can link to other tags through the ``tag_links`` table.

    ``links`` holds the tags this tag points at (link rows where this tag is
    ``tag_left_id``). The ``parent_links`` backref holds the tags pointing at
    this one (link rows where this tag is ``tag_right_id``).
    """

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(nullable=True)

    links: Mapped[list["TagModel"]] = relationship(
        # A many-to-many relationship joined through an association table
        # must name that table via ``secondary``.
        secondary=TagsLinkModel.__table__,
        primaryjoin=(TagsLinkModel.tag_left_id == id),
        secondaryjoin=(TagsLinkModel.tag_right_id == id),
        backref=backref("parent_links", lazy="selectin"),
        # Eager "selectin" loading: under AsyncSession an implicit lazy load
        # (for example while FastAPI serializes ``links``) is not allowed.
        lazy="selectin",
    )
class TagCreateSchema(BaseSchema):
    """Request body for creating a tag.

    ``links`` optionally lists the ids of existing tags to link the new tag to.
    """

    name: str
    links: Optional[list[int]] = None
class TagUpdateSchema(BaseSchema):
    """Partial-update body for a tag.

    Both fields are optional; the update endpoint applies only the fields that
    were explicitly sent in the request.
    """

    name: Optional[str] = None
    links: Optional[list[int]] = None
class TagShortResponseSchema(BaseSchema):
    """Compact tag representation (id and name only), used for nested links."""

    id: int
    name: str

    # Allow validating directly from ORM objects such as TagModel instances.
    model_config = ConfigDict(from_attributes=True)
class TagResponseSchema(TagShortResponseSchema):
    """Tag representation returned by the API, including the tags it links to."""

    links: list[TagShortResponseSchema]
# In-memory SQLite database: its contents are lost when the process exits.
DATABASE_URL = "sqlite+aiosqlite:///:memory:"

engine = create_async_engine(DATABASE_URL, echo=False, future=True)

# expire_on_commit=False keeps already-loaded attributes usable after commit();
# otherwise the next attribute access would need an implicit refresh, which an
# AsyncSession cannot perform lazily.
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create all tables on startup and release the engine on shutdown.

    FastAPI enters this async context manager when the application starts; the
    code after ``yield`` runs when the application shuts down.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    # Close pooled connections so shutdown leaves nothing open.
    await engine.dispose()


app = FastAPI(lifespan=lifespan)
async def get_tags_by_id(tag_ids: list[int], session: AsyncSession):
    """Load the tags whose primary keys are in *tag_ids*.

    Each tag's ``links`` collection is eagerly loaded, so callers can read or
    replace it without triggering implicit I/O on the async session.

    Args:
        tag_ids: Ids to look up; duplicates are allowed.
        session: Session used to run the query.

    Returns:
        The matching tags (in no guaranteed order).

    Raises:
        HTTPException: 404 if *tag_ids* is empty or any requested id does not
            exist. Previously, partially-missing ids were silently dropped.
    """
    query = (
        select(TagModel)
        .where(TagModel.id.in_(tag_ids))
        .options(selectinload(TagModel.links))
    )
    result = await session.execute(query)
    tags = result.scalars().all()
    missing_ids = set(tag_ids) - {tag.id for tag in tags}
    if not tags or missing_ids:
        raise HTTPException(status_code=404, detail="Tag not found")
    return tags
async def get_session():
    """FastAPI dependency that provides an AsyncSession for one request.

    The session is created from ``async_session`` and closed by its
    ``async with`` block once the endpoint is done with it.
    """
    db_session = async_session()
    async with db_session:
        yield db_session
@app.get("/tags", response_model=List[TagResponseSchema])
async def get_tags(session: AsyncSession = Depends(get_session)):
    """Return every tag, ordered by id, together with the tags it links to."""
    query = (
        select(TagModel)
        # Eager-load links so serializing them needs no implicit async I/O.
        .options(selectinload(TagModel.links))
        # Explicit ordering: SQL gives no row order otherwise.
        .order_by(TagModel.id)
    )
    result = await session.execute(query)
    return result.scalars().all()
@app.get("/tags/{tag_id}", response_model=TagResponseSchema)
async def get_tag(tag_id: int, session: AsyncSession = Depends(get_session)):
    """Return the tag with id *tag_id* and its links.

    Raises:
        HTTPException: 404 if the tag does not exist (raised by get_tags_by_id).
    """
    return (await get_tags_by_id([tag_id], session))[0]


@app.post("/tags", response_model=TagResponseSchema)
async def create_tag(tag_form: TagCreateSchema, session: AsyncSession = Depends(get_session)):
    """Create a tag, optionally linking it to existing tags, and return it.

    ``links`` in the request holds tag ids. They are resolved to TagModel rows
    here, because the ORM collection accepts model instances, not raw ints.

    Raises:
        HTTPException: 404 if a requested link target does not exist.
    """
    data = tag_form.model_dump(exclude_unset=True)
    link_ids = data.pop("links", None)
    linked_tags = list(await get_tags_by_id(link_ids, session)) if link_ids else []

    new_tag = TagModel(**data, links=linked_tags)
    # The new object must be added to the session before commit/refresh.
    session.add(new_tag)
    await session.commit()
    await session.refresh(new_tag)
    return new_tag
@app.patch("/tags/{tag_id}", response_model=TagResponseSchema)
async def update_tag(tag_id: int, tag_data: TagUpdateSchema, session: AsyncSession = Depends(get_session)):
    """Apply a partial update to tag *tag_id* and return the updated tag.

    Only fields explicitly present in the request body are applied. ``links``
    replaces the tag's links with the tags having the given ids; an empty list
    or null clears them.

    Raises:
        HTTPException: 404 if the tag or a requested link target cannot be
            found (raised by get_tags_by_id); 400 if the body sets no fields
            or sets ``name`` to null.
    """
    # get_tags_by_id raises 404 itself, so no separate existence check is needed.
    existing_tag = (await get_tags_by_id([tag_id], session))[0]

    data = tag_data.model_dump(exclude_unset=True)
    if not data:
        raise HTTPException(status_code=400, detail="No Data to update")
    # The response schema requires a string name. An explicit null would
    # otherwise be committed and then fail response validation with a 500.
    if "name" in data and data["name"] is None:
        raise HTTPException(status_code=400, detail="Tag name cannot be null")

    for field, value in data.items():
        if field == "links":
            existing_tag.links = list(await get_tags_by_id(value, session)) if value else []
        else:
            setattr(existing_tag, field, value)

    await session.commit()
    await session.refresh(existing_tag)
    return existing_tag
@app.delete("/tags/{tag_id}")
async def delete_tag(tag_id: int, session: AsyncSession = Depends(get_session)):
    """Delete the tag with id *tag_id*.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 if no tag has that id.
    """
    # Eager-load links: deleting the tag makes the ORM remove its association
    # rows, and the collection must not be lazy-loaded under AsyncSession.
    tag = await session.get(TagModel, tag_id, options=[selectinload(TagModel.links)])
    if tag is None:
        raise HTTPException(status_code=404, detail="Tag not found")
    await session.delete(tag)
    await session.commit()
    return {"message": "Tag deleted"}
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the import string "self-m2m-relation:app".
    # That name is not a valid Python identifier, and uvicorn would import this
    # file a second time under it, re-running all module-level code.
    # NOTE(review): the original host value was lost; 127.0.0.1 is assumed.
    uvicorn.run(app, host="127.0.0.1", port=8001)
def run_demo(url: str = "http://127.0.0.1:8001") -> None:
    """Exercise the tag API on a running server and print each response.

    This is a function rather than module-level code. At module level, the
    requests would fire whenever this module is imported, before any server
    is listening. The hard-coded ids 1-3 assume the three tags created here
    are the first rows in a fresh database.

    Args:
        url: Base URL of the running server (no trailing slash).
    """
    import requests
    from pprint import pprint

    # create tags
    for i in range(3):
        requests.post(url + "/tags", json={"name": f"tag{i}"})

    # make links
    pprint(requests.patch(url + "/tags/1", json={"links": [2, 3]}).json())
    # {'id': 1, 'name': 'tag0', 'links': [{'id': 2, 'name': 'tag1'}, {'id': 3, 'name': 'tag2'}]}
    pprint(requests.patch(url + "/tags/2", json={"links": [3]}).json())
    # {'id': 2, 'name': 'tag1', 'links': [{'id': 3, 'name': 'tag2'}]}
    pprint(requests.patch(url + "/tags/3", json={"links": [1]}).json())
    # {'id': 3, 'name': 'tag2', 'links': [{'id': 1, 'name': 'tag0'}]}

    # request tag
    pprint(requests.get(url + "/tags/1").json())
    # {'id': 1, 'name': 'tag0', 'links': [{'id': 2, 'name': 'tag1'}, {'id': 3, 'name': 'tag2'}]}

    # request tags
    pprint(requests.get(url + "/tags").json())
    # [
    #     {
    #         "id": 1,
    #         "name": "tag0",
    #         "links": [{"id": 2, "name": "tag1"}, {"id": 3, "name": "tag2"}],
    #     },
    #     {
    #         "id": 2,
    #         "name": "tag1",
    #         "links": [{"id": 3, "name": "tag2"}],
    #     },
    #     {
    #         "id": 3,
    #         "name": "tag2",
    #         "links": [{"id": 1, "name": "tag0"}],
    #     },
    # ]
