@danyaljj
Created November 17, 2022 16:37
#!/usr/bin/env python
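"""Evaluate system predictions against multi-reference gold labels with several
HuggingFace `datasets` metrics (METEOR via a local ./Meteor1.py script, BLEURT,
ROUGE-L, SacreBLEU, BERTScore).

Usage sketch (file names are placeholders; the expected JSON shapes are inferred
from how main() below reads the two files):

    python evaluate.py --gold_label_file gold.json \
                       --prediction_file predictions.json \
                       --output scores.json

    gold.json:        {"0": ["a reference", "another reference"], "1": ["..."]}
    predictions.json: {"0": "a predicted sentence", "1": "..."}

The output is a JSON object mapping each metric name to a single float score.
"""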
from typing import Iterable
from collections import Counter
import os
import logging
import sys
import json
import click
import datasets
import numpy as np
logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
# Match TF logging format
handler.setFormatter(logging.Formatter('%(asctime)s.%(msecs)06d: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'))
handler.setLevel(logging.INFO)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
METRICS = {
    'meteor': {'metric': './Meteor1.py'},
    'bleurt': {'metric': 'bleurt', 'load_kwargs': {'config_name': 'bleurt-base-128'}},
    # 'use_agregator' matches the kwarg spelling in the `datasets` rouge metric at the time
    'rouge': {'metric': 'rouge',
              'compute_kwargs': {'use_agregator': False, 'use_stemmer': True, 'rouge_types': ['rougeL']}},
    'sacrebleu': {'metric': 'sacrebleu'},
    'bert_score': {'metric': 'bertscore', 'compute_kwargs': {'lang': 'en'}},
}
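# Each METRICS entry maps a display name to a metric spec:
#   'metric'         -> metric name (or local script path) passed to datasets.load_metric
#   'load_kwargs'    -> extra kwargs forwarded to datasets.load_metric
#   'compute_kwargs' -> extra kwargs forwarded to metric.compute
# Hypothetical example of adding another built-in metric (not part of the original config):
#   METRICS['google_bleu'] = {'metric': 'google_bleu'}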
TRUNCATE_REFS = False
REF_PADDING_VALUE = ''
def _get_most_common_ref(refs: Iterable[str]):
    # Return the most frequent reference (after stripping/lowercasing) in its
    # original surface form. Note: not used by the CLI below.
    origs = list(refs)
    normalized = [s.strip().lower() for s in origs]
    most_common = Counter(normalized).most_common()[0][0]
    return origs[normalized.index(most_common)]
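# Example (illustrative values): the most frequent reference after normalization
# is returned with its original casing/whitespace:
#   _get_most_common_ref(['The cat sat.', 'the cat sat.', 'A dog ran.'])  ->  'The cat sat.'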
@click.command()
@click.option(
    '--gold_label_file',
    type=click.File('r'),
    help='Path to the JSON file with gold labels (id -> list of reference strings).',
    required=True,
)
@click.option(
    '--prediction_file',
    type=click.File('r'),
    help='Path to the JSON file with system predictions (id -> predicted string).',
    required=True,
)
@click.option(
    '--output',
    type=click.File('w'),
    help='Output results to this file.',
    required=True,
)
def main(gold_label_file, prediction_file, output=None, metrics=METRICS):
    references = json.load(gold_label_file)
    predictions = json.load(prediction_file)
    # Normalize keys to strings so gold and prediction ids can be matched
    references = {str(k): references[k] for k in references}
    predictions = {str(k): predictions[k] for k in predictions}
    scores = evaluator(predictions, references, metrics)
    if output:
        json.dump(scores, output)
    return scores
def evaluator(predictions, references, metrics):
    diff = set(references.keys()) - set(predictions.keys())
    if len(diff) > 0:
        raise ValueError(f'Prediction keys do not cover the references: {diff}')
    # Clear command line arguments for bleurt
    sys.argv = [sys.argv[0]]
    scores = {}
    for m in metrics:
        logger.info(f' - - - - - - \n Computing metric {m} with config: {metrics[m]}')
        metric_name = metrics[m]['metric']
        metric = datasets.load_metric(
            metric_name,
            # script_version=os.environ['HF_SCRIPT_VERSION'],
            **metrics[m].get('load_kwargs', {})
        )
        if m == 'meteor':
            # The meteor metric is loaded from a local script path, so fall back to its short name
            metric_name = m
        formatted_refs = list(references.values())
        if metric_name == 'sacrebleu':
            # sacrebleu needs the same number of references per prediction
            if TRUNCATE_REFS:
                # Truncate to min reference length
                min_refs = min(len(refs) for refs in formatted_refs)
                formatted_refs = [refs[:min_refs] for refs in formatted_refs]
            else:
                # Pad to max reference length
                # (Padding SacreBLEU with empty strings matches NLTK:
                # https://github.com/mjpost/sacreBLEU/issues/28)
                max_refs = max(len(refs) for refs in formatted_refs)
                formatted_refs = [
                    [
                        refs[i] if i < len(refs) else REF_PADDING_VALUE
                        for i in range(max_refs)
                    ]
                    for refs in formatted_refs
                ]
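        # Illustrative example of the padding above (strings are made up):
        #   [['ref a1', 'ref a2'], ['ref b1']]  ->  [['ref a1', 'ref a2'], ['ref b1', '']]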
        if metric_name in ['rouge', 'meteor', 'bleurt']:
            # max_metric already reduces the per-example results to a single float,
            # so no further extraction is needed for these metrics.
            score = max_metric(metric, formatted_refs,
                               [predictions[k] for k in references],
                               metric_name,
                               **metrics[m].get('compute_kwargs', {}))
        else:
            score = metric.compute(
                predictions=[predictions[k] for k in references],
                references=formatted_refs,
                **metrics[m].get('compute_kwargs', {})
            )
            logger.info(f'Computing score from result: {score}')
            # Reduce the metric-specific result dict to a single float
            if metric_name == 'bertscore':
                # np.mean handles both list and tensor outputs from the bertscore metric
                score = float(np.mean(score['f1']))
            elif metric_name == 'sacrebleu':
                score = score['score'] / 100
            else:
                score = score[metric_name]
        logger.info(f'Adding score {m} = {score}')
        scores[m] = score
    return scores
def max_metric(metric, formatted_references, predictions, metric_name, **compute_kwargs):
    # For each prediction, score it against each of its references, keep the best
    # (max) score, and report the mean of these per-example maxima.
    assert len(formatted_references) == len(predictions)
    aggregated_scores = []
    for r, p in zip(formatted_references, predictions):
        scores = metric.compute(references=r, predictions=[p for _ in range(len(r))], **compute_kwargs)
        if metric_name == 'rouge':
            max_score = max([s.fmeasure for s in scores['rougeL']])
        elif metric_name == 'meteor':
            max_score = max(scores['raw_scores'])
        elif metric_name == 'bleurt':
            max_score = max(scores['scores'])
        else:
            raise ValueError("max_metric only supports the `rouge`, `meteor`, and `bleurt` metrics.")
        aggregated_scores.append(max_score)
    return np.mean(aggregated_scores)
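# Illustrative sketch of max_metric's aggregation (all values are made up):
#   prediction p1 vs refs [r1a, r1b] -> per-ref scores [0.40, 0.70] -> max 0.70
#   prediction p2 vs refs [r2a]      -> per-ref scores [0.55]       -> max 0.55
#   reported score = mean(0.70, 0.55) = 0.625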
if __name__ == "__main__":
    main()