Skip to content

Instantly share code, notes, and snippets.

@kputnam
Last active August 29, 2015 14:08
Show Gist options
  • Save kputnam/a65437b83187e866a201 to your computer and use it in GitHub Desktop.
Save kputnam/a65437b83187e866a201 to your computer and use it in GitHub Desktop.
Sequential and Parallel k-Means Clustering
{-# LANGUAGE ScopedTypeVariables #-}
module KMeans
( euclidean
, kMeans
, kMeansPar
, chunk
, random
) where
import Control.Monad
import Control.Monad.Primitive
import System.Random.MWC (Variate, Gen, uniform)
import Control.Parallel.Strategies
import Data.Ord
import Data.List
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M
-- | A data point: a vector holding one coordinate per dimension.
type Point a = Vector a

-- | A cluster centre, represented exactly like a 'Point'.
type Centroid a = Vector a

-- | A distance (dissimilarity) measure between two points, yielding a value
-- of type @b@; the clustering functions additionally require @Ord b@ so that
-- distances can be compared.
type Distance a b = Point a -> Point a -> b
-- | Squared Euclidean distance between two points: the sum of the squared
-- component-wise differences. No square root is taken; since the square
-- root is monotone this does not change which centroid is nearest. If the
-- vectors differ in length, 'V.zipWith' ignores the surplus components.
euclidean :: Num a => Distance a a
euclidean p q = V.sum (V.zipWith squaredGap p q)
  where
    squaredGap u v = let d = u - v in d * d
-- | Running accumulator for a single cluster: how many points have been
-- assigned to it and the component-wise sum of those points. Both fields
-- are strict, so forcing an 'Intermediate' evaluates the count and the sum
-- vector (to weak head normal form).
data Intermediate a = Intermediate
  !Int          -- number of points accumulated so far
  !(Vector a)   -- component-wise sum of the accumulated points
  deriving (Eq, Show, Read)
-- | An accumulator that has seen no points yet, for points with the given
-- number of dimensions: count zero and an all-zero sum vector.
iempty :: Num a => Int -> Intermediate a
iempty dims = Intermediate 0 zeros
  where
    zeros = V.replicate dims 0
-- | An accumulator containing exactly one point (count 1, sum = the point).
isingle :: Point a -> Intermediate a
isingle = Intermediate 1
-- | Add one point to an accumulator: the count goes up by one and the
-- point's coordinates are added component-wise to the running sum.
iinsert :: Num a => Intermediate a -> Point a -> Intermediate a
iinsert (Intermediate n x) x' = Intermediate (n + 1) (addStrict x x')

-- | Merge two accumulators (for example partial results computed on
-- different groups of points): counts add, and sums add component-wise.
iappend :: Num a => Intermediate a -> Intermediate a -> Intermediate a
iappend (Intermediate n x) (Intermediate n' x') = Intermediate (n + n') (addStrict x x')

-- | Component-wise sum of two vectors, with every element evaluated.
--
-- Boxed 'Vector's store their elements lazily, so a plain
-- @V.zipWith (+)@ applied once per accumulated point would leave each
-- coordinate as a growing chain of unevaluated additions. Forcing the
-- elements here keeps each accumulator's sums fully evaluated.
addStrict :: Num a => Vector a -> Vector a -> Vector a
addStrict u v = V.foldl' (flip seq) () w `seq` w
  where
    w = V.zipWith (+) u v
-- | The mean of the accumulated points: every coordinate of the running sum
-- divided by the point count. A zero count would divide by zero, so only
-- non-empty accumulators should be turned into centroids.
icentroid :: Fractional a => Intermediate a -> Centroid a
icentroid (Intermediate count total) = V.map (/ fromIntegral count) total
-- | Lloyd's algorithm for k-means clustering (sequential version).
--
-- Each round assigns every point to its nearest centroid, as measured by
-- @distance point centroid@, then moves each centroid to the mean of the
-- points assigned to it. Iteration stops when a round leaves the centroids
-- unchanged or after @maxIterations@ rounds, whichever comes first.
--
-- A centroid that attracts no points in some round is dropped, so the
-- result can contain fewer centroids than the initial guess. With no
-- points the initial guess is returned unchanged; with no centroids the
-- result is empty.
kMeans :: forall a b. (Eq a, Fractional a, Ord b)
  => Int           -- ^ maximum number of iterations
  -> Distance a b  -- ^ distance function between points
  -> [Centroid a]  -- ^ initial guess at centroids
  -> [Point a]     -- ^ set of data points to cluster
  -> [Centroid a]
kMeans _ _ [] _ = []
kMeans _ _ cs0 [] = cs0
kMeans maxIterations distance cs0 xs@(x0 : _) = loop maxIterations cs0
  where
    -- Points are assumed to share the dimensionality of the first point.
    nDimensions = V.length x0

    -- Stop when the centroids stop moving or the iteration budget is spent.
    loop :: Int -> [Centroid a] -> [Centroid a]
    loop n cs
      | n <= 0 = cs
      | cs == cs' = cs
      | otherwise = loop (n - 1) cs'
      where
        cs' = clusters (update cs)

    -- For each point, find the nearest centroid and add the point to that
    -- centroid's running count and component-wise sum.
    update :: [Centroid a] -> Vector (Intermediate a)
    update cs = V.create $ do
        results <- M.replicate (length cs) (iempty nDimensions)
        forM_ xs $ \x -> do
          let k = fst (nearest x)
          acc <- M.read results k
          -- '$!' evaluates the new accumulator (its count and sum vector are
          -- strict fields) before storing it. Without it the boxed mutable
          -- vector keeps a deferred 'iinsert' call per point, growing a
          -- thunk chain instead of doing the arithmetic as we go.
          M.write results k $! iinsert acc x
        return results
      where
        indexed = zip [0 ..] cs
        nearest x = minimumBy (comparing (distance x . snd)) indexed

    -- Turn accumulated counts and sums into centroids, skipping clusters
    -- with no points (their mean would divide by zero).
    clusters :: Vector (Intermediate a) -> [Centroid a]
    clusters accs = [icentroid i | i@(Intermediate n _) <- V.toList accs, n > 0]
-- | Break a vector into @n@ contiguous, non-overlapping chunks in O(n)
-- time (each chunk is an O(1) 'V.slice' of the input).
--
-- Every element lands in exactly one chunk, in order. Chunk lengths differ
-- by at most one: the first @length xs `rem` n@ chunks each carry one extra
-- element. Chunks are empty when @n@ exceeds the vector's length, and the
-- result is @[]@ when @n <= 0@.
chunk :: Int -> Vector a -> [Vector a]
chunk n xs
  | n <= 0 = []
  | otherwise = [V.slice (offset k) (size k) xs | k <- [0 .. n - 1]]
  where
    (base, extra) = V.length xs `quotRem` n
    -- Chunks 0 .. extra-1 take one element of the remainder each.
    size k = base + (if k < extra then 1 else 0)
    -- Start of chunk k: k full-size chunks plus the remainder elements
    -- already handed out to earlier chunks.
    offset k = base * k + min k extra
-- | Lloyd's algorithm for k-means clustering, parallelised over groups of
-- points.
--
-- Same iteration as the sequential version, but the points arrive already
-- split into groups (e.g. via 'chunk'). In each round every group's
-- assignment-and-summation pass is evaluated as its own spark via
-- 'parList' 'rseq'. The per-group counts and sums are then merged with
-- 'iappend' before the new centroids are computed. Iteration stops when a
-- round leaves the centroids unchanged or after @maxIterations@ rounds.
--
-- A centroid that attracts no points in some round is dropped, so the
-- result can contain fewer centroids than the initial guess. With no
-- points the initial guess is returned unchanged; with no centroids the
-- result is empty.
kMeansPar :: forall a b. (Eq a, Fractional a, Ord b)
  => Int                 -- ^ maximum number of iterations
  -> Distance a b        -- ^ distance measure between two points
  -> [Centroid a]        -- ^ initial guess at centroids
  -> [Vector (Point a)]  -- ^ evenly-divided groups of points to cluster
  -> [Centroid a]
kMeansPar maxIterations distance cs0 xss
  | null cs0 = []
  | null groups = cs0
  | otherwise = loop maxIterations cs0
  where
    -- Empty groups contribute nothing; dropping them also guarantees that
    -- 'V.head' and 'foldr1' below never see an empty input.
    groups = filter (not . V.null) xss

    -- Points are assumed to share the dimensionality of the first point.
    nDimensions = case groups of
      (g : _) -> V.length (V.head g)
      [] -> 0

    -- Stop when the centroids stop moving or the iteration budget is spent.
    loop :: Int -> [Centroid a] -> [Centroid a]
    loop n cs
      | n <= 0 = cs
      | cs == cs' = cs
      | otherwise = loop (n - 1) cs'
      where
        partials = map (update cs) groups `using` parList rseq
        cs' = clusters (foldr1 (V.zipWith iappend) partials)

    -- For each point in one group, find the nearest centroid and add the
    -- point to that centroid's running count and component-wise sum.
    update :: [Centroid a] -> Vector (Point a) -> Vector (Intermediate a)
    update cs ps = V.create $ do
        results <- M.replicate (length cs) (iempty nDimensions)
        V.forM_ ps $ \x -> do
          let k = fst (nearest x)
          acc <- M.read results k
          -- '$!' evaluates the new accumulator (strict count and sum fields)
          -- before storing it. Otherwise the boxed mutable vector keeps a
          -- deferred 'iinsert' per point, and the summation would happen
          -- later, outside the spark that evaluates this group's result.
          M.write results k $! iinsert acc x
        return results
      where
        indexed = zip [0 ..] cs
        nearest x = minimumBy (comparing (distance x . snd)) indexed

    -- Turn merged counts and sums into centroids, skipping clusters with
    -- no points (their mean would divide by zero).
    clusters :: Vector (Intermediate a) -> [Centroid a]
    clusters accs = [icentroid i | i@(Intermediate n _) <- V.toList accs, n > 0]
-- | Generate a random centroid with the given number of coordinates. Each
-- coordinate is drawn independently, in order, with mwc-random's 'uniform'
-- from the supplied generator (so its range is whatever the element type's
-- 'Variate' instance defines).
random :: (PrimMonad m, Variate a, Num a) => Int -> Gen (PrimState m) -> m (Centroid a)
random dims gen = liftM V.fromList (replicateM dims (uniform gen))
---------------------------------------------------------------------------------------
-- Example GHCi session exercising the module; it is not part of the module
-- itself. NOTE(review): besides this module and Data.Vector (as V), it
-- presumably also needs 'createSystemRandom' from System.Random.MWC and
-- 'normal' from System.Random.MWC.Distributions in scope -- confirm.
-- The printed results depend on the random draws, so they will presumably
-- differ between runs (the generator comes from createSystemRandom).
--
-- 'point x y z g' builds a 3-D point whose coordinates are each sampled
-- with 'normal' around x, y and z, passing 10 as the second argument
-- (presumably the standard deviation -- confirm against mwc-random docs).
let point x y z g = do { x' <- normal x 10 g; y' <- normal y 10 g; z' <- normal z 10 g; return $ V.fromList [x',y',z'] }
g <- createSystemRandom
-- Random data normally distributed around four centroids
as <- replicateM 200 $ point 100 100 100 g
bs <- replicateM 150 $ point 50 (-10) (-10) g
cs <- replicateM 130 $ point 100 (-10) 15 g
ds <- replicateM 230 $ point (-20) 80 180 g
-- All 710 points as a single list.
let xs = mconcat [as,bs,cs,ds]
-- Random initial centroid locations
-- (four 3-D centroids drawn with 'random'; sample output below)
zs <- replicateM 4 (random 3 g)
-- [fromList [0.9848902563026655,0.765287551648172,0.7781795312633198]
-- ,fromList [0.7361694755398108,7.367420245929768e-2,0.6224181897188118]
-- ,fromList [0.8774773711990987,0.4621511922427829,0.5492891219506643]
-- ,fromList [0.43748278208197944,0.8174101215438477,0.8235959386276283]]
-- Sequential clustering: at most 1000 rounds, squared Euclidean distance.
kMeans 1000 euclidean zs xs
-- [fromList [99.89809340402812,100.38225039289729,100.24009968006065]
-- ,fromList [49.761342796903456,-8.898956628199413,-9.947110251417078]
-- ,fromList [100.33603708616106,-9.597087217849067,14.854017607086883]
-- ,fromList [-20.36856619481266,79.89870265298171,179.94788176235113]]
-- Parallel clustering of the same points, split into 4 groups by 'chunk'.
kMeansPar 1000 euclidean zs (chunk 4 $ V.fromList xs)
-- [fromList [99.89809340402812,100.38225039289729,100.24009968006065]
-- ,fromList [49.761342796903456,-8.898956628199413,-9.947110251417078]
-- ,fromList [100.33603708616106,-9.597087217849067,14.854017607086883]
-- ,fromList [-20.36856619481266,79.89870265298171,179.94788176235113]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment