Last active
August 29, 2024 08:32
-
-
Save yomajo/8ae96b4421410f579b774ed7154cd280 to your computer and use it in GitHub Desktop.
When hybrid properties do not cut it
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Hybrid properties can't take arguments, but sometimes you just want to sort on computed values...
Standalone script that illustrates sorting a product catalog by client price,
which is computed from Product.price and the client's discount configuration.
'''
import os | |
import random | |
from enum import Enum | |
from decimal import Decimal | |
from sqlalchemy import Select, String, Numeric, JSON, create_engine, select, case, asc, desc | |
from sqlalchemy.orm import sessionmaker, Mapped, declarative_base, mapped_column | |
from typing import Optional | |
# SQLite database file, created in the current working directory
DB_FNAME = 'example.db'

# engine -> session factory -> single module-level session shared by the whole script
engine = create_engine(f'sqlite:///{DB_FNAME}')
Session = sessionmaker(bind=engine)
session = Session()

# declarative base class that all ORM models below inherit from
Base = declarative_base()
class ProdCat(Enum):
    '''Product categories; member values are the strings stored in Product.category
    and used as keys of the client discount configuration'''

    ELECTRONICS = 'electronics'
    FURNITURE = 'furniture'
    DRUG = 'drug'
    OTHER = 'other'
class Product(Base):
    '''Catalog product with a base (list) price.

    The client-specific price is not stored; it is derived on demand from a
    per-category discount configuration via ``client_price``.
    '''
    __tablename__ = 'product'

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(40))
    # plain string column; populate_data fills it with ProdCat member values
    category: Mapped[str] = mapped_column(String(40))
    price: Mapped[Decimal] = mapped_column(Numeric(precision=10, scale=2))

    def client_price(self, cl_disc_config: dict[str, str] | None) -> Decimal:
        '''returns client price with computed client discount

        Args:
        - cl_disc_config - mapping of category -> discount percent as string (e.g. '25');
          None or an empty mapping means "no discounts"

        Returns base price reduced by the discount of this product's category,
        rounded to 2 decimal places. Categories missing from the config get 0%.'''
        # tolerate a missing config the same way apply_client_price_sort does
        cat_disc_str: str = (cl_disc_config or {}).get(self.category, '0')
        disc_fraction = Decimal(cat_disc_str) / 100
        return round(self.price * (1 - disc_fraction), 2)

    def __repr__(self):
        return f'Product(name={self.name}, category={self.category}, price={self.price})'
class Client(Base):
    '''Client together with its category-based discount configuration.'''
    __tablename__ = 'client'

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(60))
    # JSON column, nullable; populate_data stores a {category: percent-string} dict here
    disc_config: Mapped[Optional[dict | list]] = mapped_column(type_=JSON)

    def __repr__(self):
        # only reference mapped attributes: the model has no `email` column,
        # so the previous repr raised AttributeError
        return f'Client(name={self.name}, disc_config={self.disc_config})'
def setup_db():
    '''creates schema and seeds example data on first run; no-op when the DB file already exists'''
    if os.path.exists(DB_FNAME):
        print('DB ready')
        return

    # first run: build tables from the declared models, then seed them
    Base.metadata.create_all(engine)
    print('DB created')
    populate_data()
def populate_data():
    '''seeds 20 randomly categorized/priced products and one client with a per-category discount config'''
    categories = list(ProdCat)
    products = [
        Product(
            name=f'prod_{idx}',
            category=random.choice(categories).value,
            # random price between 1.00 and 100.00
            price=Decimal(random.randint(100, 10000)) / 100,
        )
        for idx in range(20)
    ]
    session.add_all(products)

    # discount percent (as string) per category, keyed by the category value
    percent_by_category = {
        ProdCat.ELECTRONICS: '10',
        ProdCat.FURNITURE: '5',
        ProdCat.DRUG: '25',
        ProdCat.OTHER: '0',
    }
    disc_config = {cat.value: pct for cat, pct in percent_by_category.items()}

    client = Client(name='Milky Titz Inc.', disc_config=disc_config)
    session.add(client)

    session.commit()
    print('Data populated')
def apply_client_price_sort(stmt: Select, disc_config: dict[str, str] | None, sort_asc: bool = True) -> Select:
    '''applies custom sort on computed client price based on client disc_config

    Args:
    - stmt - sqlalchemy select statement for SKU
    - disc_config - client discount configuration: category -> discount percent as string;
      None or empty means no discounts
    - sort_asc - sort direction

    Returns the statement ordered by a CASE expression that applies each category's
    discount to Product.price. Falls back to ordering by plain Product.price when
    no category has a non-zero discount.'''
    direction = asc if sort_asc else desc

    cases = []
    for category, disc_str in (disc_config or {}).items():
        disc_fraction = Decimal(disc_str) / 100
        if not disc_fraction:
            # zero discount: covered by the else_ branch, no WHEN needed
            continue
        cases.append((Product.category == category, Product.price * (1 - disc_fraction)))

    if not cases:
        # a CASE without any WHEN branch is not valid SQL; client price equals base price here
        return stmt.order_by(direction(Product.price))

    clprice_expr = case(*cases, else_=Product.price)
    return stmt.order_by(direction(clprice_expr))
def run():
    '''prints the non-drug catalog sorted by descending client price of the first client found'''
    setup_db()

    # client discount config; None when there is no client row or the JSON column is NULL
    disc_config = session.scalars(select(Client.disc_config)).first()
    # always-a-dict view for the per-product lookups below, so a missing config means 0% everywhere
    lookup_config = disc_config or {}

    # get catalog sorted by desc client price
    sober_catalog_stmt = select(Product).where(Product.category != ProdCat.DRUG.value)
    # empty config is passed as None so the sort takes its "no discounts" path
    sober_catalog_stmt = apply_client_price_sort(sober_catalog_stmt, disc_config or None, sort_asc=False)

    catalog: list[Product] = session.scalars(sober_catalog_stmt).all()
    for p in catalog:
        cl_disc = lookup_config.get(p.category, '0')
        cl_price = p.client_price(lookup_config)
        print(f'{p.name} ({p.category}) price: {p.price}, -{cl_disc}% = {cl_price}')
# run the demo only when executed as a script, not on import
if __name__ == '__main__':
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment