Skip to content

Instantly share code, notes, and snippets.

@paulwinex
Last active February 15, 2024 21:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save paulwinex/6e9c53750774a233701adef4edd1d6dd to your computer and use it in GitHub Desktop.
Save paulwinex/6e9c53750774a233701adef4edd1d6dd to your computer and use it in GitHub Desktop.
from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Table
from sqlalchemy.orm import sessionmaker, relationship, backref, declared_attr, Mapped
from sqlalchemy.orm import declarative_base, Mapped, mapped_column
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.expression import and_
from sqlalchemy import select, func, union_all, text
import asyncio
from sqlalchemy.ext.asyncio import (
AsyncSession,
create_async_engine,
async_sessionmaker,
async_scoped_session,
)
# Declarative base shared by every ORM model in this module; its metadata is
# what create_tables() turns into tables.
BaseModel = declarative_base()
# SQLite in-memory database accessed through the aiosqlite async driver; the
# data is not persisted to disk, so each run starts from an empty schema.
DATABASE_URL = "sqlite+aiosqlite:///:memory:"
class TaggableMixin:
    """Mixin giving a model a polymorphic many-to-many ``tags`` relationship.

    Links live in the shared ``tag_links`` table (``TagLinks``). A link row
    belongs to an instance of the mixing-in model when ``TagLinks.model_id``
    equals the instance's ``id`` and ``TagLinks.model_type`` equals the
    model's ``__tablename__``. The mixing-in class must therefore define both
    ``id`` and ``__tablename__``.
    """

    @declared_attr
    def tags(cls) -> Mapped[list["TagModel"]]:
        # ``cls`` is the concrete model being mapped (AssetModel, ProjectModel,
        # ...), so the join conditions below are specialised per table name.
        # NOTE(review): model_type is compared against a literal, so SQLAlchemy
        # cannot fill it in when items are appended to ``obj.tags``; this module
        # creates TagLinks rows directly instead. Consider viewonly=True — confirm.
        return relationship(
            "TagModel",
            secondary=TagLinks.__table__,
            # Lambdas defer evaluation until mapper configuration, when
            # TagModel and TagLinks are both available.
            primaryjoin=lambda: and_(
                cls.id == TagLinks.model_id,
                # tag_links.model_id has no foreign key, so this discriminator
                # is what keeps e.g. asset 1 and project 1 apart.
                TagLinks.model_type == cls.__tablename__,
            ),
            secondaryjoin=lambda: and_(
                TagModel.id == TagLinks.tag_id,
                TagLinks.model_type == cls.__tablename__,
            ),
            # Every subclass's ``tags`` maps onto the same tag_links columns;
            # this acknowledges the overlap so SQLAlchemy does not warn.
            overlaps="tags",
            lazy="selectin",
            uselist=True,
        )
class TagModel(BaseModel):
    """A node in a hierarchical tag tree.

    ``path`` is a materialised path of ids from the root down to this tag,
    e.g. ``/1/4/9/`` for tag 9 whose parent is 4 whose parent is 1. It is set
    by ``update_path`` once the row has an id, and lets descendant lookups use
    a simple prefix match.
    """

    __tablename__ = "tags"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(unique=True, nullable=False)
    # Root tags have no parent, so the column (and its annotation) is Optional.
    parent_id: Mapped[int | None] = mapped_column(ForeignKey("tags.id"), nullable=True)
    # NULL until update_path() has been run for this tag.
    path: Mapped[str | None] = mapped_column(nullable=True)
    children: Mapped[list["TagModel"]] = relationship(
        "TagModel",
        # remote_side=[id] makes the ``parent`` backref the many-to-one side
        # of this self-referential relationship.
        backref=backref("parent", remote_side=[id]),
        lazy="selectin",
        uselist=True,
    )

    async def update_path(self, session):
        """Compute and assign ``self.path`` from the parent's path.

        The tag must already be flushed so ``self.id`` is set. A parent, if
        any, must already have its own ``path``. The value is only assigned;
        committing is left to the caller.

        Raises:
            ValueError: if ``parent_id`` points at a tag that does not exist
                or whose ``path`` has not been computed yet.
        """
        if self.parent_id is None:
            self.path = f"/{self.id}/"
            return
        parent = await session.get(TagModel, self.parent_id)
        # SQLite does not enforce foreign keys by default, so a dangling
        # parent_id is possible. Fail loudly instead of crashing on None or
        # writing a bogus "None<id>/" path.
        if parent is None or parent.path is None:
            raise ValueError(
                f"cannot compute path for tag {self.id}: parent "
                f"{self.parent_id} is missing or has no path"
            )
        self.path = f"{parent.path}{self.id}/"

    def __str__(self):
        return f"{self.name} [{self.id}]"

    def __repr__(self):
        return f"<{self.name} {self.path}>"
class TagLinks(BaseModel):
    """Association row linking one tag to one object of any taggable model.

    ``model_type`` stores the owning model's ``__tablename__`` and
    ``model_id`` that object's primary key. ``model_id`` carries no foreign
    key because it may point into different tables.
    """

    __tablename__ = "tag_links"
    id: Mapped[int] = mapped_column(primary_key=True)
    tag_id: Mapped[int] = mapped_column(ForeignKey("tags.id"))
    model_id: Mapped[int]
    model_type: Mapped[str] = mapped_column(String(50))

    def __str__(self):
        # Rendered as "<table>.<object id> > TagModel.<tag id>", so both
        # sides name the row being linked. Attributes must be read via self.
        return f"{self.model_type}.{self.model_id} > TagModel.{self.tag_id}"

    def __repr__(self):
        return f"<TagLinks {self.__str__()}>"
class AssetModel(TaggableMixin, BaseModel):
    """Taggable asset row.

    ``name`` is unique and indexed. ``label`` is indexed but may repeat.
    """

    __tablename__ = "asset"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(index=True, unique=True, nullable=False)
    label: Mapped[str] = mapped_column(index=True, nullable=False)

    def __str__(self):
        shown_name, shown_id = self.name, self.id
        return f"{shown_name} [{shown_id}]"

    def __repr__(self):
        return "<" + str(self) + ">"
class ProjectModel(TaggableMixin, BaseModel):
    """Taggable project row.

    Both ``name`` and ``label`` are unique. Only ``name`` is indexed.
    """

    __tablename__ = "projects"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(nullable=False, unique=True, index=True)
    label: Mapped[str] = mapped_column(nullable=False, unique=True)

    def __str__(self):
        shown_name, shown_id = self.name, self.id
        return f"{shown_name} [{shown_id}]"

    def __repr__(self):
        return "<" + str(self) + ">"
# Module-level async engine for DATABASE_URL; echo=False keeps SQL logging off.
# NOTE(review): ``future=True`` here and ``autocommit=False`` below look like
# SQLAlchemy 1.x leftovers that match the 2.0 defaults — likely removable;
# confirm against the pinned SQLAlchemy version.
engine = create_async_engine(url=DATABASE_URL, echo=False, future=True)
# Factory producing AsyncSession objects bound to ``engine``.
# expire_on_commit=False keeps loaded attribute values usable after commit;
# otherwise reading them would require a reload, i.e. implicit IO, which the
# asyncio extension does not allow.
session_factory = async_sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
async def create_tables() -> None:
    """Create every table registered on ``BaseModel.metadata``.

    ``metadata.create_all`` skips tables that already exist.
    """
    async with engine.begin() as connection:
        # create_all is a synchronous API; run_sync executes it against the
        # async connection inside the begun transaction.
        await connection.run_sync(BaseModel.metadata.create_all)
async def add_tag(session, name, parent_id=None):
    """Insert a tag, optionally under ``parent_id``, set its path, and commit.

    Returns:
        The newly committed TagModel instance.
    """
    created = TagModel(name=name, parent_id=parent_id)
    session.add(created)
    # Flush before computing the path: the path embeds the tag's own id,
    # which only exists once the INSERT has been sent.
    await session.flush()
    await created.update_path(session)
    await session.commit()
    return created
async def add_object_with_tags(session, model_class, name, label, *tags):
    """Create a ``model_class`` row and link it to each tag in ``tags``.

    Args:
        session: AsyncSession to work in; it is committed before returning.
        model_class: a TaggableMixin model with ``name`` and ``label`` columns.
        name: value for the new object's ``name`` column.
        label: value for the new object's ``label`` column.
        *tags: TagModel instances that already have ids.

    Returns:
        The new object, refreshed so that ``obj.tags`` is already loaded.
    """
    obj = model_class(name=name, label=label)
    session.add(obj)
    # Flush to obtain obj.id, which the link rows reference.
    await session.flush()
    session.add_all(
        TagLinks(tag_id=tag.id, model_id=obj.id, model_type=model_class.__tablename__)
        for tag in tags
    )
    await session.commit()
    # ``tags`` was never loaded on this freshly inserted object. Accessing it
    # later would trigger a lazy load, which is implicit IO and fails under
    # asyncio. Refreshing reloads the object, including the selectin-loaded
    # ``tags`` collection.
    await session.refresh(obj)
    return obj
async def get_objects_by_tag(session, model_class, tag_name):
    """Return ``model_class`` objects linked directly to the tag ``tag_name``.

    Descendant tags are not considered.

    Returns:
        A list of objects, or ``[]`` when no tag has that name.
    """
    lookup = await session.execute(select(TagModel).where(TagModel.name == tag_name))
    found_tag = lookup.scalar()
    if found_tag is None:
        return []
    stmt = select(model_class).join(model_class.tags).where(TagModel.id == found_tag.id)
    rows = await session.execute(stmt)
    return rows.scalars().all()
async def find_objects_by_tag_and_descendants(session, model_class, tag_name):
    """Return ``model_class`` objects tagged with ``tag_name`` or any descendant.

    Descendants are found by materialised-path prefix: every tag in the
    subtree rooted at the named tag has a ``path`` starting with that tag's
    own ``path``.

    Returns:
        A list with each matching object at most once. An object linked to
        several tags of the subtree is not repeated. Returns ``[]`` when the
        tag does not exist or its path has not been computed.
    """
    tag_query = select(TagModel).where(TagModel.name == tag_name)
    tag = (await session.execute(tag_query)).scalar()
    if tag is None or tag.path is None:
        return []
    objects_query = (
        select(model_class)
        .join(model_class.tags)
        # Filter the joined tag rows by path prefix directly. This avoids a
        # separate round trip to collect subtree ids first. autoescape guards
        # against LIKE wildcards, although paths only hold digits and "/".
        .where(TagModel.path.startswith(tag.path, autoescape=True))
        # The join yields one row per matching link. Without DISTINCT, an
        # object tagged with e.g. both a tag and its child appears twice.
        .distinct()
    )
    objects_result = await session.execute(objects_query)
    return objects_result.scalars().all()
async def get_all_tags(session):
    """Return every TagModel row (no ORDER BY, so order is unspecified)."""
    everything = select(TagModel)
    scalar_rows = (await session.execute(everything)).scalars()
    return scalar_rows.all()
async def show_tags(session, model_class=None):
    """Print the tag tree together with the objects found under each tag.

    Tags are printed in lexicographic ``path`` order. Because "/" sorts
    before every digit, each subtree is printed contiguously right after its
    root. The number of leading dashes reflects depth. For each tag, the
    ``model_class`` objects linked to it or to any descendant are listed.

    Args:
        session: AsyncSession to query.
        model_class: taggable model to list under each tag; defaults to
            AssetModel, matching the previous hard-coded behaviour.
    """
    if model_class is None:
        model_class = AssetModel
    tags = await get_all_tags(session)
    rows = sorted(((tag.path, tag.name) for tag in tags), key=lambda row: row[0])
    print("Paths:")
    for path, name in rows:
        objects_tag = await find_objects_by_tag_and_descendants(session, model_class, name)
        # A root path "/<id>/" has two slashes -> one dash; each level adds one.
        print("-" * (path.count("/") - 1), name, path, ">", objects_tag)
class TagNotFoundError(LookupError):
    """Raised when a tag looked up by name does not exist."""


async def add_tag_to_object(session, tag_name: str | TagModel, instance):
    """Link ``instance`` to a tag and return it with ``tags`` reloaded.

    Args:
        session: AsyncSession; it is committed before returning.
        tag_name: either the TagModel itself, or the name of an existing tag.
        instance: a persisted object of a TaggableMixin model.

    Returns:
        ``instance``, refreshed so that ``instance.tags`` includes the new tag.

    Raises:
        TagNotFoundError: if ``tag_name`` is a name matching no tag.
    """
    if isinstance(tag_name, TagModel):
        tag = tag_name
    else:
        tag_result = await session.execute(
            select(TagModel).where(TagModel.name == tag_name)
        )
        tag = tag_result.scalar()
        if tag is None:
            raise TagNotFoundError(tag_name)
    session.add(
        TagLinks(
            tag_id=tag.id,
            model_id=instance.id,
            model_type=type(instance).__tablename__,
        )
    )
    await session.commit()
    # Reload so the selectin-loaded ``tags`` collection reflects the new link.
    await session.refresh(instance)
    return instance
async def testing1():
    """Demo: build a small tag tree, tag some objects, and print lookups.

    Tree built here: root -> tag1 -> {tag1-1, tag1-2 -> tag2-1} and
    root -> tag2.
    """
    await create_tables()
    async with session_factory() as session:
        root = await add_tag(session, "root")
        t1 = await add_tag(session, "tag1", parent_id=root.id)
        await add_tag(session, "tag1-1", parent_id=t1.id)
        t1_2 = await add_tag(session, "tag1-2", parent_id=t1.id)
        t2 = await add_tag(session, "tag2", parent_id=root.id)
        # NOTE(review): "tag2-1" is created under tag1-2, not tag2, although
        # its name suggests tag2 — confirm which parent was intended.
        t2_1 = await add_tag(session, "tag2-1", parent_id=t1_2.id)

        # One object per tag: (model, name, label, tag).
        seed = [
            (ProjectModel, "project1", "Project 1", root),
            (AssetModel, "asset1", "Asset 1", t1),
            (AssetModel, "asset2", "Asset 2", t2),
            (AssetModel, "asset21", "Asset 21", t2_1),
            (AssetModel, "asset12", "Asset 12", t1_2),
        ]
        for model, obj_name, obj_label, obj_tag in seed:
            await add_object_with_tags(session, model, obj_name, obj_label, obj_tag)

        # Look up objects by tag (direct links only).
        found = await get_objects_by_tag(session, ProjectModel, "root")
        print(f"ProjectModels with tag 'root': {found}")
        found = await get_objects_by_tag(session, AssetModel, "tag1")
        print(f"AssetModels with tag 'tag1': {found}")
        found = await get_objects_by_tag(session, AssetModel, "tag2")
        print(f"AssetModels with tag 'tag2': {found}")
        found = await get_objects_by_tag(session, AssetModel, "tag2-1")
        print(f"AssetModels with tag 'tag2_1': {found}")
        found = await get_objects_by_tag(session, AssetModel, "tag2")
        print(f"AssetModels with tag 'tag2': {found}")
        found = await get_objects_by_tag(session, AssetModel, "tag1-2")
        print(f"AssetModels with tag 'tag1-2': {found}")

        # Recursive lookup over tag1's whole subtree.
        subtree_hits = await find_objects_by_tag_and_descendants(
            session, AssetModel, "tag1"
        )
        print("Hi objects for tag1", subtree_hits)
        await show_tags(session)
async def testing2():
    """Demo: a chain root -> tag1 -> tag2 -> tag3 with one asset on tag3."""
    await create_tables()
    async with session_factory() as session:
        # Create the tag chain.
        tag0 = await add_tag(session, "root")
        tag1 = await add_tag(session, "tag1", parent_id=tag0.id)
        tag2 = await add_tag(session, "tag2", parent_id=tag1.id)
        tag3 = await add_tag(session, "tag3", parent_id=tag2.id)
        await add_object_with_tags(session, AssetModel, "asset1", "Asset 1", tag3)
        objects_tag3 = await get_objects_by_tag(session, AssetModel, "tag3")
        # The lookup above is for AssetModel, so the label must say so.
        print(f"AssetModels with tag 'tag3': {objects_tag3}")
async def testing3():
    """Demo: multi-tag objects, recursive lookups, and tagging after creation."""
    async with session_factory() as session:
        await create_tables()

        # Create the tag chain Tag1 -> Tag2 -> Tag3.
        tag1 = await add_tag(session, "Tag1")
        tag2 = await add_tag(session, "Tag2", tag1.id)
        tag3 = await add_tag(session, "Tag3", tag2.id)

        # Create assets: Asset1 carries Tag1 and Tag2, Asset2 only Tag3.
        await add_object_with_tags(session, AssetModel, "Asset1", "Label1", tag1, tag2)
        asset2 = await add_object_with_tags(session, AssetModel, "Asset2", "Label2", tag3)

        # Assets linked directly to Tag1.
        direct_assets = await get_objects_by_tag(session, AssetModel, "Tag1")
        print("Assets by Tag1:", direct_assets)
        # Assets linked to Tag1 or any of its descendants.
        subtree_assets = await find_objects_by_tag_and_descendants(
            session, AssetModel, "Tag1"
        )
        print("Assets by Tag1 and its descendants:", subtree_assets)

        # Create a project carrying Tag1 and Tag3.
        await add_object_with_tags(
            session, ProjectModel, "Project1", "Project Label1", tag1, tag3
        )
        # Projects linked directly to Tag1.
        direct_projects = await get_objects_by_tag(session, ProjectModel, "Tag1")
        print("Projects by Tag1:", direct_projects)
        # Projects linked to Tag1 or any of its descendants.
        subtree_projects = await find_objects_by_tag_and_descendants(
            session, ProjectModel, "Tag1"
        )
        print("Projects by Tag1 and its descendants:", subtree_projects)

        print("-" * 10)
        # Re-select Asset2 via a query, which selectin-loads its tags, then
        # attach Tag1 to it after the fact.
        fetched = await session.execute(select(AssetModel).where(AssetModel.id == asset2.id))
        asset = fetched.scalar()
        print("Asset Tags 1", asset.tags)
        asset = await add_tag_to_object(session, tag1, asset)
        print("Asset Tags 2", asset.tags)
        await show_tags(session)
if __name__ == "__main__":
    # Pick one demo scenario. testing1 and testing2 both create a tag named
    # "root" (a unique column), so they are presumably meant to be run one
    # per process — switch by (un)commenting.
    asyncio.run(testing1())
    # asyncio.run(testing2())
    # asyncio.run(testing3())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment