-
-
Save micahmelling/8717ddcb190a6d34586cbad7d7deed56 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def produce_shap_values(pipe, x_test, num_features, model_name):
    """
    Saves a SHAP summary plot that identifies important features.

    Work is required to properly extract feature names from the pipeline. Categorical feature names are extracted
    from the DictVectorizer and combined with the numeric feature names. Columns rejected by the feature_selector
    step are then removed, so the dataframe handed to SHAP carries exactly the named columns the model was fit on.

    The pipeline passed in is not modified: preprocessing runs on a sliced view of it, so the caller can keep using
    ``pipe`` (e.g. for predictions) after this function returns.

    :param pipe: fitted scikit-learn pipeline defined in the construct_pipeline function; its last two steps are
        expected to be 'feature_selector' followed by 'model'
    :param x_test: x_test dataframe
    :param num_features: list of numeric features used for modeling
    :param model_name: the string name of the model, used as the prefix of the saved plot's file name
    :raises ValueError: if the number of preprocessed columns does not match the feature selector's support mask
    """
    model = pipe.named_steps['model']
    # for every preprocessed feature, boolean of whether the feature selector kept it
    support = pipe.named_steps['feature_selector'].get_support()

    # Slicing a Pipeline returns a new Pipeline that shares the already-fitted steps. Popping entries off
    # pipe.steps instead would strip the feature selector and the model from the caller's pipeline for good.
    preprocessing_pipe = pipe[:-2]
    # NOTE(review): assumes the preprocessing steps return a dense array; confirm the DictVectorizer is not sparse
    x_test = pd.DataFrame(preprocessing_pipe.transform(x_test))

    # extract categorical feature names nested in our pipeline and combine with known numeric feature names;
    # indexing (rather than .get) gives an immediate KeyError if the transformer name is wrong
    dict_vect = (
        pipe.named_steps['preprocessor']
        .named_transformers_['categorical_transformer']
        .named_steps['dict_vectorizer']
    )
    # NOTE(review): assumes the preprocessor emits the numeric columns first and the categorical ones after
    cols = list(num_features) + list(dict_vect.feature_names_)

    if len(cols) != len(support):
        raise ValueError(
            'Found {} feature names but the feature selector reports {} features'.format(len(cols), len(support))
        )

    # assign column names to our dataframe
    x_test.columns = cols
    # keep only the columns retained by our feature selector; positional boolean selection avoids the
    # positional ``axis`` argument of DataFrame.drop, which pandas 2.x no longer accepts
    x_test = x_test.loc[:, support]

    # produce shap values, which needs a model outside a pipeline and a dataframe with column names
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(x_test)
    shap.summary_plot(shap_values, x_test, show=False)
    # the layout has to be adjusted before saving, otherwise it has no effect on the written file
    plt.tight_layout()
    plt.savefig(os.path.join(DIAGNOSTICS_DIRECTORY, SHAP_VALUES_DIRECTORY, '{}_shap_values'.format(model_name)),
                bbox_inches='tight')
    # clear the current figure so later plots do not draw on top of this one
    plt.clf()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment