Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Generate a waterfall-style ("candlestick") plot of per-feature SHAP values from PmmlShapExplainerTest output
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Global matplotlib style for this script: light-grey axes panel with white
# edges, a solid white grid placed beneath the plotted data, a white legend
# background, and tight bounding boxes when saving.
plot_style = {
    "axes.facecolor": "DDDDDD",
    "axes.edgecolor": "white",
    "axes.grid": True,
    "axes.axisbelow": True,
    "legend.facecolor": "ffffff",
    "grid.color": "white",
    "grid.linestyle": "-",
    "savefig.bbox": "tight",
}
mpl.rcParams.update(plot_style)
# Raw model feature name -> human-readable name used in the bar and tick labels.
feature_map = dict(
    AGE="Age",
    AMT_INCOME_TOTAL="Income",
    CNT_CHILDREN="# Children",
    DAYS_EMPLOYED="Days Employed",
    FLAG_OWN_REALTY="Own Realty?",
    FLAG_WORK_PHONE="Work Phone?",
    FLAG_OWN_CAR="Own Car?",
)
# vvvvvvvvvvvvvvvvvvv copy and paste output from PmmlShapExplainerTest here vvvvvvvvvvvvvvvv
# Everything between the two marker comments is meant to be replaced wholesale
# with PmmlShapExplainerTest output. The plotting code below reads `features`,
# `prediction`, `shap_values` and `fnull` by these exact names.
#
# Presumably the per-feature means of the background data, in the same order
# as `features` -- confirm against PmmlShapExplainerTest.
# NOTE(review): not read anywhere else in this script.
background_feature_means = [44.55, 74340.0, 2.01, 2478.0, 0.54, 0.46, 0.39]
# Feature values of the instance being explained; shown in the bar/tick labels.
features = {
"AGE": 57.000000,
"AMT_INCOME_TOTAL": 111000.000000,
"CNT_CHILDREN": 4.000000,
"DAYS_EMPLOYED": 3700.000000,
"FLAG_OWN_REALTY": 1.000000,
"FLAG_WORK_PHONE": 0.000000,
"FLAG_OWN_CAR": 1.000000,
}
# Model output for this instance; drawn as the "Prediction" reference line.
prediction = 0.658974
# feature -> (SHAP value, second value). Only the first element is plotted.
# The plotting loops unpack the second one as `conf` and never use it
# (presumably a confidence/error bound -- TODO confirm).
# Dict iteration order is the left-to-right bar order.
shap_values = {
"AGE": (0.058538, 0.002538),
"AMT_INCOME_TOTAL": (0.110987, 0.002538),
"CNT_CHILDREN": (-0.019712, 0.002538),
"DAYS_EMPLOYED": (0.049493, 0.002538),
"FLAG_OWN_REALTY": (-0.002034, 0.002538),
"FLAG_WORK_PHONE": (0.002363, 0.002538),
"FLAG_OWN_CAR": (-0.006651, 0.006216),
}
# One-element list holding the base ("Background") value the waterfall starts
# from. For this paste, fnull[0] + sum of the SHAP values equals `prediction`
# up to rounding.
fnull=[0.46598961291984575]
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# plot setup: a single square figure whose axes receives every bar, label
# and reference line drawn by the rest of the script
fig = plt.figure(figsize=(9, 9), dpi=200)
ax = fig.gca()

# running total of the waterfall; the first bar starts at the background value
pos = fnull[0]
# horizontal half-extent of the connector lines drawn between bars
# (bars are .84 wide, so this is just under half a bar)
barheight = .41
# face colour of each bar, appended by the bar-drawing loop
bar_colors = []
# larger / smaller of the prediction and the background value
# NOTE(review): neither is read anywhere else in this script
upwards = max(prediction, fnull[0])
downwards = min(prediction, fnull[0])
# x tick label text for each bar, appended by the bar-drawing loop
xlabels = []
# [colour for negative SHAP values, colour for all others]
colors = ['#ee0000', "#316DC1"]
# plot each shap value as one waterfall bar stacked on the running total `pos`
for i, (k, (sv, _conf)) in enumerate(shap_values.items()):
    # negative contributions use colors[0]; zero/positive ones use colors[1]
    bar_color = colors[0] if sv < 0 else colors[1]
    # plot bar: starts at the running total and extends by sv (down if negative)
    bar = ax.bar(i, height=sv, width=.84, bottom=pos,
                 label="{}={:.0f}".format(feature_map[k], features[k]),
                 color=bar_color)
    bar_colors.append(bar.patches[0].get_facecolor())
    # place shap value just past the end of the bar: above it for positive
    # values, below it (top-aligned) for zero/negative values
    if sv > 0:
        ax.text(i, sv + pos + .001, "+{:.2f}%".format(sv * 100),
                fontsize=8, fontweight='bold', color=bar_colors[-1],
                horizontalalignment='center')
    else:
        ax.text(i, sv + pos - .001, "{:.2f}%".format(sv * 100),
                fontsize=8, fontweight='bold', color=bar_colors[-1],
                horizontalalignment='center', verticalalignment='top')
    xlabels.append("{}\n={}".format(feature_map[k], features[k]))
    # the next bar starts where this one ends
    pos += sv
# place joints across bars at top/bottom for clarity: at the level where bar i
# ends, draw a line over bar i (in its colour) that continues across bar i+1
# (in bar i+1's colour), visually linking consecutive waterfall steps
pos = fnull[0]
for i, (sv, _conf) in enumerate(list(shap_values.values())[:-1]):
    level = pos + sv
    ax.plot((i - barheight, i + .5), (level, level), color=bar_colors[i])
    ax.plot((i + .5, i + 1 + barheight), (level, level), color=bar_colors[i + 1])
    pos += sv
# place horizontal line at prediction, labelled near the left edge (x = -1)
ax.axhline(prediction, linewidth=1, color='k', alpha=.25)
ax.text(-1 + .1, prediction + .001, "Prediction={:.2f}%".format(prediction * 100),
        fontsize=10, horizontalalignment='left')
# place horizontal line at background value, labelled right-aligned just
# short of x = number of bars (the last bar sits at x = n - 1), so the label
# follows the number of pasted features instead of assuming exactly 7
ax.axhline(fnull[0], linewidth=1, color='k', alpha=.25)
ax.text(len(shap_values) - .1, fnull[0] + .001, "Background={:.2f}%".format(fnull[0] * 100),
        fontsize=10, horizontalalignment="right")
# label plot
plt.title("Per-Feature SHAP Values, Loan Approval")
plt.ylabel("Probability of Loan Approval")
# y-limits: cover the background value, the prediction and every running
# total of the waterfall, padded by .01 for the value labels and snapped
# outwards to multiples of .05 (5 percentage points). These are derived from
# the pasted data so new output doesn't get clipped (for the data currently
# pasted above this gives .45 and .7).
running = fnull[0] + np.cumsum([sv for sv, _ in shap_values.values()])
lo = min(fnull[0], prediction, *running)
hi = max(fnull[0], prediction, *running)
b = np.floor((lo - .01) * 20) / 20
t = np.ceil((hi + .01) * 20) / 20
plt.ylim(b, t)
# x-limits: one unit of margin left of the first bar (room for the
# "Prediction" label), right edge at x = number of bars
n_bars = len(shap_values)
plt.xlim(-1, n_bars)
yticks = np.linspace(b, t, 11)
plt.yticks(yticks, ["{:.1f}%".format(y * 100) for y in yticks], fontsize=8)
plt.xticks(np.arange(0, n_bars), xlabels, fontsize=8)
# colour each feature's tick label like its bar: negative -> colors[0], else colors[1]
for ticklabel, (sv, _) in zip(ax.get_xticklabels(), shap_values.values()):
    ticklabel.set_color(colors[0] if sv < 0 else colors[1])
# plot and save: write the current figure (matplotlib picks the default
# file format since the name has no extension), then display it
output_name = "loan_shap_values"
plt.savefig(output_name)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment