Created
June 28, 2017 00:21
-
-
Save theodox/14d9e35a14414c4ca82b41d933754119 to your computer and use it in GitHub Desktop.
A simple implementation of the KMeans algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict, namedtuple | |
import operator | |
import random | |
from functools import reduce | |
import math | |
class Vector(namedtuple('vector', 'x y z')):
    """An immutable 3D vector with x, y and z components.

    Supports component-wise addition and subtraction with another Vector,
    division by a scalar, and Euclidean length / distance queries.
    """

    def __truediv__(self, other):
        """Return a new Vector with every component divided by the scalar <other>."""
        return Vector(self.x / other, self.y / other, self.z / other)

    # Python 3 dispatches ``/`` to __truediv__, while Python 2 (without
    # ``from __future__ import division``) uses __div__. Exposing both names
    # keeps ``vector / scalar`` working on either interpreter.
    __div__ = __truediv__

    def __add__(self, other):
        """Return the component-wise sum of this Vector and <other>."""
        return Vector(self.x + other.x, self.y + other.y, self.z + other.z)

    def __sub__(self, other):
        """Return the component-wise difference of this Vector and <other>."""
        return Vector(self.x - other.x, self.y - other.y, self.z - other.z)

    def length(self):
        """Return the Euclidean length of this Vector."""
        return math.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2)

    def distance(self, other):
        """Return the Euclidean distance between this Vector and <other>."""
        return math.sqrt((self.x - other.x) ** 2 + (self.y - other.y) ** 2 + (self.z - other.z) ** 2)
def hash_points(*points):
    """Return a dictionary {key: [Vector, None]} built from the given point tuples.

    Each key is normally hash(point). Exact duplicate points collapse into a single entry.
    Distinct points whose hashes collide (for example (-1, 0, 0) and (-2, 0, 0) under CPython,
    where hash(-1) == hash(-2)) are kept apart: the later point is stored under the next unused
    integer key instead of silently overwriting the earlier one.
    The second slot of every entry holds the means index and starts out as None.
    """
    table = {}
    for p in points:
        vec = Vector(*p)
        key = hash(p)
        # Only probe past a slot that is occupied by a *different* point, so
        # repeated identical points still map onto one entry.
        while key in table and table[key][0] != vec:
            key += 1
        table[key] = [vec, None]
    return table
def get_bounds(data):
    """Return the axis-aligned bounding box of an iterable of points.

    data: an iterable of objects exposing numeric x, y and z attributes
    returns a tuple (xmin, ymin, zmin, xmax, ymax, zmax); all six values are 0 when <data> is empty
    """
    # Materialize once so generators can be traversed for every axis.
    points = list(data)
    if not points:
        return 0, 0, 0, 0, 0, 0
    xs = [p.x for p in points]
    ys = [p.y for p in points]
    zs = [p.z for p in points]
    # The extents come from the data alone; the origin is only part of the
    # box when the points actually straddle it.
    return min(xs), min(ys), min(zs), max(xs), max(ys), max(zs)
def assign(data, means):
    """Partition <data> among <means> and recalculate the means.

    data is a dictionary hash: [point, means index]; the index slot is updated in place
    means is a list of Vectors
    returns the same dictionary with revised assignments, and a new list of means with the same
    indexes as <means>. A mean that attracted no points keeps its previous position.
    """
    members = defaultdict(list)
    for entry in data.values():
        point = entry[0]
        # min() returns the first minimum it sees, so ties go to the lowest index.
        owner = min(range(len(means)), key=lambda m: point.distance(means[m]))
        entry[1] = owner
        members[owner].append(point)

    new_means = []
    for index, old_mean in enumerate(means):
        cluster_points = members.get(index)
        if not cluster_points:
            # Nothing was assigned here; keep the slot so indexes stay aligned.
            new_means.append(old_mean)
            continue
        total = reduce(operator.add, cluster_points)
        count = float(len(cluster_points))
        # Divide per component rather than relying on the point type's ``/``
        # operator; the result is rebuilt with the same type as the points.
        new_means.append(type(total)(total.x / count, total.y / count, total.z / count))
    return data, new_means
def randomize_seeds(k, xmin, ymin, zmin, xmax, ymax, zmax):
    """Given a count <k> and a bounding box, return a list of <k> randomly generated seed points.

    The generator is seeded with <k>, so the same arguments always produce the same seeds.
    An axis whose bounds are both integers yields integer coordinates (both ends inclusive);
    an axis with any non-integer bound yields a uniformly distributed float instead.
    """
    r = random.Random(k)

    def rand_coord(low, high):
        # randint() does not accept non-integer bounds on newer Python
        # versions, so float bounding boxes fall back to uniform().
        if isinstance(low, int) and isinstance(high, int):
            return r.randint(low, high)
        return r.uniform(low, high)

    def rand_point():
        return Vector(rand_coord(xmin, xmax), rand_coord(ymin, ymax), rand_coord(zmin, zmax))

    # Exactly k seeds: one per requested cluster.
    return [rand_point() for _ in range(k)]
kmeans_result = namedtuple('kmeans', 'data means iterations')


def kmeans(k, points, max_iteration=100):
    """
    Given a number of seed points and a list of point tuples, iteratively apply the kmeans algorithm until the buckets
    are stable or max_iteration passes have been run

    returns a kmeans_result namedtuple (which still unpacks like a plain 3-tuple), whose fields are
    data: a dictionary {hash: [Vector, bucket index]} of points
    means: a list [Vector] of mean points, corresponding to the bucket index
    iterations: the number of assignment passes that were run
    """
    data = hash_points(*points)
    xmin, ymin, zmin, xmax, ymax, zmax = get_bounds([entry[0] for entry in data.values()])
    means = randomize_seeds(k, xmin, ymin, zmin, xmax, ymax, zmax)

    iterations = 0
    # Run up to max_iteration passes (inclusive), so max_iteration=1 still
    # performs one assignment instead of returning unassigned data.
    for iterations in range(1, max_iteration + 1):
        data, new_means = assign(data, means)
        if new_means == means:
            # The means did not move: the partition is stable.
            break
        means = new_means
    return kmeans_result(data, means, iterations)
bucket = namedtuple('bucket', 'mean points')


def get_clusters(data, means):
    """Group the points in <data> by the mean that owns them.

    data is a dictionary whose values are (point, means index) pairs
    means is a list of mean positions
    returns a list of 'bucket' tuples (mean, points), one per entry of <means> and in the same
    order; a mean that owns no points gets an empty list
    """
    grouped = defaultdict(list)
    for entry in data.values():
        member, index = entry
        grouped[index].append(member)
    return [bucket(center, grouped[index]) for index, center in enumerate(means)]
cluster = namedtuple('cluster', 'center members volume ratio')


def cluster_stats(data, means):
    """Given a data dictionary and a 'means' list, return a list of 'cluster' tuples, one per mean:

    'center' -- the mean position
    'members' -- integer number of members
    'volume' -- the cubic volume of the cluster's axis-aligned bounding box
    'ratio' -- the number of members per cubic unit of the cluster; infinity when a non-empty
               cluster has zero volume (a single point, or points that are flat along an axis),
               and 0.0 for a cluster with no members
    """
    results = []
    for centroid, members in get_clusters(data, means):
        count = len(members)
        xmin, ymin, zmin, xmax, ymax, zmax = get_bounds(members)
        # Volume is the product of the three box extents, not the box diagonal.
        volume = (xmax - xmin) * (ymax - ymin) * (zmax - zmin)
        if volume:
            ratio = count / float(volume)
        else:
            # Avoid dividing by zero for degenerate or empty clusters.
            ratio = float('inf') if count else 0.0
        results.append(cluster(centroid, count, volume, ratio))
    return results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment