Skip to content

Instantly share code, notes, and snippets.

@blackheaven
Created June 1, 2018 21:25
Show Gist options
  • Save blackheaven/16ca2e2d7f0d88e6801a63e5276e6d31 to your computer and use it in GitHub Desktop.
Save blackheaven/16ca2e2d7f0d88e6801a63e5276e6d31 to your computer and use it in GitHub Desktop.
Coding Dojo (18-05-30) on K-means
module Kata
( kmeans
, randV
, prettifier
) where
import Data.List(groupBy, sortBy, minimumBy, unfoldr)
import Data.Function(on)
import System.Random(newStdGen, randomRs)
-- | A point in the plane, stored as an @(x, y)@ pair of 'Float' coordinates.
type Vector = (Float, Float)
-- | A cluster is simply the list of points currently assigned to it.
type Cluster = [Vector]
-- | Signature of a function measuring how far apart two points are.
type Distance = Vector -> Vector -> Float
-- | The state handed from one refinement step to the next: the current list
-- of clusters.
type UnitOfWork = [Cluster]
-- | Partition @vectors@ into at most @nbClusters@ clusters using k-means.
--
-- The points are first dealt round-robin into @nbClusters@ groups, then the
-- assignment is refined by applying @round'@ a fixed number of times; there
-- is no convergence test.
--
-- A non-positive @nbClusters@ yields no clusters. The case is guarded
-- explicitly because seeding cycles over @[1..nbClusters]@, and 'cycle'
-- raises an error on an empty list.
kmeans :: Int -> [Vector] -> [Cluster]
kmeans nbClusters vectors
  | nbClusters <= 0 = []
  | otherwise = iterate round' (initVector nbClusters vectors) !! refinementRounds
  where
    -- Number of reassignment passes applied after the initial seeding.
    -- 'iterate' produces an infinite list, so indexing it is always safe.
    refinementRounds :: Int
    refinementRounds = 7
-- | Signature of a function that reduces a cluster to a single point.
type Mean = Cluster -> Vector

-- | Centroid of a cluster: the component-wise arithmetic mean of its points.
--
-- The cluster must be non-empty, because 'foldr1' raises an error on an
-- empty list.
means :: Mean
means cluster =
  case foldr1 addPoints cluster of
    (sumX, sumY) -> (sumX / count, sumY / count)
  where
    -- Component-wise addition of two points.
    addPoints (x, y) (restX, restY) = (x + restX, y + restY)
    -- Number of points, converted to 'Float' for the divisions above.
    count = fromIntegral (length cluster)
-- | Euclidean distance between two points of the plane.
distance :: Distance
distance (x1, y1) (x2, y2) = sqrt (squared (x1 - x2) + squared (y1 - y2))
  where
    squared delta = delta ** 2
-- | Build the starting clusters by dealing the points round-robin.
--
-- Each point is paired with a label taken from the repeating sequence
-- @1, 2, .., nbClusters, 1, 2, ..@; points sharing a label form one cluster.
-- @nbClusters@ must be positive, since 'cycle' fails on an empty list.
initVector :: Int -> [Vector] -> UnitOfWork
initVector nbClusters = groupCluster . zip labels
  where
    labels = cycle [1 .. nbClusters]
-- | Signature of one refinement step: a clustering in, a clustering out.
type Round = UnitOfWork -> UnitOfWork

-- | One k-means reassignment pass.
--
-- Computes the mean of every current cluster, tags each point with the mean
-- closest to it, and regroups the points by that tag.
round' :: Round
round' previousRound =
  groupCluster
    [ (closestMean point, point)
    | cluster <- previousRound
    , point <- cluster
    ]
  where
    -- One mean per cluster of the incoming clustering.
    centroids = map means previousRound
    -- The mean lying at the smallest distance from the given point.
    closestMean point = minimumBy (compare `on` distance point) centroids
-- | Collect the values of a keyed list into one bucket per distinct key.
--
-- Buckets come out in ascending key order; inside a bucket the values keep
-- their input order, since 'sortBy' is a stable sort.
groupCluster :: (Eq key, Ord key) => [(key, vector)] -> [[vector]]
groupCluster indexedVectors = [map snd bucket | bucket <- buckets]
  where
    -- Bring equal keys next to each other so 'groupBy' sees them as one run.
    ordered = sortBy (\l r -> compare (fst l) (fst r)) indexedVectors
    buckets = groupBy (\l r -> fst l == fst r) ordered
-- | Produce an endless list of random points whose coordinates are drawn
-- from the range @(ll, lh)@, using a generator obtained from 'newStdGen'.
randV :: Float -> Float -> IO [(Float, Float)]
randV ll lh = do
  gen <- newStdGen
  pure (unfoldr nextPair (randomRs (ll, lh) gen))
  where
    -- Consume two consecutive random values to form one point. 'randomRs'
    -- yields an infinite stream, so two values are always available.
    nextPair (x : y : rest) = Just ((x, y), rest)
-- | Render a cluster as text: a rule of 15 dashes followed by the 'show'
-- output of each point, one per line.
prettifier :: Cluster -> String
prettifier xs = unlines (separator : map show xs)
  where
    separator = replicate 15 '-'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment