Created
November 8, 2021 21:57
-
-
Save math-a3k/4f660fcc7976a63049b92feed77ca759 to your computer and use it in GitHub Desktop.
Decision Function graphing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import combinations | |
import numpy as np | |
import matplotlib | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import ListedColormap | |
from scipy.special import comb | |
from sklearn.ensemble import HistGradientBoostingClassifier | |
# Seed NumPy's global RNG so the synthetic data set generated further down is
# the same on every run.
np.random.seed(123456)
# Render figures with the GTK3 + Cairo backend.
# NOTE(review): this needs the GTK3/Cairo bindings to be installed -- confirm
# the backend is available on the target machine.
matplotlib.use("GTK3Cairo")
def trunc_normal(a, b, mean, sd, size):
    """Draw normal samples and clamp them to the closed interval [a, b].

    Despite the name, this is not a true truncated normal: draws that fall
    outside the interval are moved onto the nearest bound instead of being
    redrawn, so probability mass accumulates exactly at ``a`` and ``b``.

    Args:
        a: Lower bound; any draw below it is replaced by ``a``.
        b: Upper bound; any draw above it is replaced by ``b``.
        mean: Mean passed to ``np.random.normal``.
        sd: Standard deviation passed to ``np.random.normal``.
        size: Number of samples to draw.

    Returns:
        A list with one float per draw, each within ``[a, b]``.
    """
    samples = np.random.normal(mean, sd, size)
    # One vectorised clamp replaces two Python-level passes and guarantees
    # that every element is a float, including the ones pinned to a bound.
    return list(np.clip(samples, a, b))
def generate_observation():
    """Simulate one synthetic blood-test record.

    The label is drawn first (``True`` with probability 0.6). Each feature is
    then drawn once through ``trunc_normal``, using one (mean, sd) pair when
    the label is ``True`` and another when it is ``False``.

    Returns:
        A list ``[is_covid19, rbc, wbc, plt, lymp, neut]``. ``plt`` is rounded
        to the nearest integer; the other features are rounded to 2 decimals.
    """
    is_covid19 = bool(np.random.binomial(1, 0.6, 1)[0])

    # (feature, lower, upper, (mean, sd) if positive, (mean, sd) if negative,
    #  decimals -- None means round() is called without a digit count)
    feature_specs = (
        ("rbc", 2, 8, (4.5, 2), (4.05, 2), 2),
        ("wbc", 2, 40, (7, 4), (7.7, 5), 2),
        ("plt", 50, 550, (220, 10), (242, 20), None),
        ("lymp", 0.1, 30, (8, 1), (8, 4), 2),
        ("neut", 0.2, 40, (5, 2), (5.5, 2), 2),
    )

    record = [is_covid19]
    for _feature, lower, upper, positive, negative, decimals in feature_specs:
        mean, sd = positive if is_covid19 else negative
        # Exactly one draw per feature, in the order listed above.
        draw = trunc_normal(lower, upper, mean, sd, 1)[0]
        if decimals is None:
            record.append(round(draw))
        else:
            record.append(round(draw, decimals))
    return record
# Feature names, in the order generate_observation() emits them after the
# leading label.
cols = ["rbc", "wbc", "plt", "lymp", "neut"]
cols_indexes = list(range(len(cols)))

# 99 synthetic records: column 0 holds the label, the rest are the features.
observations = np.array([generate_observation() for _ in range(99)])
labels, data = observations[:, 0], observations[:, 1:]

# Fit the classifier whose decision surface is plotted below.
clf = HistGradientBoostingClassifier()
clf.fit(data, labels)

# Reference observation: it is marked on every panel, and features that are
# not on a panel's axes are held at these values.
obs = [3, 3, 150, 0.5, 0.5]

# One panel per unordered pair of features, laid out two per row.
n_graphs = comb(len(cols), 2)
n_graphs_rows = int(np.ceil(n_graphs / 2))

# Mesh resolution per axis, and the colour maps for the surface / the points.
mesh_steps = 500
cm = plt.cm.bwr
cm_bright = ListedColormap(['#0000FF', '#FF0000'])

fig = plt.figure()
# Draw one panel per pair of features: the classifier's decision surface over
# those two features (all other features held at the values in `obs`), the
# training points, and the reference observation.
for plot_idx, (x_col, y_col) in enumerate(combinations(cols_indexes, 2)):
    # Axis extents cover both the training data and the reference observation.
    obs_x = float(obs[x_col])
    x_min = np.fmin(np.nanmin(data[:, x_col]), obs_x)
    x_max = np.fmax(np.nanmax(data[:, x_col]), obs_x)
    x_rec = x_max - x_min

    obs_y = float(obs[y_col])
    y_min = np.fmin(np.nanmin(data[:, y_col]), obs_y)
    y_max = np.fmax(np.nanmax(data[:, y_col]), obs_y)
    y_rec = y_max - y_min

    # A feature with no spread (or an undefined one) cannot span an axis.
    if not (y_rec > 0 and x_rec > 0):
        continue

    ax = fig.add_subplot(
        n_graphs_rows, (2 if n_graphs > 1 else 1), plot_idx + 1,
        sharex=None, sharey=None
    )

    margin_x = x_rec / 20
    margin_y = y_rec / 3
    axis_x_min, axis_x_max = x_min - margin_x, x_max + margin_x
    axis_y_min, axis_y_max = y_min - margin_y, y_max + margin_y

    # linspace(endpoint=False) yields exactly `mesh_steps` evenly spaced
    # points per axis; arange with a floating-point step can produce one
    # point more or less because of rounding.
    xx, yy = np.meshgrid(
        np.linspace(axis_x_min, axis_x_max, mesh_steps, endpoint=False),
        np.linspace(axis_y_min, axis_y_max, mesh_steps, endpoint=False)
    )
    xx_ravel, yy_ravel = xx.ravel(), yy.ravel()

    # Build the full feature matrix for the mesh: the two plotted features
    # vary over the grid, every other feature is pinned to its value in `obs`.
    # The index name is distinct from the panel index on purpose.
    cols_for_stack = []
    for feat_idx, feat_value in enumerate(obs):
        if feat_idx == x_col:
            cols_for_stack.append(xx_ravel)
        elif feat_idx == y_col:
            cols_for_stack.append(yy_ravel)
        else:
            cols_for_stack.append(np.full(xx_ravel.shape, feat_value))
    grid = np.column_stack(cols_for_stack)

    # Prefer the raw decision function; otherwise fall back to the
    # predicted probability of the second class.
    if hasattr(clf, "decision_function"):
        Z = clf.decision_function(grid)
    else:
        Z = clf.predict_proba(grid)[:, 1]
    Z = Z.reshape(xx.shape)

    ax.contourf(xx, yy, Z, cmap=cm, alpha=.8)
    ax.scatter(data[:, x_col], data[:, y_col],
               c=labels, cmap=cm_bright,
               edgecolors=None, alpha=0.6)

    # Plot the observation
    ax.axvline(obs[x_col], c="green")
    ax.axhline(obs[y_col], c="green")
    ax.scatter(obs[x_col], obs[y_col], c="green",
               alpha=1, edgecolors='k')

    # Set lims, ticks and labels
    ax.set_xlim(axis_x_min, axis_x_max)
    ax.set_ylim(axis_y_min, axis_y_max)
    ax.tick_params(axis='both', which='major',
                   labelsize=4, pad=1)
    ax.set_xticks((axis_x_min, axis_x_max))
    ax.set_yticks((axis_y_min, axis_y_max))
    ax.set_xlabel(cols[x_col], labelpad=-5, size=6)
    ax.set_ylabel(cols[y_col], labelpad=-10, size=6)
fig.tight_layout()

# Probe points printed alongside the figure. Colour convention on the panels:
# 0 -> Blue, 1 -> Red
obs_2 = [3, 4, 150, 0.5, 0.5]  # <- WBC to 4, should be in red region
obs_3 = [4, 3, 150, 0.5, 0.5]  # <- RBC to 4, should be in red region

probes = (
    # (label, prediction label, decision-function label, feature vector)
    ("obs", "obs prediction", "obs dec fun", obs),  # <- Blue region
    ("obs_2", "obs_2 prediction", "obs_2 dec fun", obs_2),  # <- Red region
    # <- Blue region, Not Expected
    ("obs_3", "obs_3 prediction", "obs_3 dec fun", obs_3),
)
for label, prediction_label, decision_label, sample in probes:
    batch = [sample, ]
    print(label, sample)
    print(prediction_label, clf.predict(batch))
    print(decision_label, clf.decision_function(batch))

plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment