Skip to content

Instantly share code, notes, and snippets.

@behdad
Created June 2, 2020 17:11
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save behdad/0d36c27b6a664726a5b9c6050fc27d42 to your computer and use it in GitHub Desktop.
unionfind.py
from __future__ import print_function, division, absolute_import
#from fontTools.misc.py23 import *
class UnionFind(object):
    """Disjoint-set (union-find) structure over a fixed universe of items.

    Every item belongs to exactly one set, and each set is identified by a
    representative item.  The item -> representative mapping is kept fully
    up to date on every merge, so ``find`` is a single dict lookup.  A merge
    relabels only the members of the smaller set (union by size), so any one
    item is relabelled at most O(log n) times over the structure's lifetime.
    """

    def __init__(self, items):
        """Create one singleton set per element of *items*.

        *items* may be any iterable, including a one-shot iterator or
        generator.  It is materialized once so that both internal tables
        see every element.
        """
        items = list(items)
        # representative -> set of every member of that representative's set
        self._sets = dict((x, set([x])) for x in items)
        # item -> representative of the set containing it
        self._map = dict((x, x) for x in items)

    def _union_pair(self, a, b):
        """Merge the sets containing *a* and *b*; return the resulting representative.

        Raises KeyError if either item is not part of the universe.
        """
        mapping = self._map
        a, b = mapping[a], mapping[b]
        if a != b:
            sa, sb = self._sets[a], self._sets[b]
            # Keep the larger set's representative so only the smaller
            # set's members need to be relabelled.
            if len(sa) < len(sb):
                a, b = b, a
                sa, sb = sb, sa
            sa.update(sb)
            for item in sb:
                mapping[item] = a
            del self._sets[b]
        return a

    def union(self, *args):
        """Merge the sets containing all of *args* into a single set.

        Returns the representative of the merged set.  With one argument
        nothing is merged, and that item's representative is returned.

        Raises TypeError when called with no arguments, and KeyError if any
        argument is not part of the universe.
        """
        if not args:
            raise TypeError("union() requires at least one item")
        # Explicit left fold (instead of reduce, which is not a builtin on
        # Python 3), seeded with a representative so a single argument is
        # validated and normalized too.
        root = self.find(args[0])
        for item in args[1:]:
            root = self._union_pair(root, item)
        return root

    def find(self, item):
        """Return the representative of the set containing *item* (KeyError if unknown)."""
        return self._map[item]

    def __getitem__(self, item):
        """``uf[item]`` is shorthand for ``uf.find(item)``."""
        return self.find(item)

    def get_mapping(self):
        """Return a copy of the item -> representative mapping."""
        return dict(self._map)

    def get_set(self, item):
        """Return a mutable copy of the set containing *item*."""
        return set(self._sets[self._map[item]])

    def get_frozenset(self, item):
        """Return the set containing *item* as a frozenset."""
        return frozenset(self._sets[self._map[item]])

    def get_sets(self):
        """Return a list of mutable copies of every current set (order unspecified)."""
        return [set(s) for s in self._sets.values()]

    def get_frozensets(self):
        """Return every current set as a frozenset of frozensets."""
        return frozenset(frozenset(s) for s in self._sets.values())
# Small demonstration: build eight singletons, merge some of them, and
# print the resulting representatives and partition.
uf = UnionFind([1, 2, 3, 4, 5, 6, 7, 8])
# Merge in order: {1,2,3}, {8,6}, {4,5}, then join the {8,6} group with the {1,2,3} group.
for _group in [(1, 2, 3), (8, 6), (4, 5), (8, 2)]:
    uf.union(*_group)
print(uf.find(4))
print(uf[4])
print(uf[5])
print(uf.get_mapping())
_expected_partition = set([frozenset([1, 2, 3, 6, 8]), frozenset([4, 5]), frozenset([7])])
print(uf.get_frozensets() == _expected_partition)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment