@fxmarty
Created February 28, 2024 14:29
torch.compile + static cache decoding benchmark
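The script below loads Llama-2-7b-chat-hf in float16 with SDPA attention, allocates a StaticCache, and times a single decode-step forward pass (batch size 1, one token placed at position 501) first in eager mode and then after wrapping model.forward with torch.compile(mode="reduce-overhead", fullgraph=True). Each reported latency is the mean over the runs remaining after a 4-iteration warmup.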
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from transformers.cache_utils import StaticCache
import time
import numpy as np
tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)

# Instantiate the fp16 model with SDPA attention directly on the GPU.
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )

inputs = tokenizer(
    ["I would", "Today I am in Paris and"], padding=True, return_tensors="pt"
).to(model.device)
n_runs = 10
# Hand-built inputs simulating a single decode step: one new token at cache
# position 501, with a length-500 attention mask, instead of a full generate() loop.
inps = {
    "input_ids": torch.tensor([[5]], dtype=torch.int64).to("cuda"),
    "position_ids": torch.tensor([[501]], dtype=torch.int64).to("cuda"),
    "cache_position": torch.tensor([501], dtype=torch.int64).to("cuda"),
    "past_key_values": None,
    "use_cache": True,
    "attention_mask": torch.ones((1, 500), dtype=torch.int64).to("cuda"),
}

# Allocate a static (fixed-shape) KV cache so the compiled graph sees static shapes.
model._setup_cache(StaticCache, max_batch_size=1, max_cache_len=1000)
def run(model, n_runs, inps):
    # Time n_runs forward calls; the first 4 iterations are treated as warmup
    # and excluded from the reported mean.
    latency_forward = []
    for i in range(n_runs):
        torch.cuda.synchronize()
        start = time.time_ns()
        res = model(**inps)
        torch.cuda.synchronize()
        end = time.time_ns()
        latency_ms = (end - start) * 1e-6
        if i > 3:
            latency_forward.append(latency_ms)
        print(f"\n- {i}-th call latency: {latency_ms:.3f} ms")
    return np.mean(latency_forward)
with torch.no_grad():
    print("--------- WITHOUT TORCH.COMPILE")
    latency_forward_eager = run(model, n_runs, inps)

    print("compiling...")
    torch.cuda.synchronize()
    start = time.time_ns()
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    torch.cuda.synchronize()
    end = time.time_ns()
    latency_compile_ms = (end - start) * 1e-6
    # Note: torch.compile is lazy, so this only times wrapping the function;
    # the actual compilation happens on the first compiled forward call below
    # and is absorbed by the warmup iterations in run().
    print(f"torch.compile call: {latency_compile_ms:.3f} ms")

    print("--------- WITH TORCH.COMPILE")
    latency_forward_compile = run(model, n_runs, inps)

    print("--------- summary")
    print(f"Latency forward (eager): {latency_forward_eager:.3f} ms")
    print(f"Latency forward (compile): {latency_forward_compile:.3f} ms")
    print(f"Speedup forward: x{latency_forward_eager / latency_forward_compile:.3f}\n")