Skip to content

Instantly share code, notes, and snippets.

@recamshak
Last active August 2, 2022 17:21
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save recamshak/f5761d2d7bcd4cb0fe109aba01d8331c to your computer and use it in GitHub Desktop.
Save recamshak/f5761d2d7bcd4cb0fe109aba01d8331c to your computer and use it in GitHub Desktop.
sklearn DBSCAN with O(n) memory budget
from sklearn.datasets import make_blobs
from sklearn.cluster import dbscan
from sklearn.cluster._dbscan_inner import dbscan_inner
from sklearn.metrics import pairwise_distances_chunked
from scipy.sparse import csr_matrix
import numpy as np
# dataset: `n` synthetic samples with 100 features drawn around 50 blob centers.
# `centers` is passed by keyword because recent scikit-learn releases only accept
# `n_samples` and `n_features` positionally in `make_blobs`.
n = 50000
ds, _ = make_blobs(n_samples=n, n_features=100, centers=50)
# dbscan parameters
# `eps` is the distance threshold below which two samples are treated as neighbors;
# `min_samples` is the neighbor count at which a sample is flagged as a core sample.
eps = 20
min_samples = 5
# Build a sparse adjacency matrix. Two samples are adjacent if their Euclidean distance is smaller than `eps`.
# The memory usage can be tuned by adjusting `working_memory` in `pairwise_distances_chunked`.
# See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise_distances_chunked.html
#
# The matrix is streamed to disk in CSR layout, one block of rows at a time:
#   * indices.dat receives the concatenated neighbor (column) indices of every row,
#   * indptr.dat receives the row offsets into indices.dat, starting with a leading 0.
#
# Remark: `indptr` is O(n) so it doesn't need to be stored in file and memmapped but it's convenient to have it
# along with `indices`.
with open("indices.dat", "wb") as f_indices, open("indptr.dat", "wb") as f_indptr:
    nnz = 0  # number of neighbor entries written to indices.dat so far
    f_indptr.write(np.intp(0).tobytes())
    for block in pairwise_distances_chunked(ds):
        # Boolean mask of the pairs in this block of rows that are closer than `eps`.
        neighbors_indices = block < eps
        csr = csr_matrix(neighbors_indices)
        f_indices.write(csr.indices.astype(np.intp).tobytes())
        # Widen to intp *before* adding the running offset. SciPy usually stores
        # `indptr` as int32, so adding `nnz` first could overflow (or raise under
        # NumPy 2 promotion rules) once the total neighbor count nears 2**31.
        # The first entry of each block's indptr is dropped: it is always 0 and
        # the previous block (or the leading 0 above) already wrote that offset.
        f_indptr.write((csr.indptr[1:].astype(np.intp) + nnz).tobytes())
        nnz += csr.nnz

# Map the files only after the `with` block has closed them, so every buffered
# write has reached disk. Read-only memmaps keep the O(n^2)-sized `indices`
# out of RAM.
# NOTE(review): if the installed `dbscan_inner` rejects read-only buffers,
# mode="c" (copy-on-write) would give writable views - confirm before changing.
indices = np.memmap("indices.dat", np.intp, mode="r")
indptr = np.memmap("indptr.dat", np.intp, mode="r")
# The remainder mirrors the tail of scikit-learn's own dbscan routine, fed with
# the on-disk CSR structure instead of an in-memory neighbor query.

# Every sample starts out labelled -1; `dbscan_inner` fills in cluster ids in place.
labels = np.empty(n, dtype=np.intp)
labels.fill(-1)

# Consecutive `indptr` offsets delimit one row, so their differences are the
# per-sample neighbor counts.
n_neighbors = np.diff(indptr)
core_samples = (n_neighbors >= min_samples).astype(np.uint8)

# One entry per sample: the slice of `indices` holding that sample's neighbors.
# Splitting at the interior offsets yields exactly one piece per row.
neighborhoods = np.empty(n, dtype=object)
neighborhoods[:] = np.split(indices, indptr[1:-1])

dbscan_inner(core_samples, neighborhoods, labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment