Skip to content

Instantly share code, notes, and snippets.

@tvwerkhoven
Forked from bistaumanga/kMeans.py
Created June 2, 2018 13:31
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tvwerkhoven/4fdc9baad760240741a09292901d3abd to your computer and use it in GitHub Desktop.
Save tvwerkhoven/4fdc9baad760240741a09292901d3abd to your computer and use it in GitHub Desktop.
KMeans Clustering Implemented in python with numpy
'''Implementation of K-Means Clustering
Requires : python 2.7.x, Numpy 1.7.1+'''
import numpy as np


def _assign_clusters(X, centroids):
    """Return, for each point of X, the index of its nearest centroid.

    Distance is squared Euclidean over all non-leading axes. Works for 1-D X
    (scalar points) as well as 2-D X of shape (n_points, n_features).
    """
    # Broadcast to (n_points, K, ...) so all point/centroid pairs are computed
    # in one vectorised step instead of a Python double loop.
    diff = X[:, np.newaxis] - centroids[np.newaxis]
    sq_dist = (diff ** 2).reshape(len(X), len(centroids), -1).sum(axis=2)
    # argmin returns the first minimum on ties, like np.argmin on a list.
    return sq_dist.argmin(axis=1)


def kMeans(X, K, maxIters = 10, plot_progress = None):
    """Cluster the points of X into K groups with Lloyd's k-means algorithm.

    Parameters
    ----------
    X : array-like
        Points to cluster, one per entry along the first axis.
    K : int
        Number of clusters; must satisfy 1 <= K <= len(X).
    maxIters : int
        Maximum number of assignment/update iterations. Iteration stops
        early once the cluster assignment no longer changes.
    plot_progress : callable or None
        If given, called after every iteration as
        ``plot_progress(X, C, centroids)``.

    Returns
    -------
    centroids : numpy.ndarray
        The K cluster centres.
    C : numpy.ndarray
        Index of the nearest centroid for every point of X.

    Raises
    ------
    ValueError
        If K is not in the range [1, len(X)].

    Notes
    -----
    If an iteration produces fewer than K non-empty clusters, the centroids
    are re-drawn at random (an empty cluster would have a NaN mean). If that
    happens on the last allowed iteration, C is recomputed against the
    returned centroids so the two stay consistent, but C may then contain
    fewer than K distinct labels.
    """
    X = np.asarray(X)
    n_points = len(X)
    if not 1 <= K <= n_points:
        raise ValueError("K must satisfy 1 <= K <= len(X); got K=%r for %d points"
                         % (K, n_points))

    def _random_centroids():
        # Draw K *distinct* points: sampling with replacement could pick the
        # same point twice, which forces an empty cluster.
        return X[np.random.choice(n_points, K, replace=False)]

    centroids = _random_centroids()
    prev_C = None
    # True whenever C has not yet been computed against `centroids`
    # (no iteration ran, or the last iteration re-drew the centroids).
    was_reset = True
    for _ in range(maxIters):
        # Cluster assignment step
        C = _assign_clusters(X, centroids)
        # Ensure we have K clusters, otherwise reset centroids and start over.
        was_reset = len(np.unique(C)) < K
        if was_reset:
            centroids = _random_centroids()
        else:
            # Move centroids step
            centroids = np.array([X[C == k].mean(axis=0) for k in range(K)])
        if plot_progress is not None:
            plot_progress(X, C, centroids)
        # Unchanged assignment => unchanged means => further iterations are no-ops.
        if not was_reset and prev_C is not None and np.array_equal(C, prev_C):
            break
        prev_C = None if was_reset else C

    if was_reset:
        C = _assign_clusters(X, centroids)
    return centroids, C
'''Demonstration of K-Means Clustering
Requires : python 2.7.x, Numpy 1.7.1+, Matplotlib 1.2.1+'''
import sys  # NOTE(review): sys is not used anywhere in the visible code.
import pylab as plt
import numpy as np
# Turn on matplotlib interactive mode so the figure can be redrawn each time
# show() is called; show(..., keep=True) switches it off again at the end.
plt.ion()
def show(X, C, centroids, keep = False):
    """Redraw the points of X coloured by cluster, plus the centroids.

    Suitable as the ``plot_progress`` callback of ``kMeans``.

    Parameters
    ----------
    X : numpy.ndarray
        Points; only the first two columns are plotted.
    C : numpy.ndarray
        Cluster index of every point of X.
    centroids : numpy.ndarray
        One row per cluster; drawn as large magenta stars.
    keep : bool
        If True, leave interactive mode and block on ``plt.show()`` so the
        final figure stays open. Otherwise pause briefly so each iteration
        is visible.
    """
    plt.cla()
    # Clusters 0, 1, 2 keep the original blue/red/green; further clusters
    # cycle through more colours ('m' is reserved for the centroids).
    colors = 'brgcyk'
    for k in range(len(centroids)):
        members = X[C == k]
        plt.plot(members[:, 0], members[:, 1], '*' + colors[k % len(colors)])
    plt.plot(centroids[:, 0], centroids[:, 1], '*m', markersize=20)
    plt.draw()
    if keep:
        plt.ioff()
        plt.show()
    else:
        # plt.pause runs the GUI event loop while waiting; time.sleep would
        # block it and the window may not repaint between iterations.
        plt.pause(0.5)
if __name__ == '__main__':
    # Generate 3-cluster 2-D Gaussian data. To cluster your own data instead,
    # load it here, e.g. X = np.genfromtxt('data1.csv', delimiter=',').
    # Covariance matrices must be symmetric positive semi-definite; otherwise
    # numpy warns and samples from a different distribution than written.
    m1, cov1 = [9, 8], [[1.5, 1], [1, 2]]
    m2, cov2 = [5, 13], [[2.5, -1.5], [-1.5, 1.5]]
    m3, cov3 = [3, 7], [[0.25, -0.1], [-0.1, 0.5]]
    data1 = np.random.multivariate_normal(m1, cov1, 250)
    data2 = np.random.multivariate_normal(m2, cov2, 180)
    data3 = np.random.multivariate_normal(m3, cov3, 100)
    X = np.vstack((data1, data2, data3))
    np.random.shuffle(X)

    # Expects the k-means implementation to be importable as module `kMeans`.
    from kMeans import kMeans
    centroids, C = kMeans(X, K = 3, plot_progress = show)
    show(X, C, centroids, True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment