Last active: June 23, 2023 21:37
Save jlumpe/5ae1b5fd9ddff884c91d3b5d5316c740 to your computer and use it in GitHub Desktop.
Python class for a mutable set of trees with indexed nodes containing arbitrary data.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Mapping, MutableMapping, Set
class SetProxy(Set):
    """Read-only proxy for a set type.

    Exposes any container supporting ``len()``, iteration and ``in`` through
    the :class:`collections.abc.Set` interface without allowing mutation. The
    wrapped object is referenced, not copied, so later changes to it are
    visible through the proxy.
    """

    def __init__(self, set_):
        self._set_ = set_

    @classmethod
    def _from_iterable(cls, iterable):
        """Build the result of a set operator (``&``, ``|``, ``-``, ``^``).

        The ``Set`` mixin would otherwise call ``cls(iterable)``, producing a
        proxy around a one-shot generator (``len()`` on it fails). Return a
        plain ``set`` instead, like the standard mapping views do. Elements
        must therefore be hashable for operators to work.
        """
        return set(iterable)

    def __len__(self):
        return len(self._set_)

    def __iter__(self):
        return iter(self._set_)

    def __contains__(self, item):
        return item in self._set_

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._set_)
class AttributeMappingProxy:
    """
    Translates attribute get/set/delete on proxy to key get/set/delete on a
    mapping.

    Names starting with a double underscore are always treated as ordinary
    attributes of the proxy itself, never as mapping keys. This keeps the
    ``__mapping__`` slot and Python's special-method lookups (copy, pickle,
    ...) working.
    """

    def __init__(self, mapping):
        # A name with both leading and trailing "__" is not name-mangled;
        # __setattr__ below stores it as a real instance attribute.
        self.__mapping__ = mapping

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Never consult the
        # mapping for dunder names: on an instance built without __init__
        # (e.g. by copy.copy), reading self.__mapping__ would re-enter this
        # method and recurse forever.
        if name.startswith('__'):
            raise AttributeError(name)
        try:
            return self.__mapping__[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        if name.startswith('__'):
            object.__setattr__(self, name, value)
        else:
            self.__mapping__[name] = value

    def __delattr__(self, name):
        # Mirror __setattr__: dunder names live on the proxy itself.
        if name.startswith('__'):
            object.__delattr__(self, name)
            return
        try:
            del self.__mapping__[name]
        except KeyError:
            raise AttributeError(name) from None

    def __dir__(self):
        attrs = set(dir(type(self)))
        attrs.update(self.__dict__)
        # dir() sorts its result, so non-string keys (which cannot be
        # attribute names anyway) would make it raise TypeError.
        attrs.update(key for key in self.__mapping__ if isinstance(key, str))
        return list(attrs)
class Forest(MutableMapping):
    """Indexed collection of nodes which form a forest of trees.

    Maps keys to :class:`_ForestNode` objects. Each node is a ``dict`` holding
    arbitrary data plus links to its parent and children. ``forest[key] = data``
    creates a new root node (``KeyError`` if the key already exists), and
    ``del forest[key]`` removes a node, attaching its children to its parent.
    """

    def __init__(self):
        self._nodes = dict()
        # Read-only live view of the node objects (the values of self._nodes)
        self.nodes = SetProxy(self._nodes.values())

    def __len__(self):
        return len(self._nodes)

    def __iter__(self):
        return iter(self._nodes)

    def __contains__(self, key):
        return key in self._nodes

    def __getitem__(self, key):
        return self._nodes[key]

    def __setitem__(self, key, data):
        """Create a new root node whose data is a shallow copy of ``data``."""
        self._add_node(key, data=data)

    def __delitem__(self, key):
        """Remove a node, attaching its children to its parent."""
        self.remove(key, splice=True)

    def _check_key(self, key):
        """Check that a key can be added.

        :raises TypeError: if ``key`` is ``None`` or a node.
        :raises KeyError: if ``key`` is already in use.
        """
        # Type checks come first: nodes are unhashable, so testing membership
        # first would surface a generic "unhashable type" error instead.
        if key is None:
            raise TypeError('None is not a valid key')
        if isinstance(key, _ForestNode):
            raise TypeError('Cannot use nodes as keys')
        if key in self:
            raise KeyError('Key {!r} exists'.format(key))

    def _check_node(self, node, convert_key=True):
        """Check that the node exists in the forest and return it.

        ``None`` is returned unchanged. A non-node argument is looked up as a
        key when ``convert_key`` is true (``KeyError`` if missing), and raises
        ``TypeError`` otherwise. A node whose forest is not this one (including
        a removed node) raises ``ValueError``.
        """
        if node is None:
            return None
        if not isinstance(node, _ForestNode):
            if convert_key:
                node = self[node]
            else:
                raise TypeError('Expected node, not {}'.format(type(node)))
        elif node._forest is not self:
            raise ValueError('Node not in forest')
        return node

    def _add_node(self, key, data=None, parent=None):
        """Create a node, optionally filling in data and attaching it to ``parent``."""
        self._check_key(key)
        node = _ForestNode(self, key)
        if data is not None:
            # Fill in data before registering the node, so a failing update
            # (e.g. data that is not a mapping) leaves the forest unchanged.
            node.update(data)
        self._nodes[key] = node
        if parent is not None:
            self._attach(parent, node)
        return node

    def _attach(self, parent, child):
        """Attach child node to parent.

        Set child's parent, add child to parent's children. Detach from existing
        parent first if needed.
        """
        if child._parent_key == parent._key:
            return
        if child._parent_key is not None:
            self._detach(child.parent, child)
        child._parent_key = parent._key
        parent._child_keys.add(child._key)

    def _detach(self, parent, child):
        """Detach child node from parent.

        Set child parent to None and remove child from parent's children.
        """
        parent._child_keys.remove(child._key)
        child._parent_key = None

    def _remove(self, node):
        """Remove node from forest, updating other side of its relationships.

        The node's former children become roots; the node itself is marked
        as removed by clearing its forest reference.
        """
        if node._parent_key is not None:
            self._detach(node.parent, node)
        for child in list(node.children):
            self._detach(node, child)
        del self._nodes[node._key]
        node._forest = None

    def _prune(self, node):
        """Remove ``node`` and all of its descendants from the forest."""
        if node._parent_key is not None:
            self._detach(node.parent, node)
        # Collect the whole subtree before mutating anything. Every node is
        # appended before its children, and an explicit stack avoids hitting
        # the recursion limit on deep trees.
        subtree = []
        stack = [node]
        while stack:
            current = stack.pop()
            subtree.append(current)
            stack.extend(current.children)
        # Remove children before their parents.
        for current in reversed(subtree):
            self._remove(current)

    def _rename(self, node, new_key):
        """Change a node's key, updating its parent's and children's links."""
        self._check_key(new_key)
        old_key = node._key
        parent = node.parent
        if parent is not None:
            parent._child_keys.remove(old_key)
            parent._child_keys.add(new_key)
        for child in node.children:
            child._parent_key = new_key
        node._key = new_key
        del self._nodes[old_key]
        self._nodes[new_key] = node

    def add_node(self, key, data=None):
        """Create a new root node, optionally with a shallow copy of ``data``."""
        return self._add_node(key, data=data)

    def remove(self, node, splice=True):
        """Remove a node, given as a node object or a key.

        If ``splice`` is true and the node has a parent, its children are
        attached to that parent; otherwise they become root nodes.
        """
        node = self._check_node(node)
        parent = node.parent
        children = list(node.children)
        self._remove(node)
        if splice and parent is not None:
            for child in children:
                self._attach(parent, child)

    def prune(self, node):
        """Remove a node and all its descendants."""
        node = self._check_node(node)
        self._prune(node)

    def root_nodes(self):
        """Iterate over the nodes that have no parent."""
        return (node for node in self._nodes.values() if node._parent_key is None)

    def to_json(self):
        """Convert to a JSON-able dictionary representation.

        Maps each key to ``{'parent': parent_key, 'data': dict(node)}``. Keys
        must be strings (and node data JSON-serializable) for the result to
        survive a ``json`` round trip.
        """
        data = dict()
        for node in self.nodes:
            data[node.key] = dict(parent=node._parent_key, data=dict(node))
        return data

    @classmethod
    def from_json(cls, data):
        """Create from parsed JSON data, as created by :meth:`to_json`.

        :raises ValueError: if some node's parent is missing or nodes form a
            cycle (i.e. not every node is reachable from a root).
        """
        forest = cls()

        # Group node data by parent key; roots are grouped under None.
        children_by_parent = dict()
        for key, node_json in data.items():
            siblings = children_by_parent.setdefault(node_json['parent'], dict())
            siblings[key] = node_json['data']

        # Roots may be absent entirely (e.g. an empty forest).
        pending = [
            forest.add_node(key, node_data)
            for key, node_data in children_by_parent.pop(None, {}).items()
        ]
        # Build the trees iteratively to avoid recursion limits on deep trees.
        while pending:
            parent = pending.pop()
            for key, child_data in children_by_parent.pop(parent.key, {}).items():
                pending.append(parent.add_child(key, child_data))

        # Anything left over was never reached from a root.
        if children_by_parent:
            raise ValueError('Invalid file')
        return forest
class _ForestNode(dict):
    """A node in a :class:`Forest`; the node's own mapping holds its data.

    Parent/child links are stored as keys and resolved through the owning
    forest, so structural accessors raise ``ValueError`` once the node has
    been removed. Nodes compare by identity, not by data (use
    :meth:`data_eq` for data comparison).
    """

    def __init__(self, forest, key):
        self._forest = forest  # Set to None by Forest._remove
        self._key = key
        self._parent_key = None
        self._child_keys = set()
        self.child_keys = SetProxy(self._child_keys)  # read-only live view
        self.children = _ForestNodeChildren(self)
        # Attribute-style access to the node's data: node.a.foo <-> node['foo']
        self.a = AttributeMappingProxy(self)

    def __eq__(self, other):
        """Compare by identity, not as a mapping."""
        return other is self

    def __ne__(self, other):
        """Inverse of :meth:`__eq__`.

        Must be defined explicitly: ``dict`` has its own ``__ne__``, which would
        otherwise be inherited and compare node data, disagreeing with ``==``.
        """
        return other is not self

    def data_eq(self, other):
        """Check if this node's data is equivalent to another mapping.

        Accepts any mapping, including other nodes. Returns False for objects
        that are not mappings.
        """
        if isinstance(other, dict):
            # Plain dict comparison, bypassing the identity-based node __eq__
            return dict.__eq__(self, other)
        if isinstance(other, Mapping):
            return dict.__eq__(self, dict(other.items()))
        return False

    def _require_forest(self):
        """Return the owning forest, or raise ValueError if the node was removed."""
        if self._forest is not None:
            return self._forest
        else:
            raise ValueError('Node has been removed from forest')

    @property
    def forest(self):
        """The owning forest, or None if the node has been removed."""
        return self._forest

    @property
    def key(self):
        return self._key

    @key.setter
    def key(self, value):
        self._require_forest()._rename(self, value)

    @property
    def parent(self):
        """The parent node, or None for a root node."""
        if self._parent_key is None:
            return None
        else:
            return self._require_forest()[self._parent_key]

    @parent.setter
    def parent(self, node):
        forest = self._require_forest()
        forest._check_node(node, convert_key=False)
        if node is None:
            if self._parent_key is not None:
                forest._detach(self.parent, self)
        else:
            if self in node.ancestors(include_self=True):
                raise ValueError('Cannot graft a node onto one of its descendants')
            forest._attach(node, self)

    @property
    def parent_key(self):
        return self._parent_key

    @parent_key.setter
    def parent_key(self, value):
        self.parent = None if value is None else self._require_forest()[value]

    def add_child(self, key, data=None):
        """Create a new child node."""
        return self._require_forest()._add_node(key, parent=self, data=data)

    def traverse_pre(self):
        """Iterate through nodes in subtree in pre-order (top down).

        Uses an explicit stack, so deep trees do not hit the recursion limit.
        Siblings are visited in the iteration order of ``children``.
        """
        stack = [self]
        while stack:
            node = stack.pop()
            yield node
            # Reverse so the first child is popped (visited) first.
            stack.extend(reversed(list(node.children)))

    def traverse_post(self):
        """Iterate through nodes in subtree in post-order (bottom up).

        Uses an explicit stack, so deep trees do not hit the recursion limit.
        Siblings are visited in the iteration order of ``children``.
        """
        # Each entry is (node, children_already_scheduled).
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                yield node
            else:
                stack.append((node, True))
                stack.extend((child, False) for child in reversed(list(node.children)))

    def count(self):
        """Get the number of nodes in this node's subtree (including itself)."""
        return sum(1 for _ in self.traverse_pre())

    def depth(self):
        """Get the number of ancestors the node has."""
        return sum(1 for _ in self.ancestors())

    def ancestors(self, include_self=False):
        """Iterate over the node's ancestors, most recent first."""
        node = self if include_self else self.parent
        while node is not None:
            yield node
            node = node.parent

    def ancestor_keys(self, include_self=False):
        """Iterate over the keys of :meth:`ancestors`, most recent first."""
        return (node.key for node in self.ancestors(include_self))

    def __repr__(self):
        if self._forest is not None:
            return '<Node {!r}>'.format(self._key)
        else:
            return '<Node #DETACHED#>'
class _ForestNodeChildren(Set):
    """Live, read-only set view of a node's child nodes."""

    def __init__(self, node):
        self._node = node

    def __contains__(self, node):
        # Checking the key alone would also accept a node with the same key
        # from another forest, or a node that has since been removed.
        if not isinstance(node, _ForestNode):
            return False
        forest = self._node._forest
        return (
            forest is not None
            and node._forest is forest
            and node._key in self._node._child_keys
        )

    def __iter__(self):
        return (self._node._require_forest()[key] for key in self._node._child_keys)

    def __len__(self):
        return len(self._node._child_keys)

    def __repr__(self):
        return '{{{}}}'.format(', '.join(map(repr, self)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
sys.path.insert(0, '.') | |
import pytest | |
from forest import Forest | |
# Structure of the test forest: every key maps to a dict of its children, so
# an empty dict marks a leaf. Child keys extend their parent's digits
# (e.g. 211 is a child of 21), and insertion order is preserved.
TEST_FOREST = {
    1: {11: {112: {}, 113: {}}, 12: {}, 13: {}},
    2: {21: {211: {2111: {}}}},
    3: {},
}
def validate_relationships(node):
    """Assert that a node's parent/child links agree in both directions.

    The node must be listed among its parent's children (as a node and by
    key), and each of its children must point back to it as parent.
    """
    parent = node.parent
    if parent is not None:
        assert node in parent.children and node.key in parent.child_keys
    for kid in node.children:
        assert kid.parent is node and kid.key in node.child_keys
@pytest.fixture()
def forest():
    """A forest with nodes of varying degrees and depths (built from TEST_FOREST)."""
    result = Forest()

    def build(parent, key, children):
        # Create the node (as a root when there is no parent), then its subtree.
        node = result.add_node(key) if parent is None else parent.add_child(key)
        for child_key, grandchildren in children.items():
            build(node, child_key, grandchildren)

    for root_key, children in TEST_FOREST.items():
        build(None, root_key, children)
    return result
class TestForest:
    """Test the Forest class."""

    def test_relationships(self, forest):
        """Test all parent-child relationships in test forest."""
        for node in forest.nodes:
            validate_relationships(node)

    def test_mapping_methods(self, forest):
        """Test basic (non-mutating) mapping special methods."""
        keys_list = list(forest)
        assert len(forest) == len(keys_list)
        for key in keys_list:
            assert key in forest
            node = forest[key]
            assert node.key == key
        bad_key = 'not a valid key'
        with pytest.raises(KeyError) as exc_info:
            forest[bad_key]
        # Inspect the exception itself, not the ExceptionInfo wrapper's repr.
        assert bad_key in str(exc_info.value)

    @pytest.mark.parametrize('new_key', [11, 15])
    @pytest.mark.parametrize('method', ['add_node', '__setitem__'])
    @pytest.mark.parametrize('data', [None, dict(), dict(foo='bar')])
    def test_add_node(self, forest, new_key, method, data):
        """Test the add_node() and __setitem__ methods."""
        def add_node():
            if method == 'add_node':
                return forest.add_node(new_key, data)
            else:
                forest[new_key] = data
                return forest[new_key]

        if new_key in forest:
            with pytest.raises(KeyError):
                add_node()
        else:
            node = add_node()
            assert node.key == new_key
            assert node.parent is None
            assert dict(node) == (data or {})

    def test_add_node_invalid(self):
        """Test the add_node() method with invalid key types."""
        forest = Forest()
        # None as key
        with pytest.raises(TypeError):
            forest.add_node(None)
        # Node as key
        node = forest.add_node(1)
        with pytest.raises(TypeError):
            forest.add_node(node)

    @pytest.mark.parametrize('key', [1, 11])
    @pytest.mark.parametrize('method,splice', [
        ('node', False),
        ('node', True),
        ('key', False),
        ('key', True),
        ('delitem', True),
    ])
    def test_remove(self, forest, key, method, splice):
        """Test the remove() method."""
        node = forest[key]
        parent = node.parent
        child_keys = set(node.child_keys)
        if method == 'node':
            forest.remove(node, splice)
        elif method == 'key':
            forest.remove(key, splice)
        elif method == 'delitem':
            del forest[key]
        else:
            raise ValueError(method)
        assert key not in forest
        assert node.forest is None
        if parent is not None:
            assert node not in parent.children
            assert key not in parent.child_keys
        for child_key in child_keys:
            child = forest[child_key]
            assert child.parent is (parent if splice else None)

    @pytest.mark.parametrize('use_key', [False, True])
    def test_prune(self, forest, use_key):
        """Test the prune() method."""
        key = 11
        node = forest[key]
        parent = node.parent
        subtree_keys = {d.key for d in node.traverse_pre()}
        forest.prune(key if use_key else node)
        # Use a distinct loop variable so `key` still names the pruned node.
        for subtree_key in subtree_keys:  # Includes node
            assert subtree_key not in forest
        assert node.forest is None
        if parent is not None:
            assert node not in parent.children
            assert key not in parent.child_keys

    def test_root_nodes(self, forest):
        """Test the root_nodes() method."""
        root_nodes = list(forest.root_nodes())
        assert {node.key for node in root_nodes} == TEST_FOREST.keys()
        for node in root_nodes:
            assert node.parent is None

    def test_json(self, forest):
        """Test to_json() and from_json() methods."""
        import json
        # Need to use string keys
        for node in list(forest.nodes):
            node.key = str(node.key)
        json_data = json.dumps(forest.to_json())
        forest2 = Forest.from_json(json.loads(json_data))
        assert len(forest) == len(forest2)
        for node in forest.nodes:
            node2 = forest2[node.key]
            assert node2.parent_key == node.parent_key
            assert node.child_keys == node2.child_keys
            assert node.data_eq(node2)
class TestNode:
    """Test the _ForestNode class."""

    def test_equality(self):
        """Test standard equality between nodes and data equality with data_eq()."""
        forest = Forest()
        forest.add_node(1)
        forest.add_node(2)
        forest.add_node(3, dict(foo='bar'))
        forest.add_node(4, forest[3])
        for node1 in forest.nodes:
            for node2 in forest.nodes:
                assert (node1 == node2) == (node1 is node2)
                same_data = dict(node1) == dict(node2)
                assert node1.data_eq(node2) == same_data
                assert node1.data_eq(dict(node2)) == same_data

    @pytest.mark.parametrize('prev_key,new_key', [
        (1, -1),
        (1, 11),
        (11, -11),
        (11, 12),
        (11, 11),
    ])
    def test_rename(self, forest, prev_key, new_key):
        """Test assignment to .key attribute."""
        node = forest[prev_key]
        if new_key in forest:
            with pytest.raises(KeyError):
                node.key = new_key
        else:
            node.key = new_key
            assert node.key == new_key
            validate_relationships(node)

    @pytest.mark.parametrize('key,onto_key', [
        (1, 3),
        (1, None),
        (11, 1),
        (11, 2),
        (11, None),
    ])
    @pytest.mark.parametrize('use_key', [False, True])
    def test_graft(self, forest, key, onto_key, use_key):
        """Test assignment to .parent attribute."""
        node = forest[key]
        previous_parent = node.parent
        new_parent = None if onto_key is None else forest[onto_key]
        if use_key:
            node.parent_key = onto_key
        else:
            node.parent = new_parent
        assert node.parent is new_parent
        validate_relationships(node)
        if previous_parent is not None:
            validate_relationships(previous_parent)

    def test_graft_circular(self, forest):
        """Test grafting a node onto one of its descendants (not allowed)."""
        with pytest.raises(ValueError):
            forest[1].parent_key = 11
        with pytest.raises(ValueError):
            forest[1].parent_key = 1

    @pytest.mark.parametrize('child_key', [11, 15])
    @pytest.mark.parametrize('data', [None, dict(), dict(foo='bar')])
    def test_add_child(self, forest, child_key, data):
        """Test the add_child() method."""
        parent = forest[1]
        if child_key in forest:
            with pytest.raises(KeyError):
                parent.add_child(child_key, data)
        else:
            child = parent.add_child(child_key, data)
            assert child.key == child_key
            assert child.parent is parent
            assert dict(child) == (data or {})

    @pytest.mark.parametrize('include_self', [False, True])
    def test_ancestors(self, forest, include_self):
        """Test the ancestors() method."""
        for node in forest.nodes:
            ancestors = node.ancestors(include_self)
            if node.parent is None and not include_self:
                with pytest.raises(StopIteration):
                    next(ancestors)
                continue
            if include_self:
                assert next(ancestors) is node
            child = node
            for parent in ancestors:
                assert child.parent is parent
                child = parent
            assert child.parent is None

    @pytest.mark.parametrize('include_self', [False, True])
    def test_ancestor_keys(self, forest, include_self):
        """Test the ancestor_keys() method."""
        for node in forest.nodes:
            # Compare whole lists (zip would hide a length mismatch) and pass
            # include_self through, so both parametrizations are exercised.
            expected = [ancestor.key for ancestor in node.ancestors(include_self)]
            assert list(node.ancestor_keys(include_self)) == expected
            if include_self:
                assert expected[0] == node.key

    @pytest.mark.parametrize('order', ['pre', 'post'])
    def test_traverse(self, forest, order):
        """Test traverse_pre() and traverse_post() from every node as root."""
        for root in forest.nodes:
            seen_keys = set()
            it = root.traverse_pre() if order == 'pre' else root.traverse_post()
            for node in it:
                assert root in node.ancestors(True)
                assert node.key not in seen_keys
                if order == 'pre':
                    if node is not root:
                        assert node.parent.key in seen_keys
                else:
                    assert seen_keys.issuperset(node.child_keys)
                seen_keys.add(node.key)
            assert len(seen_keys) == root.count()

    def test_count(self, forest):
        """Test the count() method."""
        for node in forest.nodes:
            assert node.count() == 1 + sum(c.count() for c in node.children)

    def test_depth(self, forest):
        """Test the depth() method."""
        for node in forest.nodes:
            assert node.depth() == len(list(node.ancestors()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.