@fxmarty
Created July 25, 2024 14:47
transformers_compile.py
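Benchmarks `model.generate` with a static KV cache on Meta-Llama-3-8B-Instruct, comparing an eager run against runs where the model's forward pass has been wrapped with `torch.compile(mode="reduce-overhead", fullgraph=True)`.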
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from transformers.cache_utils import StaticCache
import logging
import time
#model_id = "fxmarty/tiny-llama-fast-tokenizer"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(
    model_id, padding_side="left"
)
tokenizer.pad_token_id = tokenizer.eos_token_id
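# Instantiate the model in float16 with SDPA attention; the torch.device("cuda")
# context places the weights on the GPU as they are created.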
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )
model = model.to("cuda")
inputs = tokenizer(
    ["I would", "Today I am in Paris and", "I am"], padding=True, return_tensors="pt"
).to(model.device)
new_tokens = 256
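# min_new_tokens == max_new_tokens forces exactly 256 generated tokens per sequence,
# so that the timings below are comparable.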
gen_config = GenerationConfig(
    max_new_tokens=new_tokens,
    min_new_tokens=new_tokens,
    use_cache=True,
    pad_token_id=tokenizer.pad_token_id,
    num_beams=1,
    do_sample=False,
    eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None  # greedy_search also falls back on the model's own eos_token_id, so it must be set to None as well for min_new_tokens to have an effect.
print("----- GENERATE WITHOUT COMPILE")
start = time.perf_counter()
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
end = time.perf_counter()
print(f"Non-compiled generate call took (no warmup): {end - start:.3f} s")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded)
# torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)
print("compiling...")
start = time.perf_counter()
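# torch.compile is lazy: this call only wraps the forward method, and the actual
# compilation happens on the first forward pass, i.e. during the first generate() below.
# reduce-overhead mode uses CUDA graphs; fullgraph=True errors out on any graph break.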
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
end = time.perf_counter()
print(f"Finished compile call: {end - start:.3f} s")
print("----- GENERATE WITH COMPILE")
start = time.perf_counter()
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
end = time.perf_counter()
print(f"First generate call took: {end - start:.3f} s")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("-------- decoded 1", decoded)
start = time.perf_counter()
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
end = time.perf_counter()
print(f"Second generate call took: {end - start:.3f} s")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("-------- decoded 2", decoded)
start = time.perf_counter()
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
end = time.perf_counter()
print(f"Third generate call took: {end - start:.3f} s")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("-------- decoded 3", decoded)