Skip to content

Instantly share code, notes, and snippets.

@storborg
Created April 15, 2009 04:47
Show Gist options
  • Save storborg/95607 to your computer and use it in GitHub Desktop.
Save storborg/95607 to your computer and use it in GitHub Desktop.
"""
A polymorphic many-to-many association with referential integrity and clean
collection access (at least in one direction)! SQLAlchemy is the shit.
"""
__author__ = 'scott torborg (scotttorborg.com)'
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.associationproxy import _AssociationList
class ProxyList(_AssociationList):
    """
    An iterable which proxies access to a given attribute of a wrapped list.

    Each element of the wrapped collection is an intermediate object; reading
    an item yields ``getattr(element, attr_name)`` and writing an item sets
    that attribute on the element.
    """

    def __init__(self, creator, col, attr_name):
        """
        creator
            When new items are added to this proxied list, new instances of
            the intermediate class must be created. This callable is invoked
            with the new proxied value and must return the new intermediate
            object.
        col
            Collection we are proxying across.
        attr_name
            Attribute to use when proxying across intermediate objects.
        """
        # NOTE(review): _AssociationList.__init__ is intentionally not called;
        # this subclass supplies lazy_collection/_create/_get/_set itself.
        # Re-check against the installed SQLAlchemy's _AssociationList if it
        # is upgraded, since that class is private API.
        self.creator = creator
        self._col = col
        self.attr_name = attr_name

    def lazy_collection(self):
        """Return the wrapped collection of intermediate objects."""
        # Presumably called by _AssociationList whenever it needs the
        # underlying collection; here the collection is fixed at construction.
        return self._col

    def _create(self, value):
        """
        Override this method to control how new intermediate objects are
        created when items are appended to the list: for example, if we need
        to access self while we are creating new objects.
        """
        return self.creator(value)

    def _get(self, object):
        """
        Override this method to control how object reads are proxied.
        """
        return getattr(object, self.attr_name)

    def _set(self, object, value):
        """
        Override this method to control how object writes are proxied.
        """
        # setattr() always returns None, so there is nothing to pass back.
        setattr(object, self.attr_name, value)
# MetaData bound to an engine for the URL 'sqlite://' (no database path, i.e.
# an in-memory SQLite database). Binding it lets drop_all()/create_all() be
# called later without passing an engine.
metadata = MetaData('sqlite://')
# Declarative base that registers every model's table on the bound MetaData.
Base = declarative_base(metadata=metadata)
class _AssocLeft(Base):
    """One association row per tagged object (the "left" side).

    A taggable class points at this row through its own ``assoc_id`` foreign
    key; the tags themselves hang off it through :class:`_AssocRight` rows.
    """
    __tablename__ = 'assoc_left'
    assoc_id = Column(Integer, primary_key=True)
    # Name of the owning class (taggable() passes cls.__name__ here).
    type = Column(String(50))
    # One-to-many: the _AssocRight rows whose assoc_id points at this row.
    assoc_rights = relation('_AssocRight')
    def __init__(self, type):
        # type: name of the class that owns this association row.
        self.type = type
class _AssocRight(Base):
    """Link row joining one tagged object's _AssocLeft row to one Tag."""
    __tablename__ = 'assoc_right'
    id = Column(Integer, primary_key=True)
    # A None column type lets SQLAlchemy take the type from the referenced
    # column of the ForeignKey.
    assoc_id = Column(None, ForeignKey('assoc_left.assoc_id'))
    tag_id = Column(None, ForeignKey('tags.id'))
    # Many-to-one to Tag; the backref also gives Tag an ``assoc_rights`` list.
    tag = relation('Tag', backref='assoc_rights')
    # NOTE(review): declared independently of _AssocLeft.assoc_rights (no
    # backref linking them), so presumably appending on one side does not
    # populate the other side in memory until the objects are reloaded.
    assoc_left = relation('_AssocLeft')
    def __init__(self, tag):
        # Used as ProxyList's ``creator``: wraps a new Tag in a link row.
        self.tag = tag
def taggable(cls, name):
    """Taggable interface: give ``cls`` a tag-list property called ``name``.

    Installs on the already-mapped declarative class ``cls``:

    * an ``assoc_id`` column with a foreign key to ``assoc_left.assoc_id``;
    * an ``assoc_left`` relation to :class:`_AssocLeft`;
    * a read-only property ``name`` returning a :class:`ProxyList` that
      exposes the ``tag`` of each :class:`_AssocRight` row as a plain list.

    Reading the property on an instance that has no association row yet
    creates one (an :class:`_AssocLeft` whose ``type`` is ``cls.__name__``),
    so even a read has that side effect.

    cls
        A mapped declarative model class.
    name
        Attribute name under which the tag-list property is installed.
    """
    # Assigning a Column to an already-mapped declarative class relies on the
    # declarative metaclass adding it to the table and mapper.
    # NOTE(review): confirm this behavior for the installed SQLAlchemy version.
    cls.assoc_id = Column(None, ForeignKey('assoc_left.assoc_id'))
    # Named to avoid shadowing sqlalchemy.orm.mapper from the star import.
    cls_mapper = class_mapper(cls)
    cls_mapper.add_property('assoc_left', relation(_AssocLeft))

    def tags(self):
        # Lazily create this object's association row on first access.
        if self.assoc_left is None:
            self.assoc_left = _AssocLeft(cls.__name__)
        return ProxyList(_AssocRight, self.assoc_left.assoc_rights, 'tag')

    setattr(cls, name, property(tags))
class Tag(Base):
    """A text label that can be linked to tagged objects via _AssocRight."""
    __tablename__ = 'tags'
    id = Column(Integer, primary_key=True)
    text = Column(String(50))

    def __init__(self, text):
        self.text = text

    def __repr__(self):
        template = "<Tag: %s>"
        return template % self.text
class User(Base):
    """A user, keyed by name."""
    __tablename__ = 'users'
    name = Column(String(50), primary_key=True)

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        template = "<User: %s>"
        return template % self.name
class Page(Base):
    """A page, keyed by title."""
    __tablename__ = 'pages'
    title = Column(String(50), primary_key=True)

    def __init__(self, title):
        self.title = title

    def __repr__(self):
        template = "<Page: %s>"
        return template % self.title
# Make both models taggable: each gains an assoc_id foreign-key column, an
# assoc_left relation, and a ``tags`` list property.
taggable(User, 'tags')
taggable(Page, 'tags')
######
# Demo / self-test: build a fresh schema, tag a User and a Page, persist
# everything, reload from a cleared session, and check the tags round-trip.
# Every print(...) takes a single argument, so it is valid and produces the
# same output under both Python 2 (print statement) and Python 3.
metadata.drop_all()
metadata.create_all()
print("creating user bob")
user_bob = User('bob')
print("creating page coffee")
page_coffee = Page('coffee')
print("creating tag foo")
tag_foo = Tag('foo')
print("creating tag bar")
tag_bar = Tag('bar')
print("creating tag baz")
tag_baz = Tag('baz')
print("tagging bob with foo")
user_bob.tags.append(tag_foo)
print("tagging bob with bar")
user_bob.tags.append(tag_bar)
print("tagging coffee with bar")
page_coffee.tags.append(tag_bar)
print("tagging coffee with baz")
page_coffee.tags.append(tag_baz)
print("adding all objects to session")
sess = create_session()
# A plain loop: a list comprehension here would only build a throwaway list
# of None values for its side effects.
for o in [user_bob, page_coffee, tag_foo, tag_bar, tag_baz]:
    sess.add(o)
sess.flush()
sess.expunge_all()
print("")
print("****** cleared everything *********")
print("")
bob = sess.query(User).filter_by(name='bob').one()
print("bob is %s" % bob)
print("bob's tags are %s" % [tag.text for tag in bob.tags])
coffee = sess.query(Page).filter_by(title='coffee').one()
print("coffee is %s" % coffee)
print("coffee's tags are %s" % [tag.text for tag in coffee.tags])
for tag in sess.query(Tag).all():
    # Same output as the old `print tag,` / `print "associations", [...]` pair.
    print("%s associations %s"
          % (tag, [ar.assoc_left for ar in tag.assoc_rights]))
# The assoc_rights relation has no order_by, so the SQL row order (and hence
# list order) is unspecified; compare order-insensitively.
assert sorted(tag.text for tag in bob.tags) == ['bar', 'foo']
assert sorted(tag.text for tag in coffee.tags) == ['bar', 'baz']
print("tests passed")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment