Skip to content

Instantly share code, notes, and snippets.

@spyhi
Created November 17, 2017 12:15
Show Gist options
  • Save spyhi/2ce0ab2d008a4a1b46e41b9e2dcf04e3 to your computer and use it in GitHub Desktop.
Save spyhi/2ce0ab2d008a4a1b46e41b9e2dcf04e3 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.cluster import KMeans
from sklearn.utils import check_random_state
from sklearn.datasets.samples_generator import make_blobs
# Number of independent repetitions; each one draws a fresh blob dataset and
# contributes one column to the inertia table that is averaged for the plot.
RUNS = 1
# Number of blob centers to generate, also used as k for k-means.
_CLUSTERS = 4
# Values of KMeans `n_init` to sweep: every integer from 1 to 20 inclusive.
# Built with arange so no value (e.g. 13) can be accidentally left out.
T_INIT_RANGE = np.arange(1, 21)
def kmeansinittest(random_state=None):
    """Measure how k-means inertia changes as ``n_init`` grows, and plot it.

    For each of ``RUNS`` repetitions a fresh 2-D blob dataset with
    ``_CLUSTERS`` centers is generated, and a randomly initialised
    ``KMeans`` model is fitted once per value in ``T_INIT_RANGE``. The
    final inertia of every fit is recorded.

    Two figures are created and displayed with ``plt.show()``:

    1. An error-bar plot of the mean inertia (with the standard deviation
       across runs as the error bar) against ``n_init``.
    2. The cluster regions and centroids of the *last* fitted model, drawn
       over the *last* generated dataset coloured by its true blob labels.

    The per-``n_init`` means and standard deviations are also printed.

    Args:
        random_state: Seed, ``RandomState`` instance or ``None``. It is
            normalised with ``check_random_state`` and shared by the data
            generator and every ``KMeans`` fit. ``None`` (the default)
            draws from NumPy's global random state; pass an int to make
            the whole experiment reproducible.

    Raises:
        ValueError: If ``RUNS`` is less than 1 or ``T_INIT_RANGE`` is
            empty, because no model or dataset would exist to plot.
    """
    # The second figure reuses the last dataset and model produced by the
    # loops below, so at least one iteration of each loop is required.
    if RUNS < 1 or len(T_INIT_RANGE) == 0:
        raise ValueError("RUNS must be >= 1 and T_INIT_RANGE must be non-empty")

    rng = check_random_state(random_state)

    # --- Figure 1: inertia as a function of n_init -------------------------
    plt.figure()
    # Rows index the n_init values, columns index the runs.
    inertia = np.empty((len(T_INIT_RANGE), RUNS))
    for run in range(RUNS):
        X, y = make_blobs(n_samples=200, n_features=2, centers=_CLUSTERS,
                          cluster_std=1, random_state=rng)
        for i, t_init in enumerate(T_INIT_RANGE):
            # NOTE: algorithm='full' is the classic EM-style k-means; newer
            # scikit-learn releases call the same option 'lloyd'.
            km = KMeans(n_clusters=_CLUSTERS, init='random', n_init=t_init,
                        max_iter=300, algorithm='full', random_state=rng)
            km.fit(X)
            inertia[i, run] = km.inertia_

    mean_inertia = inertia.mean(axis=1)
    std_inertia = inertia.std(axis=1)
    plt.errorbar(T_INIT_RANGE, mean_inertia, std_inertia)
    print(mean_inertia, std_inertia)
    plt.xlabel('n_init')
    plt.ylabel('inertia')
    plt.title("Mean inertia for various k-means init across %d runs" % RUNS)

    # --- Figure 2: decision regions of the last fitted model ---------------
    h = .02  # step size of the mesh used to paint the cluster regions
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    # Assign every mesh point to its nearest centroid, then restore the
    # grid shape so the labels can be drawn as an image.
    Z = km.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    centroids = km.cluster_centers_

    plt.figure()
    plt.subplot(111)
    plt.imshow(Z, interpolation='nearest',
               extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               aspect='auto', origin='lower', cmap='Pastel2')
    # Points are coloured by the generating blob, not by the predicted label.
    plt.scatter(X[:, 0], X[:, 1], marker='o', c=y, s=25, cmap='Set1')
    plt.scatter(centroids[:, 0], centroids[:, 1], marker="x", c="w", s=150,
                linewidths=3)
    plt.show()
# Run the experiment only when executed as a script, so importing this module
# does not trigger model fitting or open plot windows.
if __name__ == "__main__":
    kmeansinittest()
@spyhi
Copy link
Author

spyhi commented Nov 17, 2017

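The script prints the per-n_init means and standard deviations of the inertia to the console for inspection, and the matplotlib output should look something like the two figures below (mean inertia vs. n_init with error bars, and the final clustering plot/fit):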

inertia_1krun_1
samplekmeansclustering_1krun_1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment