Skip to content

Instantly share code, notes, and snippets.

@andfanilo
Created September 17, 2020 14:38
Show Gist options
  • Star (3) — you must be signed in to star a gist.
  • Fork (1) — you must be signed in to fork a gist.
  • Save andfanilo/6bad569e3405c89b6db1df8acf18df0e to your computer and use it in GitHub Desktop.
SHAP in Streamlit
import shap
import streamlit as st
import streamlit.components.v1 as components
import xgboost
# ``st.cache`` is deprecated since Streamlit 1.18 (and removed in later
# releases) in favour of ``st.cache_data``; use the new API when it exists and
# fall back to ``st.cache`` on older Streamlit versions.
_cache_data = st.cache_data if hasattr(st, "cache_data") else st.cache


@_cache_data
def load_data():
    """Load the SHAP demo regression dataset, cached across Streamlit reruns.

    Returns:
        A ``(X, y)`` pair as produced by ``shap.datasets``. The script indexes
        ``X`` with ``.iloc``, so it is expected to be a pandas DataFrame.
    """
    # Newer shap releases dropped ``shap.datasets.boston``; fall back to the
    # California housing loader only when ``boston`` is unavailable so older
    # installs keep the original dataset.
    loader = getattr(shap.datasets, "boston", None) or shap.datasets.california
    return loader()
def st_shap(plot, height=None):
    """Embed a JavaScript-based SHAP plot in the Streamlit page.

    The SHAP JavaScript bundle (``shap.getjs()``) is placed in ``<head>`` and
    the plot's own HTML in ``<body>``, and the document is handed to
    ``components.html``.

    Args:
        plot: A SHAP plot object exposing ``.html()``, e.g. the return value of
            ``shap.force_plot`` (without ``matplotlib=True``).
        height: Height in pixels for the HTML component; ``None`` defers to
            ``components.html``'s own default.
    """
    js_bundle = shap.getjs()
    plot_markup = plot.html()
    page = "<head>{}</head><body>{}</body>".format(js_bundle, plot_markup)
    components.html(page, height=height)
st.title("SHAP in Streamlit")

# Fit a gradient-boosted tree model on the demo regression data.
X, y = load_data()
booster_params = {"learning_rate": 0.01}
training_matrix = xgboost.DMatrix(X, label=y)
model = xgboost.train(booster_params, training_matrix, num_boost_round=100)

# Attribute every prediction to the input features with SHAP's tree explainer
# (the same API also accepts LightGBM, CatBoost, scikit-learn and Spark models).
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
base_value = explainer.expected_value

# Explanation of the first row only; passing matplotlib=True to
# shap.force_plot would render a static image instead of the JavaScript widget.
st_shap(shap.force_plot(base_value, shap_values[0, :], X.iloc[0, :]))

# Stacked explanations for every row of the training set.
st_shap(shap.force_plot(base_value, shap_values, X), height=400)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment