Skip to content

Instantly share code, notes, and snippets.

Last active March 8, 2024 13:40
Show Gist options
  • Save umbertogriffo/b599aa9b9a156bb1e8775c8cdbfb688a to your computer and use it in GitHub Desktop.
Save umbertogriffo/b599aa9b9a156bb1e8775c8cdbfb688a to your computer and use it in GitHub Desktop.
Step-by-step code tutorial on implementing a basic k-means in Spark in order to cluster geo-located devices


  • Download dataset here


* Follow the well-commented code in kmeans.scala

Main Steps of algorithm:

  • Choose K random geo-located points as starting centers
  • Find all points closest to each center
  • Find the new center of each cluster
  • Loop until the total distance between one iteration's points and the next is less than the convergence distance specified
package LearningScala.LearningScala.local.kmeans
/**
 * Basic k-means clustering of geo-located device status records with Spark.
 *
 * Input data: file(s) with device status data (delimited by ','),
 * including latitude (4th field) and longitude (5th field) of device locations.
 * A lat,lon of 0,0 indicates an unknown location and is skipped.
 *
 * Main steps:
 *   1. Choose K random geo-located points as starting centers.
 *   2. Find all points closest to each center.
 *   3. Find the new center of each cluster.
 *   4. Loop until the total (squared) distance between one iteration's centers
 *      and the next is less than the convergence distance.
 */
object kmeans {

  import scala.math.pow
  import org.apache.spark.SparkContext
  import org.apache.spark.SparkConf

  /**
   * The squared distance between two points.
   *
   * @param p1 first point
   * @param p2 second point
   * @return (p1._1 - p2._1)^2 + (p1._2 - p2._2)^2
   */
  def distanceSquared(p1: (Double, Double), p2: (Double, Double)): Double =
    pow(p1._1 - p2._1, 2) + pow(p1._2 - p2._2, 2)

  /**
   * The component-wise sum of two points.
   *
   * @param p1 first point
   * @param p2 second point
   * @return (p1._1 + p2._1, p1._2 + p2._2)
   */
  def addPoints(p1: (Double, Double), p2: (Double, Double)): (Double, Double) =
    (p1._1 + p2._1, p1._2 + p2._2)

  /**
   * For a point p and an array of points, return the index in the array of the
   * point closest to p (by squared distance).
   *
   * Ties are resolved in favour of the lowest index. For an empty array the
   * result is 0, because that is the initial value of the running best index.
   *
   * @param p      the point to classify
   * @param points candidate center points
   * @return index into `points` of the closest candidate
   */
  def closestPoint(p: (Double, Double), points: Array[(Double, Double)]): Int = {
    var bestIndex = 0
    var closest = Double.PositiveInfinity
    for (i <- points.indices) {
      val dist = distanceSquared(p, points(i))
      // Strict '<' keeps the first of several equally-close candidates.
      if (dist < closest) {
        closest = dist
        bestIndex = i
      }
    }
    bestIndex
  }

  /**
   * Runs the k-means job on a local Spark master and prints the resulting
   * centers, followed by the cluster assignment of 10 sampled devices.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Set Windows System property
    //System.setProperty("hadoop.home.dir", "c:/winutil/");
    val conf = new SparkConf().setAppName("First Scala app").setMaster("local[*]")
    val sc = new SparkContext(conf)

    // Stop the context even when the job fails, so local executors and ports are released.
    try {
      // The device status data file(s)
      val filename = "loudacre/*"

      // K is the number of means (center points of clusters) to find
      val K = 5

      // ConvergeDist -- the threshold "distance" between iterations at which we decide we are done
      val convergeDist = .1

      // Parse the device status data file into (latitude, longitude) pairs,
      // dropping records whose location is the 0,0 "unknown" marker.
      val fileRdd = sc.textFile(filename)
      // NOTE(review): the listing this was restored from ended this chain with a
      // dangling '.'; persist() is assumed because the RDD is re-scanned on every
      // iteration of the loop below -- confirm against the original file.
      val pairLatLongRdd = fileRdd
        .map(line => line.split(','))
        .map(pair => (pair(3).toDouble, pair(4).toDouble))
        .filter(point => !((point._1 == 0) && (point._2 == 0)))
        .persist()

      for ((a, b) <- pairLatLongRdd.take(2)) {
        println("Lat: " + a + " Long : " + b)
      }

      // Start with K randomly selected points from the dataset as center points.
      // The array is updated in place, so the reference itself never changes.
      // takeSample may return fewer than K points when the dataset is smaller than K.
      val kPoints = pairLatLongRdd.takeSample(false, K, 42)
      println("K Center points initialized :")
      for ((a, b) <- kPoints) {
        println("Lat: " + a + " Long : " + b)
      }

      // Loop until the total distance between one iteration's points and the next
      // is less than the convergence distance specified.
      var tempDist = Double.PositiveInfinity
      while (tempDist > convergeDist) {
        // For each point, find the index of the closest kpoint.
        // Map to (index, (point, 1)) as follows:
        // (1, ((2.1,-3.4),1))
        // (0, ((5.1,-7.4),1))
        // (1, ((8.1,-4.4),1))
        val closestToKpointRdd =
          pairLatLongRdd.map(point => (closestPoint(point, kPoints), (point, 1)))

        // For each key (k-point index), reduce by sum (addPoints) the latitudes and
        // longitudes of all the points closest to that k-point, and the number of
        // closest points. E.g.
        // (1, ((4.325,-5.444),2314))
        // (0, ((6.342,-7.532),4323))
        // The reduced RDD should have at most K members.
        val pointCalculatedRdd = closestToKpointRdd.reduceByKey {
          case ((point1, n1), (point2, n2)) => (addPoints(point1, point2), n1 + n2)
        }

        // For each key (k-point index), find a new point by calculating the average
        // of each closest point: (index, ((totalX,totalY),n)) to (index, (totalX/n,totalY/n))
        val newPoints = pointCalculatedRdd
          .map { case (i, (point, n)) => (i, (point._1 / n, point._2 / n)) }
          .collectAsMap()

        // Calculate the total of the distance between the current points (kPoints)
        // and the new points. That distance is the delta between iterations; when
        // the delta is less than convergeDist, stop iterating.
        // A center that attracted no points has no entry in newPoints; it keeps its
        // previous position and therefore contributes 0 to the delta.
        tempDist = 0.0
        for (i <- kPoints.indices) {
          tempDist += distanceSquared(kPoints(i), newPoints.getOrElse(i, kPoints(i)))
        }
        println("Distance between iterations: " + tempDist)

        // Copy the new points to the kPoints array for the next iteration
        for (i <- kPoints.indices) {
          kPoints(i) = newPoints.getOrElse(i, kPoints(i))
        }
      }

      // Display the final center points
      println("Final center points :")
      for (point <- kPoints) {
        // NOTE(review): the loop body was missing from the listing this was restored
        // from; printing each center is assumed -- confirm against the original file.
        println(point)
      }

      // Take 10 randomly selected devices from the dataset and recall the model.
      // The 2nd field is used as the device label.
      // NOTE(review): as above, the trailing persist() is an assumed restoration.
      val deviceRdd = fileRdd
        .map(line => line.split(','))
        .map(pair => (pair(1), (pair(3).toDouble, pair(4).toDouble)))
        .filter(device => !((device._2._1 == 0) && (device._2._2 == 0)))
        .persist()

      val points = deviceRdd.takeSample(false, 10, 42)
      for ((device, point) <- points) {
        val k = closestPoint(point, kPoints)
        println("device: " + device + " to K: " + k)
      }
    } finally {
      sc.stop()
    }
  }
}
Copy link

I've updated the link to the dataset. Now, you can download the dataset without requesting access.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment