Skip to content

Instantly share code, notes, and snippets.

Created January 5, 2021 14:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MaxHalford/ce20368786432273a3bb2a1b0b4c0a70 to your computer and use it in GitHub Desktop.
Save MaxHalford/ce20368786432273a3bb2a1b0b4c0a70 to your computer and use it in GitHub Desktop.
Tree implementation
"""Generic branch and leaf implementation."""
import operator
import textwrap
import typing
class Op:
    """An operator that is part of a split.

    An operator pairs a binary comparison function with the symbol used to display it.

    Parameters
    ----------
    symbol
        Text shown when the operator is printed, for instance `"<"`.
    func
        Callable taking two arguments `(a, b)`; its return value is the result of the
        comparison.

    """

    # Operators are tiny and only ever hold these two attributes.
    __slots__ = ("symbol", "func")

    def __init__(self, symbol, func):
        self.symbol = symbol
        self.func = func

    def __call__(self, a, b):
        """Apply the wrapped function to `a` and `b`."""
        return self.func(a, b)

    def __repr__(self):
        """Return the operator's symbol."""
        return self.symbol
# Ready-made operators: "strictly less than" and "equal to".
LT = Op("<", operator.lt)
EQ = Op("=", operator.eq)
class Split:
    """A split that is part of a branch.

    A split is a test of the form `how(x[on], at)`.

    Parameters
    ----------
    on
        Key used to look up the value to test in `x`.
    how
        Binary callable that compares the looked-up value with `at`.
    at
        Threshold (or reference value) the looked-up value is compared against.

    """

    def __init__(self, on, how, at):
        self.on = on
        self.how = how
        self.at = at

    def __call__(self, x):
        """Evaluate the split on `x`, which must support `x[self.on]`."""
        return self.how(x[self.on], self.at)

    def __repr__(self):
        """Return `"<on> <how> <at>"`, with float thresholds shown with 5 decimals."""
        at = self.at
        if isinstance(at, float):
            at = f"{at:.5f}"
        return f"{self.on} {repr(self.how)} {at}"
class Node:
    """A node is a branch or a leaf.

    Every keyword argument passed to the constructor is stored as an instance attribute.

    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class Branch(Node):
    """A branch is made up of a split and two child nodes.

    Parameters
    ----------
    split
        Callable applied to an observation `x`; a truthy result sends `x` to the left child.
    left
        Child node followed when the split evaluates to a truthy value.
    right
        Child node followed otherwise.
    kwargs
        Extra attributes to attach to the branch.

    """

    def __init__(self, split, left, right, **kwargs):
        super().__init__(**kwargs)
        self.split = split
        self.left = left
        self.right = right

    def next(self, x):
        """Returns the next node where x belongs."""
        return self.left if self.split(x) else self.right

    def path(self, x):
        """Iterate over the nodes that lead to the leaf which contains x."""
        yield self
        yield from self.next(x).path(x)

    def __repr__(self):
        """Represent the tree with one node per line, children indented under their parent.

        >>> tree = Branch(
        ...     Split('x', LT, 17.42),
        ...     Branch(
        ...         Split('y', LT, -4.38),
        ...         Leaf(no=2),
        ...         Leaf(no=3),
        ...         no=1
        ...     ),
        ...     Leaf(no=4),
        ...     no=0
        ... )

        >>> tree
        x < 17.42000 {'no': 0}
          y < -4.38000 {'no': 1}
            {'no': 2}
            {'no': 3}
          {'no': 4}

        """
        # Only the extra attributes are displayed; the structural ones are rendered separately.
        extras = {k: v for k, v in self.__dict__.items() if k not in ("split", "left", "right")}
        return (
            repr(self.split)
            + " "
            + str(extras)
            # The leading newline is left alone by textwrap.indent (whitespace-only line), so
            # only the children's lines get shifted.
            + textwrap.indent(f"\n{self.left}\n{self.right}", prefix=" " * 2)
        )

    def size(self):
        """Number of descendants, including thyself."""
        return 1 + self.left.size() + self.right.size()

    def height(self):
        """Distance to the deepest node."""
        return 1 + max(self.left.height(), self.right.height())

    def iter_dfs(self, depth=0):
        """Iterate over nodes in a depth-first manner.

        Yields `(node, depth)` pairs, where `depth` starts at the value given for this node.

        >>> tree = Branch(
        ...     None,
        ...     Branch(
        ...         None,
        ...         Leaf(no=2),
        ...         Leaf(no=3),
        ...         no=1
        ...     ),
        ...     Leaf(no=4),
        ...     no=0
        ... )

        >>> for node, depth in tree.iter_dfs():
        ...     print(f'#{node.no}, depth {depth}')
        #0, depth 0
        #1, depth 1
        #2, depth 2
        #3, depth 2
        #4, depth 1

        """
        yield self, depth
        yield from self.left.iter_dfs(depth=depth + 1)
        yield from self.right.iter_dfs(depth=depth + 1)

    def iter_leaves(self):
        """Iterate over the leaves in a depth-first manner.

        >>> tree = Branch(
        ...     None,
        ...     Branch(
        ...         None,
        ...         Leaf(no=2),
        ...         Leaf(no=3),
        ...         no=1
        ...     ),
        ...     Leaf(no=4),
        ...     no=0
        ... )

        >>> for leaf in tree.iter_leaves():
        ...     print(f'#{leaf.no}')
        #2
        #3
        #4

        """
        yield from self.left.iter_leaves()
        yield from self.right.iter_leaves()

    def iter_branches(self):
        """Iterate over branches in a depth-first manner.

        >>> tree = Branch(
        ...     None,
        ...     Branch(
        ...         None,
        ...         Leaf(no=2),
        ...         Leaf(no=3),
        ...         no=1
        ...     ),
        ...     Leaf(no=4),
        ...     no=0
        ... )

        >>> for branch in tree.iter_branches():
        ...     print(f'#{branch.no}')
        #0
        #1

        """
        yield self
        yield from self.left.iter_branches()
        yield from self.right.iter_branches()

    def iter_edges(self):
        """Iterate over edges in a depth-first manner.

        Yields `(parent_no, child_no, parent, child, child_depth)` tuples. Nodes are numbered
        in the order they are visited, starting from 0 for this node; the first tuple describes
        this node itself and therefore has no parent.

        >>> tree = Branch(
        ...     None,
        ...     Branch(
        ...         None,
        ...         Leaf(no=2),
        ...         Leaf(no=3),
        ...         no=1
        ...     ),
        ...     Leaf(no=4),
        ...     no=0
        ... )

        >>> for parent_no, child_no, parent, child, child_depth in tree.iter_edges():
        ...     print(parent_no, child_no, child_depth)
        None 0 0
        0 1 1
        1 2 2
        1 3 2
        0 4 1

        """
        # Shared across the recursive calls so that every node gets a unique number.
        counter = 0

        def iterate(node, depth):
            nonlocal counter
            # Number of `node`: the counter value at the time the node is entered.
            no = counter

            if isinstance(node, Branch):
                for child in (node.left, node.right):
                    counter += 1
                    yield no, counter, node, child, depth + 1
                    if isinstance(child, Branch):
                        yield from iterate(child, depth=depth + 1)

        yield None, 0, None, self, 0
        yield from iterate(self, depth=0)

    def iter_blocks(
        self, limits: typing.Dict[typing.Hashable, typing.Tuple[float, float]], depth=-1
    ):
        """Iterate over the blocks which enclose each node.

        This only makes sense if the branches of the provided tree use the `<` operator as a
        split rule.

        Parameters
        ----------
        limits
            Maps each feature to its `(lower, upper)` bounds. Splits on features that are not
            in `limits` leave the bounds untouched.
        depth
            Desired tree depth. Set to `-1` to iterate over the leaves.

        >>> import operator

        >>> tree = Branch(
        ...     Split('x', operator.lt, .5),
        ...     Leaf(no=0),
        ...     Branch(
        ...         Split('y', operator.lt, .5),
        ...         Leaf(no=1),
        ...         Leaf(no=2)
        ...     )
        ... )

        >>> for leaf, block in tree.iter_blocks(limits={'x': (0, 1), 'y': (0, 1)}):
        ...     print(leaf.no, block)
        0 {'x': (0, 0.5), 'y': (0, 1)}
        1 {'x': (0.5, 1), 'y': (0, 0.5)}
        2 {'x': (0.5, 1), 'y': (0.5, 1)}

        """
        if depth == 0:
            # The desired depth has been reached: report this node and do not descend further.
            yield (self, limits)
            return

        on, at = self.split.on, self.split.at

        # The left child covers values below the threshold, the right child the rest.
        l_limits = {**limits, on: (limits[on][0], at)} if on in limits else limits
        yield from self.left.iter_blocks(limits=l_limits, depth=depth - 1)

        r_limits = {**limits, on: (at, limits[on][1])} if on in limits else limits
        yield from self.right.iter_blocks(limits=r_limits, depth=depth - 1)

    def iter_splits(self, limits: typing.Dict[typing.Hashable, typing.Tuple[float, float]]):
        """Iterate over splits.

        This only makes sense if the branches of the provided tree use the `<` operator as a
        split rule. Each split is yielded as a copy of `limits` in which the split feature is
        collapsed to `(at, at)`. Unlike `iter_blocks`, every split feature has to be present in
        `limits`, otherwise a `KeyError` is raised when descending.

        >>> import operator

        >>> tree = Branch(
        ...     Split('x', operator.lt, .5),
        ...     Leaf(no=0),
        ...     Branch(
        ...         Split('y', operator.lt, .5),
        ...         Leaf(no=1),
        ...         Leaf(no=2)
        ...     )
        ... )

        >>> for line in tree.iter_splits({'x': (0, 1), 'y': (0, 1)}):
        ...     print(line)
        {'x': (0.5, 0.5), 'y': (0, 1)}
        {'x': (0.5, 1), 'y': (0.5, 0.5)}

        """
        on, at = self.split.on, self.split.at
        yield {**limits, on: (at, at)}
        yield from self.left.iter_splits(limits={**limits, on: (limits[on][0], at)})
        yield from self.right.iter_splits(limits={**limits, on: (at, limits[on][1])})


class Leaf(Node):
    """A leaf is a node without children.

    It implements the same traversal methods as `Branch` so that recursion over a tree ends
    naturally when a leaf is reached.

    """

    def __repr__(self):
        return str(self.__dict__)

    def path(self, x):
        """Yield the leaf itself: a path always ends at a leaf."""
        yield self

    def size(self):
        """A leaf counts for a single node."""
        return 1

    def height(self):
        """A leaf has nothing below it."""
        return 0

    def iter_dfs(self, depth=0):
        """Yield the leaf along with its depth."""
        yield self, depth

    def iter_leaves(self):
        """Yield the leaf itself."""
        yield self

    def iter_branches(self):
        """Yield nothing: a leaf contains no branch."""
        yield from ()

    def iter_edges(self):
        """Yield the single parent-less entry describing the leaf itself."""
        yield None, 0, None, self, 0

    def iter_blocks(self, limits, depth=-1):
        """Yield the leaf with the limits it has been handed, whatever `depth` is."""
        yield self, limits

    def iter_splits(self, limits):
        """Yield nothing: a leaf holds no split."""
        yield from ()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment