Skip to content

Instantly share code, notes, and snippets.

@mat3u
Last active August 29, 2015 14:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mat3u/73d195385aa00f2d12f5 to your computer and use it in GitHub Desktop.
Save mat3u/73d195385aa00f2d12f5 to your computer and use it in GitHub Desktop.
k-means in F#
module kMeans
/// Function that measures how far apart two elements are.
/// The resulting cost type must support comparison so costs can be ordered.
type Distance<'elem, 'cost when 'cost : comparison> = 'elem -> 'elem -> 'cost
/// Function that collapses a group of elements into a single representative element.
type Centroid<'elem> = 'elem seq -> 'elem
/// Function that, given a count and the data set, produces a sequence of elements;
/// execute' uses it to choose the initial centroids.
type Initializator<'elem> = int -> 'elem seq -> 'elem seq
/// Draws `k` elements from `data` uniformly at random, with replacement
/// (the same element may be picked more than once).
///
/// `data` is copied into an array once, so it is enumerated a single time and
/// every draw is an O(1) index lookup. The draws are made eagerly, in order,
/// and returned as a fixed collection: enumerating the result again yields the
/// same elements instead of asking the generator for new ones.
///
/// Returns an empty sequence when `k <= 0`.
/// Raises ArgumentException when `k > 0` and `data` is empty.
let randomSubset' (rng: System.Random) k data =
    if k <= 0 then
        Seq.empty
    else
        let pool = Seq.toArray data
        if pool.Length = 0 then
            invalidArg "data" "cannot draw elements from an empty sequence"
        // Array.init runs the generator for index 0, 1, ..., k-1 in order, so a
        // seeded rng produces the same picks as drawing them one by one.
        let picks = Array.init k (fun _ -> pool.[rng.Next(pool.Length)])
        Seq.ofArray picks
/// Convenience wrapper around randomSubset' that supplies a freshly created
/// System.Random instead of a caller-provided generator.
let randomSubset k data =
    let rng = System.Random()
    randomSubset' rng k data
/// Returns the centroid from `centroids` whose distance to `element` is the
/// smallest. `distance` is always called as `distance centroid element`.
/// Selection is delegated to List.minBy, which raises ArgumentException when
/// `centroids` is empty.
let closest (distance: Distance<'elem, 'cost>) (centroids: 'elem list) element =
    let costOf candidate = distance candidate element
    List.minBy costOf centroids
/// Performs one k-means assignment/update step.
///
/// Every element of `data` is assigned to its closest centroid (see `closest`),
/// the elements are grouped by that assignment, and `centroid` is applied to
/// each group to produce the new centroids.
///
/// The result is a lazy sequence with one entry per non-empty group; a centroid
/// that attracts no elements has no group and therefore no entry in the result.
let findNewCentroids (distance: Distance<'elem, 'cost>)
                     (centroid: Centroid<'elem>)
                     data centroids =
    // The candidate list is built once, up front, and shared by every lookup.
    let candidates = List.ofSeq centroids
    let assign = closest distance candidates
    Seq.map (snd >> centroid) (Seq.groupBy assign data)
/// Runs k-means over `data` and returns the final centroids.
///
/// randSubset - picks the initial centroids; called once as `randSubset k data`.
/// distance   - distance between two elements.
/// centroid   - collapses a group of elements into its representative.
/// k          - number of initial centroids to request from `randSubset`.
/// data       - the elements to cluster; enumerated once per epoch.
///
/// Each epoch computes new centroids with findNewCentroids and stops when the
/// sorted old and new centroid sets have the same size and every pair is at
/// "zero" distance (distance a b = distance a a). On stop the centroids from
/// before the final update are returned.
///
/// If `randSubset` yields no centroids (e.g. k = 0) the empty set is returned
/// immediately.
let execute' (randSubset: Initializator<'elem>)
             (distance: Distance<'elem, 'cost>)
             (centroid: Centroid<'elem>)
             (k: int)
             (data: 'elem seq) =
    // Centroid sequences may be lazy (a random initializer, Seq.groupBy/Seq.map).
    // They are traversed several times per epoch, so each one is forced exactly
    // once here; otherwise a random initializer could hand out different
    // centroids on every traversal and every epoch would be recomputed.
    let materialize (items: 'elem seq) : 'elem seq =
        items |> Seq.toList |> List.toSeq
    let update = findNewCentroids distance centroid
    let converged (current: 'elem seq) (next: 'elem seq) =
        let before = current |> Seq.toArray |> Array.sort
        let after = next |> Seq.toArray |> Array.sort
        // A change in the number of centroids (a cluster lost all its elements)
        // is never convergence; only equally sized sets are compared pairwise.
        before.Length = after.Length
        && Array.forall2 (fun a b -> distance a b = distance a a) after before
    let rec epoch current =
        let next = update data current |> materialize
        if converged current next then current else epoch next
    let initial = randSubset k data |> materialize
    // With no centroids there is nothing to assign elements to.
    if Seq.isEmpty initial then initial else epoch initial
/// k-means entry point that uses randomSubset to choose the initial centroids
/// and otherwise forwards everything to execute'.
let execute distance centroid k data =
    data |> execute' randomSubset distance centroid k
module kMeans
open NUnit.Framework
// Distance metric shared by the tests below: absolute difference of two numbers.
let distance = (fun a b -> abs(a - b))
// Centroid function shared by the tests below: the average of a group (Seq.average).
let avg = Seq.average
// With a generator seeded with 50, drawing 3 elements from 1..10 is expected
// to give exactly 9, 5, 8.
[<Test>]
let ``Should generate random subset`` () =
    let expected = [9;5;8]
    let seeded = System.Random(50)
    let picked =
        seq [1..10]
        |> kMeans.randomSubset' seeded 3
        |> List.ofSeq
    Assert.AreEqual(expected, picked)
// Builds three groups of points spread from -1.0 to +1.0 (step 0.1) around
// 1.0, 5.0 and 9.0, then checks that k-means with a seeded initializer finds
// centroids within 0.1 of those three values.
[<Test>]
let ``Should find centroids of given groups`` () =
    let centres = seq [1.0;5.0;9.0]
    let precision = 0.1
    let points =
        centres
        |> Seq.collect (fun c ->
            [-1.0 .. precision .. 1.0] |> List.map (fun offset -> offset + c))
    let seededInit = kMeans.randomSubset' (System.Random(50))
    let found =
        kMeans.execute' seededInit distance avg 3 points
        |> Seq.sort
        |> Seq.toList
    // Compare position by position against the sorted expected centres.
    [1.0; 5.0; 9.0]
    |> List.iteri (fun i expected -> Assert.AreEqual(expected, found.[i], 0.1))
// Among the centroids 1.0 .. 10.0 the one closest to 3.6 must be 4.0.
[<Test>]
let ``Should select closest value`` () =
    let candidates = [1.0 .. 10.0]
    let probe = 3.6
    let expected = 4.0
    // The lookup is repeated four times on purpose: the original author added
    // the repetition because of a multiple-enumeration bug, so every call must
    // give the same answer.
    let answers = List.init 4 (fun _ -> kMeans.closest distance candidates probe)
    answers |> List.iter (fun answer -> Assert.AreEqual(expected, answer))
// One update step over the points 1.0 .. 10.0 starting from centroids
// 2.0, 8.0 and 9.0 must produce the group averages 3.0, 7.0 and 9.5.
[<Test>]
let ``Should select new centroids`` () =
    let expected = [3.0; 7.0; 9.5]
    let points = seq [1.0 .. 10.0]
    let starting = seq [2.0; 8.0; 9.0]
    let actual =
        starting
        |> kMeans.findNewCentroids distance avg points
        |> List.ofSeq
    Assert.AreEqual(expected, actual)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment