Skip to content

Instantly share code, notes, and snippets.

@pims
Created May 29, 2011 05:19
Show Gist options
  • Save pims/997501 to your computer and use it in GitHub Desktop.
Save pims/997501 to your computer and use it in GitHub Desktop.
Naive python translation of https://gist.github.com/995804
#!/usr/bin/env python
import sys
from math import hypot, sqrt
from random import random, sample
"""
Naive python translation of https://gist.github.com/995804
Public domain.
"""
# Floating-point positive infinity. It compares greater than every finite
# float, which makes it a safe initial bound for a running-minimum search
# (negate it for a running-maximum search).
INFINITY = float('inf')
class Point(object):
    """A point in the 2-D plane with float coordinates.

    Both coordinates are passed through ``float()``, so numeric strings
    such as ``"3.5"`` are accepted as well as numbers.
    """

    def __init__(self, x, y):
        self.x = float(x)
        self.y = float(y)

    def __str__(self):
        return "Point{%s, %s}" % (self.x, self.y)

    # Make points readable when they are printed inside lists or shown in
    # a debugger, instead of falling back to the default object address.
    __repr__ = __str__

    def dist_to(self, p):
        """Return the Euclidean distance between this point and *p*."""
        # hypot() does not square the differences naively, so very large
        # coordinate gaps do not raise OverflowError the way
        # ``sqrt(dx**2 + dy**2)`` does.
        return hypot(self.x - p.x, self.y - p.y)
class Cluster(object):
    """A k-means cluster: a centroid plus the points assigned to it."""

    def __init__(self, center):
        """Create an empty cluster whose centroid starts at *center*."""
        self.center = center
        self.points = []

    def recenter(self):
        """Move the centroid to the mean of the assigned points.

        The point list itself is left untouched; clearing it between
        iterations is the caller's responsibility.

        Returns:
            The distance the centroid moved. A cluster with no assigned
            points keeps its current centroid and reports 0.0.
        """
        if not self.points:
            # Averaging zero points would divide by zero; an empty cluster
            # simply stays where it is.
            return 0.0

        # float() keeps the division true division on Python 2 as well.
        count = float(len(self.points))
        mean_x = sum(point.x for point in self.points) / count
        mean_y = sum(point.y for point in self.points) / count

        old_center = self.center
        self.center = Point(mean_x, mean_y)
        return old_center.dist_to(self.center)
def kmeans(data, k, delta=0.001):
    """Partition *data* into *k* clusters by iterative k-means refinement.

    Args:
        data: Sequence of point objects. Each must provide ``dist_to()``
            and the ``x``/``y`` attributes that ``Cluster.recenter`` reads.
        k: Number of clusters; must satisfy ``1 <= k <= len(data)``.
        delta: Convergence threshold. Iteration stops once every centroid
            moved less than this distance in a single pass.

    Returns:
        A list of *k* ``Cluster`` objects holding their final centroid and
        the points assigned to them in the last pass.

    Raises:
        ValueError: If *k* is outside ``1..len(data)``.
    """
    if not 1 <= k <= len(data):
        raise ValueError(
            "k must be between 1 and len(data) (%d); got %r" % (len(data), k)
        )

    # Assign initial values for all clusters. sample() draws without
    # replacement, so no two clusters start on the same data element
    # (identical seeds would leave one of them permanently empty).
    clusters = [Cluster(seed) for seed in sample(data, k)]

    while True:
        # Assign every point to its closest centroid; on ties min() keeps
        # the earliest cluster in the list.
        for point in data:
            nearest = min(
                clusters,
                key=lambda cluster: point.dist_to(cluster.center),
            )
            nearest.points.append(point)

        # Recenter and track the largest centroid movement. A cluster that
        # received no points has nothing to average, so it stays put.
        max_delta = max(
            cluster.recenter() if cluster.points else 0.0
            for cluster in clusters
        )

        # Check exit condition (points are kept so callers can read them).
        if max_delta < delta:
            return clusters

        # Reset points for the next iteration
        for cluster in clusters:
            cluster.points = []
if __name__ == '__main__':
    # Demo: cluster the embedded sample points and scatter-plot the result.
    # Usage: script.py [k]   (k defaults to 4)
    import matplotlib.pyplot as plt

    # One "x,y" pair per line.
    dataset = """
48.2641334571,86.4516903905
0.114004262656,35.8368597414
97.4319168245,92.8009240744
24.4614031388,18.3292584382
36.2367675367,32.8294024271
75.5836860736,68.30729977
38.6577034445,25.7701728584
28.2607136287,64.4493377817
61.5358486771,61.2195232194
1.52352224798,38.5083779618
11.6392182793,68.2369021579
53.9486870607,53.9136556533
14.6671651772,26.0132534731
65.9506725878,82.5639317581
58.3682872339,51.6414580337
12.6918921252,2.28888447759
31.7587852231,18.1368234166
63.6631115204,24.933301389
29.1652289905,34.456759171
44.3830953085,70.4813875779
47.0571691145,65.3507625811
74.0584537502,98.2271944247
55.8929146157,86.6196265477
20.4744253473,12.0025149302
14.2867767281,40.2850440995
40.43551369,94.5410407116
87.6178871195,12.4700151639
47.2703048197,93.0636237124
59.7895104175,69.2621288413
80.8612333922,42.9183411179
31.1271795535,55.6669044656
78.9671049353,65.833739365
39.8324533414,63.0343115139
79.126343548,14.9128874133
65.8152400306,77.5202358013
75.2762752704,42.4858435609
29.6475948493,61.2068411763
67.421857106,54.8955604259
10.4652931501,29.7954139372
32.0272462745,99.5422900971
80.1520927001,84.2710379142
2.27240208403,41.2138854089
44.4601509555,1.72563901513
16.8676021068,35.3415636277
58.1977544121,29.2752085455
24.6119080085,39.9440735137
63.0759798755,60.9841014448
30.9289119657,95.0173219502
8.54972950047,41.7384441737
61.2606910793,4.06738902059
83.2302091964,11.6373312879
89.4443065362,42.5694882801
24.5619318152,97.7947977804
50.3134024475,40.6429336223
58.1422402033,36.1112632557
32.0668520827,29.9924151435
89.6057447137,84.9532177777
9.8876440816,18.2540486261
17.9670383961,47.596032257
50.2977668282,93.6851189223
98.0700386253,86.5816924579
10.8175290981,26.4344732252
34.7463851288,24.4154447141
92.5470100593,17.3595513748
79.0426629356,4.59850018907
89.9791366918,29.523946842
3.89920214563,91.3650215111
35.4669861576,62.1865368798
2.78150918086,24.5280230552
50.0390951889,57.0414421682
64.4521660758,48.4962172448
94.4915452316,56.6508179406
47.1655534769,15.8292055671
94.2027011374,45.6802385454
30.5846324871,54.783635876
57.7043252948,0.286661610381
41.7908674949,14.7206014023
59.6689465934,64.8849831965
92.2553335495,55.9096460272
48.493467262,69.4766837809
23.1837859581,71.4406867443
29.0737623652,66.9391416961
95.7442323112,89.4677505059
68.7707275828,40.9900140055
84.5445737133,32.1707309618
67.4126251988,56.6710579117
10.688352016,28.1745892928
56.7620324155,18.3034334207
50.6751320678,86.6916908032
74.6185482896,34.022483532
20.7011996002,32.855295357
11.479054664,1.59204297586
51.6805387648,25.4063026358
84.4109522357,47.237632645
90.6395051745,57.7917166935
58.6159601042,84.1226173848
46.2184509277,28.559934585
97.0302485783,41.3135022812
31.3144587058,87.2459910122
5.93357833962,95.6812831872
"""
    k = int(sys.argv[1]) if len(sys.argv) > 1 else 4

    data = []
    # str.split() with no arguments never yields empty strings, so every
    # token is a data row. Unpacking exactly two fields makes a malformed
    # row fail right here.
    for line in dataset.split():
        x, y = line.split(',')
        data.append(Point(x, y))

    clusters = kmeans(data, k)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # One plot call per cluster so each cluster gets its own colour.
    for cluster in clusters:
        _x = [point.x for point in cluster.points]
        _y = [point.y for point in cluster.points]
        ax.plot(_x, _y, 'o')
    plt.show()
@cfdrake
Copy link

cfdrake commented May 29, 2011

Nice! :) Great to see people interested in my article!

I actually started with the intent of writing it in Python, oddly enough :p

@newacct
Copy link

newacct commented Nov 12, 2011

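index = int(len(data) * random()); rand_point = data[index]
should be written as
rand_point = random.choice(data) (this needs `import random`, or `from random import choice`, since the gist only does `from random import random`)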

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment