@mobicham
Created April 15, 2024 19:46
import torch, os
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
##########################################################################################################################################################
cache_path = '.'
model_id = "meta-llama/Llama-2-7b-chat-hf"
compute_dtype = torch.float16  # or torch.bfloat16
device = 'cuda:0'
use_flash_attn = False
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=cache_path,
    torch_dtype=compute_dtype,
    device_map=device,
    attn_implementation="flash_attention_2" if use_flash_attn else "sdpa",
)

from sfast.compilers.diffusion_pipeline_compiler import (
    CompilationConfig,
    _build_lazy_trace,
    make_dynamic_graphed_callable,
)
def compile_llm(m):
    config = CompilationConfig.Default()
    #config.enable_xformers = True
    config.enable_fused_linear_geglu = True
    config.enable_jit = True
    config.enable_jit_freeze = True
    config.enable_triton = True
    config.enable_cuda_graph = False  # CUDA graphs are applied manually below
    config.prefer_lowp_gemm = True

    enable_cuda_graph = True

    if config.enable_jit:
        # lazily trace the forward pass with stable-fast's JIT tracer
        lazy_trace_ = _build_lazy_trace(
            config,
            enable_triton_reshape=enable_cuda_graph,
            enable_triton_layer_norm=enable_cuda_graph,
        )
        m.forward = lazy_trace_(m.forward)

    if enable_cuda_graph:
        # capture the traced forward into dynamic CUDA graphs
        m.forward = make_dynamic_graphed_callable(m.forward)

    return m
model = compile_llm(model)
# warm-up: run a few forward passes to trigger JIT tracing and CUDA-graph capture
for _ in range(10):
    with torch.no_grad():
        out = model(torch.ones((1, 1024), dtype=torch.int32, device=device)).logits
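##########################################################################################################################################################
# Example usage (a minimal sketch): greedy next-token decoding with the compiled model, calling model(...)
# directly like the warm-up loop above. The prompt and max_new_tokens below are illustrative choices.
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path)

prompt = "Write a short poem about GPUs."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

max_new_tokens = 32
with torch.no_grad():
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits                            # compiled forward pass
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick of the next token
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))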