@ArthurZucker
Created June 7, 2024 08:39
Whisper static cache
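# Benchmarks Whisper generation throughput (tok/s): first an eager, dynamic-cache
# baseline, then the same model with a static KV cache and
# torch.compile(mode="reduce-overhead").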
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor, StaticCache
import torch
import torch._dynamo.config
import torch._inductor.config
import time
from tqdm import tqdm
import logging
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._dynamo.config.cache_size_limit = 32
torch._logging.set_logs(recompiles=True, graph_breaks=True)
# torch.set_float32_matmul_precision('high')
torch.set_printoptions(linewidth=200)  # wider print lines make the attention mask easier to inspect
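# Benchmark configuration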
NUM_TOKENS = 100
NUM_WARMUP = 3
NUM_ITERS = 5
ATTN_IMPLEMENTATION = "sdpa"
MODEL_ID = "openai/whisper-medium.en"
BATCH_SIZES = [1]
torch_device = "cuda:0"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
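# Load the processor and model, and move the model to the GPU in the chosen dtype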
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, attn_implementation=ATTN_IMPLEMENTATION)
model.to(torch_device, dtype=torch_dtype)
is_multilingual = getattr(model.generation_config, "is_multilingual", False)
language = "en" if is_multilingual else None
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
sample = dataset[30]["audio"]
inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").to(torch_device)
input_features = inputs.input_features.to(torch_dtype)
all_dynamic_results = {}
all_static_results = {}
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler  # only used by the commented-out profiler block below
import datetime
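# Everything below runs without gradients: the first loop measures the
# dynamic-cache (eager) baseline, the second the compiled static-cache path.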
with torch.no_grad():
    for _ in range(NUM_ITERS):
        torch.cuda.synchronize()
        start = datetime.datetime.now()
        out = model.generate(input_features, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
        torch.cuda.synchronize()
        print(processor.tokenizer.batch_decode(out))
        runtime = datetime.datetime.now() - start
        print(f"Inference took {runtime.total_seconds():.3f} seconds")
        tok_per_s = out.shape[1] / runtime.total_seconds()
        print(f"Dynamic bsz 1 - {tok_per_s} tok/s")
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    for batch_size in BATCH_SIZES:
        # cache = (StaticCache(model.config, batch_size, (NUM_WARMUP + NUM_ITERS) * NUM_TOKENS, device="cuda:0", dtype=torch.float16), StaticCache(model.config, batch_size, (NUM_WARMUP + NUM_ITERS) * NUM_TOKENS, device="cuda:0", dtype=torch.float16))
        cache = None
        # cache = StaticCache(model.config, batch_size, (NUM_WARMUP + NUM_ITERS) * NUM_TOKENS, device=model.device, dtype=torch.float16)
        input_features_batch = input_features.repeat(batch_size, 1, 1)
        for _ in range(NUM_WARMUP):
            torch.cuda.synchronize()
            start = datetime.datetime.now()
            out = model.generate(input_features_batch, past_key_values=cache, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
            torch.cuda.synchronize()
            # cache[1].reset()
            # cache[0].reset()
            print(processor.tokenizer.batch_decode(out))
            print(f"Warmup inference took {(datetime.datetime.now() - start).total_seconds():.3f} seconds")
        # start = time.time()
        # with profile(
        #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        #     on_trace_ready=tensorboard_trace_handler(f"./tb_logs/tb_{datetime.datetime.now().strftime('%Y-%m-%d_%Hh%Mm%Ss')}"),
        #     record_shapes=False,
        #     profile_memory=False,
        #     with_stack=False
        # ) as prof:
        #     for _ in range(NUM_ITERS):
        #         with torch.inference_mode():
        #             torch.cuda.synchronize()
        #             start = datetime.datetime.now()
        #             with record_function("generate"):
        #                 out = model.generate(input_features, past_key_values=cache, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
        #             torch.cuda.synchronize()
        #             print(processor.tokenizer.batch_decode(out))
        #             print(f"Profiled Inference took {datetime.datetime.now() - start} seconds")
        for _ in range(NUM_ITERS):
            torch.cuda.synchronize()
            start = datetime.datetime.now()
            out = model.generate(input_features_batch, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
            torch.cuda.synchronize()
            print(processor.tokenizer.batch_decode(out))
            runtime = datetime.datetime.now() - start
            print(f"Inference took {runtime.total_seconds():.3f} seconds")
            tok_per_s = (out.shape[1] * batch_size) / runtime.total_seconds()
            print(f"Static bsz {batch_size} - {tok_per_s} tok/s")
        # runtime = time.time() - start
        # tok_per_s = (NUM_TOKENS * batch_size) / runtime
        # all_dynamic_results[batch_size] = tok_per_s
    # model.generation_config.cache_implementation = "static"
    # model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    # for batch_size in BATCH_SIZES:
    #     input_features_batch = input_features.repeat(batch_size, 1, 1)
    #     for _ in range(NUM_WARMUP):
    #         model.generate(input_features, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
    #     start = time.time()
    #     for _ in range(NUM_ITERS):
    #         model.generate(input_features, min_new_tokens=NUM_TOKENS, max_new_tokens=NUM_TOKENS, language=language, begin_suppress_tokens=None, suppress_tokens=None)
    #     torch.cuda.synchronize()
    #     runtime = time.time() - start
    #     tok_per_s = (NUM_TOKENS * batch_size) / runtime
    #     all_static_results[batch_size] = tok_per_s
    #     print(f"Static bsz {batch_size} - {tok_per_s} tok/s")