Skip to content

Instantly share code, notes, and snippets.

@pawl
Last active December 28, 2017 15:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pawl/df5ba8923d9929dd1f4fc4e683eced40 to your computer and use it in GitHub Desktop.
Save pawl/df5ba8923d9929dd1f4fc4e683eced40 to your computer and use it in GitHub Desktop.
Example of a custom "IN()" relationship loading strategy in SQLAlchemy
from collections import defaultdict
from sqlalchemy import create_engine, Column, ForeignKey, Integer
from sqlalchemy.orm import relationship, scoped_session, sessionmaker
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.ext.declarative import declarative_base
# Engine for the demo database. echo=True logs every emitted SQL statement,
# which makes it easy to see how many IN() queries the loaders below issue.
# NOTE: ``convert_unicode=True`` was removed. It is deprecated since
# SQLAlchemy 1.3, rejected by create_engine() in 2.0, and had no effect here
# because none of the models declare String columns.
engine = create_engine('mysql://root@localhost/test?charset=utf8mb4',
                       echo=True)

# Scoped (registry-managed) session shared by the seeding code and the loaders.
session = scoped_session(sessionmaker(autocommit=False,
                                      autoflush=False,
                                      bind=engine))

Base = declarative_base()
# Exposes ``Model.query`` on every model, bound to the scoped session above.
Base.query = session.query_property()
class Post(Base):
    """A post that owns many ``Product`` rows (table ``posts``)."""

    __tablename__ = 'posts'

    id = Column(Integer, primary_key=True)
    # lazy='raise_on_sql' makes any implicit lazy load of this collection raise
    # instead of silently emitting SQL, so it must be populated explicitly
    # (here by products_loader). backref='post' adds ``Product.post``.
    products = relationship('Product', lazy='raise_on_sql', backref='post')
class Product(Base):
    """A product belonging to one ``Post`` and owning many links (table ``products``)."""

    __tablename__ = 'products'

    id = Column(Integer, primary_key=True)
    # Indexed FK; products_loader filters on this column with an IN() clause.
    post_id = Column(Integer, ForeignKey('posts.id'), index=True)
    # raise_on_sql: must be filled explicitly (here by links_loader) before access.
    links = relationship('ProductLink', lazy='raise_on_sql')
class ProductLink(Base):
    """A link row attached to one ``Product`` (table ``links``)."""

    __tablename__ = 'links'

    id = Column(Integer, primary_key=True)
    # Indexed FK; links_loader filters on this column with an IN() clause.
    product_id = Column(Integer, ForeignKey('products.id'), index=True)
# To rebuild the schema from scratch, uncomment the next line.
# Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

# Seed demo data only when the database is empty:
# 50 posts, each with 12 products, each with 8 links.
first_result = Post.query.first()
if not first_result:
    for _ in range(50):
        session.add(Post(products=[
            Product(links=[ProductLink() for _ in range(8)])
            for _ in range(12)
        ]))
    # Commit once for the whole seed. Committing per post could leave a
    # partially seeded database after a mid-way failure, and later runs would
    # then skip seeding because first() already finds a row.
    session.commit()
def products_loader(posts):
    """Populate ``Post.products`` for every post in *posts* using one IN() query.

    All products whose ``post_id`` is among the posts' ids are fetched in a
    single SELECT. They are bucketed by ``post_id`` and installed on each post
    with ``set_committed_value``, so the collection counts as already loaded
    and is not marked as a pending change. A post with no matching products
    gets its own fresh empty list.

    NOTE: SQLAlchemy 1.2+ provides this strategy built in as ``selectinload``.

    Args:
        posts: Post instances. They are iterated twice, so pass a list, not a
            generator.

    Returns:
        The same *posts* object. If it contains no posts, no query is issued.
    """
    wanted_ids = {p.id for p in posts}
    if not wanted_ids:
        return posts

    matching = (
        session.query(Product)
        .filter(Product.post_id.in_(wanted_ids))
        .all()
    )

    # Bucket the fetched products by the post they belong to.
    by_post = {}
    for item in matching:
        by_post.setdefault(item.post_id, []).append(item)

    for p in posts:
        set_committed_value(p, 'products', by_post.get(p.id, []))
    return posts
def links_loader(posts):
    """Populate ``Product.links`` for every product under *posts* using one IN() query.

    Reads ``post.products`` on each post. Because that relationship is
    ``lazy='raise_on_sql'``, ``products_loader`` must already have run on
    *posts*, otherwise the access raises. All links for the collected product
    ids are fetched in a single SELECT, bucketed by ``product_id`` and
    installed with ``set_committed_value``. A product with no links gets its
    own fresh empty list.

    Args:
        posts: Post instances whose ``products`` collections are already loaded.

    Returns:
        The same *posts* object. If there are no products, no query is issued.
    """
    # Flatten once, keeping the original post -> product visiting order.
    loaded_products = [prod for p in posts for prod in p.products]

    wanted_ids = {prod.id for prod in loaded_products}
    if not wanted_ids:
        return posts

    matching = (
        session.query(ProductLink)
        .filter(ProductLink.product_id.in_(wanted_ids))
        .all()
    )

    # Bucket the fetched links by the product they belong to.
    by_product = {}
    for item in matching:
        by_product.setdefault(item.product_id, []).append(item)

    for prod in loaded_products:
        set_committed_value(prod, 'links', by_product.get(prod.id, []))
    return posts
# Load a page of posts, then attach their products and links with one IN()
# query each, instead of one lazy load per row. ORDER BY makes the page
# deterministic: LIMIT without it lets the database return any 20 rows.
posts = Post.query.order_by(Post.id).limit(20).all()
posts = products_loader(posts)
posts = links_loader(posts)

for post in posts:
    # post.products was installed by products_loader, so this access does not
    # trigger the raise_on_sql lazy loader.
    for product in post.products:
        print(product.id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment