@ArthurZucker · Created February 8, 2024
Transformers with torch compile
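
This gist compares three ways of running greedy generation with Llama-2-7b-chat: the default dynamic KV cache in eager mode, a static KV cache in eager mode, and a static KV cache with the single-token decode forward wrapped in torch.compile. The seed is fixed and sampling is off, so all three runs should decode to the same text.
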
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

torch.set_printoptions(linewidth=400)
attn_implementation = "sdpa"

# Left padding keeps batched prompts of different lengths aligned for generation.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>")
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
).to("cuda:1")
inputs = tokenizer(
    ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
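
# Baseline: greedy decoding with the default dynamic KV cache, fully eager.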
set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("Eager dynamic", decoded)
set_seed(0)
model.generation_config.cache_implementation = "static"
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("Static eager",decoded)
set_seed(0)
model._forward = model.forward
compiled_forward = torch.compile(model.forward)

def call(input_ids, **kwargs):
    if input_ids.shape[-1] == 1:  # one new token per step -> decode phase
        return compiled_forward(input_ids, **kwargs)
    return model._forward(input_ids, **kwargs)  # prefill: keep eager

model.forward = call
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("Static compiled",decoded)