Last active
August 29, 2015 14:07
-
-
Save neizod/4d9bc6f3dc810f6c5ba7 to your computer and use it in GitHub Desktop.
k-means clustering.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import math
import sys
from itertools import count
class Point(object):
    """A point in the plane, held as plain ``x`` and ``y`` attributes."""

    def __init__(self, pp):
        """Unpack *pp*, an iterable of exactly two coordinates, into x/y."""
        x, y = pp
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Point([{x}, {y}])'.format(x=self.x, y=self.y)

    def distance(self, other):
        """Return the Euclidean distance between this point and *other*."""
        dx = self.x - other.x
        dy = self.y - other.y
        return (dx**2 + dy**2)**0.5

    def nearest(self, clusters):
        """Return the element of *clusters* closest to this point.

        Candidates are compared as ``(distance, candidate)`` tuples, so an
        exact distance tie falls through to the candidates' own ``<``
        (for Cluster: the one with fewer members); among candidates that
        do not compare lower, the earliest one wins.  Raises ValueError if
        *clusters* is empty, and TypeError on a tie between candidates that
        define no ordering.
        """
        ranked = ((self.distance(c), c) for c in clusters)
        _, closest = min(ranked)
        return closest
class Cluster(Point):
    """A k-means cluster: a centroid position plus its member points.

    The inherited ``x``/``y`` attributes hold the current centroid and
    ``ls`` holds the set of points currently assigned to this cluster.
    """

    def __init__(self, pp, ls=None):
        """Create a cluster centred at *pp* (an (x, y) pair).

        ls: optional initial set of member points.  A supplied set is used
        as-is, even when empty; ``None`` starts with a fresh empty set.
        """
        super().__init__(pp)
        # `ls or set()` would silently swap a caller's *empty* set for a new
        # one, so later add() calls would not be visible through the caller's
        # reference.
        self.ls = set() if ls is None else ls

    def __lt__(self, other):
        """Order clusters by member count.

        This lets ``(distance, cluster)`` tuples compare when two distances
        tie (as in ``Point.nearest``).
        """
        return len(self.ls) < len(other.ls)

    def __repr__(self):
        """Show the centroid to two decimals and the members in (x, y) order."""
        if self.ls:
            # Sort so printed output does not depend on set/hash ordering,
            # which varies between runs for identity-hashed points.
            ordered = sorted(self.ls, key=lambda p: (p.x, p.y))
            members = '{' + ', '.join(map(repr, ordered)) + '}'
        else:
            members = 'set()'
        return 'Cluster([{:.2f}, {:.2f}], {})'.format(self.x, self.y, members)

    def is_stable(self):
        """Return True if the centroid already equals its members' mean.

        Exact float equality is deliberate.  find_centroid uses math.fsum,
        so an unchanged membership always reproduces bit-identical
        coordinates, whatever order the set iterates in.
        """
        return (self.x, self.y) == self.find_centroid()

    def find_centroid(self):
        """Return the mean (x, y) of the members.

        With no members the current position is returned, so an empty
        cluster stays where it is.
        """
        if not self.ls:
            return (self.x, self.y)
        n = len(self.ls)
        # fsum is exactly rounded and therefore independent of summation
        # order; plain sum() over floats is not guaranteed to be.
        x = math.fsum(p.x for p in self.ls) / n
        y = math.fsum(p.y for p in self.ls) / n
        return (x, y)

    def update_centroid(self):
        """Move this cluster's position to its members' mean."""
        self.x, self.y = self.find_centroid()

    def add(self, point):
        """Assign *point* to this cluster."""
        self.ls.add(point)

    def clear(self):
        """Drop all members by rebinding ``ls`` to a new empty set."""
        self.ls = set()
def _read_pairs(path):
    """Return one list of ints per non-blank line of the file at *path*.

    Blank lines (e.g. a trailing empty line) are skipped rather than
    crashing with "not enough values to unpack" when turned into a Point.
    The file is closed before returning.
    """
    with open(path) as f:
        return [[int(n) for n in line.split()] for line in f if line.strip()]


if len(sys.argv) != 3:
    # sys.exit, not the `site` builtin `exit`, which is meant for the
    # interactive interpreter and is absent under `python -S`.
    sys.exit('usage: ./kmeans.py <map> <start>')

points = [Point(pair) for pair in _read_pairs(sys.argv[1])]
clusters = [Cluster(pair) for pair in _read_pairs(sys.argv[2])]
if not clusters:
    # Otherwise Point.nearest would fail inside min() with an opaque error.
    sys.exit('no starting clusters in {}'.format(sys.argv[2]))

for i in count(1):
    print('==== iteration: {} ===='.format(i))
    # Move each centroid to the mean of last round's members, then empty it.
    for cluster in clusters:
        cluster.update_centroid()
        cluster.clear()
    # Reassign every point to its closest centroid.
    for point in points:
        cluster = point.nearest(clusters)
        cluster.add(point)
    for cluster in clusters:
        print(cluster)
    # Converged once no centroid would move under the new assignment.
    if all(cluster.is_stable() for cluster in clusters):
        break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 10
2 5
8 4
5 8
7 5
6 4
1 2
4 9
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 10
5 8
1 2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment