Skip to content

Instantly share code, notes, and snippets.

Last active January 19, 2020 18:10
Show Gist options
  • Save benman1/47d34cd062ad171ade5076025bf321c4 to your computer and use it in GitHub Desktop.
Save benman1/47d34cd062ad171ade5076025bf321c4 to your computer and use it in GitHub Desktop.
Plot the decision boundaries of clustering or classification methods.
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import FastICA
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.base import TransformerMixin
class Reduce(TransformerMixin):
    """Trivial "dimensionality reduction" that keeps the first two columns.

    Serves as the default ``ica`` argument of ``plot_decision_boundary`` so
    that data which is already two-dimensional can be plotted without a real
    projection. ``fit_transform`` is inherited from ``TransformerMixin``.
    """

    def fit(self, X, y=None):
        """Learn nothing; return ``self`` so calls can be chained."""
        return self

    def transform(self, X):
        """Return columns 0 and 1 of ``X``.

        Assumes ``X`` supports NumPy-style 2-D slicing (``X[:, :2]``).
        """
        leading_columns = slice(None, 2)
        return X[:, leading_columns]
def plot_decision_boundary(data, kmeans, title='No title', h=.001, model=None,
                           highlight_centroids=False, ica=Reduce()):
    '''Colour the region assigned to each cluster/class in a 2-D projection.

    Based on a scikit-learn plotting example (the original reference link was
    lost). The data is projected to two dimensions with ``ica``, a regular
    mesh over the projected bounding box is labelled, and the labels are drawn
    as an image with the projected points on top.

    Parameters
    ----------
    data : array-like
        The dataset to be visualized with its clusters.
    kmeans : estimator
        The clustering (or classification) algorithm. Must provide
        ``predict``. ``cluster_centers_`` is also required when
        ``highlight_centroids`` is True.
    title : str
        The title displayed with the plot.
    h : float
        Step size of the mesh in projected coordinates. Decrease it to
        increase the quality of the picture. The mesh has about
        ``(x_range / h) * (y_range / h)`` points, so memory grows
        quadratically as ``h`` shrinks.
    model : estimator or None
        If given, it is fitted on the projected data against the labels
        ``kmeans.predict(data)``, i.e. the boundaries are re-learned in the
        2-D space, and it labels the mesh. If None, ``kmeans`` labels the mesh
        directly, which requires ``kmeans`` to accept two-feature input.
    highlight_centroids : bool
        Whether to draw ``kmeans.cluster_centers_``, projected with ``ica``,
        as white crosses. Centroids may have little bearing in the projected
        space.
    ica : transformer
        A dimensionality reduction with ``fit_transform`` and ``transform``
        that must yield exactly two dimensions, e.g.
        ``FastICA(n_components=2)``. The default keeps the first two columns.
        The default ``Reduce`` instance stores no state, so sharing it across
        calls is harmless.
    '''
    import numpy as np  # needed below; not imported at file level

    reduced_data = ica.fit_transform(data)

    # Choose the estimator that labels the mesh. Use an explicit None test
    # consistently, so that estimators with a falsy ``__len__`` are still used.
    classifier = kmeans
    if model is not None:
        classifier = model.fit(reduced_data, kmeans.predict(data))

    # Bounding box of the projected points. The 1e-15 pad is negligible, so
    # the plot is tight to the data.
    x_min, x_max = reduced_data[:, 0].min() - 1e-15, reduced_data[:, 0].max() + 1e-15
    y_min, y_max = reduced_data[:, 1].min() - 1e-15, reduced_data[:, 1].max() + 1e-15
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Label every mesh point; each label is then drawn in its own colour.
    preds = classifier.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = preds.reshape(xx.shape)

    # 'nearest' keeps labels discrete. The default interpolation (None)
    # antialiases when the large mesh is downsampled to screen pixels, which
    # blends label values into colours that belong to other labels.
    plt.imshow(Z, interpolation='nearest',
               extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               cmap='Paired',  # qualitative colormap, one colour per label
               aspect='auto', origin='lower')
    plt.plot(reduced_data[:, 0], reduced_data[:, 1], 'k.', markersize=2)

    if highlight_centroids:
        # Project the centroids with the already-fitted reducer and keep only
        # those whose label (= centroid index) occurs on the mesh. Boolean
        # masking keeps the result 2-D even when it is empty.
        centroids = ica.transform(kmeans.cluster_centers_)
        present = np.isin(np.arange(len(centroids)), np.unique(preds))
        centroids = centroids[present]
        plt.scatter(centroids[:, 0], centroids[:, 1],
                    marker='x', s=169, linewidths=3,
                    color='w', zorder=10)

    plt.title(title)
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment