Skip to content

Instantly share code, notes, and snippets.

@grakic
Last active May 5, 2023 19:41
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save grakic/66006440ed871c8b9d73a2206dff8a5e to your computer and use it in GitHub Desktop.
Save grakic/66006440ed871c8b9d73a2206dff8a5e to your computer and use it in GitHub Desktop.
from sqlalchemy_continuum.plugins import Plugin
from sqlalchemy_continuum import Operation
from sqlalchemy.inspection import inspect
class RelatedVersioningPlugin(Plugin):
    """
    SQLAlchemy-Continuum plugin that adds objects referenced through foreign
    keys to the current versioning transaction.

    For every unprocessed operation in the unit of work, the plugin follows the
    foreign-key columns of the operation target, loads the referenced rows and
    registers an operation of the same type for each one not already tracked.
    Newly added objects are walked the same way, level by level.

    A model can stop the walk through specific foreign keys by listing their
    mapped property names in ``__versioned__['exclude_for_related']``.
    """

    def __init__(self, max_allowed_level=10):
        """
        :param max_allowed_level: limit on how many foreign-key levels are
            walked from each operation target; reaching it raises RuntimeError
        """
        # Populated lazily from the first operation target's declarative base.
        self.class_registry = []
        self.max_allowed_level = max_allowed_level

    def get_model(self, table_name):
        """
        Get the declarative model mapped to a table.

        :param table_name: table name, compared against ``__tablename__``
        :return: declarative model class; when several registered classes
            share the table (single-table inheritance subclasses inherit
            ``__tablename__``), the base-most of them is returned
        :raises AttributeError: if no registered class maps ``table_name``
        """
        candidates = [
            klass for klass in self.class_registry
            if isinstance(klass, type)
            and getattr(klass, '__tablename__', None) == table_name
        ]
        # Prefer a class that does not inherit from another candidate, so that
        # queries are not narrowed by a subclass's polymorphic filter.
        for klass in candidates:
            if not any(other is not klass and issubclass(klass, other)
                       for other in candidates):
                return klass
        if candidates:
            return candidates[0]
        raise AttributeError('Unknown model for table %s' % table_name)

    @staticmethod
    def _get_class_registry(target):
        """
        Return the classes registered in the declarative base of ``target``.

        SQLAlchemy < 1.4 exposes the registry as ``_decl_class_registry``;
        1.4+ keeps it in ``registry._class_registry``. Both are private
        attributes, so this may need revisiting on SQLAlchemy upgrades.
        """
        model = type(target)
        legacy_registry = getattr(model, '_decl_class_registry', None)
        if legacy_registry is not None:
            return legacy_registry.values()
        return model.registry._class_registry.values()

    def _get_references(self, model):
        """
        Return the foreign keys of ``model`` that the walk should follow.

        :param model: declarative model class of the objects being walked
        :return: list of ``(foreign_model, foreign_column, local_property_name)``
            tuples - the referenced model class, the referenced column, and the
            name of the mapped property on ``model`` that holds the key value
        """
        table = model.__table__
        mapper = inspect(model)
        ignored_properties = getattr(model, '__versioned__', {}).get('exclude_for_related', [])

        references = []
        for column in table.columns:
            for fk in column.foreign_keys:
                local_property_name = mapper.get_property_by_column(column).key
                if local_property_name in ignored_properties:
                    continue
                foreign_column = fk.column
                foreign_model = self.get_model(foreign_column.table.name)
                # Continuum can only create versions for versioned models, so
                # references to non-versioned models are not followed.
                if not hasattr(foreign_model, '__versioned__'):
                    continue
                references.append((foreign_model, foreign_column, local_property_name))
        return references

    def before_create_version_objects(self, uow, session):
        """
        Extend ``uow.operations`` with objects referenced by pending targets.

        :param uow: SQLAlchemy-Continuum unit of work being processed
        :param session: SQLAlchemy session used to load referenced objects
        :raises RuntimeError: if the foreign-key walk from a single target goes
            ``max_allowed_level`` levels deep
        """
        # Iterate over a snapshot: operations are added inside the loop, and
        # mutating a dict while iterating its items() view raises RuntimeError.
        # Operations added here need no outer pass of their own because the
        # walk below already follows their references.
        for _, operation in list(uow.operations.items()):
            # Process each operation only once
            if operation.processed:
                continue

            if not self.class_registry:
                self.class_registry = self._get_class_registry(operation.target)

            # Breadth-first walk over groups of same-typed objects, starting
            # from operation.target; each group carries its remaining depth.
            object_sets = [(self.max_allowed_level, [operation.target])]
            for level, objects in object_sets:
                if level <= 0:
                    raise RuntimeError('Maximum level reached while collecting referenced objects')

                for foreign_model, foreign_column, local_property_name in self._get_references(type(objects[0])):
                    # Collect the non-NULL foreign key values of this group
                    values = [getattr(obj, local_property_name) for obj in objects]
                    values = [value for value in values if value is not None]
                    if not values:
                        continue

                    referenced_objects = (
                        session.query(foreign_model)
                        .filter(foreign_column.in_(values))
                        .all()
                    )

                    # NOTE(review): referenced rows get the same operation type
                    # as the original target (e.g. a DELETE is forwarded to rows
                    # that were not deleted) - confirm this is intended.
                    new_objects = []
                    for referenced_object in referenced_objects:
                        if referenced_object not in uow.operations:
                            uow.operations.add(Operation(referenced_object, operation.type))
                            new_objects.append(referenced_object)

                    # Only walk objects seen for the first time; re-walking
                    # already tracked ones makes reference cycles (including
                    # self-referential foreign keys) loop until the depth limit.
                    if new_objects:
                        object_sets.append((level - 1, new_objects))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment