Vanilla k_means algorithm with zero import (for any number of dimensions)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def distance(u, v):
    """
    Calculate the Euclidean distance between two points.

    distance = square_root( sum((u_i - v_i)^2) )

    u: [float], first point
    v: [float], second point (same number of coordinates as `u`)

    Returns the distance as a float.
    Raises ValueError if the two points do not have the same dimension.
    """
    # Fail loudly instead of silently ignoring the extra coordinates of the
    # longer point (or dying with a bare IndexError on the shorter one).
    if len(u) != len(v):
        raise ValueError(
            f"points must have the same dimension: {len(u)} != {len(v)}"
        )
    squared_sum = sum((a - b) ** 2 for a, b in zip(u, v))
    return squared_sum ** 0.5
def get_closer(target, *args):
    """
    Return the closest point (from points in `args`) to target.

    target: [float], target point
    *args: [[float]], candidate points

    When several candidates are equally close, the one given first wins.
    Raises ValueError if no candidate points are given.
    """
    # Without this guard there is nothing to return and the caller would
    # get an obscure UnboundLocalError.
    if not args:
        raise ValueError("get_closer() needs at least one candidate point")
    # min() returns the first of several equally-near candidates.
    return min(args, key=lambda point: distance(point, target))
def get_center(cluster):
    """
    Calculate the centroid point for `cluster`.

    cluster: [[float]], list of the points in the cluster (must not be empty)

    Returns [float]: the mean of every coordinate, rounded to one decimal.
    The number of coordinates is taken from the first point of the cluster.
    Raises ValueError if the cluster is empty.
    """
    if not cluster:
        raise ValueError("cannot compute the centroid of an empty cluster")
    size = len(cluster)
    dimensions = len(cluster[0])
    # Coordinates are rounded to one decimal, so centroids computed on
    # successive passes can be compared with plain equality.
    return [
        round(sum(point[axis] for point in cluster) / size, 1)
        for axis in range(dimensions)
    ]
def k_means(data, k=2, *centers, max_iter=1000):
    """
    k_means algorithm (iterative, so long runs cannot exhaust the call stack)

    data: [[float]], data points to consider for clustering
    k: int, number of clusters
    centers: [[float]], optional - initial centroids (exactly `k` of them);
        when omitted, the first `k` points of `data` are used
    max_iter: int, keyword-only - upper bound on the number of
        assign/update passes; the current result is returned when reached

    Returns (clusters, centers): `clusters` is a list of `k` lists of points
    and `centers` holds the matching centroids.
    Raises ValueError if `data` is empty, `k` does not fit the data or the
    given centers, or `max_iter` is smaller than 1.
    """
    if not data:
        raise ValueError("data must contain at least one point")
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")
    if centers:
        if len(centers) != k:
            raise ValueError(
                f"expected {k} initial centers, got {len(centers)}"
            )
        centers = list(centers)
    else:
        if not 1 <= k <= len(data):
            raise ValueError(
                f"k must be between 1 and the number of points ({len(data)})"
            )
        centers = list(data[:k])

    for _ in range(max_iter):
        clusters = [[] for _ in centers]
        for point in data:
            # Select the nearest center by position, not by value lookup,
            # so duplicate centers cannot confuse the assignment. Ties go
            # to the lowest index.
            distances = [distance(center, point) for center in centers]
            clusters[distances.index(min(distances))].append(point)
        # A cluster that attracted no point keeps its previous center
        # instead of crashing the centroid computation.
        new_centers = [
            get_center(cluster) if cluster else center
            for cluster, center in zip(clusters, centers)
        ]
        if new_centers == centers:
            break
        centers = new_centers
    return clusters, centers
# -- Test | |
>>> weights = [74, 77, 81, 76, 80, 91, 88, 93, 88, 92] | |
>>> heights = [179, 182, 181, 175, 174, 182, 178, 178, 174, 173] | |
>>> data = [list(point) for point in zip(weights, heights)] | |
>>> data | |
[[74, 179], [77, 182], [81, 181], [76, 175], [80, 174], [91, 182], [88, 178], [93, 178], [88, 174], [92, 173]] | |
>>> clusters, centroids = k_means(data) | |
>>> for c in clusters: print(c) | |
[[74, 179], [77, 182], [81, 181], [76, 175], [80, 174]] | |
[[91, 182], [88, 178], [93, 178], [88, 174], [92, 173]] | |
>>> for c in centroids: print(c) | |
[77.6, 178.2] | |
[90.4, 177.0] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment