Skip to content

Instantly share code, notes, and snippets.

@L-TChen
Last active January 22, 2019 22:48
Show Gist options
  • Save L-TChen/c7aad499760a0852fb4bc39771a25b36 to your computer and use it in GitHub Desktop.
Save L-TChen/c7aad499760a0852fb4bc39771a25b36 to your computer and use it in GitHub Desktop.
An imperative implementation of the union-find algorithm in Haskell
{-# LANGUAGE BangPatterns #-}
module Union where
import Data.Ix
import Data.Array
import Data.Array.ST
import Control.Monad
import Control.Monad.ST
-- | A frozen union-find forest. The entry at index @ix@ is a pair
-- @(parent, size)@: a root is its own parent and stores the size of its
-- component, while every non-root element stores a size of 0.
type Union ix = Array ix (ix, Int)

-- | Mutable counterpart of 'Union', living in the 'ST' state thread @s@.
-- Entries follow the same @(parent, size)@ layout.
type UnionM s ix = STArray s ix (ix, Int)
-- | Read the size field stored at index @i@. It is only meaningful for a
-- root; non-root entries hold 0.
sizeM :: (Ix i) => i -> UnionM s i -> ST s Int
sizeM i u = fmap snd (readArray u i)
-- | Read the parent pointer stored at index @i@. A root is its own parent.
parentM :: (Ix i) => i -> UnionM s i -> ST s i
parentM i u = fmap fst (readArray u i)
-- | Do @i@ and @j@ belong to the same component? It resolves the root of
-- @i@ first and then the root of @j@. Both lookups go through 'findM', so
-- both paths are compressed as a side effect.
connectedM :: (Ix i) => i -> i -> UnionM s i -> ST s Bool
connectedM i j u = do
  rootI <- findM i u
  rootJ <- findM j u
  return (rootI == rootJ)
-- | Fully compress the forest. This runs 'findM' on every index in the
-- bounds @bnd@, so afterwards each element's parent field is its root.
--
-- We iterate over @'range' bnd@ rather than over @bnd@ itself. Folding over
-- the bounds pair would visit only the upper bound, because the 'Foldable'
-- instance for @(,) a@ contains just the second component.
rebuildM :: (Ix i) => (i, i) -> UnionM s i -> ST s ()
rebuildM bnd u = forM_ (range bnd) (`findM` u)
-- | Snapshot the components of a mutable forest as @(root, size)@ pairs,
-- listed in index order of the roots.
--
-- It first runs 'rebuildM' over the array's own bounds. Roots are then
-- recognised by a non-zero size field, since non-root entries store 0.
componentsM :: (Ix i) => UnionM s i -> ST s [(i, Int)]
componentsM u = do
  bounds' <- getBounds u
  rebuildM bounds' u
  entries <- getElems u
  return (filter ((/= 0) . snd) entries)
-- | Return the root of @i@'s component, compressing the path on the way.
-- Every non-root element visited is rewritten to point straight at the
-- root, with a size field of 0.
findM :: (Ix i, Eq i) => i -> UnionM s i -> ST s i
findM i u = parentM i u >>= climb
  where
    -- The bang forces the parent before comparing, matching the strict
    -- binding style used elsewhere in this module.
    climb !p
      | i == p = return p
      | otherwise = do
          !root <- findM p u
          writeArray u i (root, 0)
          return root
-- | Allocate a fresh forest over @bnd@ in which every element is its own
-- root, with a component size of 1.
createM :: (Ix i) => (i, i) -> ST s (UnionM s i)
createM bnd = newListArray bnd [ (ix, 1) | ix <- range bnd ]
-- | Merge the components containing @i@ and @j@ using union by size.
--
-- * If both already share a root, nothing is written beyond the path
--   compression done by 'findM'.
-- * Otherwise the root of the smaller component is linked beneath the root
--   of the larger one.
-- * On equal sizes, @j@'s root is placed beneath @i@'s root.
unionM :: (Ix i) => i -> i -> UnionM s i -> ST s ()
unionM i j u = do
  !rootI <- findM i u
  !rootJ <- findM j u
  unless (rootI == rootJ) $ do
    !sizeI <- sizeM rootI u
    !sizeJ <- sizeM rootJ u
    if sizeI < sizeJ
      then linkM sizeI sizeJ rootI rootJ u
      else linkM sizeJ sizeI rootJ rootI u
-- | Attach root @i@ (component size @n@) beneath root @j@ (component size
-- @m@). Afterwards @i@ is a child of @j@ with size 0, and @j@ stays a root
-- recording the combined size.
linkM :: (Ix i) => Int -> Int -> i -> i -> UnionM s i -> ST s ()
linkM n m i j u = do
  writeArray u i (j, 0)
  writeArray u j (j, n + m)
-- | Apply 'unionM' to every pair in @reps@, from left to right.
unionsM :: (Ix i) => [(i, i)] -> UnionM s i -> ST s ()
unionsM reps u = mapM_ (\(a, b) -> unionM a b u) reps
-- | Build a frozen forest over @bnd@ in three steps:
--
-- 1. Start from singleton components ('createM').
-- 2. Merge every pair in @reps@ ('unionsM').
-- 3. Run 'rebuildM' over @bnd@ for path compression.
--
-- The result is then frozen with 'runSTArray'.
unions :: (Ix i) => (i, i) -> [(i, i)] -> Union i
unions bnd reps = runSTArray $
  createM bnd >>= \forest ->
    unionsM reps forest >> rebuildM bnd forest >> return forest
-- | Pure membership test: do @i@ and @j@ have the same root in the frozen
-- forest @u@?
connected :: (Ix i) => i -> i -> Union i -> Bool
connected i j u = rootI == rootJ
  where
    rootI = find i u
    rootJ = find j u
-- | Root of @i@'s component in a frozen 'Union'.
--
-- This follows parent pointers until it reaches an element that is its own
-- parent. On a fully path-compressed array, where every element points
-- directly at its root, that costs at most one extra lookup.
--
-- Walking the chain, instead of trusting the first pointer, keeps the answer
-- correct when the array was not fully compressed. It assumes the parent
-- pointers are acyclic, as they are when built by 'createM' and 'unionM'.
find :: (Ix i, Eq i) => i -> Union i -> i
find i u
  | i == p = p
  | otherwise = find p u
  where
    p = fst (u ! i)
-- | Size of the component containing @i@. The size is read from the entry
-- of @i@'s root, as located by 'find'.
size :: (Ix i) => i -> Union i -> Int
size i u =
  let (_, s) = u ! find i u
   in s
-- | All components of a frozen forest as @(root, size)@ pairs, in index
-- order. Roots are the entries whose size field is non-zero.
components :: Ix i => Union i -> [(i, Int)]
components u = filter ((/= 0) . snd) (elems u)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment