@gabrieldernbach
Last active June 6, 2022 17:39
Fast nonlinear clustering on millions of data points
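import numpy as np
import matplotlib.pyplot as plt
from sklearn.kernel_approximation import Nystroem
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)  # fixed seed so the figure is reproducible
# dot in the middle: a standard normal blob at the origin
X = rng.standard_normal((100, 2))
# circle around: dividing each point by the RMS of its coordinates puts it
# at radius 8 * sqrt(2) (about 11.3); unit Gaussian noise is added afterwards
Y = X / np.sqrt((X**2).mean(1, keepdims=True)) * 8
Y = Y + rng.standard_normal((100, 2))
data = np.concatenate([X, Y], 0)
# approximate RBF kernel feature map (Nystroem's default kernel is "rbf")
emb = Nystroem(n_components=100, gamma=0.1, random_state=0).fit_transform(data)
# linear k-means in the embedded space gives a nonlinear partition of the input
out = MiniBatchKMeans(n_clusters=2, random_state=0).fit_predict(emb)
plt.scatter(*data.T, c=out)
plt.show()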
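The toy example above fits everything in one call on 200 points. For the "millions of data points" case, one possible approach is to fit the Nystroem map on a random subsample and then stream the full array through `transform` and `MiniBatchKMeans.partial_fit` in chunks. The sketch below is not part of the original gist: the function name `cluster_in_chunks` and the parameters `fit_size` and `chunk_size` are hypothetical, and it assumes `data` is an in-memory (or memory-mapped) 2-D array whose rows are in roughly random order.

```python
import numpy as np
from sklearn.kernel_approximation import Nystroem
from sklearn.cluster import MiniBatchKMeans


def cluster_in_chunks(data, n_clusters=2, n_components=100, gamma=0.1,
                      fit_size=10_000, chunk_size=100_000, seed=0):
    """Hypothetical helper: Nystroem + MiniBatchKMeans without embedding all rows at once."""
    rng = np.random.default_rng(seed)

    # Fit the kernel approximation on a random subsample only.
    idx = rng.choice(len(data), size=min(fit_size, len(data)), replace=False)
    feature_map = Nystroem(n_components=n_components, gamma=gamma,
                           random_state=seed).fit(data[idx])

    km = MiniBatchKMeans(n_clusters=n_clusters, random_state=seed)

    # First pass: update the centroids one chunk at a time.
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        km.partial_fit(feature_map.transform(chunk))

    # Second pass: assign every row to its nearest centroid.
    labels = np.empty(len(data), dtype=np.int64)
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        labels[start:start + chunk_size] = km.predict(feature_map.transform(chunk))
    return labels
```

Calling `cluster_in_chunks(data)` returns one integer label per row. Only one chunk of embedded features is held in memory at a time, at the cost of transforming the data twice; if the rows are sorted by group, shuffling them first is likely to help, since the first chunk is used to initialise the centroids.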