#!/usr/bin/env python
from typing import Iterable
from collections import Counter
import os
import logging
import sys
import json

import click
import datasets
import numpy as np

logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
# Match TF logging format
handler.setFormatter(logging.Formatter('%(asctime)s.%(msecs)06d: %(message)s', datefmt='%Y-%m-%d %H:%M:%S'))
handler.setLevel(logging.INFO)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

METRICS = {
    'meteor': {'metric': './Meteor1.py'},
    'bleurt': {'metric': 'bleurt', 'load_kwargs': {'config_name': 'bleurt-base-128'}},
    # 'use_agregator' (sic) matches the misspelled kwarg in older `datasets` rouge versions
    'rouge': {'metric': 'rouge',
              'compute_kwargs': {'use_agregator': False, 'use_stemmer': True, 'rouge_types': ['rougeL']}},
    'sacrebleu': {'metric': 'sacrebleu'},
    'bert_score': {'metric': 'bertscore', 'compute_kwargs': {'lang': 'en'}},
}
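
# How each METRICS entry is consumed (a note based on the code below):
#   'metric'         -> name or local script path passed to datasets.load_metric
#   'load_kwargs'    -> extra kwargs forwarded to datasets.load_metric
#   'compute_kwargs' -> extra kwargs forwarded to metric.compute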

TRUNCATE_REFS = False
REF_PADDING_VALUE = ''

def _get_most_common_ref(refs: Iterable[str]):
    """Return the most frequent reference (after whitespace/case normalization), in its original form."""
    origs = list(refs)
    # Iterate over the materialized list so a one-shot iterator is not consumed twice
    normalized = [s.strip().lower() for s in origs]
    most_common = Counter(normalized).most_common(1)[0][0]
    return origs[normalized.index(most_common)]
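
# Example (hypothetical inputs): 'The Cat' and 'the cat ' both normalize to
# 'the cat', so the first original spelling is returned:
#   _get_most_common_ref(['The Cat', 'the cat ', 'a dog'])  # -> 'The Cat'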

@click.command()
@click.option(
    '--gold_label_file',
    type=click.File('r'),
    help='Path to the JSON file with gold labels.',
    required=True,
)
@click.option(
    '--prediction_file',
    type=click.File('r'),
    help='Path to the JSON file containing system predictions.',
    required=True,
)
@click.option(
    '--output',
    type=click.File('w'),
    help='Output results to this file.',
    required=True,
)
def main(gold_label_file, prediction_file, output=None, metrics=METRICS):
    references = json.load(gold_label_file)
    predictions = json.load(prediction_file)
    # Normalize keys to strings so gold and prediction ids match reliably
    references = {str(k): v for k, v in references.items()}
    predictions = {str(k): v for k, v in predictions.items()}
    scores = evaluator(predictions, references, metrics)
    if output:
        json.dump(scores, output)
    return scores
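
# Expected input shapes (inferred from the evaluator below; ids and strings hypothetical):
#   gold_label_file:  {"<id>": ["reference 1", "reference 2", ...], ...}
#   prediction_file:  {"<id>": "single predicted string", ...}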

def evaluator(predictions, references, metrics):
    diff = set(references.keys()) - set(predictions.keys())
    if len(diff) > 0:
        raise ValueError(f'Prediction keys do not cover the references: {diff}')
    # Clear command line arguments for bleurt
    sys.argv = [sys.argv[0]]
    scores = {}
    for m in metrics:
        logger.info(f' - - - - - - \n Computing metric {m} with config: {metrics[m]}')
        metric_name = metrics[m]['metric']
        metric = datasets.load_metric(
            metric_name,
            # script_version=os.environ['HF_SCRIPT_VERSION'],
            **metrics[m].get('load_kwargs', {})
        )
        if m == 'meteor':
            # meteor is loaded from a local script path; restore the plain metric name
            metric_name = m
        formatted_refs = list(references.values())
        if metric_name == 'sacrebleu':
            # sacrebleu needs the same number of references per prediction
            if TRUNCATE_REFS:
                # Truncate to min reference length
                min_refs = min(len(refs) for refs in formatted_refs)
                formatted_refs = [refs[:min_refs] for refs in formatted_refs]
            else:
                # Pad to max reference length
                # (Padding SacreBLEU with empty strings matches NLTK:
                # https://github.com/mjpost/sacreBLEU/issues/28)
                max_refs = max(len(refs) for refs in formatted_refs)
                formatted_refs = [
                    [
                        refs[i] if i < len(refs) else REF_PADDING_VALUE
                        for i in range(max_refs)
                    ]
                    for refs in formatted_refs
                ]
        if metric_name in ['rouge', 'meteor', 'bleurt']:
            # Best-reference score per example, averaged across examples (already a scalar)
            score = max_metric(metric, formatted_refs,
                               [predictions[k] for k in references],
                               metric_name,
                               **metrics[m].get('compute_kwargs', {}))
        else:
            score = metric.compute(
                predictions=[predictions[k] for k in references],
                references=formatted_refs,
                **metrics[m].get('compute_kwargs', {})
            )
            logger.info(f'Computing score from result: {score}')
            # Reduce the metric's raw output to a single scalar
            if metric_name == 'rouge':
                score = score['rougeL'].mid.fmeasure
            elif metric_name == 'bertscore':
                score = score['f1'].mean().item()
            elif metric_name == 'sacrebleu':
                score = score['score'] / 100
            elif metric_name == 'bleurt':
                score = np.mean(score['scores'])
            else:
                score = score[metric_name]
        logger.info(f'Adding score {m} = {score}')
        scores[m] = score
    return scores

def max_metric(metric, formatted_references, predictions, metric_name, **compute_kwargs):
    """For each example, score the prediction against every reference and keep the best score."""
    assert len(formatted_references) == len(predictions)
    aggregated_scores = []
    for r, p in zip(formatted_references, predictions):
        scores = metric.compute(references=r, predictions=[p for _ in range(len(r))], **compute_kwargs)
        if metric_name == 'rouge':
            max_score = max(s.fmeasure for s in scores['rougeL'])
        elif metric_name == 'meteor':
            max_score = max(scores['raw_scores'])
        elif metric_name == 'bleurt':
            max_score = max(scores['scores'])
        else:
            raise ValueError("max_metric only supports `rouge`, `meteor`, or `bleurt`.")
        aggregated_scores.append(max_score)
    return np.mean(aggregated_scores)
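
# Example (hypothetical call), averaging the best-reference rougeL f-measure per example;
# the kwargs mirror METRICS['rouge']['compute_kwargs'] above:
#   rouge = datasets.load_metric('rouge')
#   max_metric(rouge, [['a cat', 'one cat']], ['a cat'], 'rouge',
#              use_agregator=False, use_stemmer=True, rouge_types=['rougeL'])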

if __name__ == "__main__":
    main()
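
# Example invocation (hypothetical file names):
#   python evaluate.py --gold_label_file gold.json --prediction_file predictions.json --output scores.json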