Skip to content

Instantly share code, notes, and snippets.

@iamaziz
Last active Jun 15, 2019
Embed
What would you like to do?
Vanilla k-means algorithm with zero imports (works for any number of dimensions)
def distance(u, v):
    """
    Calculate the Euclidean distance between two points of any (equal) dimension
    distance = square_root( sum_i (u_i - v_i)^2 )
    u: [float], point1
    v: [float], point2
    Raises ValueError if u and v have a different number of coordinates
    (previously extra coordinates of v were silently ignored).
    """
    if len(u) != len(v):
        raise ValueError(
            "points have different dimensions: {} != {}".format(len(u), len(v))
        )
    # ``** 0.5`` instead of math.sqrt keeps the module free of imports.
    return sum((a - b) ** 2 for a, b in zip(u, v)) ** 0.5
def get_closer(target, *args):
    """
    Return the point from `args` that is closest to `target`
    (Euclidean distance). On ties the earliest point in `args` wins.
    target: [float], target point
    *args: [[float]], candidate points
    Raises ValueError if no candidate points are given (previously this
    surfaced as an UnboundLocalError).
    """
    if not args:
        raise ValueError("get_closer() needs at least one candidate point")
    # min() returns the first minimal element, matching the original
    # strict "<" comparison's tie-breaking.
    return min(args, key=lambda point: distance(point, target))
def get_center(cluster, ndigits=1):
    """
    Calculate the centroid (coordinate-wise mean) of the points in `cluster`
    cluster: [[float]], non-empty list of points with the same dimension
    ndigits: int or None, number of decimals each coordinate is rounded to
        (default 1); pass None to return the exact, unrounded mean
    Raises ValueError if `cluster` is empty.
    """
    if not cluster:
        raise ValueError("cannot compute the center of an empty cluster")
    n = len(cluster)
    center = [sum(p[i] for p in cluster) / n for i in range(len(cluster[0]))]
    if ndigits is None:
        # round(x, None) would return an int, so skip rounding explicitly.
        return center
    return [round(c, ndigits) for c in center]
def k_means(data, k=2, *centers, max_iter=300):
    """
    Iterative k-means clustering
    Each iteration assigns every point to its nearest centroid, then moves
    each centroid to the mean of its points (rounded by `get_center`). It
    stops when the centroids no longer change or after `max_iter` iterations.
    The bounded loop replaces the original unbounded recursion, which could
    hit RecursionError (e.g. if rounded centroids keep oscillating).

    data: [[float]], data points to consider for clustering
    k: int, number of clusters; must equal len(centers) when centers are given
    centers: [[float]], optional - initial centroids; defaults to the first
        `k` points of `data`
    max_iter: int, keyword-only - maximum number of assign/update iterations

    Returns (clusters, centroids): clusters[i] lists the points of `data`
    assigned to cluster i and centroids[i] is its centroid. A cluster that
    becomes empty keeps its previous centroid instead of crashing.
    Raises ValueError for an invalid `k`, a k/centers mismatch or max_iter < 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1, got {}".format(max_iter))
    if centers:
        if len(centers) != k:
            raise ValueError(
                "got {} initial centers but k={}".format(len(centers), k)
            )
        centers = [list(c) for c in centers]
    else:
        if not 1 <= k <= len(data):
            raise ValueError(
                "k must be between 1 and len(data)={}, got {}".format(len(data), k)
            )
        # Copy so returned centroids never alias the caller's data points.
        centers = [list(data[i]) for i in range(k)]

    for _iteration in range(max_iter):
        clusters = [[] for _ in range(k)]
        for point in data:
            # Index of the nearest centroid; the lowest index wins ties.
            nearest = min(range(k), key=lambda i: distance(centers[i], point))
            clusters[nearest].append(point)
        # An empty cluster keeps its old centroid (get_center needs points).
        new_centers = [
            get_center(cluster) if cluster else centers[i]
            for i, cluster in enumerate(clusters)
        ]
        if new_centers == centers:
            break
        centers = new_centers
    # new_centers are the centroids of exactly the returned clusters.
    return clusters, new_centers
# -- Test
>>> weights = [74, 77, 81, 76, 80, 91, 88, 93, 88, 92]
>>> heights = [179, 182, 181, 175, 174, 182, 178, 178, 174, 173]
>>> data = [list(point) for point in zip(weights, heights)]
>>> data
[[74, 179], [77, 182], [81, 181], [76, 175], [80, 174], [91, 182], [88, 178], [93, 178], [88, 174], [92, 173]]
>>> clusters, centroids = k_means(data)
>>> for c in clusters: print(c)
[[74, 179], [77, 182], [81, 181], [76, 175], [80, 174]]
[[91, 182], [88, 178], [93, 178], [88, 174], [92, 173]]
>>> for c in centroids: print(c)
[77.6, 178.2]
[90.4, 177.0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment