Last active
January 27, 2022 22:41
-
-
Save cj-dimaggio/34d0983906a14ad3c739 to your computer and use it in GitHub Desktop.
SQLAlchemy class definitions that preserve the order of lists in many-to-many relationships.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import Column, String, Integer, ForeignKey, create_engine, Table | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import sessionmaker, RelationshipProperty | |
from sqlalchemy.orm.dependency import ManyToManyDP | |
from sqlalchemy.util.langhelpers import public_factory | |
#################
# THE ACTUAL CODE
#################
class OrderedManyToManyDP(ManyToManyDP):
    """Many-to-many dependency processor that also persists list position.

    Follows the shape of ``ManyToManyDP.process_saves`` but additionally
    writes each child's index within the parent's collection into the
    secondary-table column named by ``self.prop.order_field``.
    """

    def process_saves(self, uowcommit, states):
        """Collect secondary-table inserts, updates and deletes, then run them.

        For every parent state with pending history on this relationship:

        * children newly added to the collection produce an INSERT row that
          carries their list index in the order column;
        * children already present produce an UPDATE row that rewrites the
          order column to their current index (the existing value is read
          from the database so it can be used as the ``old_`` match value);
        * children removed from the collection produce a DELETE row.

        :param uowcommit: the unit-of-work object driving the flush.
        :param states: the parent instance states being saved.
        """
        # Local imports, as in the original implementation; ``sql`` is now
        # imported once per call rather than once per unchanged child.
        from sqlalchemy import sql
        from sqlalchemy.orm import attributes, sync

        secondary_delete = []
        secondary_insert = []
        secondary_update = []

        order_field = self.prop.order_field
        order_column = getattr(self.secondary.c, order_field)

        processed = self._get_reversed_processed_set(uowcommit)
        tmp = set()

        for state in states:
            need_cascade_pks = (not self.passive_updates and
                                self._pks_changed(uowcommit, state))
            if need_cascade_pks:
                passive = attributes.PASSIVE_OFF
            else:
                passive = attributes.PASSIVE_NO_INITIALIZE

            history = uowcommit.get_attribute_history(state, self.key,
                                                      passive)
            if not history:
                continue

            # Up to this point everything is the same as ManyToManyDP's
            # process_saves implementation.

            # The history lists hold instance states, so membership can be
            # tested with sets instead of repeatedly scanning lists of
            # objects with ``in`` / ``.index()``.
            added = set(history.added)
            unchanged = set(history.unchanged)

            # SQLAlchemy's many-to-many collections do not appear to keep
            # the same child more than once, but the secondary table may
            # still contain several rows for the same (parent, child) pair.
            # To give each occurrence its own row we remember, *per child*,
            # which stored order values have already been claimed.  (A
            # single list shared by all children wrongly hid rows of other
            # children that happened to store the same order value.)
            used_orders = {}

            # A connection is needed to look up the currently stored order
            # values of existing rows.
            connection = uowcommit.transaction.connection(self.mapper)

            # Walk the collection as it is set on the Python object so that
            # each child's index is known.  The attribute may be missing
            # from the instance dict (``get`` returns None), in which case
            # there is nothing to order.
            collection = state.dict.get(self.key)
            if collection is None:
                collection = ()

            for index, child_obj in enumerate(collection):
                if child_obj is None:
                    # No instance state exists for None; nothing to write.
                    continue
                child = attributes.instance_state(child_obj)

                if child in added:
                    # From here the logic is mostly copied from the parent
                    # class.
                    if (processed is not None and
                            (state, child) in processed):
                        continue
                    associationrow = {}
                    if not self._synchronize(state,
                                             child,
                                             associationrow,
                                             False, uowcommit, "add"):
                        continue
                    # The one deviation from ManyToManyDP: associationrow
                    # holds the column values for the INSERT, e.g.
                    # {'foreignkey_a': num, 'foreignkey_b': num}, so the
                    # list index is added under the order column's key.
                    associationrow[order_field] = index
                    secondary_insert.append(associationrow)
                    tmp.add((child, state))

                elif child in unchanged:
                    # ManyToManyDP only builds these rows when
                    # need_cascade_pks is set; here every existing row is
                    # revisited because its position may have changed.
                    associationrow = {}
                    sync.update(state,
                                self.parent,
                                associationrow,
                                "old_",
                                self.prop.synchronize_pairs)
                    sync.update(child,
                                self.mapper,
                                associationrow,
                                "old_",
                                self.prop.secondary_synchronize_pairs)
                    # associationrow now looks like:
                    # {'foreignkey_a': num, 'foreignkey_b': num,
                    #  'old_foreignkey_a': num, 'old_foreignkey_b': num}
                    # _run_crud uses the 'old_'-prefixed keys to match the
                    # row to update, so the order column needs both its new
                    # value and its 'old_' counterpart.
                    claimed = used_orders.setdefault(child, [])
                    criteria = [
                        c == associationrow.get(c.key)
                        for c in self.secondary.c
                        if c.key in associationrow
                    ]
                    criteria.extend(order_column != w for w in claimed)
                    statement = self.secondary.select().where(
                        sql.and_(*criteria))
                    resp = connection.execute(statement).first()
                    if resp is None:
                        # No (remaining) stored row for this pair, so there
                        # is no old order value to match against; skip it
                        # rather than reuse a value from an earlier child.
                        continue
                    old_order = getattr(resp, order_field)
                    claimed.append(old_order)
                    associationrow[order_field] = index
                    associationrow["old_" + order_field] = old_order
                    secondary_update.append(associationrow)

            # Rows that are going to be deleted need no ordering; this is
            # copied directly from ManyToManyDP.
            for child in history.deleted:
                if (processed is not None and
                        (state, child) in processed):
                    continue
                associationrow = {}
                if not self._synchronize(state,
                                         child,
                                         associationrow,
                                         False, uowcommit, "delete"):
                    continue
                secondary_delete.append(associationrow)

            tmp.update((c, state)
                       for c in history.deleted)

        # Again, all of this is just ManyToManyDP behavior.
        if processed is not None:
            processed.update(tmp)

        self._run_crud(uowcommit, secondary_insert,
                       secondary_update, secondary_delete)
class OrderedRelationshipProperty(RelationshipProperty):
    """Relationship property that keeps list order in its secondary table.

    Only relationships declared with a ``secondary`` table are supported.
    The position of each child is stored in the secondary-table column
    named by ``order_field``; the writing is done by ``OrderedManyToManyDP``.
    """

    def __init__(self, argument, order_field, **kwargs):
        """Configure the relationship and validate ``order_field``.

        :param argument: forwarded unchanged to ``RelationshipProperty``.
        :param order_field: key of the secondary-table column that stores
            the list position.
        :param kwargs: forwarded unchanged to ``RelationshipProperty``;
            must include a ``secondary`` table.
        :raises ArgumentError: if no secondary table is configured, or if
            the secondary table has no column keyed ``order_field``.
        """
        from sqlalchemy.exc import ArgumentError

        super(OrderedRelationshipProperty, self).__init__(argument, **kwargs)

        if self.secondary is None:
            # It wouldn't theoretically be a *huge* lift to get this working
            # for non many-to-many relationships as well, but there is no
            # reason to reimplement ordering_list's established and much
            # better tested behaviour. Perhaps in these cases this class
            # could act as a simple wrapper.
            raise ArgumentError("This relationship is currently only implemented for "
                                "secondary relationships. For primary relationships you "
                                "should instead use 'ordering_list'")

        # Compare against the column keys rather than using hasattr():
        # hasattr() is also true for the column collection's own attributes
        # (for example "keys"), which would let an invalid order_field
        # through this check.
        if order_field not in self.secondary.c.keys():
            raise ArgumentError("order_field specified: %s, does not exist on the associated "
                                "secondary table: %s." % (order_field, self.secondary))

        # TODO: the 'order_by' parameter should probably default to the
        # order_field column, but that needs more back and forth with
        # RelationshipProperty than has been implemented so far.
        self.order_field = order_field

    def _post_init(self):
        """Set ``uselist`` if unset and install the ordered dependency processor."""
        from sqlalchemy.orm.base import MANYTOONE

        if self.uselist is None:
            self.uselist = self.direction is not MANYTOONE
        if not self.viewonly:
            # Because a secondary table is required, the relationship can
            # safely be assumed to be many-to-many.
            self._dependency_processor = OrderedManyToManyDP(self)


# NOTE(review): whether public_factory is strictly necessary here has not
# been investigated; it is used to follow the way the original
# ``relationship`` function is set up.
ordered_relationship = public_factory(OrderedRelationshipProperty, ".orm.ordered_relationship")
######################################
# EXAMPLE CODE TO TEST THE NEW CLASSES
######################################
# Declarative base class for the example models.
Base = declarative_base()

# Secondary (association) table linking elements to images. The extra
# 'weight' column is where the list position of each image is stored.
association_table = Table(
    'association',
    Base.metadata,
    Column('elements_id', Integer, ForeignKey('elements.id')),
    Column('images_id', Integer, ForeignKey('images.id')),
    Column('weight', Integer),
)
class Images(Base):
    """Example child model: a row of the 'images' table."""

    __tablename__ = 'images'

    id = Column(Integer, primary_key=True)
    url = Column(String)
class Elements(Base):
    """Example parent model: a row of the 'elements' table with ordered images."""

    __tablename__ = 'elements'

    id = Column(Integer, primary_key=True)
    name = Column(String)

    # Ordered many-to-many list of Images; positions are stored in, and
    # loaded in order of, the association table's 'weight' column.
    images = ordered_relationship(
        'Images',
        'weight',
        secondary=association_table,
        order_by=association_table.c.weight,
        lazy=False,
    )
# Example run: store an element with three images, read the order back,
# reorder the list, and read it back again.
engine = create_engine('sqlite:///sqlalchemy_testing.sqlite')
session = sessionmaker()
session.configure(bind=engine)
Base.metadata.create_all(engine)

img1 = Images(url='img 1')
img2 = Images(url='img 2')
img3 = Images(url='img 3')
element = Elements(name="Test")

# Note that something like [img1, img2, img3, img3] will still only show up
# as [img1, img2, img3] when loaded back from the database; the duplicate is
# dropped. OrderedManyToManyDP nevertheless tries to account for duplicate
# association rows.
element.images = [img1, img2, img3]

s = session()
s.add(element)
s.commit()
s.close()

# Read back in a fresh session (instead of reusing the one just closed) and
# close it afterwards. print(...) with a single argument behaves the same on
# Python 2 and Python 3.
s = session()
element = s.query(Elements).get(1)
print([img.url for img in element.images])
s.close()

# Reorder the existing images; this exercises the UPDATE path.
s = session()
element = s.query(Elements).get(1)
img1 = s.query(Images).get(1)
img2 = s.query(Images).get(2)
img3 = s.query(Images).get(3)
element.images = [img3, img1, img2]
s.commit()
s.close()

s = session()
element = s.query(Elements).get(1)
print([img.url for img in element.images])
s.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment