Last active
February 12, 2018 22:41
-
-
Save omgimanerd/ff0b9a8dd225f2d2e94aa4be76bc91d9 to your computer and use it in GitHub Desktop.
The k-means clustering algorithm, implemented both with NumPy and in native Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Kmeans clustering algorithm heavily influenced by | |
# https://gist.github.com/iandanforth/5862470 | |
import random | |
def get_color_distance(rgb1, rgb2):
    """
    Return the squared euclidean distance between two RGB colors.

    Both arguments must be length-3 sequences. The square root is never
    taken; callers only compare these values against each other or against
    a threshold expressed in the same squared units.
    """
    assert len(rgb1) == 3 and len(rgb2) == 3
    total = 0
    for channel_a, channel_b in zip(rgb1, rgb2):
        total += (channel_a - channel_b) ** 2
    return total
def get_centroid(colors, weights):
    """
    Return the weighted centroid of a list of colors.

    Args:
        colors: sequence of colors, each a sequence of channel values.
        weights: sequence parallel to colors; weights[i] is the weight of
            colors[i].

    Returns:
        A list with one entry per channel: the weighted mean of that channel.
        An empty input yields an empty list.

    Raises:
        ValueError: if colors and weights differ in length.
        ZeroDivisionError: if the weights of a non-empty input sum to zero.

    The inputs are left unmodified. kmeans hands in the very lists it later
    returns as clusters, so writing weighted values back into them would
    leak those values to its callers.
    """
    if len(colors) != len(weights):
        raise ValueError("colors and weights must have the same length")
    wsum = sum(weights)
    # Build the weighted colors in a fresh list instead of overwriting colors.
    weighted = [[c * weight for c in color]
                for color, weight in zip(colors, weights)]
    return [sum(channel) / wsum for channel in zip(*weighted)]
def get_closest_centroid_index(p, centroids):
    """
    Return the position in centroids of the centroid nearest to point p.

    Nearness is measured with get_color_distance. When several centroids are
    equally near, the one with the lowest index is returned. An empty
    centroids sequence raises IndexError.
    """
    closest = 0
    closest_distance = get_color_distance(p, centroids[0])
    for candidate, centroid in enumerate(centroids[1:], start=1):
        candidate_distance = get_color_distance(p, centroid)
        # Strict comparison keeps the earliest centroid on ties.
        if candidate_distance < closest_distance:
            closest, closest_distance = candidate, candidate_distance
    return closest
def kmeans(k, colors, weights, cutoff):
    """
    Cluster weighted colors into k groups with the k-means algorithm.

    Args:
        k: number of clusters. Must be at least 1 and at most len(colors)
            (random.sample raises ValueError for larger k).
        colors: list of RGB colors (length-3 sequences).
        weights: list parallel to colors giving the weight of each color.
        cutoff: convergence threshold. Iteration stops once the largest
            squared euclidean distance that any centroid moved in one pass
            is <= cutoff, so it is expressed in squared color units.

    Returns:
        (centroids, clusters). centroids is a list of k centroids. clusters
        is a list of k lists: clusters[i] holds the input colors (unmodified)
        that were assigned to centroid i in the final pass, that is, the
        assignment made just before the last centroid update.

    Raises:
        ValueError: if k < 1, cutoff < 0, or colors and weights differ in
            length.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if cutoff < 0:
        # Shifts are sums of squares (never negative), so a negative cutoff
        # would never be reached and the loop would not terminate.
        raise ValueError("cutoff must be non-negative")
    if len(colors) != len(weights):
        raise ValueError("colors and weights must have the same length")

    # Seed the centroids with k colors drawn without replacement.
    centroids = random.sample(colors, k)
    biggest_shift = cutoff + 1
    while biggest_shift > cutoff:
        # clusters[i] / cluster_weights[i] collect the colors (and their
        # weights) nearest to centroids[i]; both lists parallel centroids.
        clusters = [[] for _ in range(k)]
        cluster_weights = [[] for _ in range(k)]
        for color, weight in zip(colors, weights):
            index = get_closest_centroid_index(color, centroids)
            clusters[index].append(color)
            cluster_weights[index].append(weight)

        # Move each centroid to the weighted mean of its cluster and record
        # how far it moved. An empty cluster keeps its previous centroid.
        shifts = []
        for i, cluster in enumerate(clusters):
            new_centroid = centroids[i]
            if cluster:
                # Pass a copy so the centroid computation can never alter the
                # cluster lists returned to the caller.
                new_centroid = get_centroid(list(cluster), cluster_weights[i])
            shifts.append(get_color_distance(new_centroid, centroids[i]))
            centroids[i] = new_centroid
        # Stop once even the largest centroid movement is within cutoff.
        biggest_shift = max(shifts)
    return centroids, clusters
if __name__ == '__main__':
    # Generate some randomly distributed clusters to exercise the algorithm:
    #   $ python kmeans.py
    # Requires scikit-learn (for make_blobs) and matplotlib.
    # Axes3D is imported only for its side effect of registering the '3d'
    # projection, which older matplotlib releases need before add_subplot.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    # make_blobs is exposed from sklearn.datasets; the former
    # sklearn.datasets.samples_generator module no longer exists.
    from sklearn.datasets import make_blobs
    import matplotlib.pyplot as plt

    centers = [[25, 25, 25], [100, 100, 100], [2, 57, 20]]
    n_samples = 100
    points, _ = make_blobs(n_samples=n_samples, centers=centers,
                           cluster_std=5, random_state=0)
    weights = [1] * n_samples
    centroids, clusters = kmeans(len(centers), list(points), weights, 2)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    r, g, b = zip(*points)
    cr, cg, cb = zip(*centroids)
    ax.scatter(r, g, b, c='red')
    ax.scatter(cr, cg, cb, c='green', s=100)
    plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Kmeans clustering algorithm heavily influenced by | |
# https://gist.github.com/iandanforth/5862470 | |
import numpy as np | |
def _get_centroid(colors, weights):
    """
    Return the weighted mean of the rows of colors.

    colors holds one color per row. weights must broadcast against it row by
    row; kmeans passes an (n, 1) column for this. The result has one entry
    per channel.
    """
    weighted_rows = colors * weights
    total_weight = weights.sum()
    return weighted_rows.sum(axis=0) / total_weight
def _get_index_closest(p, centroids):
    """
    Return the index of the row of centroids nearest to point p.

    Nearness is squared euclidean distance. Ties resolve to the lowest
    index, as np.argmin returns the first minimum.
    """
    squared_distances = np.sum(np.square(centroids - p), axis=1)
    return np.argmin(squared_distances)
def kmeans(k, colors, weights, cutoff):
    """
    Cluster weighted colors into k groups with the k-means algorithm.

    Args:
        k: number of clusters; must not exceed the number of colors
            (np.random.choice without replacement raises ValueError).
        colors: sequence of colors, one row of channel values per color.
        weights: sequence parallel to colors giving the weight of each color.
        cutoff: convergence threshold. Iteration stops once the largest
            squared euclidean distance that any single centroid moved in one
            pass is <= cutoff.

    Returns:
        (centroids, clusters). centroids is a (k, channels) array. clusters
        is a length-k object array whose i-th entry is the array of colors
        assigned to centroid i in the final pass, that is, the assignment
        made just before the last centroid update.

    Raises:
        ValueError: if cutoff < 0 (the loop could never terminate).
    """
    assert len(colors) == len(weights)
    if cutoff < 0:
        # Squared shifts are never negative, so a negative cutoff is never met.
        raise ValueError("cutoff must be non-negative")
    colors = np.array(colors)
    # A column vector, so each weight scales every channel of its color.
    weights = np.array(weights).reshape(len(weights), 1)
    # Pick k distinct rows as starting centroids. Sampling with replacement
    # could choose the same row twice and leave a cluster empty from the start.
    seeds = np.random.choice(colors.shape[0], size=k, replace=False)
    centroids = colors[seeds, :]
    biggest_shift = cutoff + 1
    while biggest_shift > cutoff:
        # closest[j] is the index of the centroid nearest to colors[j].
        closest = np.array([_get_index_closest(c, centroids) for c in colors])
        # Clusters generally differ in size. Store them in an explicit object
        # array, because np.array() on ragged inputs raises on NumPy >= 1.24.
        clusters = np.empty(k, dtype=object)
        cluster_weights = []
        for i in range(k):
            members = closest == i
            clusters[i] = colors[members]
            cluster_weights.append(weights[members])
        # Move each centroid to the weighted mean of its cluster. An empty
        # cluster has no mean (it would be 0/0 = NaN), so it keeps its centroid.
        new_centroids = np.array([
            _get_centroid(cluster, w) if len(cluster) else centroids[i]
            for i, (cluster, w) in enumerate(zip(clusters, cluster_weights))])
        # Squared distance moved by each centroid (sum over channels), then
        # the largest of those; converged once even that is within cutoff.
        biggest_shift = ((new_centroids - centroids) ** 2).sum(axis=1).max()
        centroids = new_centroids
    return centroids, clusters
if __name__ == '__main__':
    # Generate some randomly distributed clusters to exercise the algorithm:
    #   $ python kmeans.py
    # Requires scikit-learn (for make_blobs) and matplotlib.
    # Axes3D is imported only for its side effect of registering the '3d'
    # projection, which older matplotlib releases need before add_subplot.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    # make_blobs is exposed from sklearn.datasets; the former
    # sklearn.datasets.samples_generator module no longer exists.
    from sklearn.datasets import make_blobs
    import matplotlib.pyplot as plt

    centers = [[25, 25, 25], [100, 100, 100], [2, 57, 20]]
    n_samples = 100
    points, _ = make_blobs(n_samples=n_samples, centers=centers,
                           cluster_std=5, random_state=0)
    weights = np.ones(n_samples)
    centroids, clusters = kmeans(len(centers), list(points), weights, 1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    cr, cg, cb = zip(*centroids)
    ax.scatter(cr, cg, cb, c='green', s=100)
    for cluster in clusters:
        if len(cluster) == 0:
            # zip(*empty) yields nothing and cannot be unpacked into r, g, b.
            continue
        r, g, b = zip(*cluster)
        ax.scatter(r, g, b, c='red', s=10)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment