Skip to content

Instantly share code, notes, and snippets.

@erikerlandson
Created January 6, 2015 19:12
Show Gist options
  • Save erikerlandson/c3c35f0b1aae737fc884 to your computer and use it in GitHub Desktop.
Save erikerlandson/c3c35f0b1aae737fc884 to your computer and use it in GitHub Desktop.
A k-medoids implementation for Spark RDDs
/**
 * Approximate k-medoids clustering of an RDD, computed on a locally collected random sample.
 *
 * A sample of at most `sampleSize` elements is collected to the driver, `k` distinct
 * elements of it are picked at random as the initial medoids, and then for
 * `maxIterations` rounds each sample element is assigned to its nearest medoid and each
 * cluster's medoid is replaced by the cluster member with the smallest summed distance
 * to the other members.
 *
 * @tparam T element type of the data
 * @tparam U a supertype of `T` on which the metric is defined
 * @param data the elements to cluster
 * @param k number of medoids to find; must be positive and no larger than `data.count`
 * @param metric distance function between two elements
 * @param sampleSize upper bound on the number of elements collected for clustering
 * @param maxIterations number of assignment/update rounds to run
 * @param resampleInterval a fresh sample is drawn on every iteration (after the first)
 *   whose number is a multiple of this value; 0 disables resampling
 * @return the medoids, and the sum over the most recent sample of each element's
 *   distance to its nearest medoid
 * @throws IllegalArgumentException if `k <= 0`, if the data has fewer than `k` elements,
 *   or if the sample contains fewer than `k` distinct elements
 */
def kMedoids[T :ClassTag, U >: T :ClassTag](
    data: RDD[T],
    k: Int,
    metric: (U,U) => Double,
    sampleSize: Int = 10000,
    maxIterations: Int = 10,
    resampleInterval: Int = 3
  ): (Seq[T], Double) = {
  // check the cheap argument first, before paying for a Spark job to count the data
  require(k > 0, s"k must be positive, got $k")
  val n = data.count
  require(n >= k, s"data has $n elements, which is fewer than k= $k")
  val ss = math.min(sampleSize, n).toInt

  // takeSample yields exactly ss elements; sample(fraction).take(ss) only yields
  // approximately fraction * n of them and can leave fewer than k to seed the medoids
  def drawSample(): Array[T] = data.takeSample(withReplacement = false, num = ss)

  var sample: Array[T] = drawSample()

  // initialize medoids to a set of (k) random and unique elements; de-duplicate into an
  // ordered collection first, because shuffling a hash Set does not randomize its order
  var medoids: Seq[T] = Random.shuffle(sample.toVector.distinct).take(k)
  require(medoids.length >= k,
    s"sample has only ${medoids.length} distinct elements, need at least k= $k")

  // (distance, index) of the medoid in mv closest to x; ties go to the lower index
  val mdf = (x: T, mv: Seq[T]) => {
    mv.view.zipWithIndex.map { z => (metric(x, z._1), z._2) }.min
  }

  var itr = 1
  var lastMetric = sample.map { x => mdf(x, medoids)._1 }.sum
  while (itr <= maxIterations) {
    println(s"\n\nitr= $itr")

    // update the sample periodically (a zero interval means never resample)
    if (itr > 1 && resampleInterval != 0 && (itr % resampleInterval) == 0) {
      sample = drawSample()
    }

    // assign each element to its closest medoid
    val dmed = Array.fill[ArrayBuffer[T]](k)(new ArrayBuffer())
    sample.foreach { x => dmed(mdf(x, medoids)._2) += x }

    // update the medoids for each cluster
    // to do, support generalizations of metric sum, e.g. Minkowski, and/or
    // user-defined function:
    val current = medoids
    medoids = dmed.zipWithIndex.map { case (clust, j) =>
      // after a resample the previous medoid may be absent from the sample, which can
      // leave its cluster empty; keep the previous medoid rather than failing
      if (clust.isEmpty) current(j)
      else clust.minBy { e =>
        clust.foldLeft(0.0)((x, v) => x + metric(v, e))
      }
    }.toSeq

    val newMetric = sample.map { x => mdf(x, medoids)._1 }.sum
    // todo: test some function of metric values over time as an optional halting condition
    // when improvement stops
    println(s"last= $lastMetric new= $newMetric")
    lastMetric = newMetric
    itr += 1
  }

  // return most recent cluster medoids
  (medoids, lastMetric)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment