@fxmarty, created February 28, 2024 14:29
torch.compile + static cache train benchmark
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from transformers.cache_utils import StaticCache
import time
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
import contextlib
import numpy as np
tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )
n_runs = 7
inps = {
    "input_ids": torch.ones((4, 500), dtype=torch.int64).to("cuda"),
    "use_cache": False,
}
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
def run(model, n_runs, inps):
    is_compiled = isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
    latency_forward = []
    latency_backward = []
    for i in range(n_runs):
        # Forward pass, timed with explicit CUDA synchronization.
        torch.cuda.synchronize()
        start = time.time_ns()
        res = model(**inps)
        torch.cuda.synchronize()
        end = time.time_ns()
        latency_ms = (end - start) * 1e-6
        # The first few iterations are treated as warmup and excluded from the average.
        if i > 3:
            latency_forward.append(latency_ms)
        print(f"\n- {i}-th call forward latency: {latency_ms:.3f} ms")

        loss = res.logits.mean()

        # Backward pass. The profiler branch is disabled; flip `False` to enable it
        # and dump a Chrome trace of the backward pass.
        torch.cuda.synchronize()
        start = time.time_ns()
        if False and i > 3:
            cm = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            )
        else:
            cm = contextlib.nullcontext()

        with cm as prof:
            loss.backward()

        if False and i > 3:
            if is_compiled:
                name = f"trace_backward_compiled_{i}.json"
            else:
                name = f"trace_backward_eager_{i}.json"
            prof.export_chrome_trace(name)

        torch.cuda.synchronize()
        end = time.time_ns()
        latency_ms = (end - start) * 1e-6
        if i > 3:
            latency_backward.append(latency_ms)
        print(f"- {i}-th call backward latency: {latency_ms:.3f} ms")

        # step() OOM
        # for name, param in model.named_parameters():
        #     print(name, param.grad.shape)
        # torch.cuda.synchronize()
        # start = time.time_ns()
        # optimizer.step()
        # torch.cuda.synchronize()
        # end = time.time_ns()
        # latency_ms = (end - start) * 1e-6
        # print(f"- {i}-th call step latency: {latency_ms:.3f} ms")

        model.zero_grad()

    latency_forward = np.mean(latency_forward)
    latency_backward = np.mean(latency_backward)
    return latency_forward, latency_backward
print("--------- WITHOUT TORCH.COMPILE")
latency_forward_eager, latency_backward_eager = run(model, n_runs, inps)
print("compiling...")
torch.cuda.synchronize()
start = time.time_ns()
model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
torch.cuda.synchronize()
end = time.time_ns()
latency_compile_ms = (end - start) * 1e-6
print(f"torch.compile call: {latency_compile_ms:.3f} ms")
print("--------- WITH TORCH.COMPILE")
latency_forward_compile, latency_backward_compile = run(model, n_runs, inps)
print("--------- summary")
print(f"Latency forward (eager): {latency_forward_eager:.3f} ms")
print(f"Latency forward (compile): {latency_forward_compile:.3f} ms")
print(f"Speedup forward: x{latency_forward_eager / latency_forward_compile:.3f}\n")
print(f"Latency backward (eager): {latency_backward_eager:.3f} ms")
print(f"Latency backward (compile): {latency_backward_compile:.3f} ms")
print(f"Speedup backward: x{latency_backward_eager / latency_backward_compile:.3f}\n")