Skip to content

Instantly share code, notes, and snippets.

@jlumpe
Last active June 23, 2023 21:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jlumpe/5ae1b5fd9ddff884c91d3b5d5316c740 to your computer and use it in GitHub Desktop.
Save jlumpe/5ae1b5fd9ddff884c91d3b5d5316c740 to your computer and use it in GitHub Desktop.
Python class for a mutable set of trees with indexed nodes containing arbitrary data.
from collections.abc import MutableMapping, Set
class SetProxy(Set):
    """Read-only, live view over another set-like collection.

    Only the three abstract methods required by ``collections.abc.Set`` are
    forwarded to the wrapped object; no mutating method is exposed, so holders
    of the proxy can inspect the collection but not change it. Changes made to
    the wrapped object are visible through the proxy.
    """

    def __init__(self, set_):
        # Keep a reference (not a copy) so the proxy tracks later changes.
        self._set_ = set_

    def __contains__(self, item):
        return item in self._set_

    def __iter__(self):
        return iter(self._set_)

    def __len__(self):
        return len(self._set_)

    def __repr__(self):
        # Use the runtime class name so subclasses are labelled correctly.
        class_name = type(self).__name__
        return '%s(%r)' % (class_name, self._set_)
class AttributeMappingProxy:
    """Translate attribute get/set/delete on the proxy to key access on a mapping.

    ``proxy.name`` reads ``mapping['name']``, ``proxy.name = value`` writes it
    and ``del proxy.name`` deletes it. A missing key is reported as
    ``AttributeError`` so the proxy behaves like a regular object.

    Names starting with a double underscore are stored on the proxy object
    itself rather than in the mapping; this is how the reference to the
    wrapped mapping (``__mapping__``) is kept.
    """

    def __init__(self, mapping):
        # The trailing double underscore keeps this name from being mangled,
        # and the leading one routes it to object.__setattr__ in __setattr__.
        self.__mapping__ = mapping

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup has failed. If the name is
        # '__mapping__' the instance has no backing mapping yet (e.g. it was
        # created through __new__ by copy/pickle); reading self.__mapping__
        # below would then re-enter this method until RecursionError.
        if name == '__mapping__':
            raise AttributeError(name)
        try:
            return self.__mapping__[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        if name.startswith('__'):
            # Internal attribute: store on the proxy, not in the mapping.
            object.__setattr__(self, name, value)
        else:
            self.__mapping__[name] = value

    def __delattr__(self, name):
        try:
            del self.__mapping__[name]
        except KeyError:
            raise AttributeError(name) from None

    def __dir__(self):
        attrs = dir(type(self))
        attrs.extend(self.__dict__)
        # dir() sorts the returned names, so non-string keys (which cannot be
        # reached through attribute syntax anyway) must be left out.
        attrs.extend(key for key in self.__mapping__ if isinstance(key, str))
        return attrs
class Forest(MutableMapping):
    """Indexed collection of nodes which form a forest of trees.

    Behaves as a mapping from keys to node objects. Every node is itself a
    ``dict`` of arbitrary data that also records its parent and children by
    key. Unlike an ordinary mapping, assigning to a key that already exists
    raises ``KeyError`` instead of replacing the node.
    """

    def __init__(self):
        self._nodes = dict()
        # Live, read-only view of the node objects.
        self.nodes = SetProxy(self._nodes.values())

    def __len__(self):
        return len(self._nodes)

    def __iter__(self):
        return iter(self._nodes)

    def __contains__(self, key):
        return key in self._nodes

    def __getitem__(self, key):
        return self._nodes[key]

    def __setitem__(self, key, data):
        # Creates a new root node; an existing key is rejected by _check_key().
        self._add_node(key, data=data)

    def __delitem__(self, key):
        self.remove(key, splice=True)

    def _check_key(self, key):
        """Check that a key can be added.

        Raises:
            KeyError: If the key is already present in the forest.
            TypeError: If the key is ``None`` or a node object.
        """
        if key in self:
            raise KeyError('Key {!r} exists'.format(key))
        elif key is None:
            raise TypeError('None is not a valid key')
        elif isinstance(key, _ForestNode):
            raise TypeError('Cannot use nodes as keys')

    def _check_node(self, node, convert_key=True):
        """Check that the node exists in the forest and return it.

        ``None`` is passed through unchanged. Any value that is not a node is
        treated as a key and looked up when ``convert_key`` is true.

        Raises:
            KeyError: If a key is given that is not in the forest.
            TypeError: If a non-node is given and ``convert_key`` is false.
            ValueError: If the node belongs to another forest (or to none).
        """
        if node is None:
            return None
        if not isinstance(node, _ForestNode):
            if convert_key:
                node = self[node]
            else:
                raise TypeError('Expected node, not {}'.format(type(node)))
        elif node._forest is not self:
            raise ValueError('Node not in forest')
        return node

    def _add_node(self, key, data=None, parent=None):
        """Create a node under ``key``, optionally with data and a parent node."""
        self._check_key(key)
        node = _ForestNode(self, key)
        self._nodes[key] = node
        if data is not None:
            node.update(data)
        if parent is not None:
            self._attach(parent, node)
        return node

    def _attach(self, parent, child):
        """Attach child node to parent.

        Set child's parent, add child to parent's children. Detach from
        existing parent first if needed. Does not check for cycles; callers
        are responsible for that.
        """
        if child._parent_key == parent._key:
            return
        if child._parent_key is not None:
            self._detach(child.parent, child)
        child._parent_key = parent._key
        parent._child_keys.add(child._key)

    def _detach(self, parent, child):
        """Detach child node from parent.

        Set child parent to None and remove child from parent's children.
        """
        parent._child_keys.remove(child._key)
        child._parent_key = None

    def _remove(self, node):
        """Remove node from forest, updating other side of its relationships.

        The node's children become root nodes and the node itself is marked
        as no longer belonging to any forest.
        """
        if node._parent_key is not None:
            self._detach(node.parent, node)
        # Copy first: _detach() mutates the set being iterated.
        for child in list(node.children):
            self._detach(node, child)
        del self._nodes[node._key]
        node._forest = None

    def _prune(self, node):
        """Remove node and, recursively, every node below it."""
        if node._parent_key is not None:
            self._detach(node.parent, node)
        # Copy first: pruning a child detaches it from this node.
        for child in list(node.children):
            self._prune(child)
        self._remove(node)

    def _rename(self, node, new_key):
        """Re-index node under ``new_key``, fixing references held by relatives.

        Raises:
            KeyError: If ``new_key`` is already in use (including by ``node``).
            TypeError: If ``new_key`` is ``None`` or a node.
        """
        self._check_key(new_key)
        old_key = node._key
        parent = node.parent
        if parent is not None:
            parent._child_keys.remove(old_key)
            parent._child_keys.add(new_key)
        for child in node.children:
            child._parent_key = new_key
        node._key = new_key
        del self._nodes[old_key]
        self._nodes[new_key] = node

    def add_node(self, key, data=None):
        """Create a new root node and return it."""
        return self._add_node(key, data=data)

    def remove(self, node, splice=True):
        """Remove a single node, given as a node object or as a key.

        With ``splice`` true the removed node's children are re-attached to
        its former parent (if it had one); otherwise they become roots.
        """
        node = self._check_node(node)
        parent = node.parent
        children = list(node.children)
        self._remove(node)
        if splice and parent is not None:
            for child in children:
                self._attach(parent, child)

    def prune(self, node):
        """Remove a node (given as node or key) and all its descendants."""
        node = self._check_node(node)
        if node._parent_key is not None:
            self._detach(node.parent, node)
        self._prune(node)

    def root_nodes(self):
        """Iterate over the nodes that have no parent."""
        return (node for node in self._nodes.values() if node._parent_key is None)

    def to_json(self):
        """Convert to a JSON-able dictionary representation.

        Keys must be strings. Each node becomes
        ``{'parent': <parent key or None>, 'data': <copy of node data>}``.
        """
        data = dict()
        for node in self.nodes:
            data[node.key] = dict(parent=node._parent_key, data=dict(node))
        return data

    @classmethod
    def from_json(cls, data):
        """Create from parsed JSON data, as created by :meth:`to_json`.

        Raises:
            ValueError: If some nodes cannot be reached from a root, e.g.
                because their parent is missing or the parents form a cycle.
        """
        forest = cls()

        # Group every node's data under the key of its parent (None = root).
        node_data_by_parent = dict()
        for key, node_json in data.items():
            siblings = node_data_by_parent.setdefault(node_json['parent'], dict())
            siblings[key] = node_json['data']

        def create_subtree(root):
            for key, child_data in node_data_by_parent.pop(root.key, {}).items():
                child = root.add_child(key, child_data)
                create_subtree(child)

        # Default to "no roots" so that an empty forest round-trips and data
        # without any root is reported below as ValueError, not as KeyError.
        for key, node_data in node_data_by_parent.pop(None, {}).items():
            node = forest.add_node(key, node_data)
            create_subtree(node)

        # Anything left over was never reached from a root.
        if node_data_by_parent:
            raise ValueError('Invalid file')
        return forest
class _ForestNode(dict):
    """Node of a :class:`Forest`; a ``dict`` that holds the node's data.

    Relationships are stored by key (``_parent_key``, ``_child_keys``) and
    resolved through the owning forest on access. Nodes compare by identity,
    not by content; use :meth:`data_eq` to compare the stored data.
    """

    def __init__(self, forest, key):
        super().__init__()
        self._forest = forest
        self._key = key
        self._parent_key = None
        self._child_keys = set()
        # Read-only views handed out to users.
        self.child_keys = SetProxy(self._child_keys)
        self.children = _ForestNodeChildren(self)
        # Attribute-style access to the node's data: node.a.foo <-> node['foo'].
        self.a = AttributeMappingProxy(self)

    def __eq__(self, other):
        """Compare by identity, not as a mapping."""
        return other is self

    def __ne__(self, other):
        """Inverse of :meth:`__eq__`.

        Must be defined explicitly: ``dict`` provides its own ``__ne__`` that
        compares contents, and it would otherwise take precedence over the
        default "not equal" derived from ``__eq__``.
        """
        return other is not self

    # Nodes are mutable mappings and stay unhashable, as they are for ``dict``.
    __hash__ = None

    def data_eq(self, other):
        """Check if this node's data is equivalent to another mapping."""
        return dict.__eq__(self, other)

    def _require_forest(self):
        """Return the owning forest, or raise ValueError if the node was removed."""
        if self._forest is not None:
            return self._forest
        else:
            raise ValueError('Node has been removed from forest')

    @property
    def forest(self):
        """The forest the node belongs to, or None once it has been removed."""
        return self._forest

    @property
    def key(self):
        """The node's key in the forest; assigning renames the node."""
        return self._key

    @key.setter
    def key(self, value):
        self._require_forest()._rename(self, value)

    @property
    def parent(self):
        """The parent node, or None for a root; assigning grafts the node."""
        if self._parent_key is None:
            return None
        else:
            return self._require_forest()[self._parent_key]

    @parent.setter
    def parent(self, node):
        forest = self._require_forest()
        # Rejects non-nodes (TypeError) and nodes of other forests (ValueError).
        forest._check_node(node, convert_key=False)
        if node is None:
            if self._parent_key is not None:
                forest._detach(self.parent, self)
        else:
            # The new parent must not be this node or lie below it.
            if self in node.ancestors(include_self=True):
                raise ValueError('Cannot graft a node onto one of its descendants')
            forest._attach(node, self)

    @property
    def parent_key(self):
        """Key of the parent node, or None; assigning grafts the node by key."""
        return self._parent_key

    @parent_key.setter
    def parent_key(self, value):
        self.parent = None if value is None else self._require_forest()[value]

    def add_child(self, key, data=None):
        """Create a new child node."""
        return self._require_forest()._add_node(key, parent=self, data=data)

    def traverse_pre(self):
        """Iterate through nodes in subtree in pre-order (top down)."""
        yield self
        for child in self.children:
            yield from child.traverse_pre()

    def traverse_post(self):
        """Iterate through nodes in subtree in post-order (bottom up)."""
        for child in self.children:
            yield from child.traverse_post()
        yield self

    def count(self):
        """Get the number of nodes in this node's subtree (including itself)."""
        return sum(1 for _ in self.traverse_pre())

    def depth(self):
        """Get the number of ancestors the node has."""
        return sum(1 for _ in self.ancestors())

    def ancestors(self, include_self=False):
        """Iterate over the node's ancestors, most recent first."""
        node = self if include_self else self.parent
        while node is not None:
            yield node
            node = node.parent

    def ancestor_keys(self, include_self=False):
        """Iterate over the keys of the node's ancestors, most recent first."""
        return (node.key for node in self.ancestors(include_self))

    def __repr__(self):
        if self._forest is not None:
            return '<Node {!r}>'.format(self._key)
        else:
            return '<Node #DETACHED#>'
class _ForestNodeChildren(Set):
    """Read-only set-like view of the child nodes of a forest node.

    The view stores nothing itself: it resolves the owning node's child keys
    through that node's forest each time it is used.
    """

    def __init__(self, node):
        self._node = node

    def __contains__(self, node):
        # Keys are only unique within one forest, so the candidate must belong
        # to the same forest as the owner; otherwise a node of another forest
        # (or a removed node) with a matching key would count as a child.
        return (
            isinstance(node, _ForestNode)
            and node._forest is self._node._forest
            and node._key in self._node._child_keys
        )

    def __iter__(self):
        return (self._node._require_forest()[key] for key in self._node._child_keys)

    def __len__(self):
        return len(self._node._child_keys)

    def __repr__(self):
        return '{{{}}}'.format(', '.join(map(repr, self)))
import sys
sys.path.insert(0, '.')
import pytest
from forest import Forest
# Structure of the forest built by the ``forest`` fixture. Each key is a node
# key and each value is the (possibly empty) mapping of that node's children.
TEST_FOREST = {
    # Tree with several children per level.
    1: {
        11: {112: {}, 113: {}},
        12: {},
        13: {},
    },
    # Deep chain with a single child per level.
    2: {21: {211: {2111: {}}}},
    # Lone root without children.
    3: {},
}
def validate_relationships(node):
    """Assert that a node's links agree with those of its parent and children.

    Checks the upward link (the node is listed among its parent's children,
    both as node and as key) and every downward link (each child names this
    node as its parent and its key is among this node's child keys).
    """
    above = node.parent
    if above is not None:
        assert node in above.children
        assert node.key in above.child_keys

    for below in node.children:
        assert below.parent is node
        assert below.key in node.child_keys
@pytest.fixture()
def forest():
    """Build a fresh forest shaped like ``TEST_FOREST``.

    The result contains nodes of varying degrees and depths. Nodes are created
    depth first, each parent before its children.
    """
    built = Forest()

    def _grow(key, subtree, parent=None):
        # Roots are added to the forest itself, other nodes to their parent.
        node = built.add_node(key) if parent is None else parent.add_child(key)
        for sub_key, sub_structure in subtree.items():
            _grow(sub_key, sub_structure, node)

    for root_key, root_structure in TEST_FOREST.items():
        _grow(root_key, root_structure)

    return built
class TestForest:
    """Test the Forest class."""

    def test_relationships(self, forest):
        """Test all parent-child relationships in test forest."""
        for node in forest.nodes:
            validate_relationships(node)

    def test_mapping_methods(self, forest):
        """Test basic (non-mutating) mapping special methods."""
        keys_list = list(forest)
        assert len(forest) == len(keys_list)

        for key in keys_list:
            assert key in forest
            node = forest[key]
            assert node.key == key

        bad_key = 'not a valid key'
        with pytest.raises(KeyError) as exc_info:
            forest[bad_key]
        # Checked after the block: statements following the raising line
        # inside the ``with`` body would never run.
        assert bad_key in str(exc_info.value)

    @pytest.mark.parametrize('new_key', [11, 15])
    @pytest.mark.parametrize('method', ['add_node', '__setitem__'])
    @pytest.mark.parametrize('data', [None, dict(), dict(foo='bar')])
    def test_add_node(self, forest, new_key, method, data):
        """Test the add_node() and __setitem__ methods."""

        def add_node():
            if method == 'add_node':
                return forest.add_node(new_key, data)
            else:
                forest[new_key] = data
                return forest[new_key]

        if new_key in forest:
            # Existing keys are never overwritten.
            with pytest.raises(KeyError):
                add_node()
        else:
            node = add_node()
            assert node.key == new_key
            assert node.parent is None
            assert dict(node) == (data or {})

    def test_add_node_invalid(self):
        """Test the add_node() method with invalid key types."""
        forest = Forest()

        # None as key
        with pytest.raises(TypeError):
            forest.add_node(None)

        # Node as key
        node = forest.add_node(1)
        with pytest.raises(TypeError):
            forest.add_node(node)

    @pytest.mark.parametrize('key', [1, 11])
    @pytest.mark.parametrize('method,splice', [
        ('node', False),
        ('node', True),
        ('key', False),
        ('key', True),
        ('delitem', True),
    ])
    def test_remove(self, forest, key, method, splice):
        """Test the remove() method."""
        node = forest[key]
        parent = node.parent
        child_keys = set(node.child_keys)

        # Exactly one removal path is exercised per parametrization.
        if method == 'node':
            forest.remove(node, splice)
        elif method == 'key':
            forest.remove(key, splice)
        elif method == 'delitem':
            del forest[key]

        assert key not in forest

        if parent is not None:
            assert node not in parent.children
            assert key not in parent.child_keys

        # Children survive; they move up when spliced, else become roots.
        for child_key in child_keys:
            child = forest[child_key]
            assert child.parent is (parent if splice else None)

    @pytest.mark.parametrize('use_key', [False, True])
    def test_prune(self, forest, use_key):
        """Test the prune() method."""
        key = 11
        node = forest[key]
        parent = node.parent
        subtree_keys = {d.key for d in node.traverse_pre()}

        forest.prune(key if use_key else node)

        # A distinct loop variable keeps ``key`` referring to the pruned node
        # for the parent checks below.
        for subtree_key in subtree_keys:  # Includes node
            assert subtree_key not in forest

        if parent is not None:
            assert node not in parent.children
            assert key not in parent.child_keys

    def test_root_nodes(self, forest):
        """Test the root_nodes() method."""
        root_nodes = list(forest.root_nodes())
        assert {node.key for node in root_nodes} == TEST_FOREST.keys()
        for node in root_nodes:
            assert node.parent is None

    def test_json(self, forest):
        """Test to_json() and from_json() methods."""
        import json

        # Need to use string keys
        for node in list(forest.nodes):
            node.key = str(node.key)

        json_data = json.dumps(forest.to_json())
        forest2 = Forest.from_json(json.loads(json_data))

        assert len(forest) == len(forest2)
        for node in forest.nodes:
            node2 = forest2[node.key]
            assert node2.parent_key == node.parent_key
            assert node.child_keys == node2.child_keys
            assert node.data_eq(node2)
class TestNode:
    """Test the node objects stored in a Forest."""

    def test_equality(self):
        """Test standard equality between nodes and data equality with data_eq()."""
        forest = Forest()
        forest.add_node(1)
        forest.add_node(2)
        forest.add_node(3, dict(foo='bar'))
        # Node 4 copies the data of node 3 but is a distinct node.
        forest.add_node(4, forest[3])

        for node1 in forest.nodes:
            for node2 in forest.nodes:
                assert (node1 == node2) == (node1 is node2)
                same_data = dict(node1) == dict(node2)
                assert node1.data_eq(node2) == same_data
                assert node1.data_eq(dict(node2)) == same_data

    @pytest.mark.parametrize('prev_key,new_key', [
        (1, -1),
        (1, 11),
        (11, -11),
        (11, 12),
        (11, 11),
    ])
    def test_rename(self, forest, prev_key, new_key):
        """Test assignment to .key attribute."""
        node = forest[prev_key]

        if new_key in forest:
            # Covers renaming to the node's own key as well.
            with pytest.raises(KeyError):
                node.key = new_key
            # A rejected rename must leave the node untouched.
            assert node.key == prev_key
            assert forest[prev_key] is node
        else:
            node.key = new_key
            assert node.key == new_key
            # The forest index must follow the rename.
            assert forest[new_key] is node
            assert prev_key not in forest

        validate_relationships(node)

    @pytest.mark.parametrize('key,onto_key', [
        (1, 3),
        (1, None),
        (11, 1),
        (11, 2),
        (11, None),
    ])
    @pytest.mark.parametrize('use_key', [False, True])
    def test_graft(self, forest, key, onto_key, use_key):
        """Test assignment to .parent attribute."""
        node = forest[key]
        previous_parent = node.parent
        new_parent = None if onto_key is None else forest[onto_key]

        if use_key:
            node.parent_key = onto_key
        else:
            node.parent = new_parent

        assert node.parent is new_parent
        validate_relationships(node)
        if previous_parent is not None:
            validate_relationships(previous_parent)

    def test_graft_circular(self, forest):
        """Test grafting a node onto one of its descendants (not allowed)."""
        with pytest.raises(ValueError):
            forest[1].parent_key = 11
        with pytest.raises(ValueError):
            forest[1].parent_key = 1

    @pytest.mark.parametrize('child_key', [11, 15])
    @pytest.mark.parametrize('data', [None, dict(), dict(foo='bar')])
    def test_add_child(self, forest, child_key, data):
        """Test the add_child() method."""
        parent = forest[1]

        if child_key in forest:
            with pytest.raises(KeyError):
                parent.add_child(child_key, data)
        else:
            child = parent.add_child(child_key, data)
            assert child.key == child_key
            assert child.parent is parent
            assert dict(child) == (data or {})

    @pytest.mark.parametrize('include_self', [False, True])
    def test_ancestors(self, forest, include_self):
        """Test the ancestors() method."""
        for node in forest.nodes:
            ancestors = node.ancestors(include_self)

            if node.parent is None and not include_self:
                # A root without itself has nothing to yield.
                with pytest.raises(StopIteration):
                    next(ancestors)
                continue

            if include_self:
                assert next(ancestors) is node

            # Each yielded node must be the parent of the previous one.
            child = node
            for parent in ancestors:
                assert child.parent is parent
                child = parent
            assert child.parent is None

    @pytest.mark.parametrize('include_self', [False, True])
    def test_ancestor_keys(self, forest, include_self):
        """Test the ancestor_keys() method."""
        for node in forest.nodes:
            # Compare whole lists (not zip) so a length mismatch is detected,
            # and pass include_self through so both variants are exercised.
            ancestors = list(node.ancestors(include_self))
            keys = list(node.ancestor_keys(include_self))
            assert keys == [ancestor.key for ancestor in ancestors]

    @pytest.mark.parametrize('order', ['pre', 'post'])
    def test_traverse(self, forest, order):
        """Test the traverse_pre() and traverse_post() methods."""
        for root in forest.nodes:
            seen_keys = set()
            it = root.traverse_pre() if order == 'pre' else root.traverse_post()

            for node in it:
                assert root in node.ancestors(True)
                assert node.key not in seen_keys

                if order == 'pre':
                    # Parents come before their children.
                    if node is not root:
                        assert node.parent.key in seen_keys
                else:
                    # Children come before their parent.
                    assert seen_keys.issuperset(node.child_keys)

                seen_keys.add(node.key)

            assert len(seen_keys) == root.count()

    def test_count(self, forest):
        """Test the count() method."""
        for node in forest.nodes:
            assert node.count() == 1 + sum(c.count() for c in node.children)

    def test_depth(self, forest):
        """Test the depth() method."""
        for node in forest.nodes:
            assert node.depth() == len(list(node.ancestors()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment