Skip to content

Instantly share code, notes, and snippets.

@ychennay
Last active May 19, 2019 18:04
Show Gist options
  • Save ychennay/7b7143b5707b25d6dcb3292341e819d0 to your computer and use it in GitHub Desktop.
Save ychennay/7b7143b5707b25d6dcb3292341e819d0 to your computer and use it in GitHub Desktop.
import numpy as np
def _fit_bootstrap_models(features, target, model_class, n_models, model_params):
    """Fit ``n_models`` fresh models, each on a bootstrap resample of the rows.

    Row indices are drawn with replacement from NumPy's global RNG, so calling
    ``np.random.seed`` beforehand makes the resampling reproducible.
    """
    n_samples = len(features)
    models = []
    for _ in range(n_models):
        # Resample by index instead of stacking X and y into one array: this
        # keeps y's dtype (e.g. int or string labels) instead of casting to float.
        idx = np.random.randint(0, n_samples, size=n_samples)
        model = model_class(**model_params)
        model.fit(features[idx], target[idx])
        models.append(model)
    return models


def _majority_vote(predictions):
    """Return the most frequent label in each column of ``predictions``.

    ``predictions`` has shape (n_models, n_points). Ties are resolved toward
    the smallest label: ``np.unique`` returns labels sorted ascending and
    ``argmax`` picks the first maximum.
    """
    labels = np.unique(predictions)
    votes = np.stack([(predictions == label).sum(axis=0) for label in labels])
    return labels[votes.argmax(axis=0)]


def plot_decision_boundaries(X, y, model_class, bootstrap=False,
                             x_label=None, y_label=None, title=None,
                             B=10, **model_params):
    """Plot a classifier's decision regions over the first two features of X.

    Adapted from https://gist.github.com/anandology/772d44d291a9daa198d4

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Feature matrix. Only the first two columns are used, because the plot
        is a 2D contour map.
    y : array-like, shape (n_samples,)
        Target labels. They are also used to colour the scatter points.
    model_class : callable
        Model constructor. Instances must provide ``fit(X, y)`` and ``predict(X)``.
    bootstrap : bool
        If True, fit ``B`` models on bootstrap resamples of the data and colour
        each mesh point by the majority vote of their predictions.
    x_label, y_label, title : str, optional
        Axis labels and plot title. Each is skipped when falsy.
    B : int
        Number of bootstrap models; only used when ``bootstrap`` is True. This
        replaces an implicit module-level ``B`` that the function used to read.
    **model_params
        Keyword arguments forwarded to ``model_class``.

    Returns
    -------
    The ``plt`` object used for drawing (presumably ``matplotlib.pyplot``, which
    must be in scope at module level), so callers can keep customising or show it.

    Raises
    ------
    ValueError
        If ``bootstrap`` is True and ``B`` is less than 1.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    reduced_data = X[:, :2]  # first two feature columns only (2D plot)

    if bootstrap:
        if B < 1:
            raise ValueError(f"B must be at least 1 when bootstrap=True, got {B}")
        models = _fit_bootstrap_models(reduced_data, y, model_class, B, model_params)
    else:  # single model use case
        model = model_class(**model_params)
        model.fit(reduced_data, y)
        models = [model]

    # Mesh spanning the data with a 1-unit margin on every side.
    h = .1  # mesh step: smaller gives smoother boundaries but more predictions
    x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
    y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    grid = np.c_[xx.ravel(), yy.ravel()]

    # Label every mesh point with the model (or the bootstrap majority vote).
    if bootstrap:
        predictions = np.stack([m.predict(grid) for m in models])
        Z = _majority_vote(predictions)
    else:
        Z = models[0].predict(grid)
    Z = Z.reshape(xx.shape)

    plt.contourf(xx, yy, Z, alpha=0.4, cmap='RdYlGn')
    plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=y, alpha=0.8)
    # NOTE(review): with no mappable argument, colorbar() attaches to the most
    # recent mappable (the scatter of true labels y), not to the predicted
    # contour; confirm that matches the label text below.
    clb = plt.colorbar()
    clb.set_label('# of Predicted Conversions')

    # annotate the chart
    if x_label:
        plt.xlabel(x_label)
    if y_label:
        plt.ylabel(y_label)
    if title:
        plt.title(title)
    return plt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment