@piercelamb
Created December 19, 2022 22:57
eval_loop
import torch

def eval_loop(model, valid_dataloader, accelerator, config, model_name,
              f1_metric, acc_metric):
    model.eval()  # switch off dropout / batch-norm updates for evaluation
    for step, batch in enumerate(valid_dataloader):
        if config.is_comparison:
            # keep only the inputs this particular model architecture expects
            batch = get_model_specific_batch(batch, model_name)
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(-1)
        # gather predictions and labels from every process; gather_for_metrics
        # also drops the samples duplicated by the distributed sampler
        predictions, references = accelerator.gather_for_metrics(
            (predictions, batch['labels'])
        )
        f1_metric.add_batch(
            predictions=predictions,
            references=references
        )
        acc_metric.add_batch(
            predictions=predictions,
            references=references
        )
    # compute the final scores over all batches accumulated above
    f1_score = f1_metric.compute(average=F1_AVERAGE)['f1']
    acc_score = acc_metric.compute()['accuracy']
    return acc_score, f1_score
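
For context, a minimal sketch of how this loop might be wired up, assuming the metric objects come from Hugging Face's evaluate library and the model and dataloader are wrapped by accelerator.prepare. The checkpoint name, num_labels, batch size, F1_AVERAGE value, the SimpleNamespace stand-in for the gist's config object, and the tokenized_valid_dataset variable are all illustrative assumptions, not taken from the gist.

from types import SimpleNamespace

import evaluate
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification

F1_AVERAGE = "macro"  # assumed value; defined elsewhere in the original project
model_name = "distilbert-base-uncased"  # illustrative checkpoint
config = SimpleNamespace(is_comparison=False)  # stand-in for the gist's config object

accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# tokenized_valid_dataset: a tokenized, torch-formatted dataset assumed to be
# built earlier in the pipeline
valid_dataloader = DataLoader(tokenized_valid_dataset, batch_size=32)
model, valid_dataloader = accelerator.prepare(model, valid_dataloader)

f1_metric = evaluate.load("f1")
acc_metric = evaluate.load("accuracy")

acc_score, f1_score = eval_loop(
    model, valid_dataloader, accelerator, config, model_name,
    f1_metric, acc_metric,
)
print(f"accuracy={acc_score:.4f}  f1={f1_score:.4f}")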