Skip to content

Instantly share code, notes, and snippets.

@timgianitsos
Last active January 24, 2023 23:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save timgianitsos/0878a0b241cb5d0ad8b16ebc2b14322a to your computer and use it in GitHub Desktop.
Save timgianitsos/0878a0b241cb5d0ad8b16ebc2b14322a to your computer and use it in GitHub Desktop.
Union-Find data structure (aka Disjoint-Set): https://en.wikipedia.org/wiki/Disjoint-set_data_structure
'''
Union-find data structure. Based on Josiah Carlson's code,
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
with significant additional changes by D. Eppstein
https://www.ics.uci.edu/~eppstein/PADS/UnionFind.py
with additional changes by Tim Gianitsos
'''
class UnionFind:
    '''
    Maintain a family of disjoint sets of hashable objects.

    The structure efficiently answers "are these items in the same set?"
    and tracks how many disjoint sets currently exist. That count goes by
    many names: number of unique transitive closures, number of connected
    components, number of equivalence classes, and the zeroth Betti number.
    https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    Lookups compress the path they walk, and unions attach lighter trees
    under the heaviest one (union by size).
    '''

    def __init__(self):
        '''Create a new empty union-find structure.'''
        # Size of each set, keyed by its root. Entries for nodes that are no
        # longer roots are left stale and never read again.
        self.weights = {}
        # Parent pointer for every known item; a root is its own parent.
        self.node_to_parent = {}
        # Number of disjoint sets currently represented.
        self.num_connected_components = 0
        # Number of distinct items ever added.
        self.len = 0

    def __iter__(self):
        '''Iterate through all items ever found or unioned by this structure.'''
        return iter(self.node_to_parent)

    def __contains__(self, obj):
        '''
        Return True if obj has already been added to the structure.

        Unlike self[obj], this never creates a new singleton set. Defining it
        also makes `obj in uf` a dict lookup instead of the linear scan
        Python would otherwise perform through __iter__.
        '''
        return obj in self.node_to_parent

    def __repr__(self):
        return str(list(self.components()))

    def __len__(self):
        '''Return the number of distinct items in the structure.'''
        return self.len

    def __getitem__(self, obj):
        '''
        Return the name (root item) of the set containing obj.

        Each set is named by an arbitrarily-chosen one of its members. As
        long as the set is not merged with another, it keeps the same name.
        If obj has not been seen before, a new singleton set is created for
        it, so a lookup can grow the structure.
        '''
        # Previously unknown item: register it as its own singleton set.
        if obj not in self.node_to_parent:
            self.node_to_parent[obj] = obj
            self.weights[obj] = 1
            self.num_connected_components += 1
            self.len += 1
            return obj
        # Walk parent pointers up to the root, remembering every node visited.
        path = [obj]
        root = self.node_to_parent[obj]
        while root != path[-1]:
            path.append(root)
            root = self.node_to_parent[root]
        # Path compression: point every visited node directly at the root.
        for ancestor in path:
            self.node_to_parent[ancestor] = root
        return root

    def union(self, *items):
        '''
        Merge the sets containing each of the given items into one set.

        Items not yet in the structure are added as members of the merged
        set. Calling with no items is a no-op.
        '''
        roots = {self[x] for x in items}
        if not roots:
            # Nothing to merge; max() below would raise on an empty set.
            return
        # Union by size: every other root is attached under the heaviest one.
        heaviest = max(roots, key=lambda r: self.weights[r])
        roots.remove(heaviest)
        for root in roots:
            self.weights[heaviest] += self.weights[root]
            self.node_to_parent[root] = heaviest
        # Each distinct root absorbed removes exactly one component.
        self.num_connected_components -= len(roots)

    def components(self):
        '''Return a collection of the connected components, each a set of items.'''
        res = {}
        # self[node] only rewrites parent pointers of already-known nodes, so
        # the dict's size does not change while it is being iterated.
        for node in self.node_to_parent:
            res.setdefault(self[node], set()).add(node)
        return res.values()
if __name__ == '__main__':
    # Smoke test: apply a sequence of unions and check the component count
    # and element count after each one.
    uf = UnionFind()
    assert uf.num_connected_components == 0
    assert len(uf) == 0
    # Each step: (items passed to union, expected components, expected size).
    steps = (
        ((1, 2), 1, 2),
        ((2, 3), 1, 3),
        ((4, 5), 2, 5),
        ((6, 7), 3, 7),
        ((7, 8), 3, 8),
        ((8, 9), 3, 9),
        ((8, 9), 3, 9),   # repeating a union changes nothing
        ((10,), 4, 10),   # a lone item becomes a new singleton component
        ((4, 7), 3, 10),  # joins {4, 5} with {6, 7, 8, 9}
    )
    for members, want_components, want_len in steps:
        uf.union(*members)
        assert uf.num_connected_components == want_components
        assert len(uf) == want_len
    print(f'Connected components: {uf}')
    print('Success!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment