Skip to content

Instantly share code, notes, and snippets.

@funseiki
Created June 19, 2018 17:50
Show Gist options
  • Save funseiki/8e5ccd5cf2be19c6b698fbdcf81ae7d6 to your computer and use it in GitHub Desktop.
Save funseiki/8e5ccd5cf2be19c6b698fbdcf81ae7d6 to your computer and use it in GitHub Desktop.
Multiple polymorphic identities
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.associationproxy import association_proxy
from utils import deco, hashFunc, queryFunc
# One MetaData collects all three tables so create_all() can emit them together.
metadata = MetaData()

# Lookup table of fruit kinds; Fruits.type_id points at its primary key.
fruitTypesTable = Table(
    'FruitTypes',
    metadata,
    Column('type_id', Integer, primary_key=True),
    Column('name', String(256)),
)

# Base table shared by every fruit row; type_id links each fruit to its kind.
fruitsTable = Table(
    'Fruits',
    metadata,
    Column('fruit_id', Integer, primary_key=True),
    Column('type_id', Integer, ForeignKey(fruitTypesTable.c.type_id)),
    Column('name', String(256)),
    Column('sweetness', Integer),
)

# Extra columns for apples only; fruit_id is both the primary key and a
# foreign key back to the matching Fruits row.
applesTable = Table(
    'Apples',
    metadata,
    Column(
        'fruit_id',
        Integer,
        ForeignKey(fruitsTable.c.fruit_id),
        primary_key=True,
    ),
    Column('appleyness', Integer),
)
from prism.core.component.Component import Component
class FruitType(Component):
    """A fruit kind (one row of the ``FruitTypes`` table)."""

    def __init__(self, *args, **kwds):
        """Assign each keyword argument as an attribute.

        A positional argument is treated as ``name`` when no ``name=`` keyword
        is supplied. This mirrors how ``hashFunc``/``queryFunc`` pick the name
        (keyword first, then ``args[0]``). It is also how
        ``association_proxy``'s default creator calls the class (a single
        positional value). Previously the positional value was dropped, which
        produced a nameless row cached under that name. Any further
        positional arguments are ignored, as before.
        """
        if args and 'name' not in kwds:
            kwds['name'] = args[0]
        # setattr (not a __dict__ update) so attribute descriptors installed
        # by the mapper see every assignment.
        for key, value in kwds.items():
            setattr(self, key, value)

    def __repr__(self):
        return "{}<{}>".format(self.__class__.__name__, self.__dict__)
class Fruit(Component):
    """Base class for rows of the ``Fruits`` table (Apple/Orange derive from it)."""

    def __init__(self, *args, **kwds):
        # Positional arguments are accepted but ignored; each keyword becomes
        # an attribute. setattr is used so mapper-installed descriptors fire.
        for attr_name, attr_value in kwds.items():
            setattr(self, attr_name, attr_value)

    def __repr__(self):
        class_name = self.__class__.__name__
        return "{}<{}>".format(class_name, self.__dict__)
class Apple(Fruit):
    """Fruit subclass whose extra columns live in the ``Apples`` table."""
class Orange(Fruit):
    """Fruit subclass with no table of its own (stored only in ``Fruits``)."""
# In-memory SQLite database with the schema created up front.
engine = create_engine('sqlite://')
metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()

# FruitTypes
mapper(FruitType, fruitTypesTable)

# Ensure that on construction we check the database (and a per-session cache)
# for an existing row with the same name before creating a new one.
meth = deco(session, hashFunc, queryFunc)
meth(FruitType)

# Seed the fruit kinds. The primary keys are pinned explicitly because the
# Apple/Orange mappers below use 1 and 2 as polymorphic identities. The mapped
# key column is ``type_id``; ``id=`` only set an unmapped attribute and left
# the key to autoincrement order.
appleType = FruitType(type_id=1, name='apple')
orangeType = FruitType(type_id=2, name='orange')
session.add(appleType)
session.add(orangeType)
session.commit()

# Fruit is the polymorphic base; its ``type_id`` column picks the subclass.
# What should this one's polymorphic_identity be?
mapper(Fruit, fruitsTable,
       polymorphic_on='type_id',
       properties={
           "typeObject": relationship(FruitType)
       })
# Expose the related FruitType's name directly as ``Fruit.fruitType``.
Fruit.fruitType = association_proxy("typeObject", 'name')

# Joined-table subclass: Apple rows span Fruits + Apples.
mapper(Apple, applesTable,
       inherits=Fruit,
       polymorphic_identity=1)
# Single-table subclass: Orange rows live only in Fruits.
mapper(Orange,
       inherits=Fruit,
       polymorphic_identity=2)

