Created April 15, 2009 04:47
-
-
Save storborg/95607 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Polymorphic many-to-many tagging for SQLAlchemy declarative models.

Any mapped class passed to ``taggable()`` gains a list-like tags property.
Every tagged object owns one anchor row (``assoc_left``, which records the
owning class name), and each tag link is an ``assoc_right`` row pointing at
both that anchor and a ``tags`` row. Foreign keys keep the links
referentially intact, and from the tagged object the collection reads like a
plain list of Tag objects. Going from a tag back to its owners still has to
walk the association rows by hand.
"""
__author__ = 'scott torborg (scotttorborg.com)'
from sqlalchemy import * | |
from sqlalchemy.orm import * | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.ext.associationproxy import _AssociationList | |
class ProxyList(_AssociationList):
    """
    An iterable which proxies access to a given attribute of a wrapped list.

    Each element of the wrapped collection is an intermediate object. Reads
    and writes through this list are redirected to the attribute named
    ``attr_name`` on those objects. Appended values are wrapped in a fresh
    intermediate object built by ``creator``.

    NOTE(review): ``_AssociationList.__init__`` is deliberately not called.
    This class instead overrides ``lazy_collection``, ``_create``, ``_get``
    and ``_set``, presumably the hooks the private base class consults.
    Those internals differ between SQLAlchemy releases, so confirm them
    against the installed version.
    """

    def __init__(self, creator, col, attr_name):
        """
        creator
            Called with the new proxied value whenever an item is added to
            this list; must return the intermediate object to store in
            ``col``.
        col
            Collection of intermediate objects we are proxying across.
        attr_name
            Name of the attribute read from / written to each intermediate
            object.
        """
        self.creator = creator
        self._col = col
        self.attr_name = attr_name

    def lazy_collection(self):
        """Return the wrapped collection of intermediate objects."""
        return self._col

    def _create(self, value):
        """
        Build a new intermediate object for ``value`` via ``creator``.

        Override this method to control how new intermediate objects are
        created when items are appended to the list: for example, if we need
        to access self while we are creating new objects.
        """
        return self.creator(value)

    def _get(self, object):
        """
        Return the proxied attribute of intermediate object ``object``.

        Override this method to control how object reads are proxied.
        """
        # The parameter name ``object`` shadows the builtin. It is kept so
        # the signature matches the original override.
        return getattr(object, self.attr_name)

    def _set(self, object, value):
        """
        Write ``value`` to the proxied attribute of ``object``.

        Override this method to control how object writes are proxied.
        """
        setattr(object, self.attr_name, value)
# Shared MetaData bound to an in-memory SQLite database (a 'sqlite://' URL
# with no path). Every declarative model below registers its table here.
metadata = MetaData(bind='sqlite://')
Base = declarative_base(metadata=metadata)
class _AssocLeft(Base):
    """
    Anchor row for one tagged object.

    ``type`` stores a discriminator string; ``taggable`` passes the tagged
    class's name. All of the object's tag links (``_AssocRight`` rows)
    reference this row's ``assoc_id``.
    """
    __tablename__ = 'assoc_left'

    assoc_id = Column(Integer, primary_key=True)
    type = Column(String(50))
    # This is the only relation over assoc_right.assoc_id. The backref
    # creates _AssocRight.assoc_left, so both directions stay in sync in
    # memory. Without it, two independent relations would both try to
    # manage the same foreign key.
    # Ordering by the link row's primary key makes the tag order
    # deterministic rather than dependent on the database's row order.
    assoc_rights = relation('_AssocRight', backref='assoc_left',
                            order_by='_AssocRight.id')

    def __init__(self, type):
        # ``type`` shadows the builtin, but it matches the mapped column name
        # and is the existing constructor parameter.
        self.type = type


class _AssocRight(Base):
    """One link between an ``_AssocLeft`` anchor and a single ``Tag``."""
    __tablename__ = 'assoc_right'

    id = Column(Integer, primary_key=True)
    assoc_id = Column(None, ForeignKey('assoc_left.assoc_id'))
    tag_id = Column(None, ForeignKey('tags.id'))
    # The backref gives every Tag an ``assoc_rights`` list of its links.
    tag = relation('Tag', backref='assoc_rights')
    # ``assoc_left`` is created by the backref on _AssocLeft.assoc_rights.

    def __init__(self, tag):
        self.tag = tag
def taggable(cls, name):
    """
    Taggable interface: give ``cls`` a tags property called ``name``.

    The function does three things:

    * assigns an ``assoc_id`` foreign-key Column (to
      ``assoc_left.assoc_id``) on ``cls``;
    * adds an ``assoc_left`` relation to the class mapper;
    * installs a read-only property that returns a ProxyList of the Tag
      objects linked through that anchor.

    cls
        A mapped declarative class.
    name
        Attribute name to use for the tags property (e.g. ``'tags'``).
    """
    cls.assoc_id = Column(None, ForeignKey('assoc_left.assoc_id'))
    # Named cls_mapper so it does not shadow ``mapper`` from the
    # ``from sqlalchemy.orm import *`` at the top of the module.
    cls_mapper = class_mapper(cls)
    cls_mapper.add_property('assoc_left', relation(_AssocLeft))

    def tags(self):
        """List-like view of this object's Tag objects."""
        # The anchor row is created lazily, tagged with the owning class
        # name. Note that merely reading the property creates it.
        if self.assoc_left is None:
            self.assoc_left = _AssocLeft(cls.__name__)
        return ProxyList(_AssocRight, self.assoc_left.assoc_rights, 'tag')

    setattr(cls, name, property(tags))
class Tag(Base):
    """A tag label; linked to tagged objects through ``_AssocRight`` rows."""
    __tablename__ = 'tags'

    id = Column(Integer, primary_key=True)
    text = Column(String(50))

    def __repr__(self):
        return "<%s: %s>" % ("Tag", self.text)

    def __init__(self, text):
        self.text = text
class User(Base):
    """A user, keyed by name."""
    __tablename__ = 'users'

    name = Column(String(50), primary_key=True)

    def __repr__(self):
        return "<%s: %s>" % ("User", self.name)

    def __init__(self, name):
        self.name = name
class Page(Base):
    """A page, keyed by title."""
    __tablename__ = 'pages'

    title = Column(String(50), primary_key=True)

    def __repr__(self):
        return "<%s: %s>" % ("Page", self.title)

    def __init__(self, title):
        self.title = title
# Give both content classes a ``tags`` property, in the same order as before
# (User first, then Page); the loop variable is removed afterwards.
for _taggable_cls in (User, Page):
    taggable(_taggable_cls, 'tags')
del _taggable_cls
######
# Demo / smoke test: tag a user and a page, round-trip everything through
# the database, and check the tags come back in order.
#
# The print calls take a single argument so they behave identically under
# Python 2 (print statement) and Python 3 (print function).

# Rebuild the schema from scratch so repeated runs start clean.
metadata.drop_all()
metadata.create_all()

print("creating user bob")
user_bob = User('bob')
print("creating page coffee")
page_coffee = Page('coffee')
print("creating tag foo")
tag_foo = Tag('foo')
print("creating tag bar")
tag_bar = Tag('bar')
print("creating tag baz")
tag_baz = Tag('baz')

print("tagging bob with foo")
user_bob.tags.append(tag_foo)
print("tagging bob with bar")
user_bob.tags.append(tag_bar)
print("tagging coffee with bar")
page_coffee.tags.append(tag_bar)
print("tagging coffee with baz")
page_coffee.tags.append(tag_baz)

print("adding all objects to session")
sess = create_session()
# add_all() rather than a list comprehension run only for its side effects.
sess.add_all([user_bob, page_coffee, tag_foo, tag_bar, tag_baz])
sess.flush()
# Detach everything so the queries below have to reload from the database.
sess.expunge_all()
print("****** cleared everything *********")

bob = sess.query(User).filter_by(name='bob').one()
print("bob is %s" % bob)
print("bob's tags are %s" % [tag.text for tag in bob.tags])
coffee = sess.query(Page).filter_by(title='coffee').one()
print("coffee is %s" % coffee)
print("coffee's tags are %s" % [tag.text for tag in coffee.tags])

for tag in sess.query(Tag).all():
    # One call prints "<Tag: x> associations [...]" on a single line.
    print("%s associations %s"
          % (tag, [ar.assoc_left for ar in tag.assoc_rights]))

assert [tag.text for tag in bob.tags] == ['foo', 'bar']
assert [tag.text for tag in coffee.tags] == ['bar', 'baz']
print("tests passed")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.