Last active
November 23, 2016 05:00
-
-
Save dhruvg/51362d06ea1daa2bbe5f3438a2737386 to your computer and use it in GitHub Desktop.
Keep track of model changes across different sessions in SQLAlchemy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import event, orm | |
from sqlalchemy.inspection import inspect | |
from sqlalchemy.orm import object_mapper | |
from sqlalchemy.orm.properties import ColumnProperty | |
from test_project.storage import Session | |
class SessionTracker(object):
    """Record column changes per session and report them to the changed objects.

    Changes are gathered on every flush and kept until the owning session
    commits or rolls back.  Tracked objects are expected to provide
    ``on_after_flush(changes)`` and ``on_after_commit(changes)`` methods,
    where ``changes`` maps a ``ColumnProperty`` to a ``(value, old_value)``
    tuple.
    """

    def __init__(self, sessionmaker):
        """Attach the tracker's listeners to ``sessionmaker``.

        :param sessionmaker: target handed to ``event.listen`` for the
            ``after_flush``, ``after_commit`` and ``after_rollback`` events.
        """
        # Keeps track of model changes across different simultaneous sessions:
        # {session: {object: {mapper_property: (value, old_value)}}}
        self.model_changes = {}
        event.listen(sessionmaker, 'after_flush', self.on_after_flush)
        event.listen(sessionmaker, 'after_commit', self.on_after_commit)
        event.listen(sessionmaker, 'after_rollback', self.on_after_rollback)

    def on_after_commit(self, session):
        """Hand the accumulated changes to each object and forget the session.

        :param session: the session that was committed.
        """
        session_changes = self.model_changes.pop(session, None)
        if session_changes:
            for _object, changes in session_changes.items():
                _object.on_after_commit(changes)

    def on_after_rollback(self, session):
        """Discard everything recorded for ``session``.

        :param session: the session that was rolled back.
        """
        self.model_changes.pop(session, None)

    def on_after_flush(self, session, _):
        """Record changed column values of the new and dirty objects.

        :param session: the session that was flushed.
        :param _: second argument supplied by the ``after_flush`` event
            (unused).
        """
        session_changes = self.model_changes.setdefault(session, {})
        for _object in session.new.union(session.dirty):
            # Look the object up in this session's bucket, not in the outer
            # dict (which is keyed by session).  Checking the outer dict was
            # always a miss, so every flush wiped the changes recorded by
            # earlier flushes of the same transaction.
            object_changes = session_changes.setdefault(_object, {})
            mapper = object_mapper(_object)
            object_state = inspect(_object)
            for mapper_property in mapper.iterate_properties:
                if not isinstance(mapper_property, ColumnProperty):
                    continue
                attribute_state = object_state.attrs.get(mapper_property.key)
                if attribute_state.history.has_changes():
                    # old_value is None for new objects and old value for
                    # dirty objects
                    object_changes[mapper_property] = (
                        attribute_state.value,
                        self._get_old_value(attribute_state),
                    )
            _object.on_after_flush(object_changes)

    @staticmethod
    def _get_old_value(_attribute_state):
        """Return the first deleted value of the attribute's history, or None.

        :param _attribute_state: object exposing ``history.deleted``.
        """
        _history = _attribute_state.history
        return _history.deleted[0] if _history.deleted else None
# Initialize session tracker. Constructing it registers its after_flush,
# after_commit and after_rollback listeners on Session as a side effect; the
# instance itself is not bound to a module-level name.
SessionTracker(Session)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment