Created
May 29, 2011 05:19
-
-
Save pims/997501 to your computer and use it in GitHub Desktop.
Naive python translation of https://gist.github.com/995804
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
"""
Naive python translation of https://gist.github.com/995804
Public domain.
"""
# NOTE: the description above must precede the imports to be picked up as the
# module docstring (``__doc__``); placed after them it is just a discarded
# string expression.
import sys
from math import sqrt
from random import random

# Positive infinity. kmeans() uses +INFINITY / -INFINITY as the starting
# values of its "smallest distance" and "largest movement" searches.
INFINITY = float('inf')
class Point:
    """A location in the 2-D plane whose coordinates are stored as floats."""

    def __init__(self, x, y):
        """Store *x* and *y*, converting each one with ``float()``.

        Anything ``float()`` accepts (numbers, numeric strings) is allowed.
        """
        self.x = float(x)
        self.y = float(y)

    def __str__(self):
        """Return the text form ``Point{x, y}``."""
        coords = (self.x, self.y)
        return "Point{%s, %s}" % coords

    def dist_to(self, p):
        """Return the straight-line (Euclidean) distance from here to *p*."""
        dx = self.x - p.x
        dy = self.y - p.y
        return sqrt(dx ** 2 + dy ** 2)
class Cluster:
    """A k-means cluster: a centre point plus the points assigned to it."""

    def __init__(self, center):
        """Create a cluster around the starting *center* with no points."""
        self.center = center
        self.points = []

    def recenter(self):
        """Move the centre to the mean of the assigned points.

        The assigned points are NOT cleared here; the caller resets
        ``self.points`` when it starts a new assignment pass.

        Returns:
            The distance between the old and the new centre. A cluster with
            no assigned points keeps its centre and reports ``0.0``.
        """
        # An empty cluster has no mean; dividing by len(self.points) would
        # raise ZeroDivisionError, so leave the centre where it is.
        if not self.points:
            return 0.0

        old_center = self.center

        # Float accumulators keep the division below a true division even
        # under Python 2 and even if a point carries integer coordinates.
        xa = ya = 0.0
        for point in self.points:
            xa += point.x
            ya += point.y

        count = len(self.points)
        self.center = Point(xa / count, ya / count)
        return old_center.dist_to(self.center)
def kmeans(data, k, delta=0.001):
    """Group *data* into *k* clusters by iterating assign / recenter steps.

    Args:
        data: sequence of points; each must offer ``dist_to(other)`` and the
            ``x`` / ``y`` attributes that ``Cluster.recenter`` reads.
        k: number of clusters, ``1 <= k <= len(data)``.
        delta: stop once no centre moved by ``delta`` or more in a pass.
            Must be positive: movements are never negative, so a
            non-positive ``delta`` would never satisfy the exit test.

    Returns:
        The list of ``Cluster`` objects, each still holding the points
        assigned to it in the final pass.

    Raises:
        ValueError: if ``k`` is smaller than 1 or larger than ``len(data)``.
    """
    # Local import: the file header only brings in ``random.random``.
    from random import sample

    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if k > len(data):
        raise ValueError(
            "k (%r) cannot exceed the number of data points (%d)"
            % (k, len(data))
        )

    # Assign initial values for all clusters. Seeds are drawn WITHOUT
    # replacement: two clusters seeded with the same point tie on every
    # distance, the strict '<' below always favours the first one, and the
    # second would stay empty forever.
    clusters = [Cluster(seed) for seed in sample(data, k)]

    while True:
        # Assign each point to the cluster with the closest centre
        # (first cluster wins on ties).
        for point in data:
            min_dist = +INFINITY
            min_cluster = None
            for cluster in clusters:
                dist = point.dist_to(cluster.center)
                if dist < min_dist:
                    min_dist = dist
                    min_cluster = cluster
            min_cluster.points.append(point)

        # Recenter every cluster and track the largest movement.
        max_delta = -INFINITY
        for cluster in clusters:
            # A cluster can still lose all of its points in a later pass
            # (e.g. coincident data points); it then has no mean to move
            # to, so count it as "did not move" instead of recentering.
            dist_moved = cluster.recenter() if cluster.points else 0.0
            if dist_moved > max_delta:
                max_delta = dist_moved

        # Converged: return with the final assignments still in place.
        if max_delta < delta:
            return clusters

        # Reset points for the next iteration.
        for cluster in clusters:
            cluster.points = []
if __name__ == '__main__':
    # Demo: cluster a fixed sample of 2-D points and scatter-plot the result,
    # one plot series per cluster.
    import matplotlib
    import matplotlib.pyplot as plt

    # Sample data, one "x,y" pair per line.
    dataset = """
48.2641334571,86.4516903905
0.114004262656,35.8368597414
97.4319168245,92.8009240744
24.4614031388,18.3292584382
36.2367675367,32.8294024271
75.5836860736,68.30729977
38.6577034445,25.7701728584
28.2607136287,64.4493377817
61.5358486771,61.2195232194
1.52352224798,38.5083779618
11.6392182793,68.2369021579
53.9486870607,53.9136556533
14.6671651772,26.0132534731
65.9506725878,82.5639317581
58.3682872339,51.6414580337
12.6918921252,2.28888447759
31.7587852231,18.1368234166
63.6631115204,24.933301389
29.1652289905,34.456759171
44.3830953085,70.4813875779
47.0571691145,65.3507625811
74.0584537502,98.2271944247
55.8929146157,86.6196265477
20.4744253473,12.0025149302
14.2867767281,40.2850440995
40.43551369,94.5410407116
87.6178871195,12.4700151639
47.2703048197,93.0636237124
59.7895104175,69.2621288413
80.8612333922,42.9183411179
31.1271795535,55.6669044656
78.9671049353,65.833739365
39.8324533414,63.0343115139
79.126343548,14.9128874133
65.8152400306,77.5202358013
75.2762752704,42.4858435609
29.6475948493,61.2068411763
67.421857106,54.8955604259
10.4652931501,29.7954139372
32.0272462745,99.5422900971
80.1520927001,84.2710379142
2.27240208403,41.2138854089
44.4601509555,1.72563901513
16.8676021068,35.3415636277
58.1977544121,29.2752085455
24.6119080085,39.9440735137
63.0759798755,60.9841014448
30.9289119657,95.0173219502
8.54972950047,41.7384441737
61.2606910793,4.06738902059
83.2302091964,11.6373312879
89.4443065362,42.5694882801
24.5619318152,97.7947977804
50.3134024475,40.6429336223
58.1422402033,36.1112632557
32.0668520827,29.9924151435
89.6057447137,84.9532177777
9.8876440816,18.2540486261
17.9670383961,47.596032257
50.2977668282,93.6851189223
98.0700386253,86.5816924579
10.8175290981,26.4344732252
34.7463851288,24.4154447141
92.5470100593,17.3595513748
79.0426629356,4.59850018907
89.9791366918,29.523946842
3.89920214563,91.3650215111
35.4669861576,62.1865368798
2.78150918086,24.5280230552
50.0390951889,57.0414421682
64.4521660758,48.4962172448
94.4915452316,56.6508179406
47.1655534769,15.8292055671
94.2027011374,45.6802385454
30.5846324871,54.783635876
57.7043252948,0.286661610381
41.7908674949,14.7206014023
59.6689465934,64.8849831965
92.2553335495,55.9096460272
48.493467262,69.4766837809
23.1837859581,71.4406867443
29.0737623652,66.9391416961
95.7442323112,89.4677505059
68.7707275828,40.9900140055
84.5445737133,32.1707309618
67.4126251988,56.6710579117
10.688352016,28.1745892928
56.7620324155,18.3034334207
50.6751320678,86.6916908032
74.6185482896,34.022483532
20.7011996002,32.855295357
11.479054664,1.59204297586
51.6805387648,25.4063026358
84.4109522357,47.237632645
90.6395051745,57.7917166935
58.6159601042,84.1226173848
46.2184509277,28.559934585
97.0302485783,41.3135022812
31.3144587058,87.2459910122
5.93357833962,95.6812831872
"""

    # Cluster count: first command-line argument if given, otherwise 4.
    k = 4
    if len(sys.argv) > 1:
        k = int(sys.argv[1])

    # Turn every whitespace-separated "x,y" entry into a Point.
    fields = (entry.split(',', 2) for entry in dataset.split() if entry)
    data = [Point(x, y) for x, y in fields]

    clusters = kmeans(data, k)

    # One set of 'o' markers per cluster on a single axes.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for cluster in clusters:
        members = cluster.points
        xs = [member.x for member in members]
        ys = [member.y for member in members]
        ax.plot(xs, ys, 'o')
    plt.show()
index = int(len(data) * random()); rand_point = data[index]
should be written as
rand_point = random.choice(data) (note: this needs `import random` rather than the script's `from random import random`)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice! :) Great to see people interested in my article!
I actually started with the intent of writing it in Python, oddly enough :p