Last active
May 7, 2017 22:33
-
-
Save alopatindev/4e8f2240cb2354da2d4734e27a432dac to your computer and use it in GitHub Desktop.
K-means clustering algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0.3 0.2
10.0 20.4
11.5 55.1
42.3 53.2
3.0 4
4 4
55 12
60 10
12 45
5 1
4 51
1 34
23 44
4.5 55.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// https://www.youtube.com/watch?v=_aWzGGNrcic | |
// https://www.youtube.com/watch?v=RD0nNK51Fp8 | |
import scala.io.Source | |
import scala.util.Random | |
import java.io.{BufferedWriter, FileWriter} | |
object Kmeans {

  /** A 2-D point; also used (via the `Centroid` alias) for cluster centers. */
  final case class Point(x: Double, y: Double) {

    /** Euclidean distance to `other`. */
    def distance(other: Point): Double = {
      val dx = x - other.x
      val dy = y - other.y
      Math.sqrt(dx * dx + dy * dy)
    }

    /** The centroid closest to this point. `centroids` must be non-empty.
      * minBy is O(n); the original sortBy(...).head was O(n log n) and relied
      * on `.head` of a possibly-empty list. */
    def nearestCentroid(centroids: List[Centroid]): Centroid = {
      require(centroids.nonEmpty, "centroids must not be empty")
      centroids.minBy(distance)
    }

    override def toString = f"<$x%.1f, $y%.1f>"
  }

  type Centroid = Point
  type Clusters = Map[Centroid, List[Point]]

  /** Bounding box and coordinate statistics of a point set. */
  class Dataset(val left: Double, val right: Double, val top: Double, val bottom: Double, val xs: List[Double], val ys: List[Double]) {
    def width = right - left
    def height = top - bottom
    def meanX = computeMean(xs)
    def meanY = computeMean(ys)

    /** Arithmetic mean of the points — the improved centroid of a cluster. */
    def mean = Point(meanX, meanY)

    def computeMean(vs: List[Double]): Double = vs.sum / vs.length.toDouble

    /** A uniformly random point inside the bounding box that is not already
      * in `old` (retries on collision). `final` so @tailrec is permitted. */
    @scala.annotation.tailrec
    final def nextRandomCentroid(old: Set[Point]): Centroid = {
      val x = left + Random.nextDouble() * width
      val y = bottom + Random.nextDouble() * height
      val candidate = Point(x, y)
      if (old contains candidate) nextRandomCentroid(old)
      else candidate
    }
  }

  object Dataset {
    /** Build a Dataset from concrete points (requires at least one point). */
    def apply(points: List[Point]): Dataset = {
      val xs = points.map(_.x)
      val ys = points.map(_.y)
      new Dataset(
        left = xs.min,
        right = xs.max,
        top = ys.max,
        bottom = ys.min,
        xs = xs,
        ys = ys
      )
    }
  }

  /** `k` distinct random centroids inside the bounding box of `points`. */
  private def initialCentroids(k: Int, points: List[Point]): List[Point] = {
    val dataset = Dataset(points)
    @scala.annotation.tailrec
    def helper(acc: Set[Point]): Set[Point] =
      if (acc.size < k) helper(acc + dataset.nextRandomCentroid(acc))
      else acc
    helper(Set()).toList
  }

  /** Convergence tolerance for centroid coordinates. */
  private val EPS = 0.001

  private def almostEqual(a: Double, b: Double): Boolean = Math.abs(a - b) <= EPS

  private def almostEqual(a: Point, b: Point): Boolean =
    almostEqual(a.x, b.x) && almostEqual(a.y, b.y)

  private def almostEqual(a: List[Point], b: List[Point]): Boolean =
    a.size == b.size && (a zip b).forall { case (c, d) => almostEqual(c, d) }

  /** Whether two cluster maps are element-wise equal within EPS.
    *
    * BUG FIX: the original zipped `clusters.keys` with `newClusters.keys`
    * directly, but Map iteration order is unspecified, so logically-equal
    * maps could compare unequal (and vice versa). Centroids are now compared
    * in a canonical (x, y) order. */
  private def almostEqual(clusters: Clusters, newClusters: Clusters): Boolean = {
    val byCoordinates = Ordering.by[Point, (Double, Double)](p => (p.x, p.y))
    clusters.size == newClusters.size && {
      val as = clusters.keys.toList.sorted(byCoordinates)
      val bs = newClusters.keys.toList.sorted(byCoordinates)
      (as zip bs).forall { case (a, b) =>
        almostEqual(a, b) && almostEqual(clusters(a), newClusters(b))
      }
    }
  }

  /** Assign every point to its nearest centroid. A centroid that attracts no
    * points gets no entry, so the map may have fewer than k keys. */
  private def assignPoints(points: List[Point], centroids: List[Centroid]): Clusters =
    points.groupBy(_.nearestCentroid(centroids))

  /** Move each centroid to the mean of its currently-assigned points. */
  private def improveCentroids(clusters: Clusters): List[Centroid] =
    clusters.values.map(points => Dataset(points).mean).toList

  /** Lloyd's k-means: random initial centroids, then iterate assign/update
    * until assignments stabilize within EPS.
    *
    * BUG FIX: on convergence the original returned the *previous* iteration's
    * clusters, whose centroids were up to EPS stale; the refined
    * `newClusters` is returned instead. */
  def computeClusters(k: Int, points: List[Point]): Clusters = {
    require(k > 0)
    require(points.nonEmpty)
    @scala.annotation.tailrec
    def helper(clusters: Clusters): Clusters = {
      val newCentroids = improveCentroids(clusters)
      val newClusters = assignPoints(points, newCentroids)
      if (almostEqual(clusters, newClusters)) newClusters
      else helper(newClusters)
    }
    helper(assignPoints(points, initialCentroids(k, points)))
  }

  /** Write `data` to `filename`, always closing the writer.
    * (The original leaked the writer if `write` threw.) */
  def writeFile(filename: String, data: String): Unit = {
    val w = new BufferedWriter(new FileWriter(filename))
    try w.write(data)
    finally w.close()
  }

  /** Parse "x y" pairs, one per line. Closes the Source when done
    * (the original never closed it). */
  def parsePoints(filename: String): List[Point] = {
    val source = Source.fromFile(filename)
    try {
      source
        .getLines
        .map { line =>
          val coords = line.split(" ").map(_.toDouble)
          Point(coords(0), coords(1))
        }
        .toList
    } finally {
      source.close()
    }
  }

  /** Entry point: cluster data.in into k=3 clusters and write one
    * gnuplot-friendly data_N.dat file per cluster.
    * (Replaces `extends App`: DelayedInit-based App is deprecated, and an
    * explicit main keeps the object's members callable without running the
    * script body. `scala Kmeans` works exactly as before.) */
  def main(args: Array[String]): Unit = {
    val k = 3
    val points = parsePoints("data.in")
    val clusters = computeClusters(k, points)
    clusters
      .zipWithIndex
      .foreach { case ((_, clusterPoints), index) =>
        writeFile(
          s"data_$index.dat",
          clusterPoints
            .map { case Point(x, y) => s"$x $y\n" }
            .mkString
        )
      }
    println(clusters)
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/sh
# Build and run the Scala k-means demo, then plot every cluster file with gnuplot.
set -e

# Remove stale cluster output. BUG FIX: options must precede operands for
# POSIX portability (`rm *.dat -fv` only works with GNU rm).
rm -fv ./*.dat

scalac Kmeans.scala
scala Kmeans

# Build a single gnuplot command listing each cluster file with its own pointtype.
CMD='set xrange[-10:100]; set yrange[-10:100]; plot '
P=5
for i in *.dat; do
    CMD="${CMD}'${i}' with points pointtype $P,"
    # BUG FIX: `(((P++)))` is a bash-ism; the shebang is /bin/sh, so use
    # POSIX arithmetic expansion instead.
    P=$((P + 1))
done
CMD="${CMD}; pause mouse key"
echo "${CMD}"
gnuplot -e "${CMD}"
Author
alopatindev
commented
Mar 28, 2017
Spark + Databricks
// NOTE(review): Spark MLlib port of the k-means demo, written for a Databricks
// notebook — `sc` (SparkContext) and `display` are presumably provided by the
// notebook runtime; this snippet is not a standalone program.
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.{Vector, Vectors}
// Parse one space-separated line of coordinates into a dense MLlib vector.
def parseLine(line: String): Vector = Vectors.dense(line
.split(" ")
.map { _.toDouble })
// Load the input file and cache the parsed RDD: KMeans.run makes repeated passes.
val rdd = sc.textFile("/FileStore/tables/.../data.in")
val data = rdd
.map { parseLine(_) }
.cache()
// k=3 clusters, tight convergence threshold, capped at 50 iterations.
val kmeans = new KMeans()
.setK(3)
.setEpsilon(0.00001)
.setMaxIterations(50)
val model = kmeans.run(data)
// display(model) // FIXME: appears empty to me
// Extract the fitted cluster centers as plain List[List[Double]] for printing.
val centroids = model
.clusterCenters
.map { _.toArray.toList }
.toList
println(s"centroids are $centroids")
// Group each input point under the index of the cluster the model predicts for it.
val clustersToPoints = data
.map { point => (model.predict(point), point) }
.groupByKey()
.mapValues { _.toList }
.collect()
.toMap
println(s"clustersToPoints are $clustersToPoints")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment