Skip to content

Instantly share code, notes, and snippets.

@isaacl
Created November 20, 2013 11:52
Show Gist options
  • Save isaacl/7561953 to your computer and use it in GitHub Desktop.
Save isaacl/7561953 to your computer and use it in GitHub Desktop.
Disjoint set data structure.
#!/usr/bin/python
"""
Simple implementation of disjoint set data structure using
path compression and union by rank.
Any sequence of m Find and Union operations on n elements runs in
O(m * alpha(n)) total time, where alpha() is the inverse of the
Ackermann function -- effectively constant time per operation, so
it's fast (for Python).
http://en.wikipedia.org/wiki/Disjoint-set_data_structure
"""
class DisjointSets(object):
    """Disjoint-set (union-find) forest over arbitrary hashable objects.

    Every object is given a dense integer index the first time it is seen
    (see Find). Internal state:
      values:  dict mapping object -> index.
      parents: list mapping index -> parent index; a root is its own parent.
      ranks:   list mapping index -> rank used for union by rank.
    """

    # The forest is a mutable container, so instances are unhashable;
    # hash(instance) raises TypeError.
    __hash__ = None

    def __init__(self, iterable=None):
        """Create a forest, optionally seeding it with singleton sets.

        Args:
          iterable: optional iterable of hashable objects; each object not
            already present is added as its own singleton set.
        """
        self.values = {}
        self.ranks = []
        self.parents = []
        if iterable is not None:
            for obj in iterable:
                self.Find(obj)

    def __len__(self):
        """Return the number of objects in the forest (not the number of sets)."""
        return len(self.values)

    def FindByNum(self, cur_num):
        """Return the root index for an existing element index.

        Path compression: every node visited on the way up is re-pointed
        directly at the root, flattening the tree for later lookups.
        """
        parents = self.parents
        # First pass: walk up to the root.
        root = cur_num
        while parents[root] != root:
            root = parents[root]
        # Second pass: point every node on the path straight at the root.
        while cur_num != root:
            next_num = parents[cur_num]
            parents[cur_num] = root
            cur_num = next_num
        return root

    def Find(self, obj):
        """Return the root index of obj's set, adding obj if it is new.

        An unseen obj becomes a new singleton set with the next free index,
        and that index (its own root) is returned.
        """
        if obj in self.values:
            return self.FindByNum(self.values[obj])
        num = len(self.parents)
        self.values[obj] = num
        self.parents.append(num)
        self.ranks.append(0)
        return num

    def UnionByNum(self, val1, val2):
        """Merge the sets containing element indices val1 and val2.

        Union by rank: the root with the lower rank is attached under the
        other one. On a tie, val1's root becomes the parent and its rank
        is incremented.
        """
        root1 = self.FindByNum(val1)
        root2 = self.FindByNum(val2)
        if root1 == root2:
            return
        # Make root1 the root with the higher (or equal) rank.
        if self.ranks[root1] < self.ranks[root2]:
            root1, root2 = root2, root1
        self.parents[root2] = root1
        if self.ranks[root1] == self.ranks[root2]:
            self.ranks[root1] += 1

    def Union(self, obj1, obj2):
        """Merge the sets containing obj1 and obj2.

        Unlike Find, this does not add new objects.

        Raises:
          KeyError: if obj1 or obj2 has not been added to the forest.
        """
        return self.UnionByNum(self.values[obj1], self.values[obj2])

    def GetSets(self):
        """Return a list of groups; each group is a list of objects in one set."""
        groups = {}
        for obj, num in self.values.items():
            groups.setdefault(self.FindByNum(num), []).append(obj)
        # A real list (not a dict view) so str()/repr() and indexing behave
        # the same on Python 2 and 3.
        return list(groups.values())

    def __str__(self):
        return str(self.GetSets())

    def __iter__(self):
        """Iterate over the groups returned by GetSets()."""
        return iter(self.GetSets())

    def __repr__(self):
        """Return the forest structure as a nested list, in index order.

        Each node is (index, obj), or ((index, obj), [child nodes]) when the
        node has children. Roots and children are listed by ascending index,
        so the output is deterministic and never requires comparing the two
        node shapes against each other (a TypeError on Python 3).
        """
        children = {}
        for num, par in enumerate(self.parents):
            if num != par:
                children.setdefault(par, []).append(num)
        objs = dict((num, obj) for obj, num in self.values.items())

        def build(num):
            node = (num, objs[num])
            if num in children:
                node = (node, [build(child) for child in sorted(children[num])])
            return node

        roots = [build(num) for num, par in enumerate(self.parents) if num == par]
        return str(roots)
if __name__ == '__main__':
    import pprint
    import random
    import time

    def RunSteps(elts, prob):
        """Add elts one at a time to a new forest, interleaving random unions.

        Before adding the element at position size (for size > 2), two
        distinct existing indices are merged with probability
        prob * (1.3 * size - 5) / size.

        Returns:
          (forest, elapsed wall-clock seconds, number of Find + Union calls)
        """
        ss = DisjointSets()
        starttime = time.time()
        unions = finds = 0
        for size, elt in enumerate(elts):
            if size > 2 and random.random() < prob * (1.3 * size - 5) / size:
                # Uniform ordered pair of distinct indices in [0, size),
                # chosen without materializing a range (which would be a
                # full list on Python 2).
                first = random.randrange(size)
                second = random.randrange(size - 1)
                if second >= first:
                    second += 1
                ss.UnionByNum(first, second)
                unions += 1
            ss.Find(elt)
            finds += 1
        # Count Find calls explicitly: the last enumerate index is one short
        # of the element count and is undefined for empty input.
        return ss, time.time() - starttime, unions + finds

    sets, _, _ = RunSteps('ABCDEFGHIJKLMNO', .8)
    pprint.pprint(sorted(sets.GetSets()))
    print('\nRunning time test...')
    sets, elapsed, ops = RunSteps(range(10**5), .4)
    print('%.1f ns/op for %d ops. Avg set size: %.1f.' % (
        10**9 * elapsed / ops, ops, len(sets) * 1.0 / len(sets.GetSets())))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment