Created
September 17, 2020 14:38
-
-
Save andfanilo/6bad569e3405c89b6db1df8acf18df0e to your computer and use it in GitHub Desktop.
SHAP in Streamlit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shap | |
import streamlit as st | |
import streamlit.components.v1 as components | |
import xgboost | |
@st.cache
def load_data():
    """Load the Boston dataset shipped with SHAP.

    Decorated with ``st.cache`` so Streamlit can reuse the result instead of
    reloading the data on every rerun of the script.

    Returns:
        Whatever ``shap.datasets.boston()`` returns; the script unpacks it
        as a ``(features, target)`` pair.
    """
    dataset = shap.datasets.boston()
    return dataset
def st_shap(plot, height=None):
    """Embed a SHAP JavaScript plot in the Streamlit page.

    Args:
        plot: Object exposing an ``html()`` method, e.g. the value returned
            by ``shap.force_plot``.
        height: Forwarded unchanged to ``components.html`` as its ``height``
            argument.
    """
    # The SHAP JavaScript bundle goes in <head> so the plot markup in <body>
    # can use it once the component is rendered.
    head_fragment = f"<head>{shap.getjs()}</head>"
    body_fragment = f"<body>{plot.html()}</body>"
    components.html(head_fragment + body_fragment, height=height)
st.title("SHAP in Streamlit")

# Train an XGBoost model on the loaded dataset.
X, y = load_data()
training_matrix = xgboost.DMatrix(X, label=y)
model = xgboost.train(
    {"learning_rate": 0.01},
    training_matrix,
    num_boost_round=100,
)

# Explain the model's predictions using SHAP
# (same syntax works for LightGBM, CatBoost, scikit-learn and spark models).
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Visualize the first prediction's explanation
# (use matplotlib=True to avoid Javascript).
first_prediction_plot = shap.force_plot(
    explainer.expected_value,
    shap_values[0, :],
    X.iloc[0, :],
)
st_shap(first_prediction_plot)

# Visualize the training set predictions; the second argument to st_shap is
# the height passed to the HTML component.
training_set_plot = shap.force_plot(explainer.expected_value, shap_values, X)
st_shap(training_set_plot, 400)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment