Skip to content

Instantly share code, notes, and snippets.

@allanj
Created July 8, 2021 04:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save allanj/62be8bffd610ec60105a3101f8910931 to your computer and use it in GitHub Desktop.
Save allanj/62be8bffd610ec60105a3101f8910931 to your computer and use it in GitHub Desktop.
Fairseq Generation
import torch
from fairseq.models.bart import BARTModel
# Load the fine-tuned BART checkpoint (plus its binarized data directory,
# which from_pretrained is pointed at via data_name_or_path), then prepare
# the model for inference: move it to the GPU, switch to eval mode, and
# cast it to half precision -- in that order.
bart = BARTModel.from_pretrained(
    'model_files/bart-large-model',
    data_name_or_path='data/cloze_replace_all-bin',
    checkpoint_file='checkpoint_best.pt',
)
for prepare_step in (bart.cuda, bart.eval, bart.half):
    prepare_step()
# Number of source lines decoded per bart.sample() call.
bsz = 32


def _write_hypotheses(model, batch, fout):
    """Decode one batch of source lines and append the results to *fout*.

    Calls ``model.sample`` with beam=1 and max_len_b=20 (the decoding
    settings this script uses for every batch) under ``torch.no_grad()``,
    writes each returned hypothesis on its own line, then flushes *fout*
    once for the whole batch.

    Args:
        model: the loaded BART hub model (anything exposing ``sample``).
        batch: list of stripped source strings to decode.
        fout: text file opened for writing.
    """
    with torch.no_grad():
        hypotheses = model.sample(batch, beam=1, max_len_b=20)
    for hypothesis in hypotheses:
        fout.write(hypothesis + '\n')
    # One flush per batch keeps partial progress on disk without a
    # syscall per line.
    fout.flush()


with open('data/cloze_replace_all/val.source', encoding='utf-8') as source, \
        open('data/cloze_replace_all/val.hypo', 'w', encoding='utf-8') as fout:
    slines = []
    for sline in source:
        # Blank lines are kept (as '') so that, presumably, output line i
        # stays aligned with input line i -- assumes sample() returns one
        # hypothesis per input; TODO confirm.
        slines.append(sline.strip())
        if len(slines) == bsz:
            _write_hypotheses(bart, slines, fout)
            slines = []
    # Decode the trailing partial batch. An empty input file leaves
    # slines empty, so no spurious '' request is sent to the model.
    if slines:
        _write_hypotheses(bart, slines, fout)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment