Last active
July 12, 2017 07:27
-
-
Save ericjster/6847183 to your computer and use it in GitHub Desktop.
Modification of the scikit-learn DBSCAN clustering example.
Improves the performance of plot_dbscan.py by minimizing the number of calls to plot.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
=================================== | |
Demo of DBSCAN clustering algorithm | |
=================================== | |
Finds core samples of high density and expands clusters from them. | |
""" | |
print(__doc__) | |
import time

import numpy as np
from sklearn import metrics
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

try:
    from sklearn.datasets import make_blobs
except ImportError:
    # Fallback: the private samples_generator path was deprecated in
    # scikit-learn 0.22 and removed in 0.24.
    from sklearn.datasets.samples_generator import make_blobs
##############################################################################
# Generate sample data
# 750 two-dimensional points are drawn around three fixed centres with a
# fixed random seed, then passed through StandardScaler before clustering.
# labels_true keeps the generating blob of each point for the metrics below.
centers = [[1, 1], [-1, -1], [1, -1]]
X, labels_true = make_blobs(
    n_samples=750,
    centers=centers,
    cluster_std=0.4,
    random_state=0,
)
X = StandardScaler().fit_transform(X)
##############################################################################
# Compute DBSCAN
db = DBSCAN(eps=0.3, min_samples=10).fit(X)
labels = db.labels_
core_samples = db.core_sample_indices_

# Boolean mask with one entry per sample: True at the indices that DBSCAN
# reported as core samples, False everywhere else.
core_samples_mask = np.zeros(labels.shape, dtype=bool)
core_samples_mask[core_samples] = True

# Number of clusters in labels, ignoring the noise label (-1) if present.
n_clusters_ = len(set(labels) - {-1})
print('Estimated number of clusters: %d' % n_clusters_)

# Each row: output template, scoring function, arguments for that function.
# The scores are computed and printed one at a time, in this order.
score_report = (
    ("Homogeneity: %0.3f",
     metrics.homogeneity_score, (labels_true, labels)),
    ("Completeness: %0.3f",
     metrics.completeness_score, (labels_true, labels)),
    ("V-measure: %0.3f",
     metrics.v_measure_score, (labels_true, labels)),
    ("Adjusted Rand Index: %0.3f",
     metrics.adjusted_rand_score, (labels_true, labels)),
    ("Adjusted Mutual Information: %0.3f",
     metrics.adjusted_mutual_info_score, (labels_true, labels)),
    # The silhouette is computed from the data itself, not the true labels.
    ("Silhouette Coefficient: %0.3f",
     metrics.silhouette_score, (X, labels)),
)
for template, scorer, scorer_args in score_report:
    print(template % scorer(*scorer_args))
##############################################################################
# Plot result
import matplotlib.pyplot as plt

# Only the plot() calls are timed; title() and show() come after endTime.
# time.perf_counter() replaces time.clock(), which no longer exists in
# Python 3.8+.
begTime = time.perf_counter()

# One Spectral colour per label; the noise label (-1) is drawn in black
# instead of its colormap entry.
unique_labels = set(labels)
colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
for k, col in zip(unique_labels, colors):
    col = 'k' if k == -1 else tuple(col)
    class_member_mask = labels == k

    # Call plot exactly twice per label -- large markers for core samples,
    # small markers for the rest -- selecting the points with numpy boolean
    # masks.  The cost per call to plot is expensive: the original author
    # measured about 0.2 sec for both 200 and 2000 elements with two calls
    # per label, versus 0.9 sec (200 elements) and 7.7 sec (2000 elements)
    # when every point was plotted with its own call.
    for member_mask, markersize in (
        (class_member_mask & core_samples_mask, 14),
        (class_member_mask & ~core_samples_mask, 6),
    ):
        xy = X[member_mask]
        plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=col,
                 markeredgecolor='k', markersize=markersize)

endTime = time.perf_counter()
print("Time: %.1f sec" % (endTime - begTime))

plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment