Last active
December 10, 2021 07:43
-
-
Save math-a3k/572fa2e9228fbbe2b3fa8471250e0327 to your computer and use it in GitHub Desktop.
Conditional Decision Function Graphing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted from https://gist.github.com/math-a3k/4f660fcc7976a63049b92feed77ca759
from itertools import combinations | |
import numpy as np | |
import matplotlib | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import ListedColormap | |
from scipy.special import comb | |
from sklearn.ensemble import HistGradientBoostingClassifier | |
# Prefer the GTK3/Cairo interactive backend, but do not abort the whole script
# on machines where GTK3 is not installed: keep the backend matplotlib has
# already selected and report why the preferred one was skipped.
try:
    matplotlib.use("GTK3Cairo")
except ImportError as exc:
    print(
        f"GTK3Cairo backend unavailable ({exc}); "
        f"using {matplotlib.get_backend()} instead"
    )
# Fixed seed so that the synthetic sample (and therefore the fitted model and
# the plots) is the same on every run.
np.random.seed(123456)

# (mean, standard deviation) of the normal distribution each feature
# x1..x5 is drawn from, per class.
_CLASS_1_PARAMS = [(4.5, 2), (7, 4), (220, 10), (8, 1), (5, 2)]
_CLASS_0_PARAMS = [(4.05, 2), (7.7, 5), (242, 20), (8, 4), (5.5, 2)]


def _sample_class(label, size, params):
    """Return a (size, 1 + len(params)) array: the label, then the features.

    Features are drawn one column at a time, in the order given by *params*,
    from the global NumPy random state.
    """
    features = [np.random.normal(mean, std, size) for mean, std in params]
    return np.column_stack([np.full(size, label), *features])


# Class 1 is sampled before class 0 so the random stream is consumed in the
# same order as the original script and the generated values are unchanged.
dataset_1 = _sample_class(1, 60, _CLASS_1_PARAMS)
dataset_0 = _sample_class(0, 40, _CLASS_0_PARAMS)

# np.vstack replaces np.row_stack, which is a deprecated alias of it.
dataset = np.vstack((dataset_0, dataset_1))
labels = dataset[:, 0]   # first column: class label (0 or 1)
data = dataset[:, 1:]    # remaining columns: features x1..x5
cols = ["x1", "x2", "x3", "x4", "x5"]
cols_indexes = list(range(len(cols)))
# Model whose decision surface is inspected, trained on the synthetic sample.
clf = HistGradientBoostingClassifier().fit(data, labels)

# Reference observation: every plot is a 2-D slice passing through it.
obs = [3, 3, 150, 0.5, 0.5]

# Plot styling: diverging map for the surfaces, two flat colours for the
# training points, and the number of mesh points per axis.
mesh_steps = 200
cm = plt.cm.bwr
cm_bright = ListedColormap(['#0000FF', '#FF0000'])

# One grid cell per unordered pair of features, laid out two cells per row
# (a single column when there is only one pair).
n_graphs = comb(len(cols), 2)
n_graphs_rows = int(np.ceil(n_graphs / 2))
fig = plt.figure()
if n_graphs > 1:
    gridspec = fig.add_gridspec(n_graphs_rows, 2)
else:
    gridspec = fig.add_gridspec(n_graphs_rows, 1)
# Draw one grid cell per unordered pair of features.  Each cell holds two
# panels over the same 2-D slice through `obs` (the pair's features vary over
# a mesh, every other feature stays at its `obs` value):
#   left  - continuous classifier score, colormap centred on the boundary
#   right - predicted class
obs_row = np.asarray(obs, dtype=float)

# Fixed levels for the class panel, so class 0 always takes the low (blue) end
# of `cm` and class 1 the high (red) end, even if a slice holds a single class.
class_levels = [-0.5, 0.5, 1.5]

for graph_index, pair in enumerate(combinations(cols_indexes, 2)):
    x_idx, y_idx = pair
    obs_x, obs_y = obs_row[x_idx], obs_row[y_idx]

    # The axis range must contain both the training data and the observation.
    x_min = min(data[:, x_idx].min(), obs_x)
    x_max = max(data[:, x_idx].max(), obs_x)
    y_min = min(data[:, y_idx].min(), obs_y)
    y_max = max(data[:, y_idx].max(), obs_y)
    margin_x = (x_max - x_min) / 20
    margin_y = (y_max - y_min) / 3
    axis_x_min, axis_x_max = x_min - margin_x, x_max + margin_x
    axis_y_min, axis_y_max = y_min - margin_y, y_max + margin_y

    # linspace gives exactly `mesh_steps` points per axis and reaches the
    # upper axis limit; arange with a float step guarantees neither.
    xx, yy = np.meshgrid(
        np.linspace(axis_x_min, axis_x_max, mesh_steps),
        np.linspace(axis_y_min, axis_y_max, mesh_steps),
    )

    # One row per mesh point: start from copies of the observation and
    # overwrite only the two features that vary in this slice.
    grid = np.tile(obs_row, (xx.size, 1))
    grid[:, x_idx] = xx.ravel()
    grid[:, y_idx] = yy.ravel()

    if hasattr(clf, "decision_function"):
        Z0 = clf.decision_function(grid)
    else:
        # No decision function: use P(class 1) shifted by 0.5 so that, like a
        # decision function, the value 0 marks the class boundary.
        Z0 = clf.predict_proba(grid)[:, 1] - 0.5
    Z1 = clf.predict(grid)
    Z0 = Z0.reshape(xx.shape)
    Z1 = Z1.reshape(xx.shape)

    # Symmetric colour range around 0: white is the decision boundary and
    # blue/red mean the same sign in every panel, instead of contourf
    # stretching the colormap over each panel's own min..max.
    z_abs = float(np.max(np.abs(Z0))) or 1.0
    score_levels = np.linspace(-z_abs, z_abs, 21)

    panels = (
        (Z0, score_levels, -z_abs, z_abs),
        (Z1, class_levels, 0, 1),
    )
    sub_gridspec = gridspec[graph_index].subgridspec(1, 2)
    for graph_sub_index, (surface, levels, vmin, vmax) in enumerate(panels):
        ax = fig.add_subplot(sub_gridspec[0, graph_sub_index])
        ax.contourf(xx, yy, surface, levels=levels, cmap=cm,
                    vmin=vmin, vmax=vmax, alpha=.8)
        # Training points, coloured by their true label.
        ax.scatter(data[:, x_idx], data[:, y_idx],
                   c=labels, cmap=cm_bright,
                   edgecolors=None, alpha=0.6)
        # Plot the observation
        ax.axvline(obs_x, c="green")
        ax.axhline(obs_y, c="green")
        ax.scatter(obs_x, obs_y, c="green",
                   alpha=1, edgecolors='k')
        # Set lims, ticks and labels
        ax.set_xlim(axis_x_min, axis_x_max)
        ax.set_ylim(axis_y_min, axis_y_max)
        ax.tick_params(axis='both', which='major',
                       labelsize=4, pad=1)
        ax.set_xticks((axis_x_min, axis_x_max))
        ax.set_yticks((axis_y_min, axis_y_max))
        ax.set_xlabel(cols[x_idx], labelpad=-5, size=6)
        ax.set_ylabel(cols[y_idx], labelpad=-10, size=6)

# Lay the figure out once, after every panel has been added.
fig.tight_layout()
# Colour legend for the figure: class 0 -> blue, class 1 -> red.
#
# Print the predicted class and the decision-function value for the base
# observation and for two perturbed copies of it, then show the figure.
# The notes beside each entry are the author's reading of the plots.

# x3 moved to 250: should be BLUE, as the decision function is blue in all
# planes with x3 at that value.
obs_2 = [3, 3, 250, 0.5, 0.5]
# x4 moved to 22: three planes in BLUE, one in RED, expected BLUE.
obs_3 = [3, 3, 150, 22, 0.5]

reports = (
    # prediction <- RED; decision function <- RED region
    ("obs", "obs prediction", "obs dec fun", obs),
    # prediction <- BLUE, expected; decision function <- blue region, expected
    ("obs_2", "obs_2 prediction", "obs_2 dec fun", obs_2),
    # prediction <- RED, not expected; decision function <- RED region,
    # not expected
    ("obs_3", "obs_3 prediction", "obs_3 dec fun", obs_3),
)
for point_tag, prediction_tag, score_tag, point in reports:
    print(point_tag, point)
    print(prediction_tag, clf.predict([point, ]))
    print(score_tag, clf.decision_function([point, ]))

plt.show()
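# Concerns
# - The decision_function panel is not consistent with the predict panel: at
#   any given point both should show the same "direction"/color/class,
#   whatever the intensity.
# - Changing one variable should have a consistent effect in every graph,
#   i.e. if the point with x4 = 22 is red, every graph that includes x4
#   should be red at x4 = 22.
# NOTE(review): contourf scales the colormap to each panel's own value range
#   unless levels/vmin/vmax are pinned, so the same color need not mean the
#   same sign in two panels - rule this out before attributing the
#   inconsistency to the classifier.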
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment