Skip to content

Instantly share code, notes, and snippets.

@elnikkis
Created March 23, 2022 11:00
Show Gist options
  • Save elnikkis/323e1dc7e5bb17289da089c5324d1cb6 to your computer and use it in GitHub Desktop.
Save elnikkis/323e1dc7e5bb17289da089c5324d1cb6 to your computer and use it in GitHub Desktop.
DBSCANの全距離対計算する版実装
'''Implement DBSCAN clustering algorithm'''
import numpy as np
from scipy.spatial import distance
from scipy.sparse.csgraph import connected_components
def dbscan(X, eps=0.5, minPts=2):
    """Cluster the rows of ``X`` with DBSCAN, using a dense distance matrix.

    All pairwise Euclidean distances are computed up front, so time and
    memory grow quadratically with the number of points.

    Parameters
    ----------
    X : array_like, shape (n_points, n_features)
        Points to cluster, one per row.
    eps : float
        Neighbourhood radius. Two points are neighbours when their distance
        is ``<= eps`` (boundary inclusive, as in the DBSCAN definition).
    minPts : int
        Minimum neighbourhood size, counting the point itself, for a point to
        be a core point.

    Returns
    -------
    numpy.ndarray of int, shape (n_points,)
        Cluster id for each point, numbered consecutively from 0. A non-core
        point within ``eps`` of at least one core point joins the cluster of
        its nearest such core point (lowest index on ties); every other
        point is labelled -1 (noise).

    Raises
    ------
    ValueError
        If ``X`` is not a 2-D array.
    """
    X = np.array(X)
    if X.ndim != 2:
        raise ValueError('X must be 2d array')
    n_points = X.shape[0]
    if n_points == 0:
        # squareform() turns an empty condensed matrix into a 1x1 matrix,
        # which would produce one spurious label for zero points.
        return np.empty(0, dtype=np.int32)

    # Dense pairwise Euclidean distances: O(n^2) time and memory.
    distances = distance.squareform(distance.pdist(X))

    # eps-neighbourhood graph. The boundary is inclusive (<= eps), and every
    # point is its own neighbour because the diagonal is exactly 0.
    neighbors = distances <= eps

    # Core points have at least minPts neighbours, themselves included.
    core_mask = np.count_nonzero(neighbors, axis=1) >= minPts

    # Clusters are the connected components of the core-to-core edges.
    # Non-core points become isolated nodes here and are handled below.
    core_graph = neighbors & core_mask[:, None] & core_mask[None, :]
    _, components = connected_components(core_graph, directed=False)

    # Renumber the core components as 0..k-1. Without this, the component
    # ids consumed by isolated non-core nodes would leave gaps in the labels.
    labels = np.full(n_points, -1, dtype=components.dtype)
    _, cluster_ids = np.unique(components[core_mask], return_inverse=True)
    labels[core_mask] = cluster_ids

    # Each non-core point within eps of some core point joins the cluster of
    # the nearest such core point (argmin keeps the lowest index on ties);
    # the remaining non-core points stay -1 (noise).
    noncore_idx = np.flatnonzero(~core_mask)
    reachable = neighbors[noncore_idx] & core_mask
    reach_dists = np.where(reachable, distances[noncore_idx], np.inf)
    nearest_core = reach_dists.argmin(axis=1)
    is_border = reachable.any(axis=1)
    labels[noncore_idx[is_border]] = labels[nearest_core[is_border]]
    return labels
def make_data():
    """Return a standardized three-blob toy dataset and its true labels.

    Draws 750 samples around three fixed 2-D centers with standard
    deviation 0.4 (``random_state=0``, so the data is reproducible), then
    rescales each feature to zero mean and unit variance.

    scikit-learn is imported inside the function, so only this helper
    needs it.

    Returns
    -------
    tuple
        ``(X, labels_true)``: the scaled points and the blob index each one
        was drawn from.
    """
    from sklearn.datasets import make_blobs
    from sklearn.preprocessing import StandardScaler

    blob_centers = [[1, 1], [-1, -1], [1, -1]]
    points, truth = make_blobs(
        n_samples=750,
        centers=blob_centers,
        cluster_std=0.4,
        random_state=0,
    )
    scaled = StandardScaler().fit_transform(points)
    return scaled, truth
if __name__ == '__main__':
    # Small hand-made example: three points within 1.1 of each other's
    # neighbourhood around (1, 1), plus (3, 3) and (4, 4), which are farther
    # than eps=1.1 from every other point.
    # To try the larger synthetic dataset instead, use:
    #     points, labels_true = make_data()
    #     print(dbscan(points, eps=0.3, minPts=10))
    points = [[1, 1], [2, 1], [1, 2], [3, 3], [4, 4]]
    print(points)
    print(dbscan(points, eps=1.1, minPts=2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment