Skip to content

Instantly share code, notes, and snippets.

@Madhuka
Created April 29, 2015 09:17
Show Gist options
  • Save Madhuka/2e27dce9680f42619b83 to your computer and use it in GitHub Desktop.
Save Madhuka/2e27dce9680f42619b83 to your computer and use it in GitHub Desktop.
Affinity Propagation clustering sample
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.cluster import AffinityPropagation

try:
    from sklearn.datasets import make_blobs
except ImportError:
    # Legacy module path (removed in scikit-learn 0.24); kept as a fallback.
    from sklearn.datasets.samples_generator import make_blobs
# Generate sample data: 500 two-dimensional points scattered around four
# fixed centers.  Because explicit centers are supplied, make_blobs takes the
# feature count from them and never samples centers from a bounding box, so
# the former n_features=5 / center_box arguments had no effect and are omitted.
centers = [[5, 5], [0, 0], [1, 5], [5, -1]]
X, labels_true = make_blobs(
    n_samples=500,
    centers=centers,
    cluster_std=0.9,
    shuffle=True,
    random_state=0,
)

# Compute Affinity Propagation.  The preference value influences how many
# exemplars (clusters) are chosen; lower values lead to fewer clusters.
af = AffinityPropagation(max_iter=150, preference=-120).fit(X)
cluster_centers_indices = af.cluster_centers_indices_
labels = af.labels_
n_clusters_ = len(cluster_centers_indices)

# Print results: cluster count plus agreement scores against the true labels.
print('Estimated number of clusters: %d' % n_clusters_)
print("Homogeneity: %0.3f" % metrics.homogeneity_score(labels_true, labels))
print("Completeness: %0.3f" % metrics.completeness_score(labels_true, labels))
print("V-measure: %0.3f" % metrics.v_measure_score(labels_true, labels))
print("Adjusted Rand Index: %0.3f"
      % metrics.adjusted_rand_score(labels_true, labels))
print("Adjusted Mutual Information: %0.3f"
      % metrics.adjusted_mutual_info_score(labels_true, labels))

# silhouette_score raises ValueError unless the number of distinct labels is
# between 2 and n_samples - 1 (e.g. it fails when the fit did not converge and
# every sample carries the same label), so only call it when it is defined.
n_distinct_labels = len(np.unique(labels))
if 1 < n_distinct_labels < len(X):
    print("Silhouette Coefficient: %0.3f"
          % metrics.silhouette_score(X, labels))
else:
    print("Silhouette Coefficient: undefined for %d distinct label(s)"
          % n_distinct_labels)
# Draw the clustering result: one colour per cluster, members as dots, the
# exemplar as a large outlined marker, and a line from the exemplar to each
# of its members.
plt.close('all')
plt.figure(1)
plt.clf()

# One period of the colour codes is enough; cycle() repeats it indefinitely.
colors = cycle('bgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    class_members = labels == k  # boolean mask selecting cluster k's samples
    cluster_center = X[cluster_centers_indices[k]]
    plt.plot(X[class_members, 0], X[class_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
    # Connect every member of the cluster to its exemplar.
    for x in X[class_members]:
        plt.plot([cluster_center[0], x[0]], [cluster_center[1], x[1]], col)

plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment