Created
August 31, 2021 20:46
-
-
Save makispl/503d43620258ac2c812f2f17964ce19c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def feat_permutation_importance(df, feats, model, *, target="gm_cluster",
                                scoring="f1_weighted", random_state=None):
    """
    Fit ``model`` on the given features, then compute, plot and return
    the permutation importance of every feature.

    Parameters
    ---------
    df : a dataframe object
        Contains the plays; must hold every column in ``feats`` plus
        the ``target`` column
    feats : a list object
        Contains the features' columns
    model : an estimator object
        An unfitted (or re-fittable) estimator exposing ``fit(X, y)``.
        It is fitted in place on ``df[feats]`` / ``df[target]``
    target : a string object, optional
        Name of the label column (default ``"gm_cluster"``)
    scoring : a string object, optional
        Scorer handed to ``permutation_importance``
        (default ``"f1_weighted"``)
    random_state : int or None, optional
        Seed forwarded to ``permutation_importance`` so the shuffles can
        be made reproducible (default ``None``)

    Returns
    -------
    importances : pd.Series object
        Contains the mean importance of each feature, indexed by feature
        name, in the same order as ``feats``
    """
    # define the dataset features and target
    X = df[feats]
    y = df[target]

    # fit the supplied estimator (note: this mutates the caller's object)
    model.fit(X, y)

    # perform permutation importance on the same data the model was fit on
    results = permutation_importance(
        model, X, y, scoring=scoring, random_state=random_state
    )

    # mean importance per feature, labelled by feature name
    importances = pd.Series(results.importances_mean, index=feats)

    # ascending order, so the most important feature ends up on top of the
    # horizontal bar chart; .iloc is required because the Series is
    # labelled by feature name and the sort order is positional
    order = np.argsort(importances.to_numpy())
    ranked = importances.iloc[order]
    positions = range(len(ranked))

    # plot feature importance
    plt.title('Permutation Feature Importance', fontsize=12)
    plt.barh(positions, ranked.to_numpy(), align='center')
    plt.yticks(positions, list(ranked.index))
    plt.xlabel('Feature Importance')
    plt.show()

    return importances
# Baseline run: permutation importance over the complete normalized
# feature list.
feat_permutation_importance(train_norm_df, train_norm_feats, logres)

# Drop 'START_POSITION_n' from the feature list (in place, so the shorter
# list is what any later code sees as well).
train_norm_feats.remove('START_POSITION_n')

# Second run: same dataframe and model, now without 'START_POSITION_n'.
feat_permutation_importance(train_norm_df, train_norm_feats, logres)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment