Skip to content

Instantly share code, notes, and snippets.

@fepegar
Last active March 2, 2018 09:52
Show Gist options
  • Save fepegar/ae258801b0a667b22d56b7aa76a76477 to your computer and use it in GitHub Desktop.
Naive example implementation of K-means clustering algorithm
import numpy as np
import matplotlib.pyplot as plt
def k_means(points, num_classes, epsilon=1e-5, plot=False, init_means=None):
    """Cluster points with the naive K-means (Lloyd's) algorithm.

    Parameters
    ----------
    points : array_like, shape (N, D)
        Samples to cluster.
    num_classes : int
        Number of clusters K.
    epsilon : float, optional
        Convergence threshold on the total squared displacement of the
        means between two iterations.
    plot : bool, optional
        If True, scatter-plot the points and means each iteration
        (requires matplotlib; blocks on plt.show()).
    init_means : array_like, shape (num_classes, D), optional
        Initial cluster centers. Defaults to uniform random in [0, 1)^D,
        matching the original behavior.

    Returns
    -------
    means : ndarray, shape (num_classes, D)
        Final cluster centers.
    labels : ndarray, shape (N,)
        Index of the nearest mean for each point.
    """
    points = np.asarray(points, dtype=float)
    num_points, dim = points.shape  # was a global `N` before: broke any non-script caller
    if init_means is None:
        # Original behavior: random init in the unit hypercube.
        means = np.random.rand(num_classes, dim)
    else:
        means = np.array(init_means, dtype=float)
    diff_means = np.inf
    while diff_means > epsilon:
        # Assignment step: squared Euclidean distance from every point
        # to every mean; each point takes the label of the nearest mean.
        distances = np.zeros((num_points, num_classes))
        for k in range(num_classes):
            distances[:, k] = np.sum(np.square(points - means[k, :]), axis=1)
        labels = np.argmin(distances, axis=1)
        # Update step: move each mean to the centroid of its points.
        new_means = np.zeros_like(means)
        for k in range(num_classes):
            points_class = points[labels == k, :]
            if plot:  # plotting was previously unconditional, ignoring the flag
                plt.scatter(means[k, 0], means[k, 1], color=f'C{k}', s=40)
                plt.scatter(points_class[:, 0], points_class[:, 1], color=f'C{k}',
                            s=10, alpha=0.5)
            if len(points_class) == 0:
                # Empty cluster: keep the old mean instead of producing NaN
                # (np.mean of an empty slice), which would poison all
                # subsequent distance computations.
                new_means[k, :] = means[k, :]
            else:
                new_means[k, :] = np.mean(points_class, axis=0)
        if plot:
            plt.show()
        # Convergence check: total squared displacement of the means.
        diff_means = np.sum(np.square(new_means - means))
        means = new_means
    return means, labels
if __name__ == '__main__':
    # Demo: cluster uniformly-distributed 2-D samples into 5 groups.
    # NOTE(review): `N` must stay module-level — k_means reads it as a global.
    N = 1000
    samples = np.random.rand(N, 2)
    n_clusters = 5
    means, labels = k_means(samples, n_clusters, plot=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment