Skip to content

Instantly share code, notes, and snippets.

@RobGeada
Last active November 29, 2021 15:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RobGeada/e273456121e70708f17394a33e802b87 to your computer and use it in GitHub Desktop.
Save RobGeada/e273456121e70708f17394a33e802b87 to your computer and use it in GitHub Desktop.
Generate a waterfall-style ("candlestick") plot of per-feature SHAP values from PmmlShapExplainerTest output
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Global matplotlib look for this figure: light-grey axes background with
# white grid lines drawn beneath the bars, white axes edges, a white legend
# background, and saved images cropped to a tight bounding box.
_plot_style = {
    # axes appearance
    "axes.facecolor": "DDDDDD",
    "axes.edgecolor": "white",
    "axes.grid": True,
    "axes.axisbelow": True,
    # legend
    "legend.facecolor": "ffffff",
    # grid lines
    "grid.color": "white",
    "grid.linestyle": "-",
    # output
    "savefig.bbox": "tight",
}
mpl.rcParams.update(_plot_style)
# Human-readable display names for the raw model feature columns. They are
# used in the per-bar labels and in the x-axis tick labels; every feature key
# in the pasted output below must appear here.
feature_map = dict(
    AGE="Age",
    AMT_INCOME_TOTAL="Income",
    CNT_CHILDREN="# Children",
    DAYS_EMPLOYED="Days Employed",
    FLAG_OWN_REALTY="Own Realty?",
    FLAG_WORK_PHONE="Work Phone?",
    FLAG_OWN_CAR="Own Car?",
)
# Input data for the plot: a SHAP explanation of a single loan-approval
# prediction. Everything between the two marker lines is meant to be replaced
# wholesale with the text printed by PmmlShapExplainerTest, so keep notes up
# here rather than inside the pasted region.
#   background_feature_means -- not referenced by the plotting code in this script.
#   features     -- raw feature values of the explained sample; shown in the
#                   bar and x-tick labels.
#   prediction   -- model output for the sample; drawn as the "Prediction" line.
#   shap_values  -- feature -> (SHAP value, second value). Only the first
#                   element is plotted; the second is unpacked as `conf` and
#                   ignored (presumably a confidence bound -- confirm against
#                   the Java test).
#   fnull        -- one-element list; only fnull[0] is read. It is the level the
#                   first bar starts from and is drawn as the "Background" line.
# Every key of `features` / `shap_values` must be present in `feature_map`,
# otherwise the label lookups raise KeyError. The bars only end on the
# Prediction line when fnull[0] + sum of the SHAP values equals `prediction`
# (true, to 6 decimals, for the values below).
# vvvvvvvvvvvvvvvvvvv copy and paste output from PmmlShapExplainerTest here vvvvvvvvvvvvvvvv
background_feature_means = [44.55, 74340.0, 2.01, 2478.0, 0.54, 0.46, 0.39]
features = {
"AGE": 57.000000,
"AMT_INCOME_TOTAL": 111000.000000,
"CNT_CHILDREN": 4.000000,
"DAYS_EMPLOYED": 3700.000000,
"FLAG_OWN_REALTY": 1.000000,
"FLAG_WORK_PHONE": 0.000000,
"FLAG_OWN_CAR": 1.000000,
}
prediction = 0.658974
shap_values = {
"AGE": (0.058538, 0.002538),
"AMT_INCOME_TOTAL": (0.110987, 0.002538),
"CNT_CHILDREN": (-0.019712, 0.002538),
"DAYS_EMPLOYED": (0.049493, 0.002538),
"FLAG_OWN_REALTY": (-0.002034, 0.002538),
"FLAG_WORK_PHONE": (0.002363, 0.002538),
"FLAG_OWN_CAR": (-0.006651, 0.006216),
}
fnull=[0.46598961291984575]
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# plot setup
fig, ax = plt.subplots(figsize=(9, 9), dpi=200)
# Running level of the waterfall: the first bar starts at the background value
# and each following bar starts where the previous one ended.
pos = fnull[0]
# Horizontal overhang of the joint lines drawn between neighbouring bars.
# The bars are .84 wide (half-width .42), so .41 keeps each joint just inside
# its bar's edge. Despite the name, this is a horizontal distance, not a height.
barheight = .41
bar_colors = []  # face colour of each bar, in plotting order
xlabels = []  # x tick label text, one entry per feature
# Palette: colors[0] for negative SHAP values, colors[1] for non-negative ones.
colors = ['#ee0000', "#316DC1"]
# plot each shap value as a floating bar that starts at the running level `pos`
# (the second tuple element is not used by the plot)
for i, (k, (sv, _conf)) in enumerate(shap_values.items()):
    # One sign test drives both the bar colour and the value label. A SHAP
    # value of exactly 0 therefore gets a "+0.00%" label above a blue bar
    # instead of an unsigned label hanging below it. The tick-label colouring
    # also uses sv < 0.
    is_negative = sv < 0
    bar = ax.bar(
        i,
        height=sv,
        width=.84,
        bottom=pos,
        label="{}={:.0f}".format(feature_map[k], features[k]),
        color=colors[0] if is_negative else colors[1],
    )
    bar_colors.append(bar.patches[0].get_facecolor())
    # place shap value onto plot, just past the far end of the bar
    if is_negative:
        ax.text(i, sv + pos - .001, "{:.2f}%".format(sv * 100), fontsize=8,
                fontweight='bold', color=bar_colors[-1],
                horizontalalignment='center', verticalalignment='top')
    else:
        ax.text(i, sv + pos + .001, "+{:.2f}%".format(sv * 100), fontsize=8,
                fontweight='bold', color=bar_colors[-1],
                horizontalalignment='center')
    xlabels.append("{}\n={}".format(feature_map[k], features[k]))
    pos += sv
# Place joints across bars at top/bottom for clarity. At the level where each
# bar ends, draw a short line from that bar (in its colour) over to the next
# bar (in the next bar's colour), so the running total can be traced across
# the waterfall. The last bar has no successor, so it gets no joint.
pos = fnull[0]
deltas = [delta for delta, _ in shap_values.values()]
for idx, delta in enumerate(deltas[:-1]):
    pos += delta
    ax.plot((idx - barheight, idx + .5), (pos, pos), color=bar_colors[idx])
    ax.plot((idx + .5, idx + 1 + barheight), (pos, pos), color=bar_colors[idx + 1])
# Faint horizontal reference lines at the prediction and at the background
# value. The prediction is labelled at the left end of the x-range (x = -1)
# and the background at the right end (x = number of bars). The right end is
# derived from the data rather than hard-coded, so the label stays put when
# the pasted output has a different number of features.
n_bars = len(shap_values)
ax.axhline(prediction, linewidth=1, color='k', alpha=.25)
ax.text(-1 + .1, prediction + .001, "Prediction={:.2f}%".format(prediction * 100),
        fontsize=10, horizontalalignment='left')
ax.axhline(fnull[0], linewidth=1, color='k', alpha=.25)
ax.text(n_bars - .1, fnull[0] + .001, "Background={:.2f}%".format(fnull[0] * 100),
        fontsize=10, horizontalalignment="right")
# label plot
plt.title("Per-Feature SHAP Values, Loan Approval")
plt.ylabel("Probability of Loan Approval")
n_bars = len(shap_values)

# Fit the y-range to the data instead of hard-coding it, so pasting a different
# explanation never clips bars. The range covers every level the waterfall
# passes through (background, each bar end) plus the prediction line. It is
# padded so the value labels beyond the bar ends stay visible, then widened
# outward to whole 5-point steps so the tick labels are round. For the
# original data this yields exactly 45%..70%.
levels = [fnull[0], prediction]
running = fnull[0]
for sv, _conf in shap_values.values():
    running += sv
    levels.append(running)
lo, hi = min(levels), max(levels)
pad = max(.01, .05 * (hi - lo))
b = np.floor((lo - pad) / .05) * .05
t = np.ceil((hi + pad) / .05) * .05
plt.ylim(b, t)
# x-range: one unit of margin on the left for the "Prediction" label,
# one slot per bar after that
plt.xlim(-1, n_bars)

yticks = np.linspace(b, t, 11)
plt.yticks(yticks, ["{:.1f}%".format(y * 100) for y in yticks], fontsize=8)
# One x tick per bar. The tick count must match len(xlabels), so derive it
# from the data rather than hard-coding 7.
plt.xticks(np.arange(n_bars), xlabels, fontsize=8)
# colour each tick label like its bar: red for negative, blue otherwise
for ticklabel, (sv, _conf) in zip(ax.get_xticklabels(), shap_values.values()):
    ticklabel.set_color(colors[0] if sv < 0 else colors[1])
# Save the finished figure, then display it. The filename has no extension,
# so matplotlib chooses the file format from its savefig defaults.
output_path = "loan_shap_values"
plt.savefig(output_path)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment