Skip to content

Instantly share code, notes, and snippets.

@MarcinMoskala
Created June 3, 2018 11:28
Show Gist options
  • Save MarcinMoskala/fe541e1c9f02ecd8e6511e01a27b1294 to your computer and use it in GitHub Desktop.
Save MarcinMoskala/fe541e1c9f02ecd8e6511e01a27b1294 to your computer and use it in GitHub Desktop.
import org.junit.Test
/**
 * Generic k-means clustering over points of type [P].
 *
 * The point type is abstracted away through three strategy functions, so the same solver works for any [P].
 *
 * @property createInitialMean produces one starting mean from the full list of points. It is called once per
 *   requested mean, so it is expected to vary between calls (e.g. by sampling randomly).
 * @property distanceBetween distance used both to assign points to their closest mean and to measure the error.
 * @property calculateMean computes the mean of a (non-empty) group of points assigned to the same mean.
 */
open class KMeansSolver<P : Any>(
    val createInitialMean: (List<P>) -> P,
    val distanceBetween: (P, P) -> Double,
    val calculateMean: (List<P>) -> P
) {
    /**
     * Tries increasing numbers of means, starting from 2 (or 1 when there is a single point), and returns the
     * first clustering whose average per-point error is below [maxAvgError].
     *
     * The number of means is capped at `points.size`: assuming `distanceBetween(p, p)` is 0 and distances are
     * non-negative, more means than points cannot lower the error further. If even `points.size` means do not
     * reach [maxAvgError], that clustering is returned instead of looping forever.
     *
     * @throws IllegalArgumentException if [points] is empty.
     */
    fun solve(
        points: List<P>,
        maxAvgError: Double
    ): List<P> {
        require(points.isNotEmpty()) { "Cannot cluster an empty list of points" }
        val maxMeansNum = points.size
        var meansNum = minOf(2, maxMeansNum)
        while (true) {
            val means = solve(points, meansNum)
            val avgError = calculateError(points, means) / points.size
            if (avgError < maxAvgError || meansNum >= maxMeansNum) return means
            meansNum++
        }
    }

    /**
     * Runs k-means with [meansNumber] means until an iteration no longer lowers the total error.
     *
     * If the last iteration made the error strictly worse, the means from before that iteration are returned.
     * This can happen when [calculateMean] does not minimise the summed [distanceBetween] of a group.
     *
     * @throws IllegalArgumentException if [meansNumber] is not positive.
     */
    fun solve(
        points: List<P>,
        meansNumber: Int
    ): List<P> {
        require(meansNumber > 0) { "meansNumber must be positive, but was $meansNumber" }
        var means = List(meansNumber) { createInitialMean(points) }
        var error = calculateError(points, means)
        while (true) {
            val candidate = nextMeans(points, means)
            val candidateError = calculateError(points, candidate)
            // `!(a < b)` instead of `a >= b` so that a NaN error also stops the loop.
            if (!(candidateError < error)) {
                return if (candidateError > error) means else candidate
            }
            means = candidate
            error = candidateError
        }
    }

    /**
     * Runs exactly [iterationNum] k-means iterations with [meansNumber] means, without checking convergence.
     *
     * @throws IllegalArgumentException if [meansNumber] is not positive.
     */
    fun solve(
        points: List<P>,
        meansNumber: Int,
        iterationNum: Int = points.size
    ): List<P> {
        require(meansNumber > 0) { "meansNumber must be positive, but was $meansNumber" }
        var means = List(meansNumber) { createInitialMean(points) }
        repeat(iterationNum) { means = nextMeans(points, means) }
        return means
    }

    /** Total error: the sum, over all [points], of the distance from each point to its closest mean. */
    protected open fun calculateError(points: List<P>, means: List<P>): Double =
        points.sumByDouble { point -> distanceBetween(point, closestMean(point, means)) }

    /**
     * One k-means iteration.
     *
     * Every point is assigned to its closest mean, and each mean that received points is replaced by the
     * [calculateMean] of its group. Means that received no points are relocated by [moveToClosestPointUntaken].
     *
     * The result lists the recomputed means first, in the order their groups were first encountered in
     * [points]. The relocated empty means follow.
     */
    protected open fun nextMeans(points: List<P>, means: List<P>): List<P> {
        val grouped = points.groupBy { point -> closestMean(point, means) }
        val newMeans = grouped.values.map(calculateMean)
        // NOTE(review): `elementMinus` is defined outside this class; presumably it removes one occurrence
        // of each key from `means`, leaving exactly the means that received no points — confirm.
        val meansWithoutPoints: List<P> = means elementMinus grouped.keys
        return newMeans + meansWithoutPoints.moveToClosestPointUntaken(points, newMeans)
    }

    /**
     * Returns the mean in [means] closest to [point]; on ties the earliest mean in the list wins.
     *
     * @throws IllegalArgumentException if [means] is empty.
     */
    protected fun closestMean(point: P, means: List<P>): P {
        require(means.isNotEmpty()) { "At least one mean is required" }
        // `!!` is safe: `means` was just checked to be non-empty, so `minBy` always finds an element.
        return means.minBy { mean -> distanceBetween(point, mean) }!!
    }

    /**
     * Improvement over plain k-means: relocates each mean in this list (the means that got no points).
     *
     * Each mean moves onto the closest point that is still untaken. A point is taken if it equals one of
     * [newMeans] or was already claimed by an earlier relocated mean in this call. A mean stays where it is
     * when no untaken point remains, instead of failing.
     */
    fun List<P>.moveToClosestPointUntaken(points: List<P>, newMeans: List<P>): List<P> {
        val untakenPoints = (points - newMeans).toMutableList()
        return map { mean ->
            if (untakenPoints.isEmpty()) {
                mean
            } else {
                // `!!` is safe: the pool was just checked to be non-empty.
                val closest = untakenPoints.minBy { point -> distanceBetween(point, mean) }!!
                // Claim the point (and any duplicates of it) so two empty means never collapse onto it.
                untakenPoints.removeAll { point -> point == closest }
                closest
            }
        }
    }
}
/** Behavioural tests for [KMeansSolver] using plain `Double` points. */
class KMeansSolverTests : SeedTest() {
    @Test
    fun `For two separate groups of points, two means end up in the middle of them`() {
        assertEquals(listOf(1.0, 6.0), DoublesKMeansSolver.solve(listOf(0.0, 2.0, 5.0, 7.0), 2))
        assertEquals(listOf(2.0, 6.0), DoublesKMeansSolver.solve(listOf(1.0, 2.0, 3.0, 5.0, 6.0, 7.0), 2))
        assertEquals(listOf(2.0, 6.0), DoublesKMeansSolver.solve(listOf(1.0, 2.0, 3.0, 6.0), 2))
    }

    @Test
    fun `Single mean is equal to average`() {
        val doubles = listOf(0.0, 2.0, 5.0, 7.0, 1.0)
        assertEquals(doubles.average(), DoublesKMeansSolver.solve(doubles, 1).first())
    }

    @Test
    fun `For the same amount of averages and points, they are equal`() {
        val doubles = listOf(1.0, 3.0, 5.0, 7.0, 10.0)
        assertEquals(doubles, DoublesKMeansSolver.solve(doubles, doubles.size))
    }

    companion object {
        /** Squared difference between two numbers, used as the clustering distance. */
        val distanceTo = { d1: Double, d2: Double -> (d1 - d2) * (d1 - d2) }

        /** [KMeansSolver] for plain numbers; each initial mean is drawn from the points' `[min, max)` range. */
        object DoublesKMeansSolver : KMeansSolver<Double>(
            createInitialMean = { list ->
                val min = list.min()!!
                val max = list.max()!!
                // Offset by `min` so the starting mean lies inside the data, not in `[0, max - min)`.
                // NOTE(review): `rand` is defined outside this view — presumably the seeded generator from SeedTest.
                min + rand.nextDouble() * (max - min)
            },
            distanceBetween = distanceTo,
            calculateMean = { list -> list.average() }
        )
    }
}
/** Placeholder entry point; it intentionally does nothing and ignores [args]. */
fun main(args: Array<String>) = Unit
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment