Skip to content

Instantly share code, notes, and snippets.

@math-a3k
Created November 8, 2021 21:57
Show Gist options
  • Save math-a3k/4f660fcc7976a63049b92feed77ca759 to your computer and use it in GitHub Desktop.
Save math-a3k/4f660fcc7976a63049b92feed77ca759 to your computer and use it in GitHub Desktop.
Decision Function graphing
from itertools import combinations
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.special import comb
from sklearn.ensemble import HistGradientBoostingClassifier
# Prefer the GTK3Cairo backend, but do not abort the whole script when it
# cannot be loaded (matplotlib.use raises ImportError in that case); keep
# whatever backend matplotlib selected by default instead.
try:
    matplotlib.use("GTK3Cairo")
except ImportError as exc:
    print(
        f"GTK3Cairo backend unavailable ({exc}); "
        f"using {matplotlib.get_backend()} instead"
    )

# Fixed seed so the synthetic data set (and therefore the fitted model and
# the plots) is reproducible between runs.
np.random.seed(123456)
def trunc_normal(a, b, mean, sd, size):
    """Draw normal samples and clamp them to the closed interval [a, b].

    Values that fall outside the interval are moved onto the nearest bound
    (clipping); they are not re-drawn. The result is therefore a censored
    sample rather than a strictly truncated normal one.

    Args:
        a: Lower bound applied to every sample.
        b: Upper bound applied to every sample.
        mean: Mean of the underlying normal distribution.
        sd: Standard deviation of the underlying normal distribution.
        size: Number of samples to draw.

    Returns:
        A list with ``size`` floats, each within ``[a, b]``.
    """
    samples = np.random.normal(mean, sd, size)
    # np.clip applies both bounds in a single vectorised pass; tolist()
    # keeps the list return type that callers index into.
    return np.clip(samples, a, b).tolist()
def generate_observation():
    """Generate one synthetic observation as a flat list.

    A class flag is drawn first (Bernoulli with p = 0.6); every feature is
    then sampled with ``trunc_normal`` using the parameters that belong to
    the drawn class.

    Returns:
        ``[is_covid19, rbc, wbc, plt, lymp, neut]`` -- the boolean flag
        followed by the five feature values. All features are rounded to
        two decimals except ``plt``, which is rounded to a whole number.
    """
    is_covid19 = bool(np.random.binomial(1, 0.6, 1)[0])

    # feature -> (lower bound, upper bound,
    #             (mean, sd) when is_covid19, (mean, sd) otherwise,
    #             ndigits passed to round(); None rounds to an integer)
    feature_specs = {
        "rbc": (2, 8, (4.5, 2), (4.05, 2), 2),
        "wbc": (2, 40, (7, 4), (7.7, 5), 2),
        "plt": (50, 550, (220, 10), (242, 20), None),
        "lymp": (0.1, 30, (8, 1), (8, 4), 2),
        "neut": (0.2, 40, (5, 2), (5.5, 2), 2),
    }

    data = {"is_covid19": is_covid19}
    for feature, spec in feature_specs.items():
        low, high, if_positive, if_negative, ndigits = spec
        mean, sd = if_positive if is_covid19 else if_negative
        # Exactly one draw per feature, in the order listed above.
        data[feature] = round(
            trunc_normal(low, high, mean, sd, 1)[0], ndigits
        )

    # Dicts preserve insertion order, so the flag comes first.
    return list(data.values())
# Synthetic data set: np.array casts the boolean flag to a number, so
# column 0 holds the class label and the remaining columns the features.
observations = np.array([generate_observation() for _ in range(99)])
labels, data = observations[:, 0], observations[:, 1:]

# Feature names, in the same order as the columns of `data`.
cols = ["rbc", "wbc", "plt", "lymp", "neut"]
cols_indexes = list(range(len(cols)))

clf = HistGradientBoostingClassifier()
clf.fit(data, labels)

# Observation whose surroundings are inspected in the plots.
obs = [3, 3, 150, 0.5, 0.5]

# One subplot per unordered pair of features, laid out two per row.
n_graphs = comb(len(cols), 2)
n_graphs_rows = int(np.ceil(n_graphs / 2))

# Resolution of the evaluation mesh along each axis.
mesh_steps = 500

# Diverging map for the decision surface, two flat colours for the samples.
cm = plt.cm.bwr
cm_bright = ListedColormap(['#0000FF', '#FF0000'])
# Draw, for every pair of features, the classifier output on a 2-D mesh
# spanning that pair while all remaining features are pinned to the values
# of `obs`. Training samples and `obs` itself are overlaid on each panel.
fig = plt.figure()
for i, pair in enumerate(combinations(cols_indexes, 2)):
    x_idx, y_idx = pair

    # Axis extents must cover both the training data and the observation.
    obs_x = float(obs[x_idx])
    x_min = np.fmin(np.nanmin(data[:, x_idx]), obs_x)
    x_max = np.fmax(np.nanmax(data[:, x_idx]), obs_x)
    x_rec = x_max - x_min
    obs_y = float(obs[y_idx])
    y_min = np.fmin(np.nanmin(data[:, y_idx]), obs_y)
    y_max = np.fmax(np.nanmax(data[:, y_idx]), obs_y)
    y_rec = y_max - y_min

    # A degenerate (zero-width) range cannot be meshed; skip that panel.
    if not (y_rec > 0 and x_rec > 0):
        continue

    ax = fig.add_subplot(
        n_graphs_rows, (2 if n_graphs > 1 else 1), i + 1,
        sharex=None, sharey=None
    )
    margin_x = x_rec / 20
    margin_y = y_rec / 3
    axis_x_min, axis_x_max = x_min - margin_x, x_max + margin_x
    axis_y_min, axis_y_max = y_min - margin_y, y_max + margin_y

    # linspace yields exactly `mesh_steps` points including both axis
    # limits; arange with a float step may produce one point more or less.
    xx, yy = np.meshgrid(
        np.linspace(axis_x_min, axis_x_max, mesh_steps),
        np.linspace(axis_y_min, axis_y_max, mesh_steps)
    )
    xx_ravel, yy_ravel = xx.ravel(), yy.ravel()

    # Build the model input: the two plotted features vary over the mesh,
    # every other feature is held at the observation's value. A separate
    # index name is used so the subplot counter `i` is never shadowed.
    cols_for_stack = [
        xx_ravel if k == x_idx
        else yy_ravel if k == y_idx
        else np.full(xx_ravel.shape, value)
        for k, value in enumerate(obs)
    ]
    grid = np.column_stack(cols_for_stack)

    # Colour scale: the diverging colormap must be centred on the decision
    # threshold. With automatic levels the blue/red split sits at the
    # middle of the value range, which can make a region that is predicted
    # as class 0 look red (or the other way round).
    if hasattr(clf, "decision_function"):
        Z = clf.decision_function(grid)
        # Symmetric range around 0; fall back to 1 for an all-zero surface.
        z_span = float(np.max(np.abs(Z))) or 1.0
        z_low, z_high = -z_span, z_span
    else:
        Z = clf.predict_proba(grid)[:, 1]
        # Probabilities: the threshold 0.5 is the centre of [0, 1].
        z_low, z_high = 0.0, 1.0
    Z = Z.reshape(xx.shape)
    # An odd number of levels puts one boundary exactly on the threshold.
    ax.contourf(xx, yy, Z, levels=np.linspace(z_low, z_high, 21),
                cmap=cm, alpha=.8)

    # Training samples, coloured by label.
    ax.scatter(data[:, x_idx], data[:, y_idx],
               c=labels, cmap=cm_bright,
               edgecolors=None, alpha=0.6)

    # Plot the observation
    ax.axvline(obs[x_idx], c="green")
    ax.axhline(obs[y_idx], c="green")
    ax.scatter(obs[x_idx], obs[y_idx], c="green",
               alpha=1, edgecolors='k')

    # Set lims, ticks and labels
    ax.set_xlim(axis_x_min, axis_x_max)
    ax.set_ylim(axis_y_min, axis_y_max)
    ax.tick_params(axis='both', which='major',
                   labelsize=4, pad=1)
    ax.set_xticks((axis_x_min, axis_x_max))
    ax.set_yticks((axis_y_min, axis_y_max))
    ax.set_xlabel(cols[x_idx], labelpad=-5, size=6)
    ax.set_ylabel(cols[y_idx], labelpad=-10, size=6)

# Lay the panels out once, after all of them exist.
fig.tight_layout()
# Colour convention of the plots: 0 -> Blue, 1 -> Red
obs_2 = [3, 4, 150, 0.5, 0.5]  # <- WBC to 4, should be in red region
obs_3 = [4, 3, 150, 0.5, 0.5]  # <- RBC to 4, should be in red region

# Regions noted by the author for each point's decision function value:
#   obs   <- Blue region
#   obs_2 <- Red region
#   obs_3 <- Blue region, Not Expected
for label, point in (("obs", obs), ("obs_2", obs_2), ("obs_3", obs_3)):
    print(label, point)
    print(f"{label} prediction", clf.predict([point, ]))
    print(f"{label} dec fun", clf.decision_function([point, ]))

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment