Skip to content

Instantly share code, notes, and snippets.

@tamuhey
Created September 24, 2019 18:44
Show Gist options
  • Save tamuhey/c151f7242cfd22e72bb56b7e5034ed37 to your computer and use it in GitHub Desktop.
Save tamuhey/c151f7242cfd22e72bb56b7e5034ed37 to your computer and use it in GitHub Desktop.
class UnionFind:
    """Disjoint-set (union-find) forest over arbitrary hashable objects.

    Every registered object is assigned an integer index:

    * ``objects[i]`` is the object with index ``i``.
    * ``obj2num[obj]`` is the index of ``obj``.
    * ``parents[i]`` is the parent index of ``i``; a root has ``parents[i] == i``.
    * ``weights[r]`` is the size of the set whose root index is ``r``. Entries
      for indices that stop being roots are removed on union.

    Uses union by size and path compression.
    """

    def __init__(self, objects):
        """Create one singleton set per distinct object in *objects*.

        Duplicate objects are registered only once (first occurrence wins),
        so no unreachable phantom sets are created.
        """
        self.objects = []
        self.weights = {}
        self.parents = []
        self.obj2num = {}
        for obj in objects:
            self.add(obj)

    def add(self, obj):
        """Register *obj* as a new singleton set and return it.

        If *obj* is already registered this is a no-op, so its existing set
        membership is never disturbed.
        """
        if obj in self.obj2num:
            return obj
        n = len(self.objects)
        self.objects.append(obj)
        self.obj2num[obj] = n
        self.weights[n] = 1
        self.parents.append(n)
        return obj

    def find(self, obj):
        """Return the representative object of the set containing *obj*.

        An object that has not been seen before is added as a new singleton
        set and returned as its own representative.
        """
        if obj not in self.obj2num:
            return self.add(obj)
        return self.objects[self._find(self.obj2num[obj])]

    def _find(self, i):
        """Return the root index of index *i*, compressing the path.

        Iterative rather than recursive so that a long parent chain can never
        hit Python's recursion limit.
        """
        root = i
        while self.parents[root] != root:
            root = self.parents[root]
        # Second pass: point every node on the walked path directly at root.
        while i != root:
            nxt = self.parents[i]
            self.parents[i] = root
            i = nxt
        return root

    def union(self, obj1, obj2):
        """Merge the sets containing *obj1* and *obj2*.

        Unknown objects are registered first (via ``find``). The root of the
        larger set becomes the root of the merged set; on a tie the root of
        *obj1*'s set is kept.
        """
        n1 = self.obj2num[self.find(obj1)]
        n2 = self.obj2num[self.find(obj2)]
        if n1 == n2:
            return
        # Union by size: hang the smaller tree under the larger one so tree
        # height stays logarithmic in the set size.
        if self.weights[n1] < self.weights[n2]:
            n1, n2 = n2, n1
        self.parents[n2] = n1
        self.weights[n1] += self.weights.pop(n2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment