Skip to content

Instantly share code, notes, and snippets.

@inside-code-yt
Created September 30, 2022 14:33
Show Gist options
  • Save inside-code-yt/daa2792992422d4d2bdbb4a39ed99120 to your computer and use it in GitHub Desktop.
Save inside-code-yt/daa2792992422d4d2bdbb4a39ed99120 to your computer and use it in GitHub Desktop.
import networkx as nx
import matplotlib.pyplot as plt
class DisjointSet:
    """Union-find (disjoint-set) structure that mirrors its forest in a graph.

    Uses union by size and path compression. Every tracked element ``x`` has
    exactly one outgoing edge ``x -> parent[x]`` in ``self.graph`` (roots point
    to themselves), so ``show()`` draws the current forest.
    """

    def __init__(self, elems):
        """Create one singleton set per element of ``elems``.

        Args:
            elems: iterable of hashable elements. Repeated elements are
                added only once.
        """
        self.elems = elems
        self.parent = {}  # element -> parent element; roots map to themselves
        self.size = {}  # root -> element count; only entries for current roots are kept up to date
        self.graph = nx.DiGraph()  # mirror of ``parent`` used for drawing
        for elem in elems:
            self.make_set(elem)

    def make_set(self, x):
        """Add ``x`` as a new singleton set.

        Re-adding an element that is already tracked is a no-op. Resetting it
        would leave its old ``x -> parent`` edge in the graph and detach it
        from its set without fixing the set's size.
        """
        if x in self.parent:
            return
        self.parent[x] = x
        self.size[x] = 1
        self.graph.add_edge(x, x)

    def find(self, x):
        """Return the root of ``x``'s set, compressing the path to it.

        Raises:
            KeyError: if ``x`` was never added.
        """
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: re-point every node on the path directly at the root,
        # keeping the drawing graph in sync with ``parent``. Nodes already
        # pointing at the root need no change.
        node = x
        while node != root:
            next_node = self.parent[node]
            if next_node != root:
                self.graph.remove_edge(node, next_node)
                self.parent[node] = root
                self.graph.add_edge(node, root)
            node = next_node
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y`` using union by size.

        The smaller tree's root is attached under the larger tree's root. On a
        tie, ``y``'s root goes under ``x``'s root. Does nothing if ``x`` and
        ``y`` are already in the same set.

        Raises:
            KeyError: if either element was never added.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        # root_x is now the larger (or equal-sized) root; hang root_y under it.
        self.graph.remove_edge(root_y, root_y)  # drop the old root's self-loop
        self.parent[root_y] = root_x
        self.graph.add_edge(root_y, root_x)
        self.size[root_x] += self.size[root_y]

    def show(self):
        """Draw the current forest with node labels and display it with matplotlib."""
        nx.draw(self.graph, with_labels=True)
        plt.show()
# Demo: start with eight singleton sets, merge them step by step, and print
# the representative (root) of a few elements along the way.
data = list('ABCDEFGH')
ds = DisjointSet(data)

# Merge B, F and H into one set and report its root.
ds.union('B', 'F')
ds.union('F', 'H')
print(ds.find('H'))

# Pair A with G, then join that pair onto the {B, F, H} set.
ds.union('A', 'G')
ds.union('F', 'A')
# Draw the forest as it stands now (plt.show() may block, depending on the backend).
ds.show()
print(ds.find('G'))

# Build a separate {C, D, E} set and report its root.
ds.union('C', 'D')
ds.union('C', 'E')
print(ds.find('E'))

# Finally merge the two remaining sets.
ds.union('D', 'A')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment