Skip to content

Instantly share code, notes, and snippets.

@yptheangel
Created October 17, 2022 09:33
Show Gist options
  • Save yptheangel/64ee5ffe067b32825a5b1e7c5c655d16 to your computer and use it in GitHub Desktop.
Save yptheangel/64ee5ffe067b32825a5b1e7c5c655d16 to your computer and use it in GitHub Desktop.
shap_streamlit_xgb
import shap
import streamlit as st
import streamlit.components.v1 as components
import xgboost
import matplotlib.pyplot as plt
# Prefer st.cache_data (Streamlit >= 1.18), where st.cache is deprecated;
# fall back to st.cache on older releases that predate cache_data.
_cache_data = getattr(st, "cache_data", None) or st.cache


@_cache_data
def load_data():
    """Load the ``(X, y)`` regression dataset used by this demo.

    Returns:
        The pair returned by the chosen ``shap.datasets`` loader. ``X`` is
        indexed with ``.iloc`` later in this script, so it is presumably a
        pandas DataFrame of features; ``y`` is the regression target.
    """
    # shap.datasets.boston() was removed from newer shap releases (the
    # Boston housing data was withdrawn), so fall back to the California
    # housing data there instead of crashing at startup. Note the fallback
    # dataset has many more rows than Boston's 506, so whole-dataset plots
    # in this script will be heavier.
    loader = getattr(shap.datasets, "boston", None) or shap.datasets.california
    return loader()


def st_shap(plot, height=None):
    """Embed an interactive SHAP JavaScript plot in the Streamlit page.

    Builds a small HTML document with the SHAP JS bundle
    (``shap.getjs()``) in ``<head>`` and the plot's own markup
    (``plot.html()``) in ``<body>``. It then hands that document to
    ``components.html``.

    Args:
        plot: object exposing ``.html()``, such as the value returned by
            ``shap.force_plot``.
        height: forwarded unchanged to ``components.html``; ``None`` defers
            to that function's own default.
    """
    js_bundle = shap.getjs()
    plot_markup = plot.html()
    page = "<head>{}</head><body>{}</body>".format(js_bundle, plot_markup)
    components.html(page, height=height)


st.title("SHAP in Streamlit")

# Fit a gradient-boosted regressor on the demo data.
X, y = load_data()
booster_params = {"learning_rate": 0.01}
training_matrix = xgboost.DMatrix(X, label=y)
model = xgboost.train(booster_params, training_matrix, num_boost_round=100)

# Compute per-row, per-feature SHAP values for every training row.
# The same TreeExplainer syntax also works for LightGBM, CatBoost,
# scikit-learn and Spark models.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
# Interactive force plot explaining the first row's prediction.
st_shap(shap.force_plot(explainer.expected_value, shap_values[0, :], X.iloc[0, :]))

# Global SHAP summary plot over all rows.
# Draw it without showing, then pass the figure to st.pyplot explicitly.
# Calling st.pyplot() with no figure relies on the global pyplot instance,
# which Streamlit deprecated and later removed. Newer releases also
# dropped the 'deprecation.showPyplotGlobalUse' option that used to
# silence that warning, so setting it now raises an error.
shap.summary_plot(shap_values, X, show=False)
summary_fig = plt.gcf()
st.pyplot(summary_fig, bbox_inches="tight")  # extra kwargs go to savefig
# Release the figure so reruns of the script start from a fresh one.
plt.close(summary_fig)
# TODO: a per-row waterfall plot was prototyped here but never worked.
# shap.plots.waterfall expects a shap.Explanation object rather than the
# raw arrays returned by explainer.shap_values(X). Presumably something
# like explainer(X)[row] would supply one; confirm against the installed
# shap version before adding it.

# Interactive force plot stacking the explanations of every row in X;
# height=400 is forwarded to components.html via st_shap.
st_shap(shap.force_plot(explainer.expected_value, shap_values, X), height=400)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment