Skip to content

Instantly share code, notes, and snippets.

@yomajo
Last active August 29, 2024 08:32
Show Gist options
  • Save yomajo/8ae96b4421410f579b774ed7154cd280 to your computer and use it in GitHub Desktop.
Save yomajo/8ae96b4421410f579b774ed7154cd280 to your computer and use it in GitHub Desktop.
When hybrid properties do not cut it
'''
Hybrid properties can't take arguments, but sometimes you need to sort on values computed from runtime input.
Standalone script illustrating Product.client_price() and a catalog query that sorts products by client price,
computed from the client's per-category discount configuration.
'''
import os
import random
from enum import Enum
from decimal import Decimal
from sqlalchemy import Select, String, Numeric, JSON, create_engine, select, case, asc, desc
from sqlalchemy.orm import sessionmaker, Mapped, declarative_base, mapped_column
from typing import Optional
# Declarative base shared by the ORM models defined in this script.
Base = declarative_base()

# SQLite database file; a relative path, so it resolves against the current working directory.
DB_FNAME = 'example.db'
engine = create_engine('sqlite:///' + DB_FNAME)

# Session factory bound to the engine, and the single module-wide session every function uses.
Session = sessionmaker(bind=engine)
session = Session()
class ProdCat(Enum):
    '''Product categories.

    Member values are the strings stored in Product.category and used as the
    keys of a client's disc_config. Declaration order is also the order
    list(ProdCat) yields when populate_data() picks random categories.'''
    ELECTRONICS = 'electronics'
    FURNITURE = 'furniture'
    DRUG = 'drug'
    OTHER = 'other'
class Product(Base):
    '''Catalog product with a base price and a category string.'''
    __tablename__ = 'product'

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(40))
    # holds a ProdCat value string; the column itself is a plain String
    category: Mapped[str] = mapped_column(String(40))
    price: Mapped[Decimal] = mapped_column(Numeric(precision=10, scale=2))

    def client_price(self, cl_disc_config: dict[str, str] | None) -> Decimal:
        '''Return this product's price after the client's category discount.

        Args:
        - cl_disc_config - mapping of category value -> discount percent given as a
          string (e.g. {'electronics': '10'}). Categories missing from the mapping,
          or a None config, get no discount.

        Returns the discounted price rounded to 2 decimal places via round(),
        which uses the current decimal context's rounding (ROUND_HALF_EVEN by default).
        '''
        # A client may have no discount config (nullable JSON column); treat it as "no discounts",
        # matching how apply_client_price_sort handles None.
        if cl_disc_config is None:
            cl_disc_config = {}
        cat_disc_str: str = cl_disc_config.get(self.category, '0')
        disc_fraction = Decimal(cat_disc_str) / 100
        return round(self.price * (1 - disc_fraction), 2)

    def __repr__(self):
        return f'Product(name={self.name}, category={self.category}, price={self.price})'
class Client(Base):
    '''Client with an optional per-category discount configuration stored as JSON.'''
    __tablename__ = 'client'

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(60))
    # nullable JSON column; populate_data() stores a {category value: discount percent string} dict
    disc_config: Mapped[Optional[dict | list]] = mapped_column(type_=JSON)

    def __repr__(self):
        # Only mapped attributes are referenced: the model has no `email` column,
        # so referencing one would raise AttributeError.
        return f'Client(name={self.name}, disc_config={self.disc_config})'
def setup_db():
    '''Create the schema and seed demo data when the DB file is absent;
    if the file already exists, only report that the DB is ready.'''
    if os.path.exists(DB_FNAME):
        print('DB ready')
        return
    Base.metadata.create_all(engine)
    print('DB created')
    populate_data()
def populate_data():
    '''Seed the DB with 20 products (random category, price 1.00-100.00)
    and one client carrying a per-category discount configuration.'''
    categories = list(ProdCat)
    products = [
        Product(
            name=f'prod_{idx}',
            category=random.choice(categories).value,
            price=Decimal(random.randint(100, 10000)) / 100,
        )
        for idx in range(20)
    ]
    session.add_all(products)

    # discount percentages are strings, keyed by category value
    discounts = {
        ProdCat.ELECTRONICS.value: '10',
        ProdCat.FURNITURE.value: '5',
        ProdCat.DRUG.value: '25',
        ProdCat.OTHER.value: '0',
    }
    session.add(Client(name='Milky Titz Inc.', disc_config=discounts))
    session.commit()
    print('Data populated')
def apply_client_price_sort(stmt: Select, disc_config: dict[str, str] | None, sort_asc: bool = True) -> Select:
    '''Order a Product select statement by the client's discounted price.

    The discounted price is computed in SQL with a CASE expression that mirrors
    Product.client_price(), so the database can sort on it.

    Args:
    - stmt - sqlalchemy select statement over Product
    - disc_config - client discount configuration: category value -> discount
      percent string. None or an empty config sorts by the base Product.price.
    - sort_asc - sort direction; True for ascending, False for descending

    Returns the statement with the ORDER BY clause appended.
    '''
    clprice_expr = Product.price
    if disc_config:
        cases = []
        for category, disc_str in disc_config.items():
            disc_fraction = Decimal(disc_str) / 100
            # zero discounts ('0', '0.0', ...) fall through to the ELSE branch (base price)
            if disc_fraction != 0:
                cases.append((Product.category == category, Product.price * (1 - disc_fraction)))
        # a SQL CASE needs at least one WHEN; with no effective discounts sort on the base price
        if cases:
            clprice_expr = case(*cases, else_=Product.price)
    sort_order = asc(clprice_expr) if sort_asc else desc(clprice_expr)
    return stmt.order_by(sort_order)
def run():
    '''Create/populate the demo DB if needed, then print every non-drug product
    sorted by client price, most expensive first.'''
    setup_db()
    # client discount config; None when there is no client row or its config is NULL
    disc_config = session.scalars(select(Client.disc_config)).first()
    # get catalog sorted by desc client price
    sober_catalog_stmt = select(Product).where(Product.category != ProdCat.DRUG.value)
    sober_catalog_stmt = apply_client_price_sort(sober_catalog_stmt, disc_config, sort_asc=False)
    catalog = session.scalars(sober_catalog_stmt).all()
    # the sort helper tolerates a missing config; the display loop must as well
    display_config = disc_config or {}
    for p in catalog:
        cl_disc = display_config.get(p.category, '0')
        cl_price = p.client_price(display_config)
        print(f'{p.name} ({p.category}) price: {p.price}, -{cl_disc}% = {cl_price}')


if __name__ == '__main__':
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment