Simulate data where SHAP values are not causal
# *-----------------------------------------------------------------
# | PROGRAM NAME: beyond SHAP.py
# | DATE: 10/14/21
# | CREATED BY: MATT BOGARD
# | PROJECT FILE:
# *----------------------------------------------------------------
# | PURPOSE: simulate SHAP values that are not causal
# *----------------------------------------------------------------
# this code is based on: Be Careful When Interpreting Predictive Models in Search of Causal Insights
# by Scott Lundberg, see: https://towardsdatascience.com/be-careful-when-interpreting-predictive-models-in-search-of-causal-insights-e68626e664b6
# Original Code: https://shap.readthedocs.io/en/latest/example_notebooks/overviews/Be%20careful%20when%20interpreting%20predictive%20models%20in%20search%20of%20causal%C2%A0insights.html
# see also: https://towardsdatascience.com/explain-your-model-with-the-shap-values-bc36aac4de3d
# https://towardsdatascience.com/shap-explained-the-way-i-wish-someone-explained-it-to-me-ab81cc69ef30
import numpy as np
import pandas as pd
import scipy.special
import scipy.stats
import sklearn
import sklearn.model_selection
import xgboost
import econml
import shap
#
# generate data
#
class FixableDataFrame(pd.DataFrame):
    """ Helper class for manipulating generative models.
    """
    def __init__(self, *args, fixed={}, **kwargs):
        self.__dict__["__fixed_var_dictionary"] = fixed
        super(FixableDataFrame, self).__init__(*args, **kwargs)
    def __setitem__(self, key, value):
        out = super(FixableDataFrame, self).__setitem__(key, value)
        if isinstance(key, str) and key in self.__dict__["__fixed_var_dictionary"]:
            out = super(FixableDataFrame, self).__setitem__(key, self.__dict__["__fixed_var_dictionary"][key])
        return out
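
# Quick illustration (my addition, not in the original gist): any column listed in fixed=
# is pinned to the supplied values, so later assignments to it are overridden. This is the
# hook the generator below uses to simulate do()-style interventions.
_demo = FixableDataFrame(fixed={"x": np.zeros(3)})
_demo["x"] = np.ones(3)           # the assignment happens, then is overridden by the fixed values
assert (_demo["x"] == 0).all()    # the column stays pinned at the intervened values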
# generate the data
def generator(n, fixed={}, seed=0):
    """ The generative model for our subscriber retention example.
    """
    if seed is not None:
        np.random.seed(seed)
    X = FixableDataFrame(fixed=fixed)
    # the number of sales calls made to this customer
    X["Sales calls"] = np.random.uniform(0, 4, size=(n,)).round()
    # the number of interactions this customer had with us (sales calls plus a few extra contacts)
    X["Interactions"] = X["Sales calls"] + np.random.poisson(0.2, size=(n,))
    # the health of the regional economy this customer is a part of
    X["Economy"] = np.random.uniform(0, 1, size=(n,))
    # the time since the last product upgrade when this customer came up for renewal
    X["Last upgrade"] = np.random.uniform(0, 20, size=(n,))
    # how much the user perceives that they need the product
    X["Product need"] = (X["Sales calls"] * 0.1 + np.random.normal(0, 1, size=(n,)))
    # the fractional discount offered to this customer upon renewal
    X["Discount"] = ((1 - scipy.special.expit(X["Product need"])) * 0.5 + 0.5 * np.random.uniform(0, 1, size=(n,))) / 2
    # what percent of the days in the last period the user was actively using the product
    X["Monthly usage"] = scipy.special.expit(X["Product need"] * 0.3 + np.random.normal(0, 1, size=(n,)))
    # how much ad money we spent per user targeted at this user (or a group this user is in)
    X["Ad spend"] = X["Monthly usage"] * np.random.uniform(0.99, 0.9, size=(n,)) + (X["Last upgrade"] < 1) + (X["Last upgrade"] < 2)
    # how many bugs this user encountered since their last renewal
    X["Bugs faced"] = np.array([np.random.poisson(v * 2) for v in X["Monthly usage"]])
    # how many bugs did the user report?
    X["Bugs reported"] = (X["Bugs faced"] * scipy.special.expit(X["Product need"])).round()
    # did the user renew?
    X["Did renew"] = scipy.special.expit(7 * (
        0.18 * X["Product need"]
        + 0.08 * X["Monthly usage"]
        + 0.1 * X["Economy"]
        + 0.05 * X["Discount"]
        + 0.05 * np.random.normal(0, 1, size=(n,))
        + 0.05 * (1 - X["Bugs faced"] / 20)
        + 0.005 * X["Sales calls"]
        + 0.015 * X["Interactions"]
        + 0.1 / (X["Last upgrade"] / 4 + 0.25)
        + X["Ad spend"] * 0.0 - 0.45
    ))
    # in the original example the label is left as this renewal probability so the causal
    # effect plots are less noisy; here we take a Bernoulli draw so the outcome is a binary
    # renew / no-renew label. Comment out the next line to reproduce the smoother plots.
    X["Did renew"] = scipy.stats.bernoulli.rvs(X["Did renew"])
    return X
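
# Illustrative check (my addition, not part of the original gist): since we own the
# generative model, a true causal effect can be read off by intervening through fixed=.
# Forcing ad spend to zero for every customer leaves renewal unchanged, because
# "Did renew" multiplies "Ad spend" by 0.0 above, so the true effect of ad spend is nil.
_n = 10000
_X_obs = generator(_n, seed=0)
_X_do = generator(_n, fixed={"Ad spend": np.zeros(_n)}, seed=0)
print("true effect of cutting ad spend to zero:", (_X_do["Did renew"] - _X_obs["Did renew"]).mean())  # 0.0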
def user_retention_dataset():
    """ The observed data for model training.
    """
    n = 10000
    X_full = generator(n)
    y = X_full["Did renew"]
    X = X_full.drop(["Did renew", "Product need", "Bugs faced"], axis=1)
    return X, y
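
# Note (my addition): "Product need" and "Bugs faced" drive renewal in the generator but are
# dropped here, so the model only sees proxies of them (Discount, Monthly usage, Bugs reported,
# Ad spend). That unobserved confounding is what makes the SHAP attributions below non-causal.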
#
# fit xgboost model
#
def fit_xgboost(X, y):
    """ Train an XGBoost model with early stopping.
    """
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y)
    dtrain = xgboost.DMatrix(X_train, label=y_train)
    dtest = xgboost.DMatrix(X_test, label=y_test)
    model = xgboost.train(
        {"eta": 0.001, "subsample": 0.5, "max_depth": 2, "objective": "reg:logistic"},
        dtrain, num_boost_round=200000,
        evals=((dtest, "test"),), early_stopping_rounds=20, verbose_eval=False
    )
    return model
X, y = user_retention_dataset() # define data
model = fit_xgboost(X, y) # fit model
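
# Optional sanity check (my addition, not in the original gist): xgboost.train returns a
# Booster, so predictions go through a DMatrix. The average predicted renewal probability
# should roughly match the observed renewal rate.
preds = model.predict(xgboost.DMatrix(X))
print("mean predicted renewal:", round(float(preds.mean()), 3), "observed renewal rate:", round(float(y.mean()), 3))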
# calculate SHAP values
explainer = shap.Explainer(model)
shap_values = explainer(X)
# plot SHAP values
clust = shap.utils.hclust(X, y, linkage="complete")
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)
# summary plot
shap.summary_plot(shap_values, X)
# features where SHAP's dependence patterns conflict with the true (simulated) causal effects
shap.plots.scatter(shap_values[:,7]) # bugs reported
shap.plots.scatter(shap_values[:,4]) # discount
shap.plots.scatter(shap_values[:,6]) # ad spend
shap.plots.scatter(shap_values[:,2]) # economy
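
# Where to go from here (my addition, a sketch under assumptions rather than the original
# author's code): the econml import above points toward the follow-up in Lundberg's article,
# which uses double/debiased ML to estimate causal effects when the confounders are observed.
# A minimal version for "Ad spend" (whose drivers, Monthly usage and Last upgrade, are
# observed) might look like this; the model choices are my own illustration.
from econml.dml import LinearDML
from sklearn.ensemble import GradientBoostingRegressor

dml = LinearDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingRegressor())
dml.fit(y, X["Ad spend"], W=X.drop(columns=["Ad spend"]))  # remaining observed features as controls
print("double ML estimate of the Ad spend effect:", dml.effect().mean())  # should be close to the true effect of 0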