Skip to content

Instantly share code, notes, and snippets.

@paulwinex
Created March 8, 2024 17:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save paulwinex/63efeddf0b085ed9d432976484773741 to your computer and use it in GitHub Desktop.
Save paulwinex/63efeddf0b085ed9d432976484773741 to your computer and use it in GitHub Desktop.
from typing import List, Optional
import aiosqlite
from fastapi import FastAPI, HTTPException, Depends
from sqlalchemy import Column, Integer, String, Table, ForeignKey, select, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_scoped_session, async_sessionmaker
from sqlalchemy.orm import relationship, mapped_column, Mapped, backref, DeclarativeBase, selectinload
from pydantic import BaseModel as BaseSchema, ConfigDict
from contextlib import asynccontextmanager
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this module.

    Declares an integer primary-key ``id`` column that subclasses inherit
    unless they define their own ``id``.
    """

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
class TagsLinkModel(Base):
    """Association row connecting one tag (left) to another tag (right).

    Both columns reference ``tags.id``; the unique constraint allows each
    ordered (left, right) pair to be stored only once.
    """

    __tablename__ = "tag_links"
    __table_args__ = (
        UniqueConstraint("tag_left_id", "tag_right_id", name="tag_links_uc"),
    )

    # The tag that owns the link.
    tag_left_id: Mapped[int] = mapped_column(
        ForeignKey("tags.id", ondelete="CASCADE"),
        nullable=False,
    )
    # The tag being linked to.
    tag_right_id: Mapped[int] = mapped_column(
        ForeignKey("tags.id", ondelete="CASCADE"),
        nullable=False,
    )
class TagModel(Base):
    """A tag that can be linked to other tags through ``tag_links``.

    ``links`` holds the tags this tag points at (this tag is the
    ``tag_left_id`` side of the association row).  The ``parent_links``
    backref is the reverse direction: the tags that point at this one.
    Both directions are configured with ``lazy="selectin"``.
    """

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(primary_key=True)
    # The column is explicitly nullable, so the annotation has to admit None;
    # ``Mapped[str]`` told type checkers (and readers) the opposite.
    name: Mapped[Optional[str]] = mapped_column(nullable=True)
    links: Mapped[list["TagModel"]] = relationship(
        "TagModel",
        secondary=TagsLinkModel.__table__,
        # ``id`` here is the column declared just above in this class body.
        primaryjoin=(TagsLinkModel.tag_left_id == id),
        secondaryjoin=(TagsLinkModel.tag_right_id == id),
        backref=backref("parent_links", lazy="selectin"),
        lazy="selectin",
        uselist=True,
    )
# Request body for POST /tags: a required name plus optional IDs of tags to
# link to.
class TagCreateSchema(BaseSchema):
    name: str
    links: Optional[list[int]] = None
# Request body for PATCH /tags/{tag_id}: every field is optional so that
# callers can send only what they want to change.
class TagUpdateSchema(BaseSchema):
    name: Optional[str] = None
    links: Optional[list[int]] = None
# Compact tag representation (id and name only), used for the entries of
# a tag's ``links`` list.
class TagShortResponseSchema(BaseSchema):
    # from_attributes lets the schema be populated from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
# Full tag representation: the short form plus the tags it links to, each
# rendered in short form.
class TagResponseSchema(TagShortResponseSchema):
    links: list[TagShortResponseSchema]
# In-memory SQLite database accessed through the aiosqlite driver.
DATABASE_URL = "sqlite+aiosqlite:///:memory:"

engine = create_async_engine(DATABASE_URL, future=True, echo=False)

# Session factory bound to the engine above.  expire_on_commit=False keeps
# already-loaded attributes readable after commit().
async_session = async_sessionmaker(
    bind=engine,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create all tables declared on ``Base.metadata`` before the app serves.

    Nothing is done after the ``yield``.
    """
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield


app = FastAPI(lifespan=lifespan)
async def get_tags_by_id(tag_ids: list[int], session: AsyncSession) -> list[TagModel]:
    """Load the tags with the given IDs, including their ``links``.

    Args:
        tag_ids: Primary keys of the tags to fetch; duplicates are allowed.
        session: Session used to run the query.

    Returns:
        The matching tags, in whatever order the database returns them.

    Raises:
        HTTPException: 404 when ``tag_ids`` is empty or when at least one
            requested ID does not exist.
    """
    query = (
        select(TagModel)
        .where(TagModel.id.in_(tag_ids))
        .options(selectinload(TagModel.links))
    )
    result = await session.execute(query)
    tags = list(result.scalars().all())

    # IDs are primary keys, so each distinct requested ID yields at most one
    # row.  Fewer rows than distinct IDs means some IDs are unknown; report
    # that instead of silently returning a partial result.
    if not tags or len(tags) != len(set(tag_ids)):
        raise HTTPException(status_code=404, detail="Tag not found")
    return tags
async def get_session():
    """Yield a session from ``async_session`` for the duration of a request.

    The session is opened with ``async with``, so it is closed once the
    consumer of this generator is finished with it.
    """
    async with async_session() as db:
        yield db
# List every tag together with the tags it links to.
@app.get("/tags", response_model=List[TagResponseSchema])
async def get_tags(session: AsyncSession = Depends(get_session)):
    statement = select(TagModel).options(selectinload(TagModel.links))
    rows = await session.execute(statement)
    return rows.scalars().all()
# Return a single tag; get_tags_by_id raises the 404 when it does not exist.
@app.get("/tags/{tag_id}", response_model=TagResponseSchema)
async def get_tag(tag_id: int, session: AsyncSession = Depends(get_session)):
    matches = await get_tags_by_id([tag_id], session)
    return matches[0]
@app.post("/tags", response_model=TagResponseSchema)
async def create_tag(tag_form: TagCreateSchema, session: AsyncSession = Depends(get_session)):
    """Create a tag, optionally linking it to existing tags by ID."""
    payload = tag_form.model_dump(exclude_unset=True)

    # ``links`` arrives as a list of integer IDs, but the relationship only
    # accepts TagModel instances, so it must not be passed to the constructor
    # as-is.  Resolve the IDs first (get_tags_by_id raises 404 for unknown
    # ones); an omitted, null or empty list means "no links".
    link_ids = payload.pop("links", None)

    new_tag = TagModel(**payload)
    if link_ids:
        new_tag.links = list(await get_tags_by_id(link_ids, session))

    session.add(new_tag)
    await session.commit()
    await session.refresh(new_tag)
    return new_tag
@app.patch("/tags/{tag_id}", response_model=TagResponseSchema)
async def update_tag(tag_id: int, tag_data: TagUpdateSchema, session: AsyncSession = Depends(get_session)):
    """Update a tag's name and/or replace the set of tags it links to."""
    # get_tags_by_id raises 404 itself when the tag does not exist, so no
    # extra existence check is needed here.
    existing_tag = (await get_tags_by_id([tag_id], session))[0]

    data = tag_data.model_dump(exclude_unset=True)
    if not data:
        raise HTTPException(status_code=400, detail="No Data to update")

    # An explicit ``"name": null`` would be stored, after which every response
    # containing this tag fails validation (the response schema requires a
    # string name).  Reject it before touching the object.
    if "name" in data and data["name"] is None:
        raise HTTPException(status_code=422, detail="Tag name must not be null")

    for field, value in data.items():
        if field == "links":
            # A null or empty list clears the links; otherwise the IDs are
            # resolved to tags (404 when any of them is unknown).
            existing_tag.links = list(await get_tags_by_id(value, session)) if value else []
        else:
            setattr(existing_tag, field, value)

    await session.commit()
    await session.refresh(existing_tag)
    return existing_tag
# Delete a tag by primary key; responds with 404 when it does not exist.
@app.delete("/tags/{tag_id}")
async def delete_tag(tag_id: int, session: AsyncSession = Depends(get_session)):
    doomed = await session.get(TagModel, tag_id)
    if doomed is None:
        raise HTTPException(status_code=404, detail="Tag not found")

    await session.delete(doomed)
    await session.commit()
    return {"message": "Tag deleted"}
if __name__ == "__main__":
    import uvicorn

    # Pass the application object itself rather than the import string
    # "self-m2m-relation:app": the string only resolves while the file keeps
    # that exact name, and it makes uvicorn import this module a second time.
    uvicorn.run(app, host="0.0.0.0", port=8001)
import requests
from pprint import pprint
# Connect through the loopback address.  The server binds to 0.0.0.0 (all
# interfaces), but 0.0.0.0 is not a portable address to connect *to*.
url = "http://127.0.0.1:8001"
# Seconds to wait for the server before giving up; requests waits
# indefinitely when no timeout is given.
timeout = 5
# create tags
for i in range(3):
    requests.post(url + "/tags", json={"name": f"tag{i}"}, timeout=timeout)
# make links
requests.patch(url + "/tags/1", json={"links": [2, 3]}, timeout=timeout).json()
# {'id': 1, 'name': 'tag0', 'links': [{'id': 2, 'name': 'tag1'}, {'id': 3, 'name': 'tag2'}]}
requests.patch(url + "/tags/2", json={"links": [3]}, timeout=timeout).json()
# {'id': 2, 'name': 'tag1', 'links': [{'id': 3, 'name': 'tag2'}]}
requests.patch(url + "/tags/3", json={"links": [1]}, timeout=timeout).json()
# {'id': 3, 'name': 'tag2', 'links': [{'id': 1, 'name': 'tag0'}]}
# request tag
requests.get(url + "/tags/1", timeout=timeout).json()
# {'id': 1, 'name': 'tag0', 'links': [{'id': 2, 'name': 'tag1'}, {'id': 3, 'name': 'tag2'}]}
# request tags
pprint(requests.get(url + "/tags", timeout=timeout).json())
# [
#     {
#         "id": 1,
#         "name": "tag0",
#         "links": [{"id": 2, "name": "tag1"}, {"id": 3, "name": "tag2"}],
#     },
#     {
#         "id": 2,
#         "name": "tag1",
#         "links": [{"id": 3, "name": "tag2"}],
#     },
#     {
#         "id": 3,
#         "name": "tag2",
#         "links": [{"id": 1, "name": "tag0"}],
#     },
# ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment