Skip to content

Instantly share code, notes, and snippets.

@roman-kashitsyn
Last active October 12, 2016 13:28
Show Gist options
  • Save roman-kashitsyn/b2c8989c4d8c8a5a200537399d433412 to your computer and use it in GitHub Desktop.
Save roman-kashitsyn/b2c8989c4d8c8a5a200537399d433412 to your computer and use it in GitHub Desktop.
Simple Kruskal's MST in Haskell
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (sortBy)
import Data.Function (on)
import Data.Array (Array, bounds, rangeSize, elems, accumArray)
import Data.Maybe (fromMaybe)
-- Minimal immutable graph definition

-- | An undirected weighted edge: @(from, to, weight)@.
type Edge = (Int, Int, Double)

-- | Adjacency-list graph: index @v@ holds every edge incident to vertex @v@.
-- Vertices are numbered @0 .. n - 1@ (see 'graphOfEdges').
type Graph = Array Int [Edge]

-- | All edges of the graph, concatenated in vertex order.
--
-- 'graphOfEdges' stores each edge under both of its endpoints, so every
-- edge appears twice in the result (once per endpoint; a self-loop appears
-- twice under its single vertex).
listEdges :: Graph -> [Edge]
listEdges = concat . elems

-- | The weight component of an edge.
edgeWeight :: Edge -> Double
edgeWeight (_, _, w) = w

-- | Number of vertices, i.e. the size of the array's index range.
numVertices :: Graph -> Int
numVertices = rangeSize . bounds

-- | Build a graph with vertices @0 .. n - 1@ from a list of edges.
--
-- Each edge is recorded in the adjacency list of both endpoints. Because
-- edges are consed onto the lists, a vertex's edges end up in reverse input
-- order. An endpoint outside @0 .. n - 1@ makes 'accumArray' raise an
-- index-out-of-range error.
graphOfEdges :: Int -> [Edge] -> Graph
graphOfEdges n = accumArray (flip (:)) [] (0, n - 1) . concatMap indexEdge
  where
    -- One association per endpoint, both carrying the full edge.
    indexEdge e@(from, to, _) = [(from, e), (to, e)]
-- Functional (persistent) disjoint-set forest. TODO: move this into its own
-- module so the representation can be hidden behind the uf* functions.
-- There is no path compression: every update returns a new structure.

-- | Disjoint-set (union-find) structure over 'Int' elements.
--
-- An element with no entry in 'parents' is the root of its own set; a root
-- with no entry in 'sizes' has size 1. Sizes are only maintained for roots.
data UnionFind = UnionFind
  { sizes   :: IntMap Int -- ^ set size, keyed by root element
  , parents :: IntMap Int -- ^ parent link of every non-root element
  }

-- | The structure in which every element is in its own singleton set.
ufEmpty :: UnionFind
ufEmpty = UnionFind { sizes = IntMap.empty, parents = IntMap.empty }

-- | Merge the sets containing @x@ and @y@ using union by size.
--
-- The root of the smaller set is linked under the root of the larger one;
-- on a tie, @x@'s root stays the root. If @x@ and @y@ are already in the
-- same set, the structure is returned unchanged: linking a root to itself
-- would make 'ufParent' loop forever.
ufJoin :: UnionFind -> Int -> Int -> UnionFind
ufJoin uf x y
  | px == py  = uf
  | sx < sy   = ufJoin uf py px
  | otherwise = UnionFind
      { sizes   = IntMap.insert px (sx + sy) (sizes uf)
      , parents = IntMap.insert py px (parents uf)
      }
  where
    px = ufParent uf x
    py = ufParent uf y
    -- Compare the sizes of the roots; the element-level entries of
    -- non-roots are missing or stale.
    sx = ufSize uf px
    sy = ufSize uf py

-- | Size of the set containing @x@ (1 for an element never joined).
ufSize :: UnionFind -> Int -> Int
ufSize uf x = fromMaybe 1 $ IntMap.lookup (ufParent uf x) (sizes uf)

-- | Root (representative) of the set containing @x@, found by following
-- parent links until an element without a parent is reached.
ufParent :: UnionFind -> Int -> Int
ufParent uf x = maybe x (ufParent uf) (IntMap.lookup x (parents uf))

-- | Whether @x@ and @y@ belong to the same set.
ufConnected :: UnionFind -> Int -> Int -> Bool
ufConnected uf x y = ufParent uf x == ufParent uf y
-- | Kruskal's algorithm. Returns, as a lazy list, the edges of a minimum
-- spanning forest in order of increasing weight.
--
-- Edges are scanned cheapest first, and an edge is kept only when its
-- endpoints are still in different components. This also discards the
-- second copy of each edge, since 'listEdges' reports every edge once per
-- endpoint. The scan stops once @n - 1@ edges have been kept, or when the
-- edges run out (which yields a forest for a disconnected graph).
mst :: Graph -> [Edge]
mst g = go ufEmpty 0 sortedEdges
  where
    target = numVertices g - 1
    byWeight = compare `on` edgeWeight
    sortedEdges = sortBy byWeight (listEdges g)
    go _ _ [] = []
    go forest taken (e@(u, v, _) : rest)
      | taken == target        = []
      | ufConnected forest u v = go forest taken rest
      | otherwise              = e : go (ufJoin forest u v) (taken + 1) rest
-- | Run Kruskal's algorithm on a small six-vertex example graph and print
-- the resulting spanning-tree edges.
main :: IO ()
main = print spanningTree
  where
    spanningTree = mst (graphOfEdges 6 sampleEdges)
    sampleEdges :: [Edge]
    sampleEdges =
      [ (0, 1, 3), (0, 2, 1), (0, 3, 6)
      , (1, 2, 5), (1, 4, 3)
      , (2, 3, 5), (2, 4, 6), (2, 5, 4)
      , (3, 5, 2)
      , (4, 5, 6)
      ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment