Skip to content

Instantly share code, notes, and snippets.

@kbfreder
Created February 11, 2020 23:31
Show Gist options
  • Save kbfreder/8c847d03dd1c18d3348b535117bd1b52 to your computer and use it in GitHub Desktop.
# import function, or patch it:
# Note: may need to install mlxtend
try:
    from sklearn.inspection import permutation_importance
except ImportError:
    # Older scikit-learn (< 0.22) has no sklearn.inspection.permutation_importance;
    # fall back to mlxtend's implementation behind a compatible interface.
    print("Problem importing permutation_importance -- patching")
    import numpy as np
    from mlxtend.evaluate import feature_importance_permutation

    def permutation_importance(estimator, X, y, scoring='r2', n_repeats=5,
                               random_state=None):
        """Compute permutation feature importances using mlxtend.

        Mirrors the call signature and result keys of
        ``sklearn.inspection.permutation_importance`` (scikit-learn >= 0.22),
        so downstream code can index the result the same way with either
        backend.

        Parameters
        ----------
        estimator : fitted model exposing ``predict``.
        X : pandas DataFrame or array-like, shape (n_samples, n_features).
        y : array-like of targets matching ``X``.
        scoring : str or callable
            Forwarded to mlxtend as ``metric``. Note that mlxtend only
            understands its own string names (e.g. ``'r2'``, ``'accuracy'``)
            or a ``metric(y_true, y_pred)`` callable, not every sklearn
            scorer name.
        n_repeats : int
            Number of shuffles per feature (mlxtend ``num_rounds``).
        random_state : int or None
            Seed forwarded to mlxtend so shuffles are reproducible.

        Returns
        -------
        dict
            ``'importances'`` with shape (n_features, n_repeats), plus
            ``'importances_mean'`` and ``'importances_std'`` with shape
            (n_features,).
        """
        # np.asarray accepts DataFrames/Series as well as plain arrays,
        # whereas ``X.values`` only worked for pandas inputs.
        means, values = feature_importance_permutation(
            np.asarray(X), np.asarray(y), estimator.predict, scoring,
            num_rounds=n_repeats, seed=random_state)
        return {
            'importances': values,
            'importances_mean': means,
            # sklearn's result also carries the per-feature spread.
            'importances_std': np.std(values, axis=1),
        }
# define and fit classifier:
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier()
clf.fit(train_X, train_y)

# perform permutation importance
# NOTE: scoring on the training data shows what the model relies on to fit
# the rows it has already seen. The scikit-learn guide recommends a held-out
# set for judging generalisation, so swap in validation data if available.
perm_importance = permutation_importance(clf, train_X, train_y, scoring='accuracy')
mean_perm_imp = perm_importance['importances_mean']

# Rescale so the values sum to 1, making them comparable to the forest's
# feature_importances_. Permutation importances can be negative, so the
# total may be zero; leave the values unscaled in that case rather than
# dividing by zero.
pi_total = mean_perm_imp.sum()
mean_pi_scaled = mean_perm_imp / pi_total if pi_total != 0 else mean_perm_imp
# compile & plot results
import pandas as pd

# One row per feature: impurity-based importances from the fitted forest
# next to the rescaled permutation importances computed above.
feat_imp_df = pd.DataFrame({
    'features': train_X.columns,
    'model_importances': clf.feature_importances_,
    'permuted_importances': mean_pi_scaled,
})
# Horizontal bar chart of both importance columns, ordered by permutation importance.
feat_imp_df.sort_values(by='permuted_importances').plot.barh(x='features')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment