Created
August 31, 2021 20:08
-
-
Save makispl/aea51b0d79649629b28ef41a1fdfd7f4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run_model(df, folds, feats, model, target="gm_cluster"):
    """
    Cross-validate ``model`` over pre-assigned folds and return the mean F1.

    For each fold index ``0 .. folds - 1``, the rows whose ``kfold`` column
    equals that index form the validation set, and all other rows form the
    training set. The model is fit on the training rows and scored on the
    validation rows with the weighted F1 score. Each per-fold score is
    printed, followed by a summary line with the mean.

    Parameters
    ----------
    df : pandas.DataFrame
        Contains the plays. It must have a ``kfold`` column holding fold
        indices, every column named in ``feats``, and the ``target`` column.
    folds : int
        Number of folds to evaluate.
    feats : list of str
        Names of the feature columns.
    model : estimator
        Any object exposing ``fit(X, y)`` and ``predict(X)`` (e.g. a
        scikit-learn classifier). The same instance is refit on every fold,
        so it is left fitted on the last fold when this function returns.
    target : str, default "gm_cluster"
        Name of the label column.

    Returns
    -------
    float
        The mean weighted F1 score across all folds.

    Raises
    ------
    ValueError
        If ``folds`` is less than 1.
    """
    # Without at least one fold, np.mean([]) would silently yield nan.
    if folds < 1:
        raise ValueError(f"folds must be >= 1, got {folds}")

    scores = []
    for fold in range(folds):
        # Split into training / validation rows using the precomputed folds.
        df_train = df[df.kfold != fold].reset_index(drop=True)
        df_valid = df[df.kfold == fold].reset_index(drop=True)

        x_train = df_train[feats].to_numpy()
        x_valid = df_valid[feats].to_numpy()

        # Refit the caller's estimator on this fold's training rows.
        model.fit(x_train, df_train[target].to_numpy())
        valid_preds = model.predict(x_valid)

        # Weighted F1 accounts for class imbalance across clusters.
        f1 = f1_score(df_valid[target].to_numpy(), valid_preds, average='weighted')
        print(f"Fold = {fold}, F1 = {f1}")
        scores.append(f1)

    mean_score = np.mean(scores)
    print(f"Model {model} \n===================\nMean F1 Score = {mean_score}")
    return mean_score
# Initialize the Logistic Regression model.
# With the lbfgs solver, scikit-learn fits a multinomial (softmax) model
# whenever there are 3+ classes. The explicit `multi_class='multinomial'`
# argument was deprecated in scikit-learn 1.5 and removed in 1.7, so it is
# omitted here. `n_jobs` is also omitted: it only parallelizes one-vs-rest
# fits and is ignored for a multinomial model.
logres = LogisticRegression(solver='lbfgs')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment