@disa-mhembere
Created April 5, 2015 17:09
Spark example implementation of k-means
# This script was wholly obtained from: https://spark-summit.org/2013/exercises/machine-learning-with-spark.html.
# I take no credit/blame for the implementation.
import os
import sys
import numpy as np
from pyspark import SparkContext
def setClassPath():
    oldClassPath = os.environ.get('SPARK_CLASSPATH', '')
    cwd = os.path.dirname(os.path.realpath(__file__))
    os.environ['SPARK_CLASSPATH'] = cwd + ":" + oldClassPath

def parseVector(line):
    return np.array([float(x) for x in line.split(',')])

def closestPoint(p, centers):
    # Index of the centroid nearest to p, by squared Euclidean distance.
    bestIndex = 0
    closest = float("+inf")
    for i in range(len(centers)):
        dist = np.sum((p - centers[i]) ** 2)
        if dist < closest:
            closest = dist
            bestIndex = i
    return bestIndex

def average(points):
    # Element-wise mean of a list of vectors.
    numVectors = len(points)
    out = np.array(points[0])
    for i in range(1, numVectors):
        out += points[i]
    out = out / numVectors
    return out

if __name__ == "__main__":
    setClassPath()
    master = open("/root/spark-ec2/cluster-url").read().strip()
    masterHostname = open("/root/spark-ec2/masters").read().strip()
    sc = SparkContext(master, "PythonKMeans")
    K = 10
    convergeDist = 1e-5

    lines = sc.textFile(
        "hdfs://" + masterHostname + ":9000/wikistats_featurized")
    data = lines.map(
        lambda x: (x.split("#")[0], parseVector(x.split("#")[1]))).cache()
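    # Each cached record is a (page_name, feature_vector) pair; an input line
    # is assumed to look like "<page_name>#<f1>,<f2>,...", given the split on
    # "#" above and the comma split in parseVector().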
    count = data.count()
    print "Number of records " + str(count)

    # TODO: PySpark does not support takeSample(). Use first K points instead.
    centroids = map(lambda (x, y): y, data.take(K))
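    # Note: newer PySpark releases do provide RDD.takeSample(), so a hedged
    # alternative to the first-K initialization above would be to sample K
    # random points; the seed value is an arbitrary illustration, not part of
    # the original exercise.
    # centroids = map(lambda (x, y): y, data.takeSample(False, K, 42))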
    tempDist = 1.0

    while tempDist > convergeDist:
        closest = data.map(
            lambda (x, y): (closestPoint(y, centroids), y))
        pointsGroup = closest.groupByKey()
        newCentroids = pointsGroup.mapValues(
            lambda x: average(x)).collectAsMap()

        tempDist = sum(np.sum((centroids[x] - y) ** 2)
                       for (x, y) in newCentroids.iteritems())

        for (x, y) in newCentroids.iteritems():
            centroids[x] = y

        print "Finished iteration (delta = " + str(tempDist) + ")"
        sys.stdout.flush()

print "Clusters with some articles"
numArticles = 10
for i in range(0, len(centroids)):
samples = data.filter(lambda (x,y) : closestPoint(y, centroids) == i).take(numArticles)
for (name, features) in samples:
print name
print " "