Function to plot the decision boundaries of a scikit-learn classification model.
import numpy as np
import matplotlib.pyplot as plt


def plot_decision_boundaries(X, y, model_class, **model_params):
    """Plot the decision boundaries of a classification model.

    Only the first two columns of X are used to fit the model, because
    the model has to predict a label for every point of the 2-D mesh
    drawn behind the scatter plot.

    One possible improvement would be to fit on all columns and, when
    predicting on the mesh, use the first two columns together with the
    median of every other column.

    Adapted from:
    """
    X = np.asarray(X)
    y = np.asarray(y)
    reduced_data = X[:, :2]
    model = model_class(**model_params)
    model.fit(reduced_data, y)

    # Step size of the mesh. Decrease it for a finer (but slower) plot.
    h = 0.1

    # Build a mesh over [x_min, x_max] x [y_min, y_max], padded by 1 unit.
    x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
    y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # Predict a label for each mesh point and reshape back to the grid.
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Shade each region by its predicted class, then overlay the data points.
    plt.contourf(xx, yy, Z, alpha=0.4)
    plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
    return plt
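The docstring suggests fitting on all columns and, on the mesh, holding every column beyond the first two at its median. A minimal sketch of that idea, under stated assumptions: the name `plot_decision_boundaries_median`, the `h` keyword, and returning `(xx, yy, Z)` instead of `plt` are hypothetical choices for illustration, not part of the original gist. X must have at least two numeric columns, and the plotted boundary is only a 2-D slice of feature space taken at the medians.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def plot_decision_boundaries_median(X, y, model_class, h=0.1, **model_params):
    """Hypothetical variant: fit on every column of X, then predict on a
    2-D mesh over the first two columns with all other columns held at
    their medians. Returns the mesh and predicted labels."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    model = model_class(**model_params)
    model.fit(X, y)  # uses all features, unlike the two-column version

    # Mesh over the first two features, padded by 1 unit on each side.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # One row per mesh point: start from the column medians, then
    # overwrite the first two columns with the mesh coordinates.
    grid = np.tile(np.median(X, axis=0), (xx.size, 1))
    grid[:, 0] = xx.ravel()
    grid[:, 1] = yy.ravel()

    Z = model.predict(grid).reshape(xx.shape)
    plt.contourf(xx, yy, Z, alpha=0.4)
    plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
    return xx, yy, Z
```

As with the original function, a scikit-learn classifier class and its parameters are passed separately, e.g. `plot_decision_boundaries_median(X, y, KNeighborsClassifier, n_neighbors=5)` with `KNeighborsClassifier` imported from `sklearn.neighbors`. Any class whose instances expose `fit` and `predict` works the same way.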