Skip to content

Instantly share code, notes, and snippets.

@danyaljj
Created February 18, 2021 00:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danyaljj/3f1aedc2a8c46d87eb4d1a9f68df7b5e to your computer and use it in GitHub Desktop.
Save danyaljj/3f1aedc2a8c46d87eb4d1a9f68df7b5e to your computer and use it in GitHub Desktop.
# This file extracts the predictions of several existing summarization systems for the XSUM dataset.
import json
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Generate summaries for the XSUM test split with each pretrained model and
# write one JSON file per model mapping example id -> generated summary.
dataset = load_dataset('xsum')
total_len = len(dataset['test'])
batch_size = 16
device = 'cuda'

for modeln in ["google/pegasus-xsum", "facebook/bart-large-xsum"]:
    tokenizer = AutoTokenizer.from_pretrained(modeln)
    model = AutoModelForSeq2SeqLM.from_pretrained(modeln).to(device)
    out_map = {}  # example id -> decoded model output

    # Stepping by batch_size never produces an empty trailing batch, even
    # when total_len is an exact multiple of batch_size.
    for start in tqdm(range(0, total_len, batch_size)):
        batch_x = dataset['test'][start:start + batch_size]
        # The model input is the article text; the 'summary' field holds the
        # reference summary, which is what the model is supposed to predict.
        documents = batch_x['document']
        ids = batch_x['id']
        # Calling the tokenizer directly replaces the deprecated
        # prepare_seq2seq_batch(); inputs longer than the model limit are
        # truncated and the batch is padded to its longest member.
        batch = tokenizer(
            documents,
            truncation=True,
            padding='longest',
            return_tensors='pt',
        ).to(device)
        translated = model.generate(**batch)
        tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
        out_map.update(zip(ids, tgt_text))

    # Derive a filesystem-friendly name without clobbering the loop variable.
    out_name = modeln.replace("/", "_").replace("-", "_")
    # The context manager guarantees the file is flushed and closed.
    with open(f"{out_name}.json", "w") as outfile:
        outfile.write(json.dumps(out_map, sort_keys=True, indent=4))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment