Create a gist now

Instantly share code, notes, and snippets.

@yaseminn /Kmeans.scala Secret
Created Dec 5, 2015

Embed
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.{KMeansModel, KMeans}
import org.apache.spark.mllib.linalg.Vectors
/**
 * Command-line driver that clusters comma-separated numeric vectors with MLlib k-means.
 *
 * Usage: `<input> <output> <numClusters> <maxIterations> [runs]`
 *
 * Each non-blank line of `input` is parsed as one dense vector of comma-separated doubles.
 * The fitted cluster centers are printed to stdout, and every input vector is written to
 * `output` as a `(vectorString, clusterIndex)` pair.
 */
object Kmeans {

  /** Number of mandatory positional arguments (input, output, numClusters, maxIterations). */
  private val RequiredArgs = 4

  /**
   * Entry point.
   *
   * @param args `<input> <output> <numClusters> <maxIterations> [runs]`; `runs` defaults to 1.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < RequiredArgs) {
      // A usage error is a failure: report it on stderr with a non-zero status so callers notice.
      Console.err.println("usage: <input> <output> <numClusters> <maxIterations>")
      System.exit(1)
    }

    // Parse arguments before starting Spark, so malformed numbers fail fast without a context.
    val input = args(0)
    val output = args(1)
    val K = args(2).toInt
    val maxIteration = args(3).toInt
    val runs = calculateRuns(args)

    val conf = new SparkConf()
      .setAppName("Spark KMeans Example")
      // Default to local mode, but let an externally supplied master (e.g. spark-submit --master)
      // take precedence instead of silently overriding it.
      .setIfMissing("spark.master", "local")
    val context = new SparkContext(conf)
    try {
      val data = context.textFile(input)
        .map(_.trim)
        // Skip blank lines (e.g. a trailing newline); "".toDouble would otherwise throw.
        .filter(_.nonEmpty)
        .map(line => Vectors.dense(line.split(',').map(_.trim.toDouble)))
        .cache()

      // NOTE(review): in newer Spark releases the `runs` overload is deprecated and `runs` is
      // ignored — confirm against the Spark version this is built with.
      val clusters: KMeansModel = KMeans.train(data, K, maxIteration, runs)
      println(s"cluster centers: ${clusters.clusterCenters.mkString(",")}")

      val vectorsAndClusterIdx = data.map { point =>
        (point.toString, clusters.predict(point))
      }
      vectorsAndClusterIdx.saveAsTextFile(output)
    } finally {
      // Always release the SparkContext, even if reading, training or saving fails.
      context.stop()
    }
  }

  /**
   * Reads the optional fifth argument as the number of k-means runs.
   *
   * @param args the full command-line argument array
   * @return `args(4).toInt` when a fifth argument is present, otherwise 1
   * @throws NumberFormatException if the fifth argument is not an integer
   */
  def calculateRuns(args: Array[String]): Int =
    args.lift(RequiredArgs).map(_.toInt).getOrElse(1)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment