Skip to content

Instantly share code, notes, and snippets.

@ianjosephwilson
Last active April 29, 2024 22:02
Show Gist options
  • Save ianjosephwilson/fe169f3ab5d1d34e9d25c2e1cea077b8 to your computer and use it in GitHub Desktop.
Save ianjosephwilson/fe169f3ab5d1d34e9d25c2e1cea077b8 to your computer and use it in GitHub Desktop.
from itertools import islice
def p_strs(items):
    """Generate the restricted growth strings for every set partition of *items*.

    Each yielded value is a list ``p`` with ``len(p) == len(items)``, where
    ``p[i]`` is the part number assigned to ``items[i]``. The lists obey the
    restricted-growth rule: ``p[0] == 0`` and each ``p[i]`` is at most one
    more than ``max(p[:i])``. They are produced in lexicographic order, one
    per set partition.

    Examples:
        items = [1, 2, 3, 4]
        [0,0,0,0] -- [[1,2,3,4]]
        [0,0,0,1] -- [[1,2,3], [4]]
        ...
        [0,1,2,3] -- [[1], [2], [3], [4]]

    Only ``len(items)`` is used; the element values are ignored. An empty
    *items* yields a single empty list.

    Every yielded list is a fresh copy. Mutating it does not affect the
    lists generated afterwards.
    """
    items_len = len(items)
    # Private working state, mutated in place; callers only ever see copies.
    current = [0] * items_len
    # max_track[i] == max(current[:i + 1]), kept up to date incrementally so
    # the validity test below is O(1) instead of a max() over the prefix.
    max_track = [0] * items_len
    yield current[:]
    # Candidate positions from right to left. Index 0 is never visited
    # because a restricted growth string always starts with 0.
    indices_to_try = range(items_len - 1, 0, -1)
    while True:
        for i in indices_to_try:
            # current[i] may be incremented only while the result stays
            # <= max(prefix) + 1, i.e. while current[i] <= max_track[i - 1].
            if current[i] <= max_track[i - 1]:
                break
        else:
            # No position can be incremented: the last string was emitted.
            return
        current[i] += 1
        new_max = max_track[i] = max(current[i], max_track[i - 1])
        # Reset the suffix to zeros (the smallest valid continuation); the
        # running maximum over those positions is therefore new_max.
        for j in range(i + 1, items_len):
            current[j] = 0
            max_track[j] = new_max
        yield current[:]
def is_prefix_valid(p_str):
    """Return True if the last entry of *p_str* keeps it a valid restricted growth prefix.

    Only the final element is checked; ``p_str[:-1]`` is assumed to be a
    valid prefix already. An empty prefix is trivially valid, and a
    length-1 prefix is valid only if its single entry is 0.

    NOTE: this incremental check could be inlined into the generator that
    uses it.
    """
    if not p_str:
        # Previously max([]) raised ValueError here.
        return True
    if len(p_str) == 1:
        # Accept any sequence type (tuple as well as list).
        return p_str[0] == 0
    # The new entry may open at most one new part: <= largest part so far + 1.
    return p_str[-1] <= max(p_str[:-1]) + 1
def get_partitions(items):
    """Yield every set partition of *items*.

    Each partition is whatever get_parts() produces for one restricted
    growth string from p_strs(items): an iterable of parts, each itself an
    iterable of the original items.
    """
    for code in p_strs(items):
        partition = get_parts(code, items)
        yield partition
def get_parts(p_str, items):
    """Yield the parts of *items* described by the part-number list *p_str*.

    Args:
        p_str: sequence of part numbers; ``p_str[i]`` is the part of
            ``items[i]`` (as produced by p_strs).
        items: indexable sequence of the elements being partitioned.

    Yields:
        One generator per part, in order of first appearance of each part
        number in *p_str*. For a valid restricted growth string this is the
        smallest to the largest part number. Each generator produces that
        part's members in their original order.

    Which items belong to a part is decided when the part is yielded. A part
    therefore stays correct even if the caller advances this generator (for
    example with list()) before consuming it.
    """
    indices = range(len(items))
    last_part_num = -1
    for part_num in p_str:
        # A part number larger than any seen so far marks a new part.
        if part_num > last_part_num:
            last_part_num = part_num
            members = [i for i in indices if p_str[i] == part_num]
            # Only the outermost iterable of a generator expression is evaluated
            # immediately. Filtering on part_num inside the genexp would read the
            # loop variable late, after it had moved on to a later part.
            yield (items[i] for i in members)
def simple_test():
    """Check p_strs([1, 2, 3, 4]) against the 15 published restricted growth strings.

    The expected strings come from https://arxiv.org/pdf/2105.07472.pdf.
    Fails with an AssertionError showing both lists on mismatch (note that
    asserts are stripped when Python runs with -O).
    """
    from pprint import pformat
    found = list(p_strs([1, 2, 3, 4]))
    # These are from a research paper: https://arxiv.org/pdf/2105.07472.pdf
    expected = [[int(c) for c in s] for s in "0000, 0001, 0010, 0011, 0012, 0100, 0101, 0102, 0110, 0111, 0112, 0120, 0121, 0122, 0123".split(', ')]
    assert found == expected, f"found({pformat(found)})!=expected({pformat(expected)})"
def test_partition(length=6):
    """Self-check get_partitions() on range(length) and print summary counts.

    Runs simple_test() first, then checks that:
      * every yielded partition places each item in exactly one part, and
      * the number of partitions equals the Bell number B(length).

    Args:
        length: number of items to partition (default 6). It must index the
            built-in Bell table, i.e. 0 <= length <= 13.

    Raises:
        ValueError: if *length* is outside the Bell table. The check happens
            up front, before any enumeration work.
        AssertionError: if any check fails.
    """
    # bell_numbers[n] is the number of set partitions of n items (B(0)..B(13)).
    bell_numbers = [1, 1, 2, 5, 15, 52, 203, 877, 4140, 21147, 115975, 678570, 4213597, 27644437]
    if not 0 <= length < len(bell_numbers):
        raise ValueError(f"length must be between 0 and {len(bell_numbers) - 1}, got {length}")
    simple_test()
    original_items = list(range(length))
    partitions_count = 0
    parts_count = 0
    items_count = 0
    for partition in get_partitions(original_items):
        partitions_count += 1
        seen = []
        for subset in partition:
            parts_count += 1
            seen.extend(subset)
        items_count += len(seen)
        # A set partition must cover every item exactly once.
        assert sorted(seen) == original_items, f"partition #{partitions_count} covers {sorted(seen)}, expected {original_items}"
    assert partitions_count == bell_numbers[length], f"partitions count: {partitions_count} != bell number for {length}: {bell_numbers[length]}"
    print(f"Found {partitions_count} partitions, {parts_count} parts and {items_count} items for length {length}")
if __name__ == "__main__":
    # Run the self-checks only when executed as a script, not on import.
    test_partition()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment