@omgimanerd
Last active February 12, 2018 22:41
The KMeans clustering algorithm implemented using both numpy and native Python

#!/usr/bin/env python3
# Kmeans clustering algorithm heavily influenced by
# https://gist.github.com/iandanforth/5862470
import random


def get_color_distance(rgb1, rgb2):
    """
    Given two RGB color values, this returns the squared euclidean distance
    between the two colors.
    """
    assert len(rgb1) == 3 and len(rgb2) == 3
    return sum([(rgb1[i] - rgb2[i]) ** 2 for i in range(3)])


def get_centroid(colors, weights):
    """
    Given a list of colors and a parallel list of weights, this function
    returns the weighted centroid of the points.
    """
    # Scale each color by its weight without mutating the caller's list.
    weighted = [[c * w for c in color] for color, w in zip(colors, weights)]
    rgb = zip(*weighted)
    wsum = sum(weights)
    return [sum(v) / wsum for v in rgb]


def get_closest_centroid_index(p, centroids):
    """
    Given a point p and list of centroids, this function returns the index
    of the centroid in the list that p is closest to.
    """
    index = 0
    min_distance = get_color_distance(p, centroids[0])
    for i in range(1, len(centroids)):
        distance = get_color_distance(p, centroids[i])
        if distance < min_distance:
            min_distance = distance
            index = i
    return index


def kmeans(k, colors, weights, cutoff):
    """
    Given a k value, a list of colors, a parallel list containing the weights
    of those colors, and a centroid shift cutoff value, this function will
    return k centroids.
    """
    # Select k random centroids to start with
    centroids = random.sample(colors, k)
    biggest_shift = cutoff + 1
    while biggest_shift > cutoff:
        clusters = [[] for i in centroids]
        cluster_weights = [[] for i in centroids]
        shifts = [0 for i in centroids]
        # For each point, figure out which centroid it is closest to and add
        # it to that centroid's cluster. We will represent the clusters as
        # a list of lists parallel to the list of centroids.
        for color, weight in zip(colors, weights):
            index = get_closest_centroid_index(color, centroids)
            clusters[index].append(color)
            cluster_weights[index].append(weight)
        # Calculate the amount that the new centroids shifted. When the maximum
        # shift amount is lower than a specified threshold, then we stop the
        # algorithm. This ensures that the centroids have stopped shifting and
        # we have found the desired approximation for their location.
        for i, cluster in enumerate(clusters):
            new_centroid = centroids[i]
            if len(cluster) > 0:
                new_centroid = get_centroid(cluster, cluster_weights[i])
            shifts[i] = get_color_distance(new_centroid, centroids[i])
            centroids[i] = new_centroid
        biggest_shift = max(shifts)
    return centroids, clusters


if __name__ == '__main__':
    """
    Generates some randomly distributed clusters to test the algorithm:
        $ python kmeans.py
    Running the test requires matplotlib and sklearn to be installed.
    """
    from mpl_toolkits.mplot3d import Axes3D  # registers the 3d projection
    from sklearn.datasets import make_blobs
    import matplotlib.pyplot as plt

    centers = [[25, 25, 25], [100, 100, 100], [2, 57, 20]]
    p, l = make_blobs(n_samples=100, centers=centers, cluster_std=5,
                      random_state=0)
    w = [1 for i in range(100)]
    centroids, clusters = kmeans(3, list(p), w, 2)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    r, g, b = zip(*p)
    cr, cg, cb = zip(*centroids)
    ax.scatter(r, g, b, c='red')
    ax.scatter(cr, cg, cb, c='green', s=100)
    plt.show()
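
For reference, here is a minimal sketch of how the pure-Python version above might be called outside the matplotlib demo. The module name kmeans, the sample colors, and the weights are assumptions made up for the example; they are not part of the gist.

# usage_sketch.py -- illustrative only, not part of the original gist
from kmeans import kmeans  # assumes the pure-Python file above is saved as kmeans.py

# A handful of made-up RGB colors with weights (e.g. pixel counts).
colors = [
    [255, 0, 0], [250, 10, 5], [240, 20, 20],   # reddish colors
    [0, 0, 255], [10, 5, 250], [20, 20, 240],   # bluish colors
]
weights = [10, 3, 1, 8, 4, 2]

# Ask for 2 centroids and stop once the largest squared shift drops below 1.
centroids, clusters = kmeans(2, colors, weights, 1)
for centroid in centroids:
    print([round(c) for c in centroid])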

#!/usr/bin/env python3
# Kmeans clustering algorithm heavily influenced by
# https://gist.github.com/iandanforth/5862470
import numpy as np


def _get_centroid(colors, weights):
    """
    Given an array of colors and a parallel column vector of weights, this
    function returns the weighted centroid of the points.
    """
    return (colors * weights).sum(axis=0) / weights.sum()


def _get_index_closest(p, centroids):
    """
    Given a point p and list of centroids, this function returns the index
    of the centroid in the list that p is closest to.
    """
    return np.argmin(((centroids - p) ** 2).sum(axis=1))


def kmeans(k, colors, weights, cutoff):
    """
    Given a k value, a list of colors, a parallel list containing the weights
    of those colors, and a centroid shift cutoff value, this function will
    return k centroids.
    """
    assert len(colors) == len(weights)
    # Convert the inputs to numpy arrays, reshaping the weights into a column
    # vector so that they broadcast across the color channels.
    colors = np.array(colors)
    weights = np.array(weights).reshape(len(weights), 1)
    # Pick k distinct random colors as the starting centroids
    centroids = colors[
        np.random.choice(colors.shape[0], size=k, replace=False)]
    biggest_shift = cutoff + 1
    while biggest_shift > cutoff:
        # Calculate which centroid each color is closest to. This generates an
        # array of indices representing which centroid the point is closest to.
        # This array is parallel to the points array.
        closest = np.array([_get_index_closest(c, centroids) for c in colors])
        # Cluster the points by grouping them according to which centroid
        # they're closest to. We will also cluster the weights of the points
        # for recalculation of the centroids later.
        clusters = [colors[closest == i] for i in range(k)]
        cluster_weights = [weights[closest == i] for i in range(k)]
        # Recalculate the locations of the centroids, keeping the old centroid
        # if its cluster ended up empty.
        new_centroids = np.array([
            _get_centroid(c, w) if len(c) > 0 else centroids[i]
            for i, (c, w) in enumerate(zip(clusters, cluster_weights))])
        # Calculate the amount that each centroid shifted. When the largest
        # shift is lower than the specified threshold, we stop the algorithm.
        biggest_shift = ((new_centroids - centroids) ** 2).sum(axis=1).max()
        centroids = new_centroids
    return centroids, clusters


if __name__ == '__main__':
    """
    Generates some randomly distributed clusters to test the algorithm:
        $ python kmeans.py
    Running the test requires matplotlib and sklearn to be installed.
    """
    from mpl_toolkits.mplot3d import Axes3D  # registers the 3d projection
    from sklearn.datasets import make_blobs
    import matplotlib.pyplot as plt

    centers = [[25, 25, 25], [100, 100, 100], [2, 57, 20]]
    p, l = make_blobs(n_samples=100, centers=centers, cluster_std=5,
                      random_state=0)
    weights = np.ones(100)
    centroids, clusters = kmeans(3, list(p), weights, 1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    cr, cg, cb = zip(*centroids)
    ax.scatter(cr, cg, cb, c='green', s=100)
    for cluster in clusters:
        r, g, b = zip(*cluster)
        ax.scatter(r, g, b, c='red', s=10)
    plt.show()
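
Since the point of the gist is the native-Python versus numpy comparison, a rough timing sketch like the one below can put the two kmeans implementations side by side. The module names kmeans_native and kmeans_numpy are assumptions made for the example; the gist does not specify file names.

# benchmark_sketch.py -- illustrative only; module names are assumed
import timeit

import numpy as np
from sklearn.datasets import make_blobs

import kmeans_native  # hypothetical name for the pure-Python file above
import kmeans_numpy   # hypothetical name for the numpy file above

centers = [[25, 25, 25], [100, 100, 100], [2, 57, 20]]
p, _ = make_blobs(n_samples=1000, centers=centers, cluster_std=5,
                  random_state=0)
colors = list(p)
weights = [1] * len(colors)

native_time = timeit.timeit(
    lambda: kmeans_native.kmeans(3, colors, weights, 1), number=10)
numpy_time = timeit.timeit(
    lambda: kmeans_numpy.kmeans(3, colors, np.ones(len(colors)), 1), number=10)
print('native: %.3fs, numpy: %.3fs' % (native_time, numpy_time))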