Created
July 9, 2022 00:21
-
-
Save lo5/2bf02c6455617aa65b244fa11cd1926e to your computer and use it in GitHub Desktop.
A Python dict with undo/redo and transaction support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional | |
class Operation:
    """A reversible action built from two zero-argument callables.

    ``do`` applies the change and ``undo`` reverts it.  Both are kept as plain
    attributes, so holders invoke them as ``op.do()`` and ``op.undo()``.
    """

    def __init__(self, do, undo):
        self.do, self.undo = do, undo
class BatchOperation(Operation):
    """Combine several operations into a single undoable step.

    ``do`` runs the children in order and ``undo`` reverts them in reverse
    order.  If a child's ``do`` raises, the children that already ran are
    reverted before the exception propagates, so the batch is all-or-nothing.
    """

    def __init__(self, ops: List[Operation]):
        # Snapshot the sequence so later mutation of the caller's list cannot
        # change what this batch applies or reverts.
        ops = list(ops)

        def do():
            applied = []
            try:
                for op in ops:
                    op.do()
                    applied.append(op)
            except BaseException:
                # Keep the batch atomic: undo the partial work, then re-raise.
                for op in reversed(applied):
                    op.undo()
                raise

        def undo():
            for op in reversed(ops):
                op.undo()

        super().__init__(do, undo)
class History:
    """Linear undo/redo stack of operations.

    ``_ops[:_index]`` are the applied (undoable) operations and
    ``_ops[_index:]`` are the undone (redoable) ones.  Doing a new operation
    discards the redoable tail.
    """

    def __init__(self):
        self._ops: List[Operation] = []
        self._index = 0  # number of operations currently applied

    def do(self, op: "Operation"):
        """Apply *op* and record it as the newest undoable step.

        The op runs first; the redo tail is dropped and *op* appended only if
        that succeeds.  A failing op therefore leaves the history untouched
        instead of lingering as a bogus redoable entry.
        """
        op.do()
        del self._ops[self._index:]
        self._ops.append(op)
        self._index += 1

    def can_redo(self) -> bool:
        """Return True if at least one undone operation can be re-applied."""
        return self._index < len(self._ops)

    def can_undo(self) -> bool:
        """Return True if at least one applied operation can be reverted."""
        return self._index > 0

    def redo(self):
        """Re-apply the next undone operation.

        Raises:
            ValueError: if there is nothing to redo.
        """
        if not self.can_redo():
            raise ValueError('Cannot redo: at end of history')
        self._ops[self._index].do()
        self._index += 1

    def undo(self):
        """Revert the most recently applied operation.

        The index moves only after ``undo()`` succeeds, so a failing undo keeps
        the op counted as applied (mirroring ``redo``).

        Raises:
            ValueError: if there is nothing to undo.
        """
        if not self.can_undo():
            raise ValueError('Cannot undo: at start of history')
        self._ops[self._index - 1].undo()
        self._index -= 1
class Transaction:
    """Scratch state for one open ``with`` block on a RevertibleDict.

    ``dict`` is the mapping reads consult while the block is open, and
    ``ops`` collects the staged operations in the order they were issued.
    """

    def __init__(self):
        self.ops: List[Operation] = []
        self.dict = {}
class RevertibleDict:
    """A dict-like mapping with undo/redo and transaction support.

    Outside a transaction each ``d[key] = value`` or ``del d[key]`` is applied
    immediately and recorded as one undoable step.  Inside ``with d:`` writes
    are staged: reads see them at once, but the underlying dict changes only
    when the block exits without an exception, at which point every staged
    write becomes a single undoable step.  If the block raises, they are
    discarded.
    """

    # Private marker for "key was absent"; None is a legitimate stored value.
    _MISSING = object()

    def __init__(self):
        self._dict = dict()
        self._history = History()
        self._transaction: Optional[Transaction] = None

    def _view(self):
        """Return the mapping that reads and existence checks should use."""
        if self._transaction is not None:
            return self._transaction.dict
        return self._dict

    def __getitem__(self, key):
        return self._view()[key]

    def __contains__(self, key):
        return key in self._view()

    def __setitem__(self, key, value):
        d = self._dict
        missing = self._MISSING
        previous = missing

        def do():
            # Capture the prior value when the op runs, not when it is built:
            # inside a transaction, earlier staged ops may change it first.
            nonlocal previous
            previous = d.get(key, missing)
            d[key] = value

        def undo():
            if previous is missing:
                del d[key]
            else:
                d[key] = previous

        op = Operation(do, undo)
        if self._transaction is not None:
            self._transaction.dict[key] = value
            self._transaction.ops.append(op)
        else:
            hash(key)  # reject unhashable keys before the op reaches history
            self._history.do(op)

    def __delitem__(self, key):
        # Validate against what the caller can currently see, which inside a
        # transaction includes staged writes.
        if key not in self._view():
            raise KeyError(key)
        d = self._dict
        previous = None

        def do():
            # Capture the removed value at run time (same reason as above).
            nonlocal previous
            previous = d.pop(key)

        def undo():
            d[key] = previous

        op = Operation(do, undo)
        if self._transaction is not None:
            del self._transaction.dict[key]
            self._transaction.ops.append(op)
        else:
            self._history.do(op)

    def __enter__(self):
        """Open a transaction and return self; nesting is rejected."""
        if self._transaction is not None:
            raise RuntimeError('Nesting transactions is not allowed.')
        t = Transaction()
        # Seed the staged view with the current contents so keys that existed
        # before the transaction remain readable and deletable inside it.
        t.dict.update(self._dict)
        self._transaction = t
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit staged ops as one history step, or drop them on error.

        Returns None, so an exception raised in the block propagates.
        """
        t = self._transaction
        if t is None:
            raise RuntimeError('Cannot commit or rollback: not in transaction.')
        self._transaction = None
        if exc_type is None:
            self._history.do(BatchOperation(t.ops))

    def undo(self):
        """Revert the most recent recorded step.

        Refused while a transaction is open, because it would desynchronise
        the staged view from the underlying dict.
        """
        if self._transaction is not None:
            raise RuntimeError('Cannot undo while a transaction is open.')
        self._history.undo()

    def redo(self):
        """Re-apply the most recently undone step (not inside a transaction)."""
        if self._transaction is not None:
            raise RuntimeError('Cannot redo while a transaction is open.')
        self._history.redo()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from .history import RevertibleDict | |
def _assert_undo_redo_scenario(key):
    """Run the shared set/undo/redo/clobber script against a single *key*.

    test_state and test_state_nested differ only in the key they use, so both
    delegate here to keep the two scripts from drifting apart.
    """
    d = RevertibleDict()

    def expect(value):
        # None means "key must be absent"; the script only stores ints.
        if value is None:
            assert key not in d
        else:
            assert d[key] == value

    # undo/redo once with empty history
    expect(None)
    d[key] = 42
    expect(42)
    d.undo()
    expect(None)
    d.redo()
    expect(42)
    # undo/redo once with non-empty history
    d[key] = 43
    expect(43)
    d.undo()
    expect(42)
    d.redo()
    expect(43)
    # undo/redo twice with non-empty history
    d.undo()
    expect(42)
    d.undo()
    expect(None)
    d.redo()
    expect(42)
    d.redo()
    expect(43)
    # undo/redo thrice with non-empty history
    d[key] = 44
    expect(44)
    d.undo()
    expect(43)
    d.undo()
    expect(42)
    d.undo()
    expect(None)
    d.redo()
    expect(42)
    d.redo()
    expect(43)
    d.redo()
    expect(44)
    # undo twice and clobber the redo tail with a new write
    d.undo()
    expect(43)
    d.undo()
    expect(42)
    d[key] = 45
    expect(45)
    d.undo()
    expect(42)
    d.redo()
    expect(45)
    # The 43/44 steps were discarded by the write, so nothing is left to redo.
    try:
        d.redo()
    except ValueError:
        pass
    else:
        raise AssertionError('redo tail should have been discarded by the write')


def test_state():
    _assert_undo_redo_scenario('x')


def test_state_nested():
    # d['x', 'y'] indexes with the tuple ('x', 'y').
    _assert_undo_redo_scenario(('x', 'y'))
def test_rollback():
    """An exception inside ``with d:`` discards every staged write."""
    d = RevertibleDict()
    # Catch only the deliberate ValueError: a bare except would also swallow
    # an AssertionError from the checks inside the block.
    try:
        with d:
            d['x'] = 42
            d['y'] = 43
            d['z'] = 44
            assert d['x'] == 42
            assert d['y'] == 43
            assert d['z'] == 44
            raise ValueError('test')
    except ValueError:
        pass
    else:
        raise AssertionError('ValueError should propagate out of the with block')
    assert 'x' not in d
    assert 'y' not in d
    assert 'z' not in d
    # A rolled-back transaction must not leave an undoable step behind.
    try:
        d.undo()
    except ValueError:
        pass
    else:
        raise AssertionError('rollback should not record history')
def test_commit():
    """A clean exit from ``with d:`` commits all writes as one undo step."""
    d = RevertibleDict()
    staged = {'x': 42, 'y': 43, 'z': 44}

    def assert_all_present():
        for name, expected in staged.items():
            assert d[name] == expected

    def assert_all_absent():
        for name in staged:
            assert name not in d

    with d:
        for name, value in staged.items():
            d[name] = value
        assert_all_present()  # visible while the transaction is open
    assert_all_present()  # and after commit
    d.undo()  # a single undo reverts the whole batch
    assert_all_absent()
    d.redo()
    assert_all_present()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment