@makispl
Created August 31, 2021 20:46
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.inspection import permutation_importance


def feat_permutation_importance(df, feats, model):
    """
    Fits the model on a dataframe of 'plays', computes the permutation
    importance of each feature, plots the result and returns it.

    Parameters
    ----------
    df : a dataframe object
        Contains the plays
    feats : a list object
        Contains the features' columns
    model : an estimator object
        A scikit-learn style classifier exposing fit and predict

    Returns
    -------
    importances : pd.Series object
        Contains the mean importance of each feature across the
        permutation repeats
    """
    # define the dataset features and target
    X = df[feats]
    y = df["gm_cluster"]
    # fit the given model
    model.fit(X, y)
    # perform permutation importance
    results = permutation_importance(model, X, y, scoring='f1_weighted')
    # get the mean importance per feature and the ascending sort order
    importance = results.importances_mean
    idxs = np.argsort(importance)
    importances = pd.Series(importance, index=feats)
    # plot feature importance
    plt.title('Permutation Feature Importance', fontsize=12)
    # idxs holds positions, while the Series is labelled by feature name
    plt.barh(range(len(idxs)), importances.iloc[idxs], align='center')
    plt.yticks(range(len(idxs)), [feats[i] for i in idxs])
    plt.xlabel('Feature Importance')
    plt.show()
    return importances


# check for all the features
feat_permutation_importance(train_norm_df, train_norm_feats, logres)
# drop 'START_POSITION_n' from the feature list (modifies the list in place)
train_norm_feats.remove('START_POSITION_n')
# check again with 'START_POSITION_n' excluded
feat_permutation_importance(train_norm_df, train_norm_feats, logres)
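
The snippet above relies on scikit-learn's `permutation_importance`. As a rough illustration of the idea behind that call, here is a minimal standard-library sketch: shuffle one feature column at a time and record how much the score drops, averaged over a few repeats. The names `permutation_importance_sketch`, `accuracy` and the toy `predict` function are hypothetical and are not part of the gist or of scikit-learn. The toy data is synthetic, and accuracy stands in for the `f1_weighted` scoring used above.

```python
import random


def permutation_importance_sketch(predict, X, y, score, n_repeats=5, seed=0):
    """Mean drop in score when each column of X is shuffled.

    predict : callable taking a list of rows and returning predictions
    X       : list of rows, each row a list of feature values
    y       : list of true labels
    score   : callable(y_true, y_pred) -> float, higher is better
    """
    rng = random.Random(seed)
    baseline = score(y, predict(X))
    importances = []
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            # shuffle only column j, leaving every other column intact
            column = [row[j] for row in X]
            rng.shuffle(column)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(baseline - score(y, predict(X_perm)))
        importances.append(sum(drops) / n_repeats)
    return importances


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical toy "model": it looks only at feature 0 and ignores feature 1
def predict(rows):
    return [int(row[0] > 0.5) for row in rows]


# Synthetic data where the label depends on feature 0 alone
data_rng = random.Random(42)
X = [[data_rng.random(), data_rng.random()] for _ in range(200)]
y = [int(row[0] > 0.5) for row in X]

imps = permutation_importance_sketch(predict, X, y, accuracy)
print(imps)
```

Shuffling feature 1 leaves every prediction unchanged, so its importance is zero, while shuffling feature 0 breaks the link to the labels and yields a clearly positive importance. This is the pattern to look for when comparing the two plots produced by the calls above.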