Created
June 3, 2018 11:28
-
-
Save MarcinMoskala/fe541e1c9f02ecd8e6511e01a27b1294 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.junit.Test | |
/**
 * Generic k-means clustering over points of type [P].
 *
 * @param createInitialMean produces one starting mean from the full point list; it is called once for
 *   every requested mean.
 * @param distanceBetween distance (cost) between two points. The clustering error is the sum of this value
 *   between every point and its closest mean.
 * @param calculateMean computes the mean of a non-empty group of points.
 */
open class KMeansSolver<P : Any>(
    val createInitialMean: (List<P>) -> P,
    val distanceBetween: (P, P) -> Double,
    val calculateMean: (List<P>) -> P
) {
    /**
     * Tries 2, 3, ... means in turn and returns the first result whose average per-point error is below
     * [maxAvgError].
     *
     * The search stops at the number of distinct points, where every point can be its own mean. The means
     * found for that count are returned even if they still do not satisfy [maxAvgError].
     * Returns an empty list when [points] is empty.
     */
    fun solve(
        points: List<P>,
        maxAvgError: Double
    ): List<P> {
        if (points.isEmpty()) return emptyList()
        // Without this cap an unreachable threshold (e.g. 0.0 when a run ends in a local optimum) loops forever.
        val maxMeansNum = points.distinct().size
        var meansNum = minOf(2, maxMeansNum)
        while (true) {
            val means = solve(points, meansNum)
            val error = calculateError(points, means)
            if (error / points.size < maxAvgError || meansNum >= maxMeansNum) return means
            meansNum++
        }
    }

    /**
     * Runs k-means with [meansNumber] means, starting from [createInitialMean], until an iteration no longer
     * strictly lowers the total error. Returns the better of the last two sets of means.
     */
    fun solve(
        points: List<P>,
        meansNumber: Int
    ): List<P> {
        var means = List(meansNumber) { createInitialMean(points) }
        var error = calculateError(points, means)
        while (true) {
            val candidate = nextMeans(points, means)
            val candidateError = calculateError(points, candidate)
            // Written as `!(a < b)` so that a NaN error also terminates the loop.
            if (!(candidateError < error)) {
                // On a tie keep the newer means; if the step made things worse, keep the previous ones.
                return if (candidateError <= error) candidate else means
            }
            means = candidate
            error = candidateError
        }
    }

    /**
     * Runs exactly [iterationNum] k-means iterations (by default one per point) from fresh initial means,
     * without checking for convergence.
     */
    fun solve(
        points: List<P>,
        meansNumber: Int,
        iterationNum: Int = points.size
    ): List<P> {
        var means = List(meansNumber) { createInitialMean(points) }
        repeat(iterationNum) { means = nextMeans(points, means) }
        return means
    }

    /** Sum of [distanceBetween] from every point to its closest mean; 0.0 when there are no points. */
    protected open fun calculateError(points: List<P>, means: List<P>): Double =
        points.fold(0.0) { total, point -> total + distanceBetween(point, closestMean(point, means)) }

    /**
     * One k-means iteration.
     *
     * Every mean that attracted points is replaced by [calculateMean] of its group. These come first, ordered
     * by the first point assigned to them. Every mean that attracted no points is then relocated with
     * [moveToClosestPointUntaken].
     */
    protected open fun nextMeans(points: List<P>, means: List<P>): List<P> {
        // Group by the *index* of the closest mean rather than its value: with duplicate means, the copy that
        // loses every tie is then correctly recognised as a mean without points.
        val groups = points.groupBy { point -> closestMeanIndex(point, means) }
        val newMeans = groups.values.map { group -> calculateMean(group) }
        val meansWithoutPoints = means.filterIndexed { index, _ -> !groups.containsKey(index) }
        return newMeans + meansWithoutPoints.moveToClosestPointUntaken(points, newMeans)
    }

    /** Returns the mean closest to [point]; the earliest mean wins ties. [means] must not be empty. */
    protected fun closestMean(point: P, means: List<P>): P = means[closestMeanIndex(point, means)]

    /**
     * Moves each mean in this list (the means that attracted no points) onto the closest point that is
     * not one of [newMeans] and not already claimed by an earlier mean of this list. This keeps several
     * empty means from collapsing onto the same point. When no such point is left, the mean is kept unchanged.
     */
    fun List<P>.moveToClosestPointUntaken(points: List<P>, newMeans: List<P>): List<P> {
        val untakenPoints = (points - newMeans).distinct().toMutableList()
        return map { mean ->
            val closest = indexOfMinimum(untakenPoints) { point -> distanceBetween(point, mean) }
            if (closest < 0) mean else untakenPoints.removeAt(closest)
        }
    }

    /** Index of the mean closest to [point] (earliest wins ties). */
    private fun closestMeanIndex(point: P, means: List<P>): Int {
        require(means.isNotEmpty()) { "At least one mean is required" }
        return indexOfMinimum(means) { mean -> distanceBetween(point, mean) }
    }

    /** Index of the candidate with the smallest [cost] (earliest wins ties), or -1 when [candidates] is empty. */
    private fun indexOfMinimum(candidates: List<P>, cost: (P) -> Double): Int {
        var bestIndex = -1
        var bestCost = Double.POSITIVE_INFINITY
        candidates.forEachIndexed { index, candidate ->
            val candidateCost = cost(candidate)
            if (bestIndex < 0 || candidateCost < bestCost) {
                bestIndex = index
                bestCost = candidateCost
            }
        }
        return bestIndex
    }
}
/** JUnit tests for [KMeansSolver] on one-dimensional `Double` points with squared distance. */
class KMeansSolverTests : SeedTest() {
    @Test
    fun `For two separate groups of points, two means end up in the middle of them`() {
        assertEquals(listOf(1.0, 6.0), DoublesKMeansSolver.solve(listOf(0.0, 2.0, 5.0, 7.0), 2))
        assertEquals(listOf(2.0, 6.0), DoublesKMeansSolver.solve(listOf(1.0, 2.0, 3.0, 5.0, 6.0, 7.0), 2))
        assertEquals(listOf(2.0, 6.0), DoublesKMeansSolver.solve(listOf(1.0, 2.0, 3.0, 6.0), 2))
    }

    @Test
    fun `Single mean is equal to average`() {
        val doubles = listOf(0.0, 2.0, 5.0, 7.0, 1.0)
        assertEquals(doubles.average(), DoublesKMeansSolver.solve(doubles, 1).first())
    }

    @Test
    fun `For the same amount of averages and points, they are equal`() {
        val doubles = listOf(1.0, 3.0, 5.0, 7.0, 10.0)
        assertEquals(doubles, DoublesKMeansSolver.solve(doubles, doubles.size))
    }

    companion object {
        /** Squared difference between two doubles, used as the clustering distance. */
        val distanceTo = { d1: Double, d2: Double -> (d1 - d2) * (d1 - d2) }

        /**
         * Solver for `Double` points. Each initial mean is placed at `lowest + r * (highest - lowest)`,
         * where `r = rand.nextDouble()`. That is inside the data's range assuming `r` is in [0, 1).
         * The previous `r * (highest - lowest)` omitted the `lowest` offset and could start outside the data.
         */
        object DoublesKMeansSolver : KMeansSolver<Double>(
            createInitialMean = { list ->
                val lowest = list.min()!!
                lowest + rand.nextDouble() * (list.max()!! - lowest)
            },
            distanceBetween = distanceTo,
            calculateMean = { list -> list.average() }
        )
    }
}
/** Program entry point. It currently performs no work, and [args] is ignored. */
fun main(args: Array<String>) = Unit
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment