@tezansahu
Created February 12, 2022 22:00
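```python
import numpy as np

# `wup_measure` and `answer_space` (label id -> answer string) are defined elsewhere, not in this snippet.
# Wrapper around the wup_measure(...) function to process batch inputs
def batch_wup_measure(labels, preds):
    wup_scores = [wup_measure(answer_space[label], answer_space[pred]) for label, pred in zip(labels, preds)]
    return np.mean(wup_scores)
```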
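`batch_wup_measure` converts each integer label and prediction back into an answer string via `answer_space`, scores every (gold, predicted) pair with `wup_measure`, and returns the batch mean. The sketch below reproduces only that lookup-and-average loop in plain Python so it can run without the original helpers. `toy_answer_space`, `stub_wup` and `batch_similarity` are hypothetical: the stub returns 1.0 for an exact match and 0.0 otherwise, which is a placeholder and not the Wu-Palmer similarity itself.

```python
from statistics import mean
from typing import Callable, List, Sequence

# Hypothetical answer vocabulary: index i -> answer string (stand-in for `answer_space`)
toy_answer_space: List[str] = ["cat", "dog", "two", "red"]


def stub_wup(a: str, b: str) -> float:
    """Hypothetical placeholder for wup_measure: exact match only, NOT WordNet WUP."""
    return 1.0 if a == b else 0.0


def batch_similarity(
    labels: Sequence[int],
    preds: Sequence[int],
    answer_space: Sequence[str],
    sim: Callable[[str, str], float],
) -> float:
    # Translate ids to answer strings, score each (gold, predicted) pair, then average
    scores = [sim(answer_space[gold], answer_space[pred]) for gold, pred in zip(labels, preds)]
    return mean(scores)


labels = [0, 1, 2, 3]
preds = [0, 2, 2, 1]
print(batch_similarity(labels, preds, toy_answer_space, stub_wup))  # 0.5
```

Passing the similarity function in as a parameter keeps the averaging logic independent of whichever WUP implementation is plugged in.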
```python
from typing import Dict, Tuple

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Function to compute all relevant performance metrics, to be passed into the trainer
def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
    logits, labels = eval_tuple
    preds = logits.argmax(axis=-1)  # predicted answer id = index of the highest logit
    return {
        "wups": batch_wup_measure(labels, preds),
        "acc": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average='macro')
    }
```
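`compute_metrics` turns logits into class ids with `argmax` over the last axis, then scores them with scikit-learn's `accuracy_score` and `f1_score(..., average='macro')`. The dependency-free sketch below makes those two steps explicit, assuming logits arrive as a list of per-example score lists. Macro averaging computes F1 separately for every class that appears in either the labels or the predictions, then takes an unweighted mean, so rare answers count as much as frequent ones. The helper names (`argmax_rows`, `accuracy`, `macro_f1`, `toy_metrics`) are hypothetical; this is meant to mirror scikit-learn's default behaviour (classes with no true positives score 0) as an illustration, not as a drop-in replacement.

```python
from typing import Dict, List, Sequence


def argmax_rows(logits: Sequence[Sequence[float]]) -> List[int]:
    # Index of the highest score in each row (ties go to the first index, like np.argmax)
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def accuracy(labels: Sequence[int], preds: Sequence[int]) -> float:
    # Fraction of examples whose predicted id equals the gold id
    return sum(y == p for y, p in zip(labels, preds)) / len(labels)


def macro_f1(labels: Sequence[int], preds: Sequence[int]) -> float:
    # Per-class F1 over every class seen in labels or preds, then an unweighted mean
    classes = sorted(set(labels) | set(preds))
    f1s = []
    for c in classes:
        tp = sum(y == c and p == c for y, p in zip(labels, preds))
        fp = sum(y != c and p == c for y, p in zip(labels, preds))
        fn = sum(y == c and p != c for y, p in zip(labels, preds))
        # 2TP / (2TP + FP + FN); the denominator is > 0 for any class in `classes`
        f1s.append(2 * tp / (2 * tp + fp + fn))
    return sum(f1s) / len(f1s)


def toy_metrics(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> Dict[str, float]:
    preds = argmax_rows(logits)
    return {"acc": accuracy(labels, preds), "f1": macro_f1(labels, preds)}


logits = [
    [2.0, 0.1, 0.3],  # -> 0
    [0.2, 1.5, 0.1],  # -> 1
    [0.9, 0.2, 0.4],  # -> 0 (gold label is 1)
    [0.1, 0.3, 2.2],  # -> 2
]
labels = [0, 1, 1, 2]
print(toy_metrics(logits, labels))  # acc = 0.75, f1 = 7/9 (about 0.778)
```

Here classes 0 and 1 each get F1 = 2/3 (one false positive and one false negative, respectively) and class 2 gets 1.0, so the macro mean is 7/9 while plain accuracy is 3/4.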