Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Kernel K-means.
"""Kernel K-means"""
# Author: Mathieu Blondel <>
# License: BSD 3 clause
import numpy as np
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.utils import check_random_state
class KernelKMeans(BaseEstimator, ClusterMixin):
    """Kernel K-means.

    Lloyd-style k-means carried out in the feature space induced by a
    kernel. Cluster centroids are never formed explicitly; all distances
    are obtained from the kernel matrix (the "kernel trick").

    Parameters
    ----------
    n_clusters : int, default=3
        Number of clusters.
    max_iter : int, default=50
        Maximum number of assignment iterations performed by ``fit``.
    tol : float, default=1e-3
        ``fit`` stops once the fraction of samples whose label changed in an
        iteration drops below this value.
    random_state : None, int or numpy.random.RandomState
        Seed for the random initial label assignment.
    kernel : str or callable, default="linear"
        Passed to ``sklearn.metrics.pairwise.pairwise_kernels`` as
        ``metric``. ``"precomputed"`` means ``X`` already is a kernel matrix.
    gamma, degree, coef0 : kernel parameters
        Forwarded to ``pairwise_kernels`` for string kernels; parameters the
        chosen kernel does not accept are filtered out.
    kernel_params : dict or None
        Extra keyword arguments for a callable ``kernel``.
    verbose : int, default=0
        If non-zero, print the iteration at which convergence was reached.

    Attributes
    ----------
    labels_ : ndarray of shape (n_samples,)
        Cluster index of each training sample.
    within_distances_ : ndarray of shape (n_clusters,)
        Squared feature-space norm of each weighted cluster centroid,
        consistent with ``labels_``.
    sample_weight_ : ndarray of shape (n_samples,)
        Sample weights used during fitting.
    X_fit_ : array-like
        Training data, kept so ``predict`` can evaluate the kernel against it.

    References
    ----------
    Kernel k-means, Spectral Clustering and Normalized Cuts.
    Inderjit S. Dhillon, Yuqiang Guan, Brian Kulis.
    KDD 2004.
    """

    def __init__(self, n_clusters=3, max_iter=50, tol=1e-3, random_state=None,
                 kernel="linear", gamma=None, degree=3, coef0=1,
                 kernel_params=None, verbose=0):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state
        self.kernel = kernel
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
        self.kernel_params = kernel_params
        self.verbose = verbose

    @property
    def _pairwise(self):
        # scikit-learn reads this as an attribute; as a plain method it would
        # always be truthy (a bound method), wrongly marking every kernel as
        # precomputed.
        return self.kernel == "precomputed"

    def _get_kernel(self, X, Y=None):
        """Return the kernel matrix between the rows of X and of Y (or X)."""
        if callable(self.kernel):
            params = self.kernel_params or {}
        else:
            params = {"gamma": self.gamma,
                      "degree": self.degree,
                      "coef0": self.coef0}
        return pairwise_kernels(X, Y, metric=self.kernel,
                                filter_params=True, **params)

    def fit(self, X, y=None, sample_weight=None):
        """Cluster the samples of X.

        Parameters
        ----------
        X : array-like
            Training samples, or a square kernel matrix when
            ``kernel="precomputed"``.
        y : ignored
            Present for scikit-learn API compatibility.
        sample_weight : array-like of shape (n_samples,) or None
            Per-sample weights; ``None`` means all ones.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``sample_weight`` has the wrong shape, or if a cluster becomes
            empty during the iterations.
        """
        K = self._get_kernel(X)
        n_samples = K.shape[0]

        # Explicit None test: ``if sample_weight`` is ambiguous (and raises)
        # for NumPy arrays with more than one element.
        if sample_weight is None:
            sw = np.ones(n_samples)
        else:
            sw = np.asarray(sample_weight, dtype=float)
            if sw.shape != (n_samples,):
                raise ValueError("sample_weight must have shape (n_samples,).")
        self.sample_weight_ = sw

        rs = check_random_state(self.random_state)
        self.labels_ = rs.randint(self.n_clusters, size=n_samples)

        dist = np.zeros((n_samples, self.n_clusters))
        self.within_distances_ = np.zeros(self.n_clusters)

        for it in range(self.max_iter):
            # _compute_dist accumulates into dist, so reset it each iteration.
            dist.fill(0)
            self._compute_dist(K, dist, self.within_distances_,
                               update_within=True)
            labels_old = self.labels_
            self.labels_ = dist.argmin(axis=1)

            # Compute the number of samples whose cluster did not change
            # since last iteration.
            n_same = np.sum(self.labels_ == labels_old)
            if 1 - float(n_same) / n_samples < self.tol:
                if self.verbose:
                    print("Converged at iteration", it + 1)
                break

        # The last _compute_dist call computed within-cluster terms for the
        # labels *before* the final reassignment. Refresh them so predict()
        # uses terms that match labels_. Empty clusters are skipped here;
        # predict() reports them.
        for j in range(self.n_clusters):
            mask = self.labels_ == j
            if mask.any():
                self.within_distances_[j] = self._within_distance(K, mask)

        self.X_fit_ = X
        return self

    def _within_distance(self, K, mask):
        """Return the squared feature-space norm of one weighted centroid.

        ``mask`` selects the training samples of the cluster; ``K`` is the
        training kernel matrix.
        """
        sw = self.sample_weight_[mask]
        KK = K[mask][:, mask]  # K[mask, mask] does not work.
        return np.sum(np.outer(sw, sw) * KK) / (sw.sum() ** 2)

    def _compute_dist(self, K, dist, within_distances, update_within):
        """Accumulate kernel-trick distances into ``dist``.

        For each cluster j, ``dist[:, j]`` receives the squared feature-space
        distance from each row's sample to the weighted centroid of cluster j,
        minus that sample's own K(x, x) term. That term is constant per row,
        so it does not change ``argmin(axis=1)``.

        Parameters
        ----------
        K : ndarray of shape (n_rows, n_train)
            Kernel between the samples being assigned and the training data.
        dist : ndarray of shape (n_rows, n_clusters)
            Output buffer; values are added to it in place.
        within_distances : ndarray of shape (n_clusters,)
            Per-cluster centroid norms; overwritten when ``update_within``.
        update_within : bool
            Recompute ``within_distances`` from ``K`` (requires K to be the
            square training kernel).

        Raises
        ------
        ValueError
            If some cluster has no training samples assigned.
        """
        sw = self.sample_weight_
        for j in range(self.n_clusters):
            mask = self.labels_ == j
            if not mask.any():
                raise ValueError("Empty cluster found, try smaller n_cluster.")
            denom = sw[mask].sum()
            if update_within:
                within_distances[j] = self._within_distance(K, mask)
            # Add the centroid-norm term exactly once per cluster.
            dist[:, j] += within_distances[j]
            dist[:, j] -= 2 * np.sum(sw[mask] * K[:, mask], axis=1) / denom

    def predict(self, X):
        """Assign each sample of X to the nearest cluster found by ``fit``.

        Parameters
        ----------
        X : array-like
            New samples, or their kernel against the training data when
            ``kernel="precomputed"``.

        Returns
        -------
        ndarray of shape (n_samples,)
            Cluster index of each sample.
        """
        K = self._get_kernel(X, self.X_fit_)
        n_samples = K.shape[0]
        dist = np.zeros((n_samples, self.n_clusters))
        self._compute_dist(K, dist, self.within_distances_,
                           update_within=False)
        return dist.argmin(axis=1)
if __name__ == '__main__':
    # Small demo: cluster synthetic blobs, then re-predict the first samples.
    # print() calls replace the Python 2 print statements, which are a
    # syntax error under Python 3.
    from sklearn.datasets import make_blobs

    X, y = make_blobs(n_samples=1000, centers=5, random_state=0)
    km = KernelKMeans(n_clusters=5, max_iter=100, random_state=0, verbose=1)
    print(km.fit_predict(X)[:10])
    print(km.predict(X[:10]))
Copy link

ajilling commented Jun 19, 2019

Is there a way to make this not fail when an empty cluster is found?

Copy link

ravi2k1 commented Nov 15, 2019

How can I use k-means++ initialization with the same code?

Copy link

lonevetad commented Jan 30, 2020

Important note for everyone:
If you are looking for a Gaussian kernel, just pass kernel='rbf' (it is forwarded to pairwise_kernels as its "metric").
Sometimes I forget it.

Copy link

arnab-007 commented Aug 27, 2020

Any advice on how to extend this to multi-kernel k-means?

Copy link

mblondel commented Nov 16, 2020

The code was written for Python 2 and you're using Python 3. Replace xrange by range and print ... by print(...).

Copy link

nancychenxizhong commented Aug 24, 2022

@amueller Sorry for late reply (why no notifications in gists?!). I haven't compared it to other algorithms but I am open to inclusion in scikit-learn if someone wants to work on it.

@mblondel I am happy to work on it. Are you still open to inclusion of this gist to scikit-learn?

Copy link

mblondel commented Aug 24, 2022

Sure, go ahead!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment