Last active
September 12, 2024 09:09
-
-
Save timgianitsos/0878a0b241cb5d0ad8b16ebc2b14322a to your computer and use it in GitHub Desktop.
Union-Find data structure, (aka Disjoint-Set) https://en.wikipedia.org/wiki/Disjoint-set_data_structure
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Union-find data structure. Based on Josiah Carlson's code,
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
with significant additional changes by D. Eppstein
https://www.ics.uci.edu/~eppstein/PADS/UnionFind.py
with additional changes by Tim Gianitsos
'''
class UnionFind:
    '''
    The Union-find data structure maintains a family of disjoint sets
    of hashable objects. This allows for the efficient computation of
    a specific use-case. The use-case goes by many names: number of
    unique transitive closures, number of connected components,
    number of equivalence classes, and the zeroth betti number.

    https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    Attributes:
        weights: maps an item to the size recorded for the set it roots.
            Only the entries of current roots are kept up to date.
        node_to_parent: maps every known item to its parent; a root is
            its own parent.
        num_connected_components: current number of disjoint sets.
        len: number of distinct items seen so far (also `len(self)`).
    '''

    def __init__(self):
        '''Create a new empty union-find structure.'''
        self.weights = {}
        self.node_to_parent = {}
        self.num_connected_components = 0
        self.len = 0

    def __iter__(self):
        '''Iterate through all items ever found or unioned by this structure.'''
        return iter(self.node_to_parent)

    def __repr__(self):
        '''Return the connected components rendered as a list of sets.'''
        return str(list(self.components()))

    def __len__(self):
        '''Return the number of distinct items known to this structure.'''
        return self.len

    def __getitem__(self, obj):
        '''
        Returns a name for the set containing the given item.

        Each set is named by an arbitrarily-chosen one of its members; as
        long as the set remains unchanged it will keep the same name. If
        the item is not yet part of a set in this structure, a new
        singleton set is created for it (so a lookup may mutate the
        structure).
        '''
        # Previously unknown obj: register it as the root of a new singleton set.
        if obj not in self.node_to_parent:
            self.node_to_parent[obj] = obj
            self.weights[obj] = 1
            self.num_connected_components += 1
            self.len += 1
            return obj

        # Walk parent links until reaching a node that is its own parent,
        # remembering every node visited on the way.
        path = [obj]
        root = self.node_to_parent[obj]
        while root != path[-1]:
            path.append(root)
            root = self.node_to_parent[root]

        # Path compression: point every visited node directly at the root
        # so later lookups are short.
        for ancestor in path:
            self.node_to_parent[ancestor] = root
        return root

    def union(self, *items):
        '''
        Merges the sets containing each item into a single larger set.

        If any item is not yet part of a set in this structure, it is
        added as one of the members of the merged set. Calling `union`
        with no items is a no-op.
        '''
        # Looking each item up also registers any item not seen before.
        roots = {self[x] for x in items}
        if not roots:
            # Nothing to merge; `max` below would raise ValueError on an
            # empty collection.
            return

        # Union by weight: hang the lighter trees under the heaviest root.
        heaviest = max(roots, key=lambda r: self.weights[r])  # Argmax
        roots.remove(heaviest)
        for root in roots:
            self.weights[heaviest] += self.weights[root]
            self.node_to_parent[root] = heaviest
        # Each absorbed root removes exactly one component.
        self.num_connected_components -= len(roots)

    def components(self):
        '''Return a collection of the connected components, each as a set.'''
        res = {}
        # `self[node]` only rewrites values of existing keys (path
        # compression), so iterating the dict here is safe.
        for node in self.node_to_parent:
            res.setdefault(self[node], set()).add(node)
        return res.values()
if __name__ == '__main__':
    # Self-test: replay a fixed sequence of unions and check the
    # bookkeeping counters after every step.
    u = UnionFind()
    assert u.num_connected_components == 0
    assert len(u) == 0

    # Each row: (items to union, expected component count, expected item count)
    scenario = (
        ((1, 2), 1, 2),
        ((2, 3), 1, 3),
        ((4, 5), 2, 5),
        ((6, 7), 3, 7),
        ((7, 8), 3, 8),
        ((8, 9), 3, 9),
        ((8, 9), 3, 9),    # duplicate
        ((10,), 4, 10),    # single item creates a new singleton set
        ((4, 7), 3, 10),   # merges two existing components
    )
    for members, want_components, want_size in scenario:
        u.union(*members)
        assert u.num_connected_components == want_components
        assert len(u) == want_size

    print(f'Connected components: {u}')
    print('Success!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment