Skip to content

Instantly share code, notes, and snippets.

@wisedier
Created November 20, 2019 06:43
Show Gist options
  • Save wisedier/ddde0466bc303b05621334b4eaad1315 to your computer and use it in GitHub Desktop.
Save wisedier/ddde0466bc303b05621334b4eaad1315 to your computer and use it in GitHub Desktop.
SQLAlchemy eager loading in a recursive CTE query for a self-referential declarative model
def _random_network_name():
    """Return a fresh random UUID4 string, used as the default ``Network.name``."""
    return str(uuid.uuid4())


class Network(Base, IdMixin, TimestampMixin):
    """A network described by an IP address and a netmask.

    ``name`` is unique and, when not supplied, defaults to a random UUID4
    string.
    """

    ip_address = sa.Column(IPAddressType, nullable=False)
    netmask = sa.Column(IPAddressType, nullable=False)
    name = sa.Column(
        sa.String(64),
        unique=True,
        nullable=False,
        default=_random_network_name,  # zero-arg callable, evaluated per INSERT
    )
class Host(Base, IdMixin, TimestampMixin):
    """A host record: a name, an address string and a compromise flag.

    ``is_compromised`` defaults to False; ``meta`` is an optional JSONB value.
    """

    name = sa.Column(sa.String(64), nullable=False)
    addr = sa.Column(sa.String(256), nullable=False)
    is_compromised = sa.Column(sa.Boolean, default=False, nullable=False)
    # NOTE(review): free-form JSON; its expected structure is not defined here.
    meta = sa.Column(JSONB, nullable=True)
class NetworkHost(Base, IdMixin):
    """Links one Network and one Host, arranged as a self-referential tree.

    Each row may point at a parent ``NetworkHost`` row through ``parent_id``;
    rows whose ``parent_id`` is NULL are the roots of the tree. The reverse
    side of ``parent`` is exposed as the ``children`` backref.
    """

    network_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(Network.id, ondelete='CASCADE'),
        nullable=False,
    )
    network = orm.relationship(Network)

    host_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(Host.id, ondelete='CASCADE'),
        nullable=False,
    )
    host = orm.relationship(Host)

    # The self-referential column/relationship are built with declared_attr
    # because they read ``__tablename__`` (not defined in this class body --
    # presumably supplied by a mixin) and ``id`` from the mapped class.
    # declared_attr invokes these functions with the *class*, hence ``cls``.
    @declared_attr
    def parent_id(cls):
        """Nullable FK to this same table's ``id``; NULL marks a root entry."""
        return sa.Column(
            sa.Integer,
            sa.ForeignKey(f'{cls.__tablename__}.id', ondelete='SET NULL'),
            doc='When it is null, it means TEP agent is running in the host',
        )

    @declared_attr
    def parent(cls):
        """Many-to-one link to the parent row; ``children`` is the reverse side."""
        # remote_side=cls.id marks ``id`` as the remote column, making this
        # the many-to-one direction of the adjacency list (parent_id -> id).
        # NOTE(review): cascade='all' on ``children`` includes 'delete', so an
        # ORM ``session.delete(parent)`` also deletes its children, whereas
        # the DB-level ``ondelete='SET NULL'`` on parent_id would keep them
        # and turn them into roots. Confirm which behavior is intended.
        return orm.relationship(
            cls,
            remote_side=cls.id,
            backref=orm.backref('children', cascade='all'),
        )
from sqlalchemy import literal, orm
from db.base import Session
from db.models import Host, Network, NetworkHost
# Walk the NetworkHost tree with a recursive CTE and print one edge per row:
# (row id, parent id, network id, host id, depth level).

# Anchor member: every root entry (parent_id IS NULL) at level 0.
# Entities are passed to Session.query() directly rather than being swapped
# in afterwards with with_entities().
hierarchy = (
    Session.query(NetworkHost, literal(0).label('level'))
    .filter(NetworkHost.parent_id.is_(None))
    .cte(name='hierarchy', recursive=True)
)

# Recursive member: rows whose parent is already in the CTE, one level deeper.
# The query is built on the alias ``c`` itself. The previous
# ``children.query`` resolved to a Query on the un-aliased NetworkHost and
# only worked because with_entities() discarded that entity.
parent = orm.aliased(hierarchy, name='p')
children = orm.aliased(NetworkHost, name='c')
hierarchy = hierarchy.union_all(
    Session.query(children, (parent.c.level + 1).label('level'))
    .filter(children.parent_id == parent.c.id)
)

# Join each CTE row to the Network and Host rows it references.
h = (
    hierarchy
    .join(Network, Network.__table__.c.id == hierarchy.c.network_id)
    .join(Host, Host.__table__.c.id == hierarchy.c.host_id)
)

# Select only the edge columns FROM the joined CTE. Ordering by level
# guarantees every parent is printed before its children.
edges = (
    Session.query(
        hierarchy.c.id.label('id'),
        hierarchy.c.parent_id.label('parent_id'),
        Network.id.label('network_id'),
        Host.id.label('host_id'),
        hierarchy.c.level.label('level'),
    )
    .select_from(h)
    .order_by(hierarchy.c.level, hierarchy.c.id)
)

for obj in edges:
    print(obj.id, obj.parent_id, obj.network_id, obj.host_id, obj.level)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment