Skip to content

Instantly share code, notes, and snippets.

@MortisHuang
Last active June 11, 2019 09:01
Show Gist options
  • Save MortisHuang/2e9847c158173d7eab2e94311237230b to your computer and use it in GitHub Desktop.
Save MortisHuang/2e9847c158173d7eab2e94311237230b to your computer and use it in GitHub Desktop.
Let's train a K-Means model to cluster the MNIST handwritten digits into 10 clusters.
from sklearn.cluster import KMeans
from keras.datasets import mnist
import numpy as np
def accu(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one label mapping.

    Cluster ids produced by a clustering algorithm are arbitrary, so the
    predicted ids are first matched to the true labels by solving a linear
    assignment problem on the contingency matrix; the accuracy is the
    fraction of samples that agree under that optimal matching.
    Requires SciPy (`scipy.optimize.linear_sum_assignment`).

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`
    # Return
        accuracy, in [0,1]
    """
    # Imported lazily, as the original did for its solver, so SciPy is only
    # needed when the metric is actually evaluated.
    # `sklearn.utils.linear_assignment_` was removed in scikit-learn 0.23;
    # SciPy's solver is the documented replacement.
    from scipy.optimize import linear_sum_assignment

    # Both label arrays are used as integer indices into the count matrix.
    y_true = np.asarray(y_true).astype(np.int64)
    y_pred = np.asarray(y_pred).astype(np.int64)
    assert y_pred.size == y_true.size

    # w[p, t] = number of samples assigned to cluster p whose true label is t.
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # np.add.at accumulates correctly when the same (p, t) pair repeats,
    # unlike plain fancy-index `+=`.
    np.add.at(w, (y_pred, y_true), 1)

    # The solver minimises cost, so turn counts into costs to maximise matches.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size
# Load MNIST and merge the train/test splits: clustering is unsupervised,
# so every sample is used for fitting and the labels only for scoring.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x = np.concatenate((x_train, x_test))
y = np.concatenate((y_train, y_test))
# Flatten each image into a single feature vector and divide by 255
# (presumably mapping 8-bit pixel intensities into [0, 1]).
x = x.reshape((x.shape[0], -1))
x = np.divide(x, 255.)
# One cluster per distinct label (10 clusters for MNIST).
n_clusters = len(np.unique(y))
# NOTE: the `n_jobs` argument was deprecated in scikit-learn 0.23 and removed
# in 1.0, where passing it raises TypeError. KMeans now parallelises through
# OpenMP threads; limit them with the OMP_NUM_THREADS environment variable
# if needed.
kmeans = KMeans(n_clusters=n_clusters, n_init=20)
# Train K-Means and get the cluster id assigned to every sample.
y_pred_kmeans = kmeans.fit_predict(x)
# Evaluate the K-Means clustering accuracy (value in [0, 1]).
acc = accu(y, y_pred_kmeans)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment