Skip to content

Instantly share code, notes, and snippets.

@sshleifer
Created October 26, 2020 17:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sshleifer/d36089da7ef102b6106a59233863adc3 to your computer and use it in GitHub Desktop.
Save sshleifer/d36089da7ef102b6106a59233863adc3 to your computer and use it in GitHub Desktop.
Timing Generate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time
from tqdm import tqdm
from pathlib import Path
import pandas as pd
# Checkpoint names that are timed with the `facebook/bart-large` tokenized batch.
models = [
    'sshleifer/distilbart-cnn-12-3',
    'sshleifer/distilbart-cnn-12-6',
    'sshleifer/distilbart-cnn-6-6',
    'sshleifer/distilbart-xsum-1-1',
    'sshleifer/distilbart-xsum-12-1',
    'sshleifer/distilbart-xsum-12-3',
    'sshleifer/distilbart-xsum-12-6',
    'sshleifer/distilbart-xsum-6-6',
    'sshleifer/distilbart-xsum-9-6',
    'facebook/bart-large-cnn',
    'facebook/bart-large-xsum',
]
# Checkpoint names that are timed with the `google/pegasus-xsum` tokenized batch.
pegs = [
    'sshleifer/distill-pegasus-cnn-16-4',
    'sshleifer/distill-pegasus-xsum-16-4',
    'sshleifer/distill-pegasus-xsum-16-8',
    'google/pegasus-xsum',
    'sshleifer/pegasus-cnn-ft-v2',
]
def time_generate(mname, batch, fp16=True, **gen_kwargs):
    """Time ``generate`` for one checkpoint on a full batch and on one example.

    The model is loaded and moved to the module-level ``torch_device`` before
    any timer starts, so load time is not included in either measurement.

    Args:
        mname: Model identifier given to ``AutoModelForSeq2SeqLM.from_pretrained``.
        batch: Tokenized inputs. Unpacked as keyword arguments into
            ``generate``; must also expose an ``input_ids`` attribute.
        fp16: When true, the model is converted with ``.half()`` before timing.
        **gen_kwargs: Extra keyword arguments forwarded to both ``generate`` calls.

    Returns:
        dict with keys ``m`` (the model name), ``t1`` (seconds to generate for
        the first row of ``batch.input_ids`` alone) and ``tbatch`` (seconds to
        generate for the whole batch).
    """
    import torch  # function-scope import: torch is not imported at module level

    device = torch.device(torch_device)

    def _wait_for_device():
        # CUDA work is queued asynchronously, so block until it has finished
        # before reading the clock; other device types need no barrier.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    model = AutoModelForSeq2SeqLM.from_pretrained(mname).to(torch_device)
    if fp16:
        model = model.half()

    # Whole batch. NOTE(review): this is also the model's first generate call,
    # so any one-off start-up cost presumably lands in tbatch -- confirm
    # whether a warm-up call is wanted.
    _wait_for_device()
    start_time = time.perf_counter()
    model.generate(**batch, **gen_kwargs)
    _wait_for_device()
    tbatch = time.perf_counter() - start_time

    # First example only; just the input ids are passed, as in the batch row.
    bb1 = batch.input_ids[:1]
    _wait_for_device()
    start_time = time.perf_counter()
    model.generate(bb1, **gen_kwargs)
    _wait_for_device()
    t1 = time.perf_counter() - start_time

    return dict(m=mname, t1=t1, tbatch=tbatch)
# Number of source lines taken from the validation file for each batch.
BS = 16
peg_tok = AutoTokenizer.from_pretrained('google/pegasus-xsum')
bart_tok = AutoTokenizer.from_pretrained('facebook/bart-large')
# read_text() opens, reads and closes the file in one call, so no handle is
# left open the way Path.open().read() leaves one.
data = Path('xsum/val.source').read_text().split('\n')[:BS]
torch_device = 'cuda'
# num_beams=8,6 are published defaults.
# NOTE(review): num_beams is handed to the tokenizer calls below, not to
# generate(), and time_generate is called without gen_kwargs -- so generation
# presumably relies on each model's own configured beam count. Confirm.
peg_batch = peg_tok.prepare_seq2seq_batch(data, max_length=512, return_tensors='pt', num_beams=8).to(torch_device)
bart_batch = bart_tok.prepare_seq2seq_batch(data, max_length=512, return_tensors='pt', num_beams=6).to(torch_device)
# The `models` checkpoints are timed in half precision, the `pegs` checkpoints
# in full precision.
times = [time_generate(mname, bart_batch, fp16=True) for mname in tqdm(models)]
timesp2 = [time_generate(mname, peg_batch, fp16=False) for mname in tqdm(pegs)]
# One row per checkpoint, slowest full-batch time first.
all_times = pd.concat([pd.DataFrame(timesp2), pd.DataFrame(times)]).sort_values('tbatch', ascending=False)
all_times.to_csv('generation_timings.csv', index=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment