Skip to content

Instantly share code, notes, and snippets.

@warmlogic
Created September 10, 2015 22:28
Show Gist options
Save warmlogic/bb9810f7a0dc350297a4 to your computer and use it in GitHub Desktop.
# calculate and visualize silhouette score from k-means clustering.
# plots first two features in 2D and first three features in 3D.
# from: http://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, silhouette_samples
from mpl_toolkits.mplot3d import Axes3D
##### cluster data into K=1..K_MAX clusters #####
# Fit one KMeans model per candidate k and keep every fitted model in KM
# (KM[i] has n_clusters == KK[i]) so the silhouette analysis below can reuse
# each model's labels_ and cluster_centers_.
# NOTE(review): `trans` (2D array, presumably PCA-transformed data given the
# "kept %d PCs" title below) and `n_components` (number of leading columns to
# cluster on) are assumed to be defined earlier in the notebook -- confirm.
K_MAX = 10
KK = range(1, K_MAX + 1)
KM = []
for k in KK:
    # n_jobs is not passed: it was deprecated in scikit-learn 0.23 and removed
    # in 1.0 (passing it raises TypeError). KMeans now parallelizes via OpenMP
    # threads instead (e.g. controlled with OMP_NUM_THREADS).
    thisKM = KMeans(n_clusters=k, init='k-means++', n_init=10)
    thisKM.fit(trans[:, :n_components])
    KM.append(thisKM)
# choose a random subset to visualize
# silhouette_samples computes distances between all pairs of rows it is given,
# so its cost grows quadratically with n; lower n (e.g. 7500 or 5000) if this
# is too slow or uses too much memory.
# Sampling is without replacement, so n can never exceed the number of rows.
n = min(10000, trans.shape[0])
sil_rows = np.random.choice(trans.shape[0], n, replace=False)
min_n_clus = 2  # the silhouette coefficient needs at least 2 clusters
max_n_clus = K_MAX
# Mean silhouette score per k, in order k = min_n_clus..max_n_clus
# (NaN for a k whose sampled labels contain fewer than 2 clusters).
silhouettes = []
X = trans[sil_rows, :n_components]
for k in KM:
    n_clusters = k.n_clusters
    cluster_labels = k.labels_[sil_rows]
    if not (min_n_clus <= n_clusters <= max_n_clus):
        continue
    if np.unique(cluster_labels).size < 2:
        # silhouette_samples raises ValueError for fewer than 2 distinct
        # labels; record NaN so silhouettes stays aligned with k.
        print('k=%d, skipped: fewer than 2 clusters in the sample' % n_clusters)
        silhouettes.append(np.nan)
        continue

    sample_silhouette_values = silhouette_samples(X, cluster_labels)
    # Same value as silhouette_score(X, cluster_labels, metric='euclidean'),
    # but reuses the per-sample values already computed above.
    silhouette_avg = np.mean(sample_silhouette_values)
    print('k=%d, score=%.5f' % (n_clusters, silhouette_avg))
    silhouettes.append(silhouette_avg)

    fig = plt.figure(figsize=(14, 5))
    ax1 = fig.add_subplot(1, 3, 1)
    ax2 = fig.add_subplot(1, 3, 2)
    ax3 = fig.add_subplot(1, 3, 3, projection='3d')

    # 1st plot: per-sample silhouette coefficients, one sorted band per cluster
    y_lower = 10
    for i in range(n_clusters):
        # Aggregate the silhouette scores for samples belonging to
        # cluster i, and sort them
        ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
        ith_cluster_silhouette_values.sort()
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i
        # plt.cm.spectral was removed in matplotlib 2.2 in favor of nipy_spectral.
        color = plt.cm.nipy_spectral(float(i) / n_clusters)
        ax1.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.7)
        # Label the silhouette plots with their cluster numbers at the middle
        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
        # Leave a 10-row gap before the next cluster's band
        y_lower = y_upper + 10
    ax1.set_title("Silhouette plot")
    ax1.set_xlabel("Silhouette coefficient values")
    ax1.set_ylabel("Cluster label")
    # The vertical line for average silhouette score of all the values
    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")
    ax1.set_yticks([])  # Clear the yaxis labels / ticks
    ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])

    sil_colors = plt.cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)

    # 2nd plot: sampled points in the first two features, colored by cluster
    ax2.scatter(X[:, 0], X[:, 1], marker='.', s=30, lw=0, alpha=0.7, c=sil_colors)
    centers = k.cluster_centers_
    # Draw white circles at cluster centers, then the cluster number on top
    ax2.scatter(centers[:, 0], centers[:, 1], marker='o', c="white", alpha=1, s=200)
    for i, c in enumerate(centers):
        ax2.scatter(c[0], c[1], marker='$%d$' % i, alpha=1, s=50)
    ax2.set_title("2D visualization of clustered data")
    ax2.set_xlabel("Feature space, 1st feature")
    ax2.set_ylabel("Feature space, 2nd feature")

    # 3rd plot: first three features in 3D (cluster centers are not drawn here).
    # Indexing X[:, 2] requires at least 3 kept columns; hide the panel otherwise.
    if X.shape[1] >= 3:
        ax3.scatter(X[:, 0], X[:, 1], X[:, 2],
                    marker='.', s=30, lw=0, alpha=0.7, c=sil_colors)
        ax3.set_title("3D visualization of clustered data")
        ax3.set_xlabel("Feature space, 1st feature")
        ax3.set_ylabel("Feature space, 2nd feature")
        ax3.set_zlabel("Feature space, 3rd feature")
    else:
        ax3.set_visible(False)

    plt.suptitle(("Silhouette score on sample data=%.4f"
                  " for KMeans clustering (k=%d)"
                  ", kept %d PCs"
                  % (silhouette_avg, n_clusters, n_components)),
                 fontsize=14, fontweight='bold')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment