-- Last active October 12, 2016 13:28
-- Simple Kruskal's minimum spanning tree (MST) in Haskell
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (sortBy)
import Data.Function (on)
import Data.Array (Array, bounds, rangeSize, elems, accumArray)
import Data.Maybe (fromMaybe)
-- Minimal immutable graph definition.
-- | An undirected weighted edge: (endpoint, endpoint, weight).
type Edge = (Int, Int, Double)
-- | Vertex-indexed adjacency lists; as built by graphOfEdges, element @i@
-- holds the edges incident to vertex @i@.
type Graph = Array Int [Edge]
-- | Every edge stored in the graph, gathered from the adjacency lists in
-- vertex order. For graphs built by 'graphOfEdges' each edge therefore
-- appears twice (once per endpoint).
listEdges :: Graph -> [Edge]
listEdges g = [edge | incident <- elems g, edge <- incident]
-- | The weight (third component) of an edge.
edgeWeight :: Edge -> Double
edgeWeight edge = case edge of
  (_, _, weight) -> weight
-- | Number of vertices, i.e. the number of indices in the array's bounds.
numVertices :: Graph -> Int
numVertices g = rangeSize (bounds g)
-- | Build an undirected graph on vertices @0 .. n - 1@ from an edge list.
--
-- Each edge @(from, to, w)@ is stored in the adjacency lists of both
-- endpoints, so 'listEdges' on the result yields every edge twice (a
-- self-loop @(v, v, w)@ appears twice in @v@'s own list). Because new
-- entries are prepended, each adjacency list is in reverse input order.
--
-- Every endpoint must lie in @[0, n - 1]@; an out-of-range endpoint makes
-- 'accumArray' fail.
graphOfEdges :: Int -> [Edge] -> Graph
graphOfEdges n = accumArray (flip (:)) [] (0, n - 1) . concatMap indexEdge
  where
    -- Pair the edge with each of its endpoints so both vertices see it.
    indexEdge e@(from, to, _) = [(from, e), (to, e)]
-- Purely functional disjoint-set (union-find) implementation.
-- TODO: move it into its own module to hide the representation.
-- | Disjoint sets over 'Int' elements. An element absent from both maps is
-- an untouched singleton set.
data UnionFind = UnionFind
  { sizes :: IntMap Int
    -- ^ size of a merged set, recorded at its root element
  , parents :: IntMap Int
    -- ^ parent link of each non-root element (roots have no entry)
  }
-- | The structure with no recorded merges: every element is implicitly its
-- own singleton set.
ufEmpty :: UnionFind
ufEmpty = UnionFind IntMap.empty IntMap.empty
-- | Merge the sets containing @x@ and @y@ using union by size.
--
-- The root of the smaller set is attached under the root of the larger
-- one; on a tie, @x@'s root remains the representative. Joining two
-- elements that are already in the same set returns the structure
-- unchanged.
ufJoin :: UnionFind -> Int -> Int -> UnionFind
ufJoin uf x y
  -- Already connected: linking a root to itself would put a cycle into
  -- 'parents' and make 'ufParent' loop forever.
  | rx == ry = uf
  | sx >= sy = link rx ry
  | otherwise = link ry rx
  where
    rx = ufParent uf x
    ry = ufParent uf y
    -- Sizes are only maintained at roots, so look them up at the roots
    -- rather than at x and y themselves.
    sx = ufSize uf rx
    sy = ufSize uf ry
    -- Make @root@ the parent of @child@ and record the merged size at
    -- @root@; @child@ stops being a root, so its size entry is dropped.
    link root child = UnionFind
      { sizes = IntMap.insert root (sx + sy) (IntMap.delete child (sizes uf))
      , parents = IntMap.insert child root (parents uf)
      }
-- | Number of elements in the set containing @x@.
--
-- Sizes are recorded only at set roots, so @x@ is resolved to its root
-- first. A root with no recorded size has never absorbed another set,
-- so it is a singleton (size 1).
ufSize :: UnionFind -> Int -> Int
ufSize uf x = fromMaybe 1 $ IntMap.lookup (ufParent uf x) (sizes uf)
-- | Representative (root) of the set containing @x@. It follows 'parents'
-- links until it reaches an element that has no parent entry.
--
-- No path compression is done; only the root is returned and the
-- structure is not updated. This would not terminate if 'parents'
-- contained a cycle.
ufParent :: UnionFind -> Int -> Int
ufParent uf x = maybe x (ufParent uf) (IntMap.lookup x (parents uf))
-- | Whether @x@ and @y@ belong to the same set, i.e. resolve to the same
-- root.
ufConnected :: UnionFind -> Int -> Int -> Bool
ufConnected uf x y = rootX == rootY
  where
    rootX = ufParent uf x
    rootY = ufParent uf y
-- | Kruskal's MST algorithm. Lazily produces the edges of a minimum
-- spanning forest, in order of increasing weight.
--
-- How it works:
--
-- * Edges are scanned by increasing weight. The sort is stable, so ties
--   keep 'listEdges' order.
-- * An edge is emitted only if it joins two different components. This
--   also skips the second copy of each edge that 'graphOfEdges' stores.
-- * The scan stops early once @n - 1@ edges have been emitted, where @n@
--   is the number of vertices.
mst :: Graph -> [Edge]
mst graph = go ufEmpty 0 sortedEdges
  where
    vertexCount = numVertices graph
    sortedEdges = sortBy (compare `on` edgeWeight) (listEdges graph)
    go _ _ [] = []
    go sets taken (edge@(u, v, _) : rest)
      -- A spanning tree on vertexCount vertices has vertexCount - 1 edges.
      | taken == vertexCount - 1 = []
      | ufConnected sets u v = go sets taken rest
      | otherwise = edge : go (ufJoin sets u v) (taken + 1) rest
main :: IO ()
main = print $ mst $ graphOfEdges 6 [ (0, 1, 3), (0, 2, 1), (0, 3, 6)
, (1, 2, 5), (1, 4, 3)
, (2, 3, 5), (2, 4, 6), (2, 5, 4)
, (3, 5, 2)
, (4, 5, 6)
