Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Last active March 5, 2022 20:39
Show Gist options
Save tezansahu/ea2b2771c534440807e8a608b614fc9e to your computer and use it in GitHub Desktop.
def trainModel(model_params, training_args, run, run_name, num_labels=3):
    """Train a ClassificationModel on ``train_df`` and evaluate it on ``eval_df``.

    The caller's ``training_args`` dict is NOT modified. A per-run copy is
    built with ``output_dir`` set to ``outputs/<run_name>`` and
    ``overwrite_output_dir`` forced to ``True``.

    Args:
        model_params: Mapping with ``'model_type'`` and ``'model_name'`` keys,
            passed positionally to ``ClassificationModel``.
        training_args: Dict of training arguments passed as ``args=``.
        run: Unused in this function. Presumably an experiment-tracking run
            handle (e.g. W&B); kept for interface compatibility. TODO confirm.
        run_name: Name of the run; used as the output sub-directory.
        num_labels: Number of target classes (default 3, as before).

    Returns:
        Tuple ``(model, result, model_outputs, wrong_predictions)``, where the
        last three are what ``model.eval_model`` returns.

    Note:
        Reads the module-level ``train_df`` and ``eval_df`` DataFrames.
    """
    # Copy rather than mutate: the original wrote into the caller's dict,
    # leaking per-run settings (output_dir) back to the caller.
    run_args = {
        **training_args,
        'output_dir': os.path.join('outputs', run_name),
        'overwrite_output_dir': True,
    }

    model = ClassificationModel(
        model_params['model_type'],
        model_params['model_name'],
        num_labels=num_labels,
        args=run_args,
    )

    print("Training the model...")
    model.train_model(train_df)

    def f1_multiclass(labels, preds):
        # NOTE(review): for single-label multiclass data, micro-averaged F1 is
        # identical to accuracy, so this duplicates `acc` below. Consider
        # 'macro' or 'weighted' if a class-balanced score was intended.
        return f1_score(labels, preds, average='micro')

    print("Evaluating the model...")
    result, model_outputs, wrong_predictions = model.eval_model(
        eval_df, f1=f1_multiclass, acc=accuracy_score
    )
    return model, result, model_outputs, wrong_predictions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment