Skip to content

Instantly share code, notes, and snippets.

@LiutongZhou
Last active June 5, 2022 21:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save LiutongZhou/f67bf1a1546a996531e51ccb81898d5b to your computer and use it in GitHub Desktop.
Save LiutongZhou/f67bf1a1546a996531e51ccb81898d5b to your computer and use it in GitHub Desktop.
"""UnionFind (Disjoint Sets)"""
from typing import Optional, Iterable, Hashable, Any
class UnionFind:
    """Disjoint-set forest (union-find) over hashable items.

    Uses union by size and full path compression, so trees stay shallow and
    repeated ``find`` calls along the same path become cheap.

    Attributes:
        parent: maps every item to its parent; a root maps to itself.
        num_sets: number of disjoint sets currently tracked.
    """

    def __init__(
        self, initial_disjoint_items: Optional[Iterable[Hashable]] = None
    ):
        """Initialize a UnionFind, optionally seeded with singleton sets.

        Args:
            initial_disjoint_items: items to add, each as its own singleton
                set. Duplicates collapse into a single item. ``None`` (the
                default) starts with no items.
        """
        # Compare against None explicitly: truth-testing an arbitrary iterable
        # is unreliable (e.g. a NumPy array raises on ``bool()``), and an empty
        # iterable already produces an empty mapping below.
        items = () if initial_disjoint_items is None else initial_disjoint_items
        # maps each element to its parent; roots map to themselves
        self.parent = {u: u for u in items}
        # size of each disjoint set, keyed only by that set's root
        self._size = dict.fromkeys(self.parent, 1)
        self.num_sets = len(self.parent)

    def add(self, u: Hashable) -> None:
        """Add ``u`` as a new singleton set; no-op if ``u`` is already present."""
        if u in self.parent:
            return
        self.parent[u] = u
        self._size[u] = 1
        self.num_sets += 1

    def find(self, u: Hashable) -> Any:
        """Return the root of the set that ``u`` belongs to.

        Every node on the path from ``u`` to the root is re-pointed directly
        at the root (full path compression).

        Raises:
            AssertionError: if ``u`` has not been added. (With assertions
                disabled via ``-O``, the dict lookup raises ``KeyError``.)
        """
        parent = self.parent
        assert u in parent, f"{u} has not been added yet"
        # First pass: walk up until reaching a node that is its own parent.
        root = u
        while (p := parent[root]) != root:
            root = p
        # Second pass: re-point each node on the path at the root. The RHS is
        # evaluated first and targets are bound left to right, so
        # ``parent[node]`` is written for the *current* node before ``node``
        # advances to its old parent.
        node = u
        while node != root:
            parent[node], node = root, parent[node]
        return root

    def union(self, u: Hashable, v: Hashable) -> None:
        """Merge the sets containing ``u`` and ``v``; no-op if already joined.

        The smaller set's root is attached under the larger set's root; on a
        tie, ``v``'s root is attached under ``u``'s root.
        """
        root_u, root_v = self.find(u), self.find(v)
        if root_u == root_v:
            return
        size = self._size
        # Ensure root_u is the root of the larger (or equal-sized) set.
        if size[root_u] < size[root_v]:
            root_u, root_v = root_v, root_u
        self.parent[root_v] = root_u
        # root_v is no longer a root, so its size entry moves to root_u.
        size[root_u] += size.pop(root_v)
        self.num_sets -= 1

    def is_connected(self, u: Hashable, v: Hashable) -> bool:
        """Return True if ``u`` and ``v`` are in the same set, else False."""
        return self.find(u) == self.find(v)

    def get_set_size(self, u: Hashable) -> int:
        """Return the number of items in the set that ``u`` belongs to."""
        return self._size[self.find(u)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment