@erip
Last active July 31, 2023 15:06
Scoring translations with HF
#!/usr/bin/env python3
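# Usage (hypothetical invocation; the script name and checkpoint below are
# placeholders, any seq2seq translation model on the Hub should work):
#
#   paste src.txt ref.txt | ./score.py -t Helsinki-NLP/opus-mt-en-de \
#       -m Helsinki-NLP/opus-mt-en-de --device cuda > scores.txt
#
# Reads one source/reference pair per line, separated by --delimiter (tab by
# default), and writes one sequence score per line.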
import itertools
from argparse import ArgumentParser, FileType

import torch
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    PrefixConstrainedLogitsProcessor,
)


def setup_argparse():
    parser = ArgumentParser()
    parser.add_argument("-t", "--tokenizer", type=str, required=True)
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-bs", "--batch-size", type=int, default=16)
    parser.add_argument("-i", "--input", type=FileType("r"), default="-")
    parser.add_argument("-o", "--output", type=FileType("w"), default="-")
    parser.add_argument("-d", "--delimiter", type=str, default="\t")
    parser.add_argument("--device", type=str, default="cpu")
    return parser


def create_processor_fn(ref_tokens_by_segment):
    def inner(batch_id, _):
        return ref_tokens_by_segment[batch_id]
    return inner
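# PrefixConstrainedLogitsProcessor calls the closure above for each hypothesis
# at every decoding step and masks all logits except the returned token ids, so
# the beam can only emit tokens from the corresponding reference. That
# constrains decoding to the reference so that generate() scores it.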


def tokenize(src, tgt, tokenizer, num_beams=5):
    inputs = tokenizer(src, text_target=tgt, padding=True, return_tensors="pt")
    logit_processor = PrefixConstrainedLogitsProcessor(
        create_processor_fn(inputs["labels"]), num_beams=num_beams
    )
    return inputs, logit_processor


def forced_decode(inputs, logit_processor, model, num_beams=5):
    inputs = inputs.to(model.device)
    output = model.generate(**inputs, num_beams=num_beams, logits_processor=[logit_processor],
                            return_dict_in_generate=True, output_scores=True)
    # sequences_scores: final beam score per segment (length-normalized log-prob).
    return output.sequences_scores.tolist()


def batch_lines(it, batch_size):
    it = iter(it)
    item = list(itertools.islice(it, batch_size))
    while item:
        yield item
        item = list(itertools.islice(it, batch_size))
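# e.g. list(batch_lines(range(5), 2)) == [[0, 1], [2, 3], [4]]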


if __name__ == "__main__":
    args = setup_argparse().parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(args.model).to(args.device))

    # Batch in order; this may be suboptimal because padding length varies
    # within a batch, but the batch size is small, so the impact is minor.
    with args.input as fin:
        inputs = list(batch_lines(map(str.strip, fin), args.batch_size))

    # Tokenize the predefined batches and create a logits processor en route
    # that constrains generation to the reference tokens.
    inputs_logits = []
    for batch in tqdm(inputs):
        src, tgt = zip(*[line.split(args.delimiter) for line in batch])
        inputs_logits.append(tokenize(src, tgt, tokenizer))

    # Forward pass and score extraction.
    with args.output as fout, torch.no_grad():
        for input, logit_processor in tqdm(inputs_logits):
            scores = forced_decode(input, logit_processor, model)
            print(*scores, sep="\n", file=fout)
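
As a quick sanity check on the numbers this prints, the minimal sketch below (not part of the original script; the checkpoint and sentence pair are placeholders) scores a single pair with a plain forward pass instead of constrained beam search. generate()'s sequences_scores are length-normalized sums of the chosen tokens' log-probabilities, so the two values should be close when the beam follows the reference.

#!/usr/bin/env python3
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "Helsinki-NLP/opus-mt-en-de"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

src, tgt = "Hello, world!", "Hallo, Welt!"  # placeholder pair
inputs = tokenizer(src, text_target=tgt, return_tensors="pt")

with torch.no_grad():
    # With labels supplied, the model builds decoder inputs by shifting them
    # right, so logits[:, i] predicts labels[:, i].
    logits = model(**inputs).logits

labels = inputs["labels"]
token_log_probs = logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# Length-normalized log-probability of the reference, comparable to
# sequences_scores under the default length_penalty of 1.0.
print((token_log_probs.sum() / labels.shape[-1]).item())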