Created April 29, 2015 09:17
-
-
Save Madhuka/2e27dce9680f42619b83 to your computer and use it in GitHub Desktop.
Affinity Propagation clustering sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Affinity Propagation clustering sample: generate Gaussian blobs around four
# known centres, cluster them with scikit-learn's AffinityPropagation, print
# how well the recovered clusters agree with the generating labels, and plot
# every cluster together with its exemplar.
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.cluster import AffinityPropagation
# make_blobs is exported from sklearn.datasets; the old private module
# sklearn.datasets.samples_generator was removed in scikit-learn 0.24.
from sklearn.datasets import make_blobs

# --- Generate sample data ---------------------------------------------------
# 500 points scattered (std 0.9) around four 2-D centres. Because `centers` is
# an explicit list of points, make_blobs takes the feature count from it and
# never draws random centres, so n_features / center_box would be ignored and
# are not passed.
centers = [[5, 5], [0, 0], [1, 5], [5, -1]]
X, labels_true = make_blobs(
    n_samples=500,
    centers=centers,
    cluster_std=0.9,
    shuffle=True,
    random_state=0,
)

# --- Compute Affinity Propagation -------------------------------------------
# Lower preference values make points less likely to be chosen as exemplars,
# so a strongly negative preference tends to produce fewer clusters.
af = AffinityPropagation(max_iter=150, preference=-120).fit(X)
cluster_centers_indices = af.cluster_centers_indices_
labels = af.labels_
n_clusters_ = len(cluster_centers_indices)

# --- Print results ----------------------------------------------------------
print('Estimated number of clusters: %d' % n_clusters_)
print("Homogeneity: %0.3f" % metrics.homogeneity_score(labels_true, labels))
print("Completeness: %0.3f" % metrics.completeness_score(labels_true, labels))
print("V-measure: %0.3f" % metrics.v_measure_score(labels_true, labels))
print("Adjusted Rand Index: %0.3f"
      % metrics.adjusted_rand_score(labels_true, labels))
print("Adjusted Mutual Information: %0.3f"
      % metrics.adjusted_mutual_info_score(labels_true, labels))
# silhouette_score raises ValueError unless the number of distinct labels is
# between 2 and n_samples - 1. That can fail here, e.g. when Affinity
# Propagation does not converge and labels every sample -1.
n_labels = len(np.unique(labels))
if 2 <= n_labels <= len(X) - 1:
    print("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(X, labels))
else:
    print("Silhouette Coefficient: undefined for %d distinct label(s)"
          % n_labels)

# --- Plot result ------------------------------------------------------------
plt.close('all')
plt.figure(1)
plt.clf()
# cycle() repeats the palette indefinitely, so one copy of the colour codes
# covers any number of clusters.
colors = cycle('bgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    class_members = labels == k
    cluster_center = X[cluster_centers_indices[k]]
    # Members of cluster k as small dots.
    plt.plot(X[class_members, 0], X[class_members, 1], col + '.')
    # The exemplar as a large, black-edged marker.
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
    # A line from the exemplar to each member of its cluster.
    for x in X[class_members]:
        plt.plot([cluster_center[0], x[0]], [cluster_center[1], x[1]], col)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.