
@micahmelling
Last active July 14, 2020 02:07
import os

import matplotlib.pyplot as plt
import pandas as pd
import shap

# DIAGNOSTICS_DIRECTORY and SHAP_VALUES_DIRECTORY are assumed to be module-level constants defined elsewhere in the project.


def produce_shap_values(pipe, x_test, num_features, model_name):
    """
    Produces a plot of SHAP values to identify important features. Recovering the feature names takes extra work
    because the pipeline's transformers return an unlabeled array. Categorical feature names are extracted from the
    DictVectorizer and appended to the numeric feature names. Features dropped by the pipeline's feature_selector
    step are then removed. The resulting names are used to label the columns of the transformed x_test.
    :param pipe: fitted scikit-learn pipeline defined in the construct_pipeline function
    :param x_test: x_test dataframe
    :param num_features: list of numeric features used for modeling
    :param model_name: the string name of the model
    """
    # for every feature, grab a boolean indicating whether the feature selector kept it
    support = pipe.named_steps['feature_selector'].get_support()
    model = pipe.named_steps['model']
    # slice off the feature_selector and model steps (requires scikit-learn 0.21+);
    # unlike pipe.steps.pop(), slicing leaves the caller's fitted pipeline intact
    preprocessing_pipe = pipe[:-2]
    # transform the dataframe with the remaining pipeline
    x_test = preprocessing_pipe.transform(x_test)
    # densify if the transformers returned a scipy sparse matrix
    if hasattr(x_test, 'toarray'):
        x_test = x_test.toarray()
    x_test = pd.DataFrame(x_test)
    # extract categorical feature names nested in our pipeline and combine with known numeric feature names
    dict_vect = (pipe.named_steps['preprocessor']
                 .named_transformers_['categorical_transformer']
                 .named_steps['dict_vectorizer'])
    cat_features = dict_vect.feature_names_
    # assign column names to our dataframe (assumes the numeric transformer comes first in the preprocessor)
    x_test.columns = list(num_features) + list(cat_features)
    # keep only the columns retained by our feature selector
    x_test = x_test.loc[:, support]
    # produce shap values, which needs a model outside a pipeline and a dataframe with column names
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(x_test)
    shap.summary_plot(shap_values, x_test, show=False)
    # adjust the layout before saving so it affects the written file
    plt.tight_layout()
    plt.savefig(os.path.join(DIAGNOSTICS_DIRECTORY, SHAP_VALUES_DIRECTORY, '{}_shap_values.png'.format(model_name)),
                bbox_inches='tight')
    # close the figure so repeated calls do not accumulate open figures
    plt.close()
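The name-recovery steps described in the docstring can be tried end to end on a toy pipeline. This is a minimal sketch, not the original construct_pipeline: only the step names preprocessor, categorical_transformer, dict_vectorizer, feature_selector, and model come from the gist. The helpers build_toy_pipeline and selected_feature_frame, the to_dicts and numeric_transformer steps, the SelectKBest and RandomForestClassifier choices, and the data are all hypothetical. It assumes scikit-learn 0.21 or later for pipeline slicing, and that the numeric branch is listed first in the ColumnTransformer so numeric names precede categorical ones.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler


def build_toy_pipeline(num_features, cat_features):
    """Hypothetical stand-in for construct_pipeline; only the step names match the gist."""
    categorical = Pipeline([
        # DictVectorizer expects a list of dicts, so convert the DataFrame slice first (hypothetical step)
        ('to_dicts', FunctionTransformer(lambda df: df.to_dict(orient='records'))),
        ('dict_vectorizer', DictVectorizer(sparse=False)),
    ])
    preprocessor = ColumnTransformer([
        ('numeric_transformer', StandardScaler(), num_features),
        ('categorical_transformer', categorical, cat_features),
    ])
    return Pipeline([
        ('preprocessor', preprocessor),
        ('feature_selector', SelectKBest(f_classif, k=3)),
        ('model', RandomForestClassifier(n_estimators=10, random_state=0)),
    ])


def selected_feature_frame(pipe, x, num_features):
    """Return the model's input as a DataFrame labeled with the selected feature names."""
    support = pipe.named_steps['feature_selector'].get_support()
    # slicing returns a new Pipeline over the same fitted steps, so `pipe` is not modified
    transformed = pipe[:-2].transform(x)
    dict_vect = (pipe.named_steps['preprocessor']
                 .named_transformers_['categorical_transformer']
                 .named_steps['dict_vectorizer'])
    names = list(num_features) + list(dict_vect.feature_names_)
    frame = pd.DataFrame(transformed, columns=names, index=x.index)
    # keep only the columns the selector retained
    return frame.loc[:, support]


# made-up data for illustration
x = pd.DataFrame({
    'age': [25, 32, 47, 51, 62, 23, 36, 44],
    'income': [40, 52, 80, 61, 90, 30, 58, 75],
    'color': ['red', 'blue', 'red', 'green', 'blue', 'green', 'red', 'blue'],
})
y = [0, 0, 1, 1, 1, 0, 0, 1]

pipe = build_toy_pipeline(['age', 'income'], ['color']).fit(x, y)
selected = selected_feature_frame(pipe, x, ['age', 'income'])
print(list(selected.columns))
```

DictVectorizer names one-hot columns as `feature=value` (for example `color=red`), which is why its `feature_names_` can be appended directly to the numeric names. Because the pipeline is sliced rather than popped, `pipe` still has all three steps afterwards and can keep making predictions.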
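The preprocessor can return a scipy sparse matrix: DictVectorizer is sparse by default, and ColumnTransformer returns sparse output when the stacked result is sparse enough relative to its `sparse_threshold`. Labeling the columns is simplest on a dense array. The sketch below uses a hypothetical helper, to_selected_frame, that densifies if needed, checks that the names and mask line up with the columns, and applies the selector's support mask. It assumes the dense matrix fits in memory.

```python
import numpy as np
import pandas as pd
from scipy import sparse


def to_selected_frame(transformed, names, support):
    """Label a (possibly sparse) transformed matrix and keep only the selected columns.

    `transformed` is what the preprocessing steps return, `names` lists one name per column,
    and `support` is the boolean mask from a selector's get_support().
    """
    if sparse.issparse(transformed):
        # convert scipy sparse output to a dense array before building the DataFrame
        transformed = transformed.toarray()
    support = np.asarray(support, dtype=bool)
    if transformed.shape[1] != len(names) or len(names) != support.size:
        raise ValueError('names and support must each have one entry per column')
    frame = pd.DataFrame(transformed, columns=list(names))
    return frame.loc[:, support]


matrix = sparse.csr_matrix(np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]))
frame = to_selected_frame(matrix, ['a', 'b', 'c'], [True, False, True])
print(frame)
```

Failing loudly on a length mismatch is deliberate. If the numeric and categorical names are concatenated in the wrong order, the count still matches. If a transformer adds or drops columns, however, the check turns a silent mislabeling into an error.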
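The saving step at the end of the function assumes the output directory already exists, because `plt.savefig` does not create missing directories. Below is a hedged sketch of a hypothetical save_current_figure helper. It creates the directory, applies `tight_layout` before saving, and closes the figure afterwards. The directory and model names in the usage lines are made up, and a simple line plot stands in for `shap.summary_plot(..., show=False)`.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt


def save_current_figure(directory, model_name):
    """Save the current pyplot figure as <model_name>_shap_values.png and close it."""
    os.makedirs(directory, exist_ok=True)  # savefig does not create missing directories
    path = os.path.join(directory, '{}_shap_values.png'.format(model_name))
    plt.tight_layout()  # adjust layout before saving so it affects the written file
    plt.savefig(path, bbox_inches='tight')
    plt.close()  # release the figure; clf() would only clear it and leave it open
    return path


# stand-in plot; hypothetical directory and model names
plt.plot([0, 1, 2], [0, 1, 4])
output_dir = os.path.join(tempfile.mkdtemp(), 'diagnostics', 'shap_values')
saved_path = save_current_figure(output_dir, 'random_forest')
print(saved_path)
```

Closing rather than clearing matters when the function runs once per model in a loop. With `clf()`, each call leaves an open figure behind, and matplotlib warns once many figures accumulate.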