Skip to content

Instantly share code, notes, and snippets.

@raghavrv
Created January 10, 2015 13:07
Show Gist options
  • Save raghavrv/d51fed2cd65873f79414 to your computer and use it in GitHub Desktop.
Save raghavrv/d51fed2cd65873f79414 to your computer and use it in GitHub Desktop.
Silhouette plots for different n_clusters values when clustering with KMeans
import numpy as np


def stacked_silhouette_profile(sample_silhouette_values, cluster_labels):
    """Return per-cluster sorted silhouette values stacked into one 1-D array.

    Clusters are taken in ascending label order. Within each cluster the
    silhouette values are sorted ascending, and a single ``0.0`` is appended
    after every cluster so that neighbouring clusters are visually separated
    when the profile is plotted against its index.

    Parameters
    ----------
    sample_silhouette_values : array-like of float
        Silhouette coefficient for every sample (e.g. the output of
        ``silhouette_samples``).
    cluster_labels : array-like
        Cluster label for every sample, aligned with
        ``sample_silhouette_values``.

    Returns
    -------
    numpy.ndarray
        Float array of length ``n_samples + n_distinct_labels``; empty when
        there are no samples.
    """
    values = np.asarray(sample_silhouette_values, dtype=float)
    labels = np.asarray(cluster_labels)
    pieces = []
    for label in np.unique(labels):
        pieces.append(np.sort(values[labels == label]))
        # The 0 sample is to differentiate clearly between the clusters.
        pieces.append(np.zeros(1))
    if not pieces:
        return np.empty(0)
    return np.concatenate(pieces)


def main():
    """Plot silhouette profiles and cluster scatters for several n_clusters.

    For each value in ``range_n_clusters`` two figures are created: the
    sorted per-sample silhouette values grouped by cluster, and the sample
    points coloured by their KMeans cluster label.
    """
    # Plotting/ML imports are kept local so the helper above can be imported
    # with only numpy installed.
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.metrics import silhouette_samples

    # Generating the sample data from make_blobs
    X, _ = make_blobs(n_samples=100,
                      n_features=2,
                      centers=4,
                      cluster_std=1.0,
                      center_box=(-10.0, 10.0),
                      shuffle=True,
                      random_state=0)  # For reproducibility

    range_n_clusters = [2, 4, 6]
    fignum = 0
    for n_clusters in range_n_clusters:
        # Initialize KMeans with n_clusters clusters and a random generator
        # seed of 10 for reproducibility. n_init=10 pins the historical
        # default; newer scikit-learn versions default to n_init="auto".
        kmeans_clusterer = KMeans(n_clusters=n_clusters, n_init=10,
                                  random_state=10)
        cluster_labels = kmeans_clusterer.fit_predict(X)

        # Compute the silhouette score of each sample, then sort within each
        # cluster and stack the clusters for plotting.
        sample_silhouette_values = silhouette_samples(X, cluster_labels)
        profile = stacked_silhouette_profile(sample_silhouette_values,
                                             cluster_labels)

        fignum += 1
        plt.figure(fignum, figsize=(6, 6))
        plt.grid()
        plt.xlim([-0.1, 1])
        # Size the y-axis from the profile itself (samples plus one separator
        # per cluster) so the topmost clusters are not clipped.
        plt.ylim([0, len(profile)])
        plt.plot(profile, np.arange(len(profile)))

        fignum += 1
        plt.figure(fignum, figsize=(6, 6))
        plt.grid()
        # `cm.spectral` was removed from matplotlib; `nipy_spectral` replaces
        # it. One vectorised scatter call replaces per-point plot calls.
        colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
        plt.scatter(X[:, 0], X[:, 1], marker='.', c=colors)
        plt.title("The visualization of the clustered data")

    # Figure.show() does not block in a non-interactive script, so windows
    # could close as soon as the script exits; pyplot.show() keeps them open.
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment