# bistaumanga/kMeans.py

Last active Oct 27, 2019
K-Means clustering implemented in Python with NumPy.
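
In symbols, each iteration of `kMeans.py` alternates two steps. Writing $C_i$ for the cluster label of point $x_i$ and $\mu_k$ for the $k$-th centroid, with labels $k = 0, \dots, K-1$ as in `range(K)`:

$$
C_i = \arg\min_{k} \lVert x_i - \mu_k \rVert^2,
\qquad
\mu_k = \frac{1}{\lvert \{ i : C_i = k \} \rvert} \sum_{i \,:\, C_i = k} x_i
$$

The update for $\mu_k$ is undefined when no point has $C_i = k$. The loop also runs a fixed `maxIters` iterations rather than stopping once the labels no longer change.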
`kMeans.py`:

```python
'''Implementation of K-Means Clustering
Requires : python 2.7.x, Numpy 1.7.1+'''
import numpy as np

def kMeans(X, K, maxIters = 10, plot_progress = None):
    # Initialise centroids with the rows of X at K distinct random indices
    centroids = X[np.random.choice(len(X), K, replace = False), :]
    for i in range(maxIters):
        # Cluster assignment step: index of the nearest centroid (squared Euclidean distance)
        C = np.array([np.argmin([np.dot(x_i - y_k, x_i - y_k) for y_k in centroids]) for x_i in X])
        # Move centroids step: each centroid becomes the mean of its assigned points
        centroids = [X[C == k].mean(axis = 0) for k in range(K)]
        if plot_progress is not None:
            plot_progress(X, C, np.array(centroids))
    return np.array(centroids), C
```

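
The assignment step loops over every point and every centroid in pure Python, which gets slow for large `X`. A minimal sketch of the same step vectorised with NumPy broadcasting follows; `assign_clusters` is a hypothetical helper name, and it assumes `X` is an `(n, d)` array and `centroids` a `(K, d)` array.

```python
import numpy as np

def assign_clusters(X, centroids):
    """Return, for each row of X, the index of its nearest centroid.

    Uses squared Euclidean distance like the list-comprehension version;
    argmin breaks ties towards the lower index in both versions.
    """
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    # diff has shape (n, K, d): every point minus every centroid
    diff = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    # Sum squared coordinates to get an (n, K) matrix of squared distances
    d2 = (diff ** 2).sum(axis=2)
    return d2.argmin(axis=1)
```

For example, `assign_clusters([[0, 0], [1, 0], [9, 9]], [[0, 0], [10, 10]])` returns `[0, 0, 1]`. The trade-off is memory: the intermediate `diff` array holds n × K × d values.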
Demonstration script (imports `kMeans` from `kMeans.py`):

```python
'''Demonstration of K-Means Clustering
Requires : python 2.7.x, Numpy 1.7.1+, Matplotlib 1.2.1+'''
import matplotlib.pyplot as plt
import numpy as np

from kMeans import kMeans

plt.ion()

def show(X, C, centroids, keep = False):
    # Redraw the current assignment: one colour per cluster, centroids as large magenta stars
    plt.cla()
    for k, colour in enumerate('brg'):
        plt.plot(X[C == k, 0], X[C == k, 1], '*' + colour)
    plt.plot(centroids[:, 0], centroids[:, 1], '*m', markersize = 20)
    # Render the frame and wait 0.5 s so each iteration stays visible
    plt.pause(0.5)
    if keep:
        # Leave interactive mode and block until the window is closed
        plt.ioff()
        plt.show()

# Generate data with 3 clusters
# data = np.genfromtxt('data1.csv', delimiter=',')
# Each covariance matrix must be symmetric positive semi-definite
m1, cov1 = [9, 8], [[1.5, 1], [1, 2]]
m2, cov2 = [5, 13], [[2.5, -1.5], [-1.5, 1.5]]
m3, cov3 = [3, 7], [[0.25, -0.1], [-0.1, 0.5]]
data1 = np.random.multivariate_normal(m1, cov1, 250)
data2 = np.random.multivariate_normal(m2, cov2, 180)
data3 = np.random.multivariate_normal(m3, cov3, 100)
X = np.vstack((data1, data2, data3))
np.random.shuffle(X)

centroids, C = kMeans(X, K = 3, plot_progress = show)
show(X, C, centroids, True)
```
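
The demo draws its clusters with `np.random.multivariate_normal`, which expects each covariance matrix to be symmetric and positive semi-definite; with its default `check_valid='warn'` it only warns when that is not the case. A small check like the one below can catch a bad matrix, such as `[[1.5, 2], [1, 2]]`, before sampling. `is_valid_covariance` and its `tol` parameter are hypothetical, not part of NumPy.

```python
import numpy as np

def is_valid_covariance(cov, tol=1e-8):
    """Return True if cov is a square, symmetric, positive semi-definite matrix."""
    cov = np.asarray(cov, dtype=float)
    if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
        return False
    # A covariance matrix must equal its own transpose
    if not np.allclose(cov, cov.T):
        return False
    # eigvalsh is for symmetric matrices; PSD means no eigenvalue below zero (up to tol)
    return bool(np.all(np.linalg.eigvalsh(cov) >= -tol))
```

Symmetry alone is not enough: `[[1.5, 2], [2, 2]]` is symmetric but has a negative eigenvalue, so it is rejected too.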

### tvwerkhoven commented May 31, 2018

Line 12 of kMeans.py (`centroids = [X[C == k].mean(axis = 0) for k in range(K)]`) always fails when the function is given a list of zeros as X, and in my application it also fails on 'real' data a few percent of the time. This happens because C ends up with fewer than K distinct labels, so for every k that does not occur in C, X[C == k] is empty and X[C == k].mean() fails. I solved this by checking whether len(np.unique(C)) < K and, if so, resetting the centroids to a new random sample. See https://gist.github.com/tvwerkhoven/4fdc9baad760240741a09292901d3abd for the fix.
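
A minimal sketch of the workaround described in the comment, written from its description rather than copied from the linked gist: whenever the assignment leaves fewer than K distinct labels, the centroids are redrawn instead of averaged. The function name `kmeans_reseed` and the `rng` parameter are hypothetical, plotting is omitted, and `np.random.default_rng` assumes NumPy 1.17+.

```python
import numpy as np

def kmeans_reseed(X, K, maxIters=10, rng=None):
    """K-means that redraws random centroids whenever a cluster ends up empty."""
    rng = np.random.default_rng() if rng is None else rng
    X = np.asarray(X, dtype=float)
    # Start from the rows of X at K distinct random indices
    centroids = X[rng.choice(len(X), K, replace=False)]
    C = np.zeros(len(X), dtype=int)
    for _ in range(maxIters):
        # Assignment step: squared distance from every point to every centroid
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        C = d2.argmin(axis=1)
        if len(np.unique(C)) < K:
            # At least one cluster is empty, so its mean is undefined:
            # reset the centroids to a new random sample of rows instead
            centroids = X[rng.choice(len(X), K, replace=False)]
            continue
        # Update step: every cluster is non-empty, so each mean is well defined
        centroids = np.array([X[C == k].mean(axis=0) for k in range(K)])
    return centroids, C
```

Reseeding skips the update for that iteration, so if the last iteration reseeds, the returned centroids are the fresh random sample rather than cluster means. For an all-zero X every sample is all zeros, so the function returns zero centroids instead of NaN means.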