Skip to content

Instantly share code, notes, and snippets.

@tkmharris
Last active May 15, 2021 14:33
Show Gist options
  • Save tkmharris/9e158305ed61c157875b48bc721b4955 to your computer and use it in GitHub Desktop.
Save tkmharris/9e158305ed61c157875b48bc721b4955 to your computer and use it in GitHub Desktop.
Quick implementation of the two trees for partitions discussed here: https://11011110.github.io/blog/2005/08/07/two-trees-on.html
"""
Quick implementation of the two trees for partitions discussed here:
https://11011110.github.io/blog/2005/08/07/two-trees-on.html
"""
from collections import deque

from treelib import Node, Tree
# partition class
class InvalidPartitionError(ValueError):
    """Raised when a sequence of parts does not form a valid integer partition.

    Subclasses ValueError so callers that caught the ValueError previously
    raised for out-of-order parts keep working.
    """


class Partition:
    """An integer partition stored as a non-increasing list of positive ints.

    Attributes:
        parts: list of the parts, largest first.
        weight: sum of the parts (the integer being partitioned).
    """

    def __init__(self, parts):
        """Validate and store *parts*.

        parts: sequence of positive ints in non-increasing order. It is copied
            into a new list, so tuples are accepted and later mutation of the
            caller's sequence does not affect this partition.

        raises: InvalidPartitionError if a part is not a positive int (bools
            are rejected) or if the parts are not non-increasing.
        """
        parts = list(parts)
        previous = None
        for part in parts:
            # bool is a subclass of int, but True/False are not meaningful parts
            if not isinstance(part, int) or isinstance(part, bool) or part <= 0:
                raise InvalidPartitionError("Parts of a partition must be positive ints")
            if previous is not None and part > previous:
                raise InvalidPartitionError("Parts of a partition must be non-increasing")
            previous = part
        self.parts = parts
        self.weight = sum(parts)

    def __str__(self):
        return ' '.join(map(str, self.parts))

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, self.parts)

    def fixed_n_partitions_tree_children(self):
        """Return the children of self in the tree of partitions of self.weight.

        A child is formed by splitting the first part `head` into two parts
        `big >= small` (prepended to the remaining parts), where `small` must
        still be at least as large as the next part. Conversely, a partition's
        parent is obtained by merging its first two parts.

        returns: (possibly empty) list of Partition objects, ordered by
            decreasing new first part. The empty partition has no children.
        """
        if not self.parts:
            return []
        head, tail = self.parts[0], self.parts[1:]
        children = []
        # big runs from head - 1 down to ceil(head / 2), so big >= small >= 1
        for big in range(head - 1, (head - 1) // 2, -1):
            small = head - big
            if not tail or small >= tail[0]:
                children.append(Partition([big, small] + tail))
        return children

    def all_partitions_tree_children(self):
        """Return the children of self in the infinite tree of all partitions.

        Every partition has the child formed by appending a part of size 1.
        It also has the child formed by incrementing its last part, when the
        result is still non-increasing (always true for a single part).
        The empty partition's only child is [1].

        returns: list of one or two Partition objects, the appended-1 child first.
        """
        with_one = Partition(self.parts + [1])
        if not self.parts:
            return [with_one]
        if len(self.parts) == 1 or self.parts[-1] < self.parts[-2]:
            bumped = Partition(self.parts[:-1] + [self.parts[-1] + 1])
            return [with_one, bumped]
        return [with_one]
# tree-construction functions
def fixed_n_partitions_tree(num):
    """Build the tree of all partitions of *num*.

    The root is the one-part partition [num]; each node's children are given
    by Partition.fixed_n_partitions_tree_children. Nodes are added
    breadth-first, each tagged with str(partition) and identified by the
    Partition object itself.

    num: positive integer
    returns: treelib Tree of all partitions of num
    raises: InvalidPartitionError (from Partition) if num is not a positive int
    """
    tree = Tree()
    root = Partition([num])
    # deque gives O(1) pops from the front; list.pop(0) made the BFS quadratic
    pending = deque([(root, None)])
    while pending:
        partition, parent = pending.popleft()
        # parent=None makes this node the tree's root
        tree.create_node(str(partition), partition, parent=parent)
        pending.extend(
            (child, partition)
            for child in partition.fixed_n_partitions_tree_children()
        )
    return tree
def all_partitions_tree(depth):
    """Build the tree of all partitions of weight at most *depth*.

    The root is [1]; each node's children come from
    Partition.all_partitions_tree_children. A node is expanded only while its
    weight is below depth, so every node in the tree has weight <= depth.
    Nodes are added breadth-first, tagged with str(partition) and identified
    by the Partition object itself.

    depth: positive integer
    returns: treelib Tree of all partitions of weight 1..depth
    raises: ValueError if depth < 1 (otherwise the root [1] would already
        exceed the weight bound)
    """
    if depth < 1:
        raise ValueError("depth must be a positive integer")
    tree = Tree()
    # deque gives O(1) pops from the front; list.pop(0) made the BFS quadratic
    pending = deque([(Partition([1]), None)])
    while pending:
        partition, parent = pending.popleft()
        # parent=None makes this node the tree's root
        tree.create_node(str(partition), partition, parent=parent)
        # each child has weight exactly one more than its parent
        if partition.weight < depth:
            pending.extend(
                (child, partition)
                for child in partition.all_partitions_tree_children()
            )
    return tree
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment