Skip to content

Instantly share code, notes, and snippets.

@vrjkmr
Last active September 1, 2020 18:46
Show Gist options
  • Save vrjkmr/f41ab9ddbbbdbc47285264ff3a0c25a9 to your computer and use it in GitHub Desktop.
Save vrjkmr/f41ab9ddbbbdbc47285264ff3a0c25a9 to your computer and use it in GitHub Desktop.
Plotting decision boundaries
# imports
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
# Seed NumPy's global RNG; the same seed is also passed explicitly as
# random_state to the dataset generator and to the train/test split.
seed = 1
np.random.seed(seed)

# generate dummy dataset (500 instances, 2 classes, 2 features)
X, y = make_blobs(n_samples=500, centers=2, n_features=2,
                  cluster_std=4.5, random_state=seed)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)

# build and train model (fitted on the training split only)
model = LogisticRegression()
model.fit(X_train, y_train)

# 1. generate 2D background grid
# The figure shows both the training and the test points and its axis limits
# are set from these bounds, so they are computed over the full dataset X.
# Deriving them from X_train alone would crop any test point that falls
# outside the training range.
pad = 0.5
min_x1, min_x2 = X.min(axis=0) - pad  # per-feature minimum, padded
max_x1, max_x2 = X.max(axis=0) + pad  # per-feature maximum, padded
def generate_grid_points(min_x, max_x, min_y, max_y, resolution=100):
    """Generate resolution * resolution points within a given range.

    Args:
        min_x: Lower bound of the grid along the first axis (inclusive).
        max_x: Upper bound of the grid along the first axis (inclusive).
        min_y: Lower bound of the grid along the second axis (inclusive).
        max_y: Upper bound of the grid along the second axis (inclusive).
        resolution: Number of evenly spaced samples taken along each axis.

    Returns:
        Array of shape (resolution * resolution, 2). Column 0 holds the
        x-coordinates and column 1 the y-coordinates; x varies fastest, so
        the first `resolution` rows all share y == min_y.
    """
    xx, yy = np.meshgrid(
        np.linspace(min_x, max_x, resolution),
        np.linspace(min_y, max_y, resolution)
    )
    # Flatten both coordinate matrices and pair them up as (x, y) rows.
    return np.c_[xx.ravel(), yy.ravel()]
# generate a grid of 100 x 100 = 10k points
grid_points = generate_grid_points(min_x1, max_x1, min_x2, max_x2)

# 2. get model's predictions for grid
background = model.predict(grid_points)

# plot data along with decision boundary
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
cmap = "Set1"

# 3. plot grid with predictions (this forms the decision boundary)
# Each grid point is drawn as a small translucent dot coloured by its
# predicted class, so the colour change marks the boundary.
ax.scatter(grid_points[:, 0], grid_points[:, 1], c=background,
           cmap=cmap, alpha=0.4, s=4)

# 4. plot training and test data (dots = train, crosses = test)
scatter = ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train,
                     cmap=cmap, marker=".", label="Train")
ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test,
           cmap=cmap, marker="x", label="Test")

# other specifications
ax.set_title("Logistic Regression")
# Clamp the view to the extent used to build the background grid.
ax.set_xlim([min_x1, max_x1])
ax.set_ylim([min_x2, max_x2])
# Hide both axes (ticks and tick labels).
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
# ax.legend() already attaches the legend to the axes. Passing it to
# ax.add_artist() as well would register it a second time and draw it twice;
# that call is only needed when keeping several legends on one Axes.
legend = ax.legend(loc="best", title="Data")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment