Skip to content

Instantly share code, notes, and snippets.

@BaxterEaves
Created October 6, 2016 17:20
Show Gist options
  • Save BaxterEaves/f345681f956428c30c3955bcf8686b29 to your computer and use it in GitHub Desktop.
Save BaxterEaves/f345681f956428c30c3955bcf8686b29 to your computer and use it in GitHub Desktop.
K-means algorithm in Haskell
-- This is the first thing I've ever written in Haskell by myself.
-- Be gentle.
import Data.List (genericLength)
-- | Cluster assignment: one zero-based cluster index per data point, listed
-- in the same order as the data points.
type Assignment = [Int]
-- | Partition the data points @xs@ into (at most) @k@ clusters.
--
-- The first @k@ points of @xs@ serve as the initial means. Every point is
-- assigned to its nearest initial mean, and that assignment is then refined
-- by '_kmeans' until it stops changing. The result gives, for each point of
-- @xs@ in order, the zero-based index of the cluster it ended up in.
kmeans :: (Floating a, Ord a) => [[a]] -> Int -> Assignment
kmeans xs k = _kmeans xs initialAsgn
  where
    seeds       = take k xs
    initialAsgn = nearestMeans xs seeds
-- | Iterate the update/reassign steps until the assignment is stable.
--
-- Each round recomputes the cluster means from the current assignment
-- ('calcMeans') and reassigns every point to its nearest mean
-- ('nearestMeans'). Recursion stops as soon as a round leaves the assignment
-- unchanged, and that assignment is returned.
_kmeans :: (Floating a, Ord a) => [[a]] -> Assignment -> Assignment
_kmeans xs asgn =
  if reassigned == asgn
    then asgn
    else _kmeans xs reassigned
  where
    updatedMeans = calcMeans xs asgn
    reassigned   = nearestMeans xs updatedMeans
-- | Recompute the mean of every cluster from the current assignment.
--
-- @xs@ are the data points and @asgn@ holds the cluster index of each point.
-- The two lists are zipped, so surplus elements of the longer one are
-- ignored. The result has one mean per cluster index from @0@ up to the
-- largest index present in @asgn@, in index order.
--
-- An empty assignment yields no means instead of failing inside 'maximum'.
-- A cluster index in that range that has no points assigned to it still
-- raises the 'listMean' error, because there is nothing to average.
calcMeans :: (Floating a) => [[a]] -> Assignment -> [[a]]
calcMeans xs asgn
  | null asgn = []
  | otherwise = map (listMean . members) [0 .. maximum asgn]
  where
    -- All points currently assigned to cluster j.
    members j = [x | (x, i) <- zip xs asgn, i == j]
-- | Component-wise mean of a non-empty list of vectors.
--
-- The vectors are summed position by position with 'listSum' and every
-- component of that sum is divided by the number of vectors. Calling this on
-- an empty list is an error.
listMean :: (Floating a) => [[a]] -> [a]
listMean rows = case rows of
  [] -> error "mean of empty list "
  _  -> [total / count | total <- listSum rows]
  where
    count = genericLength rows
-- | Component-wise sum of a non-empty list of vectors.
--
-- Vectors are added pairwise with 'zipWith', so the result is as long as the
-- shortest input vector. The fold has no starting value ('foldl1'), hence an
-- empty outer list is an error.
listSum :: (Floating a) => [[a]] -> [a]
listSum = foldl1 (zipWith (+))
-- | For every data point, find the index of the closest mean.
--
-- Closeness is the sum of squared differences computed by 'sumSquares'; the
-- winning index for each point is picked by 'argmin'. The result has one
-- entry per element of @xs@, in the same order.
nearestMeans :: (Floating a, Ord a) => [[a]] -> [[a]] -> [Int]
nearestMeans xs ms = map closest xs
  where
    -- Index (into ms) of the mean with the smallest distance to point p.
    closest p = argmin [sumSquares p m | m <- ms]
-- | Sum of the squared component-wise differences between two vectors.
--
-- The vectors are paired up with 'zipWith', so any surplus components of the
-- longer one are ignored.
sumSquares :: (Floating a) => [a] -> [a] -> a
sumSquares xs ys = sum (zipWith squaredDiff xs ys)
  where
    squaredDiff p q = (p - q) ^ (2 :: Int)
-- | Zero-based index of the smallest element of a list.
--
-- Comparison is strict ('<'), so when the minimum occurs more than once the
-- index of its first occurrence is returned. A single-element list yields
-- @0@. There is no valid index for an empty list, so that case raises an
-- error with an explicit message.
argmin :: (Ord a) => [a] -> Int
argmin []     = error "argmin: empty list"
argmin (x:xs) = _argmin xs x 0 0

-- | Worker for 'argmin'.
--
-- Arguments, in order: the elements still to be examined, the smallest value
-- seen so far, the index of the element examined last, and the index of the
-- smallest value seen so far. Once the remaining list is exhausted the index
-- of the smallest value is returned; handling the empty list here is what
-- makes single-element inputs to 'argmin' work.
_argmin :: (Ord a) => [a] -> a -> Int -> Int -> Int
_argmin [] _ _ ix = ix
_argmin (x:xs) val cix ix
  | x < val   = _argmin xs x next next
  | otherwise = _argmin xs val next ix
  where
    -- Index of x, the element being examined in this step.
    next = cix + 1
-- | Sample input for 'kmeans': 60 two-dimensional points.
--
-- Judging by the values, the points fall into three groups of 20 consecutive
-- rows, located roughly around (20, 10), (24, 10) and (20, 12).
testData :: [[Double]]
testData =
  [ [19.44228191, 10.01954739]
  , [19.53244135, 9.27926943]
  , [19.47813524, 10.11702067]
  , [19.53707403, 10.34821127]
  , [20.22351335, 10.15144988]
  , [19.28171056, 9.97469793]
  , [20.02206671, 10.44740153]
  , [20.50197584, 10.21343973]
  , [19.21130219, 10.12152178]
  , [20.12801058, 9.72609094]
  , [19.83785357, 10.26330778]
  , [20.41310726, 9.46206339]
  , [19.45552404, 9.40547838]
  , [19.89360546, 10.44014115]
  , [20.36125385, 9.90567986]
  , [20.35849112, 10.48796981]
  , [20.88746346, 10.35747471]
  , [20.49813333, 10.65385512]
  , [20.62931967, 10.31353257]
  , [19.71910217, 10.11721853]
  , [23.88999494, 9.49780921]
  , [23.53417881, 10.22200125]
  , [24.20515374, 10.20352821]
  , [24.27042841, 8.94589297]
  , [24.33017732, 10.46796864]
  , [24.1848164, 9.39154275]
  , [24.25760608, 10.19827234]
  , [24.14990874, 9.95071613]
  , [23.79546847, 9.40557273]
  , [24.12813273, 9.45732106]
  , [24.88221126, 10.67626518]
  , [24.24725526, 9.38441264]
  , [23.70951565, 9.48528594]
  , [23.84057452, 10.4090588]
  , [25.30846271, 10.06039913]
  , [23.88388447, 9.42833493]
  , [23.77114657, 9.98375689]
  , [23.40834507, 10.49376609]
  , [23.98728878, 10.22910904]
  , [22.96215117, 9.97424991]
  , [19.58555678, 11.8102289]
  , [20.03125014, 10.6541982]
  , [20.61266943, 11.811877]
  , [19.81763777, 12.87517492]
  , [19.70835143, 11.64469843]
  , [20.08416497, 11.49784966]
  , [19.82591625, 11.04150988]
  , [19.28604287, 12.57047196]
  , [19.60033648, 12.37376337]
  , [19.73600258, 11.13350926]
  , [20.6532054, 12.40243406]
  , [19.88155388, 11.5102145]
  , [20.16608498, 11.52421444]
  , [19.84189865, 12.16468441]
  , [20.07146249, 11.9723343]
  , [19.80540283, 12.25633198]
  , [20.36592307, 12.04606745]
  , [19.76154363, 12.04326451]
  , [21.13411946, 11.97343609]
  , [20.00825767, 12.22197451]
  ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment