# Source: GitHub Gist by @eustlb (gist id ef06f00858cbae4d8743f5024be869ec),
# last active June 20, 2024 10:09.
import os
import time
import pickle
from tqdm import tqdm
import torch
from transformers import WhisperForConditionalGeneration
def benchmark_gen(
    model_name,
    dtype,
    attn_implementation,
    batch_sizes,
    n_measures,
    n_tokens,
    num_beams,
    return_timestamps,
    device="cuda:0",
    n_warmup=3,
):
    """Measure Whisper ``generate`` throughput (tokens/s) for several batch sizes.

    Loads ``model_name`` once. For each batch size, it feeds random features of
    shape ``(batch_size, num_mel_bins, 3000)`` to ``model.generate``. It pins
    ``min_new_tokens == max_new_tokens == n_tokens`` so every run decodes the
    same number of tokens.

    Args:
        model_name: Checkpoint id or path passed to ``from_pretrained``.
        dtype: ``torch.dtype`` used for both the weights and the input features.
        attn_implementation: Attention backend forwarded to ``from_pretrained``
            (e.g. ``"sdpa"``).
        batch_sizes: Batch sizes to benchmark, in order.
        n_measures: Number of timed ``generate`` calls averaged per batch size.
        n_tokens: New tokens generated per sample.
        num_beams: Beam count forwarded to ``generate``.
        return_timestamps: Forwarded to ``generate``.
        device: Device for the model and inputs. Defaults to ``"cuda:0"``,
            the value that was previously hard-coded.
        n_warmup: Untimed ``generate`` calls run before measuring each batch size.

    Returns:
        A list with one tokens-per-second value per entry in ``batch_sizes``.

    Raises:
        ValueError: If ``n_measures`` is less than 1. Checked before loading the
            model, so no time is wasted on an invalid configuration.
    """
    if n_measures < 1:
        raise ValueError(f"n_measures must be >= 1, got {n_measures}")

    is_cuda = torch.device(device).type == "cuda"

    def _sync():
        # CUDA kernels launch asynchronously. Wait for all queued GPU work so the
        # timed interval covers exactly one generate call's device work.
        if is_cuda:
            torch.cuda.synchronize(device)

    model = WhisperForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
    )
    model.to(device, dtype=dtype)
    num_mel_bins = model.config.num_mel_bins

    gen_kwargs = {
        "return_timestamps": return_timestamps,
        "num_beams": num_beams,
        # NOTE(review): top_k only affects sampling. With the default
        # non-sampling decoding it presumably just overrides the checkpoint's
        # generation config. Confirm it is intended.
        "top_k": 0,
        # Equal min/max pins the output length, so the token count below is exact.
        "min_new_tokens": n_tokens,
        "max_new_tokens": n_tokens,
    }

    tokens_per_sec = []
    for batch_size in tqdm(batch_sizes):
        inputs = torch.randn(
            (batch_size, num_mel_bins, 3000),
            dtype=dtype,
            device=device,
        )

        # Warm-up runs absorb one-off costs (kernel selection, allocator growth).
        for _ in range(n_warmup):
            model.generate(inputs, **gen_kwargs)

        times = []
        for _ in range(n_measures):
            _sync()
            start_time = time.perf_counter()  # monotonic, high resolution
            model.generate(inputs, **gen_kwargs)
            _sync()
            times.append(time.perf_counter() - start_time)

        avg_gen_time = sum(times) / len(times)
        n_generated_tokens = n_tokens * batch_size
        tokens_per_sec.append(n_generated_tokens / avg_gen_time)
        del inputs

    # Release the weights and cached allocator blocks so the next checkpoint
    # benchmarked in the same process starts from a clean device.
    del model
    if is_cuda:
        torch.cuda.empty_cache()
    return tokens_per_sec
def main():
    """Benchmark generation throughput for several Whisper checkpoints.

    For each model, runs ``benchmark_gen`` over batch sizes 1, 2, 4, ..., 128.
    The throughput results and the benchmark settings are pickled to
    ``benchmark-results/gen_<model_name with '/' replaced by '_'>.pkl``.
    """
    model_names = [
        "eustlb/distil-large-v3-fr",
        "openai/whisper-large-v3",
        "openai/whisper-tiny",
        "openai/whisper-base",
        "openai/whisper-small",
        "openai/whisper-medium",
    ]
    batch_sizes = [2**i for i in range(8)]  # 1 .. 128
    dtype = torch.bfloat16
    attn_implementation = "sdpa"
    n_measures = 5
    n_tokens = 128
    num_beams = 1
    return_timestamps = True

    results_dir = "benchmark-results"
    # The output directory does not depend on the model, so create it once.
    os.makedirs(results_dir, exist_ok=True)

    pbar = tqdm(model_names)
    for model_name in pbar:
        pbar.set_description(f"benchmarking {model_name}")
        tokens_per_sec = benchmark_gen(
            model_name,
            dtype,
            attn_implementation,
            batch_sizes,
            n_measures,
            n_tokens,
            num_beams,
            return_timestamps,
        )

        file_name = f"gen_{model_name.replace('/', '_')}.pkl"
        file_path = os.path.join(results_dir, file_name)
        # Same keys as before, plus "model_name" so each file is self-describing.
        results = {
            "batch_sizes": batch_sizes,
            "tokens_per_sec": tokens_per_sec,
            "attn_implementation": attn_implementation,
            "dtype": dtype,
            "n_measures": n_measures,
            "n_tokens": n_tokens,
            "num_beams": num_beams,
            "return_timestamps": return_timestamps,
            "model_name": model_name,
        }
        print(file_name)
        with open(file_path, "wb") as file:
            pickle.dump(results, file)
if __name__ == "__main__":
    # Run the benchmark sweep only when executed as a script, not on import.
    main()
# (End of gist. GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")