delicious = Apple(name="red delicious", appleyness=1, sweetness=5)
session.add(delicious)
session.commit()

# TODO: Figure out how to instantiate just a Fruit with any polymorphic identity
# NOTE(review): Fruit's mapper declares no polymorphic_identity, and "poop" is
# not one of the seeded type names, so this row presumably gets a type_id no
# mapper claims; loading it back through session.query(Fruit) may fail --
# confirm.
f = Fruit(name="hello", fruitType="poop", sweetness=0)
session.add(f)
session.commit()

print("Printing fruits")
fruits = session.query(Fruit).all()
for fruit in fruits:
    # Function-call form works on both Python 2 and 3 (``print fruit`` is a
    # SyntaxError on Python 3) and matches the print() call above.
    print(fruit)
def unique(session, cls, hashfunc, queryfunc, constructor, arg, kw):
    """Return the one instance of ``cls`` identified by ``arg``/``kw`` (get-or-create).

    Lookup order: a per-session cache, then the database, then construction.
    A newly constructed object is added to the session. Whatever object is
    found or built is cached for the session's lifetime.

    session: SqlAlchemy session. The cache is stored on it as
        ``_unique_cache``.
    cls: The model class to look up / create.
    hashfunc: ``hashfunc(*arg, **kw)`` -> the uniqueness key within ``cls``
        (e.g. ``lambda name: name``).
    queryfunc: ``queryfunc(query, *arg, **kw)`` -> the query narrowed to the
        matching row (e.g. adds ``Widget.name == name``).
    constructor: ``constructor(*arg, **kw)`` -> a new instance, used only when
        no row is found.
    arg: positional arguments forwarded to hashfunc, queryfunc and constructor.
    kw: keyword arguments forwarded to hashfunc, queryfunc and constructor.
    """
    registry = getattr(session, '_unique_cache', None)
    if registry is None:
        registry = session._unique_cache = {}

    cache_key = (cls, hashfunc(*arg, **kw))
    if cache_key in registry:
        return registry[cache_key]

    # Autoflush is suspended so the lookup query cannot prematurely flush
    # pending objects (such as the one being constructed).
    with session.no_autoflush:
        found = queryfunc(session.query(cls), *arg, **kw).first()
        if not found:
            found = constructor(*arg, **kw)
            session.add(found)
    registry[cache_key] = found
    return found
def deco(session, hashfunc, queryfunc):
    """Build a class decorator that makes construction a get-or-create via ``unique``.

    After decoration:
      * ``cls(...)`` with any arguments returns ``unique(...)``'s result, so
        an existing (cached or persisted) instance is reused when one matches.
      * ``cls()`` / ``cls.__new__(cls)`` with no arguments returns a bare,
        un-initialised instance via ``object.__new__``.
      * ``cls.__init__`` becomes a no-op. The original initialiser is kept as
        ``cls._init`` and run only when ``unique`` constructs a new object.
    """
    def decorate(cls):
        def noop(self, *arg, **kwds):
            # Python calls __init__ on whatever __new__ returns. Real setup was
            # already done via _init (or the object is a reused instance), so
            # nothing happens here.
            pass

        def __new__(cls, bases, *arg, **kwds):
            # Installed through classmethod(), yet Python still passes the
            # class explicitly to __new__, so ``bases`` receives that duplicate
            # class object and is otherwise unused.
            if arg or kwds:
                def constructor(*ctor_arg, **ctor_kwds):
                    fresh = object.__new__(cls)
                    fresh._init(*ctor_arg, **ctor_kwds)
                    return fresh
                return unique(session, cls, hashfunc, queryfunc,
                              constructor, arg, kwds)
            return object.__new__(cls)

        # Capture the original initialiser before replacing it.
        cls._init = cls.__init__
        cls.__init__ = noop
        cls.__new__ = classmethod(__new__)
        return cls

    return decorate
def hashFunc(*args, **kwds):
    """Return the uniqueness key for a constructor call: the name.

    The ``name`` keyword wins if present; otherwise the first positional
    argument is used; with neither, the key is ``None``.
    """
    if 'name' in kwds:
        return kwds['name']
    return args[0] if args else None
def queryFunc(query, *args, **kwds):
    """Narrow ``query`` to the FruitType row whose name matches the call.

    The name is resolved like ``hashFunc`` does it: the ``name`` keyword
    first, else the first positional argument, else ``None``.
    """
    if 'name' in kwds:
        wanted = kwds['name']
    else:
        wanted = args[0] if args else None
    return query.filter(FruitType.name == wanted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment