@anandology
Created March 18, 2016 22:25
Function to plot the decision boundaries of a scikit-learn classification model.
import numpy as np
import matplotlib.pyplot as plt


def plot_decision_boundaries(X, y, model_class, **model_params):
    """Function to plot the decision boundaries of a classification model.

    This uses just the first two columns of the data for fitting
    the model, as we need to find the predicted value for every point in
    the scatter plot.

    One possible improvement could be to use all columns for fitting,
    and to use the first 2 columns plus the median of all other columns
    for predicting.

    Adapted from:
    http://scikit-learn.org/stable/auto_examples/ensemble/plot_voting_decision_regions.html
    http://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html
    """
    reduced_data = X[:, :2]
    model = model_class(**model_params)
    model.fit(reduced_data, y)

    # Step size of the mesh. Decrease for a smoother boundary
    # (at the cost of more predictions).
    h = 0.1  # spacing between points in the mesh [x_min, x_max]x[y_min, y_max].

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh, padding the data range by 1 on every side.
    x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
    y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Obtain labels for each point in the mesh using the model, then reshape
    # them back to the grid's shape so contourf can draw filled regions.
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    plt.contourf(xx, yy, Z, alpha=0.4)
    plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=y, alpha=0.8)
    return plt
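The core trick in the function is to evaluate the model at every point of a regular grid and then reshape the flat predictions back into the grid's shape. The sketch below isolates just that step. `ThresholdClassifier` and `predict_on_grid` are hypothetical names used only for illustration; the classifier is a hand-written stand-in exposing a scikit-learn-style `predict`, so the example runs with numpy alone.

```python
import numpy as np


class ThresholdClassifier:
    """Hypothetical stand-in for a fitted model: predicts 1 when x0 + x1 > 0."""

    def predict(self, points):
        # points has shape (n_samples, 2); return one integer label per row.
        return (points[:, 0] + points[:, 1] > 0).astype(int)


def predict_on_grid(model, x_range, y_range, h):
    """Evaluate `model` at every point of a regular grid with spacing `h`."""
    xx, yy = np.meshgrid(np.arange(*x_range, h), np.arange(*y_range, h))
    # np.c_ stacks the flattened coordinates into an (n_points, 2) array,
    # which is the input shape predict() expects.
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    Z = model.predict(grid_points)
    # Reshape the flat labels back to the grid so contourf can draw them.
    return xx, yy, Z.reshape(xx.shape)


xx, yy, Z = predict_on_grid(ThresholdClassifier(), (-1, 1), (0, 1), 0.5)
print(xx.shape)     # (2, 4): rows follow the y values, columns the x values
print(Z.tolist())   # [[0, 0, 0, 1], [0, 0, 1, 1]]
```

Note that `np.meshgrid` returns arrays shaped `(len(y_values), len(x_values))`, which is why the reshape uses `xx.shape` rather than a hand-computed size.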
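The docstring suggests an improvement: fit on all columns, then predict with the first two columns varied over the grid and every other column held at its median. Below is a minimal sketch of that idea under stated assumptions; it is not part of the original gist. `MeanCentroidClassifier` and `grid_predictions_with_medians` are hypothetical names, and the classifier is a tiny numpy stand-in with scikit-learn-style `fit`/`predict` so the sketch runs without scikit-learn. A scikit-learn classifier trained on all columns should be usable in its place.

```python
import numpy as np


class MeanCentroidClassifier:
    """Hypothetical minimal classifier (scikit-learn-style fit/predict):
    assigns each sample to the class whose mean vector is closest."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.centroids_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        # Squared Euclidean distance from every sample to every centroid.
        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[d.argmin(axis=1)]


def grid_predictions_with_medians(model, X, h=0.1):
    """Predict on a grid over the first two columns of X, holding every
    other column fixed at its median. Assumes `model` was fit on all
    columns of X."""
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Start every grid point at the per-column median, then overwrite the
    # first two columns with the grid coordinates.
    grid = np.tile(np.median(X, axis=0), (xx.size, 1))
    grid[:, 0] = xx.ravel()
    grid[:, 1] = yy.ravel()
    Z = model.predict(grid).reshape(xx.shape)
    return xx, yy, Z


# Toy 3-feature data: two well-separated classes.
X = np.array([[0.0, 0.0, 5.0], [0.5, 0.5, 5.0], [4.0, 4.0, 5.0], [4.5, 4.5, 5.0]])
y = np.array([0, 0, 1, 1])
model = MeanCentroidClassifier().fit(X, y)  # fit on all three columns
xx, yy, Z = grid_predictions_with_medians(model, X, h=0.5)
# To draw it: plt.contourf(xx, yy, Z, alpha=0.4); plt.scatter(X[:, 0], X[:, 1], c=y)
```

Holding the remaining features at their medians shows a single 2-D slice of the decision surface, so scatter points whose other features are far from the median may appear on the "wrong" side of the drawn boundary.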