Skip to content

Instantly share code, notes, and snippets.

@mahenzon
Last active March 8, 2024 15:42
Show Gist options
  • Save mahenzon/b19cc43f1cdd20863a2bae44aeb42571 to your computer and use it in GitHub Desktop.
Save mahenzon/b19cc43f1cdd20863a2bae44aeb42571 to your computer and use it in GitHub Desktop.
SQLAlchemy self-referential many-to-many (tag-to-tag) relationship example
from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    String,
    UniqueConstraint,
    create_engine,
    select,
)
from sqlalchemy.orm import (
    DeclarativeBase,
    relationship,
    selectinload,
    sessionmaker,
)
# In-memory SQLite: nothing is persisted once the process exits.
# Point DB_URL at "sqlite:///tags.db" to keep the data in a file instead.
DB_URL: str = "sqlite:///:memory:"
# Passed as create_engine(echo=...); set to True to log the emitted SQL.
DB_ECHO: bool = False
class Base(DeclarativeBase):
    """Declarative base shared by every model in this module.

    The ``id`` column declared here is inherited by each mapped subclass
    (``Tag`` and ``TagToTagAssociation`` do not declare their own), so every
    table gets an integer primary key named ``id``.
    """

    id = Column(Integer, primary_key=True)
class Tag(Base):
    """A uniquely named tag that can point at other tags.

    ``related_tags`` follows rows of ``tag_to_tag_association`` from
    ``left_tag_id`` (this tag) to ``right_tag_id`` (the related tag). The link
    is one-directional: if A lists B, B does not automatically list A.
    """

    __tablename__ = "tags"

    name = Column(String, nullable=False, unique=True)

    # Both foreign keys of the association table target tags.id, so the two
    # join conditions are given explicitly to say which column is this tag's
    # side (left) and which is the related tag's side (right).
    # For the reverse direction, pass backref="related_by" here.
    related_tags = relationship(
        "Tag",
        secondary="tag_to_tag_association",
        secondaryjoin="Tag.id==TagToTagAssociation.right_tag_id",
        primaryjoin="Tag.id==TagToTagAssociation.left_tag_id",
    )

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(id={self.id}, name={self.name!r})"

    def __str__(self):
        return self.name
class TagToTagAssociation(Base):
    """One directed link: tag ``left_tag_id`` lists tag ``right_tag_id`` as related.

    Each row also carries the surrogate integer ``id`` primary key inherited
    from ``Base``.
    """

    __tablename__ = "tag_to_tag_association"
    # Store each directed pair at most once. Without this constraint the same
    # link could be inserted repeatedly, since the surrogate ``id`` alone does
    # not prevent duplicate (left, right) pairs.
    __table_args__ = (
        UniqueConstraint("left_tag_id", "right_tag_id"),
    )

    left_tag_id = Column(Integer, ForeignKey("tags.id"), nullable=False)
    right_tag_id = Column(Integer, ForeignKey("tags.id"), nullable=False)
def _seed_tags(session) -> int:
    """Insert the demo tags and directed links, commit, and return the id of tag "main".

    Creates tags ``main``, ``a``, ``b``, ``c`` and the links
    main->a, main->b and a->c.
    """
    tag_main = Tag(name="main")
    tag_a = Tag(name="a")
    tag_b = Tag(name="b")
    tag_c = Tag(name="c")
    session.add_all([tag_main, tag_a, tag_b, tag_c])
    # Flush so the database assigns primary keys; the links below need them.
    session.flush()

    links = [(tag_main, tag_a), (tag_main, tag_b), (tag_a, tag_c)]
    session.add_all(
        [
            TagToTagAssociation(left_tag_id=left.id, right_tag_id=right.id)
            for left, right in links
        ]
    )
    # Read the id before committing: with sessionmaker's default
    # expire_on_commit=True, reading it afterwards would trigger a refresh SELECT.
    main_tag_id = tag_main.id
    session.commit()
    return main_tag_id


def _load_with_related(session, tag_id: int) -> "Tag":
    """Return the tag with primary key ``tag_id``, with ``related_tags`` eagerly loaded.

    Raises:
        LookupError: if no tag has that id.
    """
    tag = session.scalar(
        select(Tag)
        .where(Tag.id == tag_id)
        .options(selectinload(Tag.related_tags))
    )
    # Explicit check rather than ``assert``, which ``python -O`` strips out.
    if tag is None:
        raise LookupError(f"no tag with id={tag_id}")
    return tag


def main():
    """Run the demo: build a small tag graph and print one tag's related tags.

    Creates the schema on the engine configured by ``DB_URL``/``DB_ECHO``,
    seeds it, then reloads tag ``main`` together with its related tags.
    """
    engine = create_engine(url=DB_URL, echo=DB_ECHO)
    try:
        Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)
        # The context manager closes the session even if a step below raises.
        with session_factory() as session:
            main_tag_id = _seed_tags(session)
            tag_main_with_related = _load_with_related(session, main_tag_id)

            print(tag_main_with_related)
            # output: main

            # The relationship declares no order_by, so sort for stable output.
            related = sorted(
                tag_main_with_related.related_tags,
                key=lambda tag: tag.id,
            )
            print(related)
            # output (fresh database): [Tag(id=2, name='a'), Tag(id=3, name='b')]
    finally:
        engine.dispose()
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
SQLAlchemy==2.0.28
typing_extensions==4.10.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment