Skip to content

Instantly share code, notes, and snippets.

@charris
Created July 24, 2020 00:00
Show Gist options
  • Save charris/62f2dadc0ab597196635e8803eab786a to your computer and use it in GitHub Desktop.
Save charris/62f2dadc0ab597196635e8803eab786a to your computer and use it in GitHub Desktop.
union-find
include "stdlib.pxd"
cdef class UnionFind:
    """Disjoint-set (union-find) structure over the integers 0 .. size-1.

    Each set is tracked in two parallel C arrays:

    * ``_dad`` -- parent links.  ``_dad[i] >= 0`` is the parent of ``i``;
      ``_dad[i] < 0`` marks ``i`` as the root of its set, and ``-_dad[i]`` is
      the number of elements in that set (used for union by size).
    * ``_sib`` -- a circular "next element" list threading every member of a
      set, so ``getset`` can list a set without scanning all elements.

    Index checks in the public methods use ``assert`` and therefore raise
    ``AssertionError``.  NOTE(review): those checks vanish if the module is
    compiled with assertions disabled.
    """
    cdef int *_dad       # parent link, or -(set size) for a root
    cdef int *_sib       # next element of the same set (circular list)
    cdef int _size       # number of elements currently in use
    cdef int _capacity   # capacity recorded at construction
    cdef int _setcount   # number of disjoint sets among the _size elements

    cdef void _reset(self):
        """Put every element 0 .. _size-1 into its own singleton set."""
        cdef int i
        cdef int *dad = self._dad
        cdef int *sib = self._sib
        for i in range(self._size):
            dad[i] = -1   # root of a set containing one element
            sib[i] = i    # a one-element circular list points at itself
        self._setcount = self._size

    cdef int _root(self, int elt):
        """Return the root of the set containing ``elt``, or -1 if ``elt`` is
        out of range.  Compresses the path from ``elt`` to the root."""
        cdef int top
        cdef int nxt
        cdef int *dad = self._dad
        # Reject negative indices too; checking only the upper bound let a
        # negative index read outside the array.
        if elt < 0 or elt >= self._size:
            return -1
        # Find the root.
        top = elt
        while dad[top] >= 0:
            top = dad[top]
        # Path compression: point every node on the path straight at the root.
        while elt != top and dad[elt] != top:
            nxt = dad[elt]
            dad[elt] = top
            elt = nxt
        return top

    cdef int _is_valid_index(self, int i):
        """Return nonzero when ``i`` is in the range 0 .. _size-1."""
        return 0 <= i < self._size

    def __cinit__(self, int size=0, int capacity=100):
        """Create ``size`` singleton sets, recording ``capacity`` slots.

        ``capacity`` is raised to ``size`` when it is smaller.

        Raises:
            ValueError: if ``size`` is negative.
            MemoryError: if the internal arrays cannot be allocated.
        """
        cdef size_t nalloc
        if size < 0:
            raise ValueError("size must be non-negative")
        if size > capacity:
            capacity = size
        self._size = size
        self._capacity = capacity
        # malloc(0) may legitimately return NULL, so always request at least
        # one slot; that keeps the NULL test below an unambiguous failure.
        nalloc = <size_t>(capacity if capacity > 0 else 1)
        self._dad = <int *>malloc(nalloc * sizeof(int))
        self._sib = <int *>malloc(nalloc * sizeof(int))
        if self._dad == NULL or self._sib == NULL:
            # Release whichever allocation succeeded; nulling the pointers
            # keeps the later free() calls in __dealloc__ harmless.
            free(self._dad)
            free(self._sib)
            self._dad = NULL
            self._sib = NULL
            raise MemoryError("could not allocate UnionFind arrays")
        self._reset()

    def __dealloc__(self):
        free(self._dad)
        free(self._sib)

    def size(self):
        """Return the number of elements in use."""
        return self._size

    def capacity(self):
        """Return the capacity recorded at construction."""
        return self._capacity

    def setcount(self):
        """Return the current number of disjoint sets."""
        return self._setcount

    def clear(self):
        """Drop all elements (size and set count become 0); memory is kept."""
        self._size = 0
        self._setcount = 0

    def reset(self):
        """Put each current element back into its own singleton set."""
        self._reset()

    def felt(self, int elt1, int elt2):
        """Return True if ``elt1`` and ``elt2`` are in *different* sets.

        NOTE(review): the name does not convey the sense of the result; it is
        True when the two elements are NOT connected.
        """
        assert self._is_valid_index(elt1), "elt1 out of range."
        assert self._is_valid_index(elt2), "elt2 out of range."
        return self._root(elt1) != self._root(elt2)

    def union(self, elts):
        """Merge the sets containing every element of the sequence ``elts``.

        ``elts`` must support ``len()`` and integer indexing; an empty
        sequence is a no-op.
        """
        cdef int n = len(elts)
        cdef int *dad = self._dad
        cdef int *sib = self._sib
        cdef int r1, r2, tmp, i, elt
        if n == 0:
            return
        elt = elts[0]
        assert self._is_valid_index(elt), "element out of range"
        r1 = self._root(elt)
        for i in range(1, n):
            elt = elts[i]
            assert self._is_valid_index(elt), "element out of range"
            r2 = self._root(elt)
            if r1 == r2:
                continue
            self._setcount -= 1
            # Swapping the successors of one node from each circular sibling
            # list splices the two lists into a single cycle.
            tmp = sib[r1]
            sib[r1] = sib[r2]
            sib[r2] = tmp
            # Union by size: dad[root] == -(set size), so the more negative
            # value is the larger set, whose root becomes the merged root.
            if dad[r1] < dad[r2]:
                dad[r1] += dad[r2]
                dad[r2] = r1
            else:
                dad[r2] += dad[r1]
                dad[r1] = r2
                # r1 is no longer a root: keep tracking the merged root.
                # Without this, later iterations used a stale r1 and could
                # create parent cycles or "merge" a set with itself.
                r1 = r2

    def getset(self, int elt):
        """Return a list of the members of the set containing ``elt``,
        starting with ``elt`` and following the circular sibling list."""
        cdef int *sib = self._sib
        cdef int nxt
        assert self._is_valid_index(elt), "elt is out of range"
        members = [elt]
        nxt = sib[elt]
        while nxt != elt:
            members.append(nxt)
            nxt = sib[nxt]
        return members

    def getallsets(self):
        """Return one member list per disjoint set (see ``getset``), ordered
        by the index of each set's root."""
        cdef int *dad = self._dad
        cdef int i
        return [self.getset(i) for i in range(self._size) if dad[i] < 0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment