Created
November 20, 2013 11:52
-
-
Save isaacl/7561953 to your computer and use it in GitHub Desktop.
Disjoint set data structure.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
""" | |
Simple implementation of disjoint set data structure using | |
path compression and union by rank. | |
Each Find or Union runs in amortized O(a(n)) time, where a() is the
inverse of the Ackermann function and n is the number of elements.
That means it's fast, for python.
http://en.wikipedia.org/wiki/Disjoint-set_data_structure | |
""" | |
class DisjointSets(object):
    """Disjoint-set (union-find) forest over arbitrary hashable objects.

    Each object is assigned an integer index in insertion order
    (0, 1, 2, ...).  ``values`` maps object -> index, ``parents[i]`` is the
    index of i's parent (a root is its own parent) and ``ranks[i]`` is the
    rank used to decide which root survives a union.

    The ``*ByNum`` methods work directly on indices; ``Find`` and ``Union``
    work on the objects themselves.
    """

    # Mutable container: explicitly unhashable (hash() raises TypeError).
    __hash__ = None

    def __init__(self, iterable=None):
        """Create the structure, registering each item of ``iterable``.

        Args:
            iterable: optional iterable of hashable objects; each one starts
                in its own singleton set.
        """
        self.values = {}
        self.ranks = []
        self.parents = []
        # Compare against None rather than testing truthiness, so that any
        # supplied iterable is consumed regardless of its boolean value.
        if iterable is not None:
            for obj in iterable:
                self.Find(obj)

    def __len__(self):
        """Return the number of registered objects (not the number of sets)."""
        return len(self.values)

    def FindByNum(self, cur_num):
        """Find root of existing elt with path compression.

        Args:
            cur_num: index of an already registered element.

        Returns:
            Index of the root of the tree containing ``cur_num``.

        Raises:
            IndexError: if ``cur_num`` is past the end of the index range.
        """
        parents = self.parents
        # First pass: walk up to the root.
        root = cur_num
        while parents[root] != root:
            root = parents[root]
        # Second pass: point every node on the walked path at the root.
        while cur_num != root:
            next_num = parents[cur_num]
            parents[cur_num] = root
            cur_num = next_num
        return root

    def Find(self, obj):
        """Return index corresponding to root.

        An object that has not been seen before is registered as a new
        singleton set and its freshly assigned index is returned.
        """
        if obj in self.values:
            return self.FindByNum(self.values[obj])
        num = len(self)
        self.values[obj] = num
        self.parents.append(num)
        self.ranks.append(0)
        return num

    def UnionByNum(self, val1, val2):
        """Combine two trees using known indices.

        Does nothing when both indices already share a root.  Returns None.
        """
        root1 = self.FindByNum(val1)
        root2 = self.FindByNum(val2)
        if root1 == root2:
            return
        rank1 = self.ranks[root1]
        rank2 = self.ranks[root2]
        # Lower rank tree points to higher ranked tree.
        if rank1 < rank2:
            self.parents[root1] = root2
            return
        self.parents[root2] = root1
        if rank1 == rank2:
            # Equal ranks: the surviving root's tree may have grown taller.
            self.ranks[root1] += 1

    def Union(self, obj1, obj2):
        """Combine two trees.

        Raises:
            KeyError: if either object has not been registered (via the
                constructor or ``Find``).
        """
        return self.UnionByNum(self.values[obj1], self.values[obj2])

    def GetSets(self):
        """Return the groups of objects as a list of lists."""
        roots = {}
        # .items() (rather than the Python 2 only .iteritems()) keeps this
        # working on both Python 2 and Python 3.
        for val, num in self.values.items():
            roots.setdefault(self.FindByNum(num), []).append(val)
        # Materialise a real list so callers (and __str__) never see a
        # dict view object.
        return list(roots.values())

    def __str__(self):
        """Return the string form of ``GetSets()``."""
        return str(self.GetSets())

    def __iter__(self):
        """Iterate over the groups returned by ``GetSets()``."""
        return iter(self.GetSets())

    def __repr__(self):
        """Return the current forest shape as a string.

        A node without children is shown as ``(index, obj)``; a node with
        children is shown as ``((index, obj), [child nodes])``.  No path
        compression is performed, so the structure is shown as it is.
        """
        # One (initially empty) child list per index that is somebody's parent.
        parent_lists = {}
        for num, par in enumerate(self.parents):
            if num != par:
                parent_lists.setdefault(par, [])
        keyed_roots = []
        for obj, num in self.values.items():
            has_children = num in parent_lists
            node = (num, obj)
            if has_children:
                # The list is shared by reference; children found later in
                # this loop still end up inside this node.
                node = (node, parent_lists[num])
            par = self.parents[num]
            if par == num:
                keyed_roots.append(((has_children, num), node))
            else:
                parent_lists[par].append(node)
        # Sorting the raw nodes would compare an int index with a tuple,
        # which raises TypeError on Python 3.  The explicit key orders
        # childless roots first, then roots with children, each by index
        # (the order the original mixed-tuple sort gave on Python 2).
        keyed_roots.sort(key=lambda pair: pair[0])
        return str([node for _, node in keyed_roots])
if __name__ == '__main__':
    import pprint
    import random
    import time

    # Use the lazy range type on both Python 2 (xrange) and Python 3 (range)
    # so random.sample() never has to build a full list of indices.
    try:
        lazy_range = xrange
    except NameError:
        lazy_range = range

    def RunSteps(elts, prob):
        """Insert ``elts`` one by one, randomly unioning earlier elements.

        Before inserting the element at position ``size`` (for size > 2), two
        distinct existing indices are unioned with probability
        ``prob * (1.3 * size - 5) / size``.

        Returns:
            (sets, elapsed_seconds, op_count), where op_count is the number
            of UnionByNum calls plus the number of Find calls performed.
        """
        ss = DisjointSets()
        starttime = time.time()
        unions = 0
        finds = 0
        for size, elt in enumerate(elts):
            if size > 2 and random.random() < prob * (1.3 * size - 5) / size:
                ss.UnionByNum(*random.sample(lazy_range(size), 2))
                unions += 1
            ss.Find(elt)
            # Count finds explicitly: the last loop index is one less than
            # the number of inserts and is undefined for empty input.
            finds += 1
        return ss, time.time() - starttime, unions + finds

    sets, _, _ = RunSteps('ABCDEFGHIJKLMNO', .8)
    pprint.pprint(sorted(sets.GetSets()))
    print('\nRunning time test...')
    sets, elapsed, ops = RunSteps(lazy_range(10**5), .4)
    print('%.1f ns/op for %d ops. Avg set size: %.1f.' % (
        10**9 * elapsed / ops, ops, len(sets) * 1.0 / len(sets.GetSets())))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment