Skip to content

Instantly share code, notes, and snippets.

@ianoc
Forked from azymnis/KMeansJob.scala
Last active September 29, 2015 17:56
Show Gist options
  • Save ianoc/5a5513939a44b0914262 to your computer and use it in GitHub Desktop.
Save ianoc/5a5513939a44b0914262 to your computer and use it in GitHub Desktop.
K-Means in scalding
import com.twitter.algebird.{Aggregator, Semigroup}
import com.twitter.scalding._
import scala.util.Random
/**
* This job is a tutorial of sorts for scalding's Execution[T] abstraction.
* It is a simple implementation of Lloyd's algorithm for k-means on 2D data.
*
* http://en.wikipedia.org/wiki/K-means_clustering
*
* The assumption here is that the number of clusters is much smaller than the
* number of points, and the cluster centroids can easily fit in memory.
*
* Input: a text TSV file with two columns, where each row holds the
* (x, y) coordinates of one point in the dataset.
*
* Output: a three column TSV where the first two columns are the (x, y)
* coordinates of each point and the last column is the final cluster id.
*
* E.g. to cluster data into 3 clusters:
*
* scald --local KMeansJob.scala --clusters 3 --input input.tsv --output output.tsv
*
*/
// NOTE(review): in the scalding versions I know of, ExecutionApp is a parameterless trait
// (`def job: Execution[Unit]`) and no `args` value is in scope at this point; confirm this
// header compiles against the scalding version in use.
object KMeansJob extends ExecutionApp[Unit](args) {

  // kMeansStep sums (Long, Point) pairs, which needs an implicit Semigroup[Point] to derive
  // Semigroup[(Long, Point)]. It must live at object scope: a local implicit declared inside
  // `job` is not visible from kMeansStep.
  implicit val pointSemigroup: Semigroup[Point] = PointSemigroup

  /** Runs the whole job.
    *
    * Reads `(x, y)` rows from `--input` and assigns each point to a uniformly random cluster in
    * `[0, --clusters)`. It then repeats Lloyd steps until no point changes cluster, or until the
    * optional `--max-iterations` cap is reached. Finally it writes `(x, y, cluster)` rows to
    * `--output`.
    */
  override def job = Execution.getArgs.flatMap { args =>
    // Number of clusters to partition the data into.
    val numClusters = args("clusters").toInt
    // Random.nextInt rejects non-positive bounds; fail fast with a clear message instead.
    require(numClusters > 0, "--clusters must be a positive integer")
    val inputFile = args("input")
    val outputFile = args("output")
    // Optional upper bound on the number of k-means steps; unbounded when not given.
    val maxIterations = args.optional("max-iterations").map(_.toInt).getOrElse(Int.MaxValue)
    require(maxIterations > 0, "--max-iterations must be a positive integer")

    TypedPipe.from(TypedTsv[(Double, Double)](inputFile))
      .map { case (x, y) => Point(x, y, Random.nextInt(numClusters)) }
      // TypedPipes are lazy: without materializing here, every re-read of this pipe would
      // draw a fresh set of random clusters.
      .forceToDiskExecution
      .flatMap { initialPoints => updateAndCheck(initialPoints, maxIterations) }
      .flatMap { points =>
        points
          .map { p => (p.x, p.y, p.cluster) }
          .writeExecution(TypedTsv[(Double, Double, Int)](outputFile))
      }
  }

  /** One iteration of Lloyd's algorithm.
    *
    * Update step: compute the mean of every cluster that currently has at least one point.
    * A cluster with no points gets no centroid, so its id is never assigned again.
    * Assignment step: move every point to the cluster with the nearest centroid.
    *
    * @param currentPoints points tagged with their current cluster id
    * @return an Execution yielding (number of points whose cluster changed, reassigned points)
    */
  def kMeansStep(currentPoints: TypedPipe[Point]): Execution[(Long, TypedPipe[Point])] = {
    val newClusterMeansExecution: Execution[Seq[Point]] = currentPoints
      .map { p => (p.cluster, (1L, p)) }
      .group
      .sum
      .map { case (_, (count, summed)) => summed.normalize(count) }
      .toIterableExecution
      // Fix the centroid order so that distance ties (first centroid wins) resolve
      // deterministically instead of depending on the reducer's output order.
      .map(_.toList.sortBy(_.cluster))

    newClusterMeansExecution.flatMap { means =>
      val pointsWithDeltas: TypedPipe[(Point, Long)] = currentPoints.map { p =>
        val nearest = p.closestPoint(means).cluster
        val delta = if (p.cluster == nearest) 0L else 1L
        (p.copy(cluster = nearest), delta)
      }
      // The result is read twice: once for the change count and once by the next iteration.
      // Materializing it computes it once and stops the lineage from growing each iteration.
      pointsWithDeltas.forceToDiskExecution.flatMap { cached =>
        val newPoints = cached.map { case (point, _) => point }
        cached
          .map { case (_, delta) => delta }
          .aggregate(Aggregator.fromSemigroup[Long])
          .toOptionExecution
          .map { changedOpt => (changedOpt.getOrElse(0L), newPoints) }
      }
    }
  }

  /** Repeats kMeansStep until an iteration changes no point's cluster.
    *
    * @param points        current cluster assignments
    * @param maxIterations maximum number of steps to run (at least one step always runs);
    *                      the default keeps the original run-until-converged behavior
    * @return the final cluster assignments
    */
  def updateAndCheck(
      points: TypedPipe[Point],
      maxIterations: Int = Int.MaxValue): Execution[TypedPipe[Point]] =
    kMeansStep(points).flatMap {
      case (0L, newPoints) => Execution.from(newPoints)
      case (count, newPoints) =>
        System.out.println("%d points changed".format(count))
        if (maxIterations <= 1) Execution.from(newPoints)
        else updateAndCheck(newPoints, maxIterations - 1)
    }
}
/** A 2D point tagged with a cluster id.
  *
  * Data points use `cluster` for their current assignment. Centroids (and the coordinate sums
  * they are normalized from) use it for the id of the cluster they represent.
  */
case class Point(x: Double, y: Double, cluster: Int) {

  /** Divides both coordinates by `count` and keeps the cluster id; turns a coordinate sum
    * into a mean. `count` is expected to be positive (zero yields infinite/NaN coordinates).
    */
  def normalize(count: Long): Point = Point(x / count, y / count, cluster)

  /** Squared Euclidean distance to `other`; cluster ids are ignored. */
  def squareDistanceFrom(other: Point): Double = {
    val dx = x - other.x
    val dy = y - other.y
    dx * dx + dy * dy
  }

  /** Returns the element of `points` with the smallest squared distance to this point.
    *
    * This is a single linear scan using a strict comparison, so ties go to the earliest
    * element. (The previous `sortWith(_ <= _)` passed a non-strict ordering, which breaks the
    * sort contract.)
    *
    * @throws IllegalArgumentException if `points` is empty
    */
  def closestPoint(points: Seq[Point]): Point = {
    require(points.nonEmpty, "closestPoint requires at least one candidate point")
    points.iterator
      .map { p => (p, squareDistanceFrom(p)) }
      .reduceLeft { (best, candidate) => if (candidate._2 < best._2) candidate else best }
      ._1
  }
}
/** Algebird semigroup that adds two points' coordinates component-wise.
  *
  * The result always carries the left operand's cluster id; `r.cluster` is ignored.
  */
object PointSemigroup extends Semigroup[Point] {
  def plus(l: Point, r: Point): Point =
    l.copy(x = l.x + r.x, y = l.y + r.y)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment