Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Last active March 5, 2022 21:10
Show Gist options
  • Save tezansahu/e0046405854c0ee398b55d237ee27088 to your computer and use it in GitHub Desktop.
Save tezansahu/e0046405854c0ee398b55d237ee27088 to your computer and use it in GitHub Desktop.
def _label_to_sentiment(label):
    """Map an integer class index to its sentiment name.

    0 -> "negative", 1 -> "neutral"; any other value falls through to
    "positive".
    """
    if label == 0:
        return "negative"
    if label == 1:
        return "neutral"
    return "positive"


def _f1_multiclass(labels, preds):
    """Return the micro-averaged F1 score of ``preds`` against ``labels``.

    NOTE(review): for single-label multiclass data, micro-averaged F1 is
    numerically equal to accuracy, so this duplicates the ``acc`` metric.
    ``average='macro'`` may have been intended. Confirm before changing,
    since it would alter the values logged for new runs.
    """
    return f1_score(labels, preds, average='micro')


def trainModel(model_params, training_args, run, run_name):
    """Train a 3-class sentiment classifier, evaluate it, and log the run.

    Training and evaluation use the ``train_df`` / ``eval_df`` DataFrames
    defined elsewhere in this file.

    Args:
        model_params: Model specification. Must contain ``'model_type'`` and
            ``'model_name'``. All entries are logged as run parameters.
        training_args: Training hyperparameters for the model's ``args``.
            All entries are logged as run parameters. This dict is NOT
            modified; the run-specific ``output_dir`` and
            ``overwrite_output_dir`` settings are added to a copy.
        run: Tracking run object that receives ``log_dataset``,
            ``log_params``, ``log_metrics`` and ``log_dataset_stats`` calls.
        run_name: Used to set the model's ``output_dir`` to
            ``outputs/<run_name>``.

    Returns:
        ``(model, result, model_outputs, wrong_predictions)``. The last three
        are exactly what ``model.eval_model`` returns.
    """
    # Log the training & evaluation datasets as CSV files.
    run.log_dataset(train_df, data_slice=mlf.DataSlice.TRAIN, fileformat=mlf.FileFormat.CSV)
    run.log_dataset(eval_df, data_slice=mlf.DataSlice.TEST, fileformat=mlf.FileFormat.CSV)

    # Log the model specification and training hyperparameters as run parameters.
    # This runs before the run-specific output settings are added below, so only
    # the caller-supplied values are recorded.
    run.log_params({**model_params, **training_args})

    # Build the model args on a copy. Writing into the caller's dict would leak
    # this run's output_dir into the parameters logged by any later call that
    # reuses the same dict.
    model_args = {
        **training_args,
        'output_dir': os.path.join('outputs', run_name),
        'overwrite_output_dir': True,
    }

    model = ClassificationModel(
        model_params['model_type'],
        model_params['model_name'],
        num_labels=3,  # must agree with the three classes in _label_to_sentiment
        args=model_args,
    )

    print("Training the model...")
    model.train_model(train_df)

    print("Evaluating the model...")
    # f1/acc are passed as extra (labels, preds) -> score metric functions;
    # presumably they appear under those keys in `result` — confirm against the library.
    result, model_outputs, wrong_predictions = model.eval_model(
        eval_df, f1=_f1_multiclass, acc=accuracy_score
    )

    # Log the performance metrics.
    run.log_metrics(result)

    # Evaluation set with human-readable true and predicted sentiments.
    # Assumes each element of model_outputs is one row of per-class scores,
    # so argmax gives the predicted class index — TODO confirm.
    eval_df_to_log = pd.DataFrame({
        "headline": eval_df["text"],
        "sentiment": [_label_to_sentiment(label) for label in eval_df["labels"].to_list()],
        "prediction": [_label_to_sentiment(np.argmax(scores)) for scores in model_outputs],
    })

    # Log the stats for the evaluation data.
    run.log_dataset_stats(
        eval_df_to_log,
        data_slice=mlf.DataSlice.TEST,
        data_schema=mlf.Schema(
            feature_column_names=['headline'],
            prediction_column_name='prediction',
            actual_column_name='sentiment',
        ),
        model_type=mlf.ModelType.MULTICLASS_CLASSIFICATION,
    )

    return model, result, model_outputs, wrong_predictions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment