@erip
Last active March 4, 2022 22:09
Forced decoding with Huggingface Transformers
from transformers import PrefixConstrainedLogitsProcessor


def create_processor_fn(ref_tokens_by_segment):
    # The processor calls this as fn(batch_id, input_ids) and expects the allowed token ids back.
    # Note: every reference token is allowed at every step; the order is not enforced.
    def inner(batch_id, _):
        return ref_tokens_by_segment[batch_id]
    return inner

# ...

# Tokenize the references as target-side text.
with tokenizer.as_target_tokenizer():
    tgt_encoded = tokenizer(tgt_lines)

logit_processor = PrefixConstrainedLogitsProcessor(create_processor_fn(tgt_encoded["input_ids"]), num_beams=5)
output = model.generate(**inputs, num_beams=5, logits_processor=[logit_processor], return_dict_in_generate=True, output_scores=True)
print(output.sequences_scores)
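The helper above returns the full list of reference token ids at every step. That limits generation to the reference's vocabulary but does not pin down the order of the tokens. A stricter variant, sketched below, allows only the next reference token at each step. It is not part of the original gist. The name `create_forced_fn` and the parameters `eos_token_id` and `prefix_len` are hypothetical. It assumes the decoder input starts with a single start token (`prefix_len=1`), which may not hold for every model.

```python
def create_forced_fn(ref_tokens_by_segment, eos_token_id, prefix_len=1):
    """Build a prefix_allowed_tokens_fn that permits one token per step.

    ref_tokens_by_segment: list of token-id lists, one per batch item.
    eos_token_id: id returned once the reference is exhausted (assumption:
        ending generation with EOS is the desired behavior).
    prefix_len: number of tokens already in the decoder input before the
        first generated token (assumed 1, i.e. only a decoder start token).
    """
    def inner(batch_id, input_ids):
        ref = ref_tokens_by_segment[batch_id]
        # Index of the token about to be generated within the reference.
        step = len(input_ids) - prefix_len
        if 0 <= step < len(ref):
            return [ref[step]]
        # Past the end of the reference: only allow EOS.
        return [eos_token_id]
    return inner


# Small illustration with plain lists standing in for decoder input ids.
fn = create_forced_fn([[5, 6, 7], [8, 9]], eos_token_id=2)
print(fn(0, [0]))        # first step for segment 0
print(fn(0, [0, 5]))     # second step for segment 0
print(fn(1, [0, 8, 9]))  # reference for segment 1 is exhausted
```

The returned function has the same `(batch_id, input_ids)` signature as `inner` in the gist, so it could be passed to `PrefixConstrainedLogitsProcessor` in place of `create_processor_fn(...)`. Whether the tokenized reference already ends in an EOS token depends on the tokenizer, so check that before relying on the fallback.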