@dfuller22
Created September 25, 2020 23:56

```python
def plot_importance(tree, X_train, top_n=10, figsize=(10, 10), ax=None):
    """Takes in a pre-fit decision tree and the training X data used to fit it.
    Plots a horizontal bar chart of the top_n (default 10) most important
    features and returns all feature importances as a pandas Series."""
    ## Imports
    import pandas as pd
    import matplotlib.pyplot as plt  # pandas .plot() draws onto matplotlib axes
    ## Generate feature importances + store into series with correct column names
    imps = pd.Series(tree.feature_importances_, index=X_train.columns)
    ## Sort ascending so .tail(top_n) keeps the largest importances; barh then
    ## draws the most important feature at the top of the chart
    imps.sort_values(ascending=True).tail(top_n).plot(kind='barh', figsize=figsize, ax=ax)
    return imps
```
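
To see why the function sorts in ascending order before taking `.tail(top_n)`, the selection step can be run on its own. This is a minimal sketch that assumes only pandas and NumPy; `FakeTree`, the column names, and the importance values are hypothetical stand-ins, not part of the original gist. Any fitted scikit-learn tree exposing `feature_importances_` would slot in the same way.

```python
import numpy as np
import pandas as pd


class FakeTree:
    """Hypothetical stand-in for a fitted tree: it only carries the
    feature_importances_ attribute that plot_importance reads."""

    def __init__(self, importances):
        self.feature_importances_ = np.asarray(importances)


# Hypothetical training frame; only its column names are used.
X_train = pd.DataFrame(columns=["age", "income", "tenure", "region"])
tree = FakeTree([0.10, 0.45, 0.30, 0.15])

# Same two steps as plot_importance: label each importance with its column,
# then sort ascending and keep the last top_n rows (the largest values).
imps = pd.Series(tree.feature_importances_, index=X_train.columns)
top = imps.sort_values(ascending=True).tail(2)
print(top)
```

Here `top` holds `tenure` (0.30) followed by `income` (0.45). Because `barh` draws the first row of the Series at the bottom of the y-axis, the ascending order places the most important feature at the top of the chart.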