Scoring translations with HF
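
This gist scores existing translations by forced decoding with Hugging Face transformers: each input line holds a source sentence and its candidate translation separated by a delimiter (tab by default), and the script prints one beam-search sequence score per line, with generation constrained to the tokenized reference. A hypothetical invocation, assuming the file is saved as score_translations.py and using a Marian checkpoint purely as an example:

    python score_translations.py -t Helsinki-NLP/opus-mt-en-de -m Helsinki-NLP/opus-mt-en-de -i pairs.tsv -o scores.txt --device cuda
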
#!/usr/bin/env python3
import itertools
from argparse import ArgumentParser, FileType

import torch
from tqdm import tqdm
from transformers import PrefixConstrainedLogitsProcessor, AutoTokenizer, AutoModelForSeq2SeqLM

def setup_argparse():
    parser = ArgumentParser()
    parser.add_argument("-t", "--tokenizer", type=str, required=True)
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-bs", "--batch-size", type=int, default=16)
    parser.add_argument("-i", "--input", type=FileType("r"), default="-")
    parser.add_argument("-o", "--output", type=FileType("w"), default="-")
    parser.add_argument("-d", "--delimiter", type=str, default="\t")
    parser.add_argument("--device", type=str, default="cpu")
    return parser

def create_processor_fn(ref_tokens_by_segment):
    # For segment `batch_id` in the batch, allow only the tokens of its reference translation.
    def inner(batch_id, _):
        return ref_tokens_by_segment[batch_id]
    return inner

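# tokenize() below builds the batched model inputs and, from the tokenized references
# ("labels"), a PrefixConstrainedLogitsProcessor. At each decoding step the processor
# masks out every logit except the tokens returned by the prefix function above, so the
# beam search can only emit tokens drawn from the matching reference segment.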
def tokenize(src, tgt, tokenizer, num_beams=5):
    inputs = tokenizer(src, text_target=tgt, padding=True, return_tensors="pt")
    logit_processor = PrefixConstrainedLogitsProcessor(create_processor_fn(inputs["labels"]), num_beams=num_beams)
    return inputs, logit_processor

def forced_decode(inputs, logit_processor, model, num_beams=5):
    inputs = inputs.to(model.device)
    # generate() does not consume the labels; the references act through the logits processor.
    inputs.pop("labels", None)
    output = model.generate(**inputs, num_beams=num_beams, logits_processor=[logit_processor], return_dict_in_generate=True, output_scores=True)
    return output.sequences_scores.tolist()

def batch_lines(it, batch_size):
    # Yield successive batches of at most batch_size items from an iterable.
    it = iter(it)
    item = list(itertools.islice(it, batch_size))
    while item:
        yield item
        item = list(itertools.islice(it, batch_size))

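# A quick illustration of batch_lines (comment only, not executed):
#   list(batch_lines("abcde", 2)) == [["a", "b"], ["c", "d"], ["e"]]
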
if __name__ == "__main__":
    args = setup_argparse().parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    # torch.compile requires PyTorch >= 2.0; drop the wrapper if running on an older version.
    model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(args.model).to(args.device))
    # Batch in order; this may be suboptimal because of padding variance, but the batch size is small anyway, so the impact is minor.
    with args.input as fin:
        inputs = list(batch_lines(map(str.strip, fin), args.batch_size))
    # Tokenize in predefined batches and create a "logits processor" en route which constrains
    # the decoded prefix to be the reference.
    inputs_logits = []
    for batch in tqdm(inputs):
        src, tgt = zip(*[line.split(args.delimiter) for line in batch])
        inputs_logits.append(tokenize(src, tgt, tokenizer))
    # Forward pass and score extraction.
    with args.output as fout, torch.no_grad():
        for batch_inputs, logit_processor in tqdm(inputs_logits):
            scores = forced_decode(batch_inputs, logit_processor, model)
            print(*scores, sep="\n", file=fout)
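
The printed values are generate()'s sequences_scores: the final beam score of each forced hypothesis, i.e. the summed token log-probabilities divided by the length penalty (1.0 by default), so values closer to zero indicate translations the model considers more likely.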