import time
from argparse import ArgumentParser
from types import MethodType
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaAttention, LlamaModel, LlamaForCausalLM

from colossalai.logging import get_dist_logger
from inference import CaiInferenceConfig, convert_to_ds_model, recover_from_ds_model
from inference.policy.attention_helper.llama_attention import _forward_v1
from inference.policy.attention_helper.llama2_attention import _forward_v2

logger = get_dist_logger()

parser = ArgumentParser()
parser.add_argument("--name", default="/data3/users/lcjt/projs/chinese-llama2/pretrained/llama/llama-2-7b-hf", type=str, help="model_name")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
parser.add_argument("--max_tokens", default=2048, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=128, type=int, help="maximum new tokens to generate")
parser.add_argument("--greedy", default=False, type=bool, help="greedy generation mode")
parser.add_argument("--use_cache", default=True, type=bool, help="use cache for generation")
parser.add_argument("--test_performance", default=True, type=bool , help="enable latency, bandwidth, and throughout testing")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
parser.add_argument("--kernel_type", type=str, default="triton", choices=["torch", "ds", "triton"], help="kernel implementation")
args = parser.parse_args()
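
# Example invocation (the script name and model path are placeholders, not part of
# the original gist; adjust them to your environment):
#   python benchmark_llama_flash_attn.py --name /path/to/llama-2-7b-hf \
#       --batch_size 1 --dtype float16 --max_new_tokens 128 --kernel_type triton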

# Use the Llama-2 flash-attention forward as the replacement for `LlamaAttention.forward`.
_llama_flash_attention_forward = _forward_v2


def self_defined_tokens(tokenizer):
    text = "how is weather today? I want to know the weather of beijing. "
    inputs = [text]
    input_tokens = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")
    return input_tokens
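
# For reference (standard transformers behavior): `batch_encode_plus` returns a
# BatchEncoding whose "input_ids" and "attention_mask" entries are (batch_size, padded_len)
# tensors; test() below iterates over these keys to move each tensor onto the current
# CUDA device.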


def print_perf_stats(latency_set, config, warmup=3):
    # trim warmup queries
    latency_set = list(latency_set)
    latency_set = latency_set[warmup:]
    count = len(latency_set)

    if count > 0:
        latency_set.sort()
        avg = sum(latency_set) / count

        # Approximate the weight count as 12 * hidden_size^2 per transformer layer
        # (attention + MLP), ignoring embeddings and the LM head.
        num_layers = getattr(config, "num_layers", config.num_hidden_layers)
        num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
        if args.dtype == "float16":
            num_bytes = 2
        elif args.dtype == "float32":
            num_bytes = 4
        else:
            num_bytes = 1

        print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
        print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
        print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * args.batch_size / 1e12))
        print("Avg Throughput: tokens/s: {}".format(1000 / (avg * 1000)))


def _prepare_decoder_flash_attention_mask(self: LlamaModel,
                                          attention_mask: torch.Tensor,
                                          input_shape: Union[torch.Size, Sequence[int]],
                                          inputs_embeds: torch.Tensor,
                                          past_key_values_length: int = 0) -> torch.Tensor:
    """
    Prepare the attention mask for a decoder-only LLM (e.g., Llama) when using flash-attn.

    The patched flash-attention forward handles causal masking itself, so the 2D padding
    mask is returned unchanged instead of being expanded to a 4D causal mask.

    Args:
        attention_mask (`torch.Tensor`):
            A tensor of shape (bsz, max_len) holding the 2D padding mask of the mini-batch.
        input_shape: Shape of the input ids, (bsz, max_len).
        inputs_embeds: Embedded inputs; used by the original implementation for dtype/device.
        past_key_values_length: Length of the cached key/value states.

    Returns:
        (`torch.Tensor`):
            A mask tensor of shape (bsz, max_len).
    """
    return attention_mask
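
# Example (illustrative shapes): for a batch of 2 prompts padded to length 8, the mask
# returned here stays (2, 8), whereas the stock `LlamaModel._prepare_decoder_attention_mask`
# (transformers 4.31-era) would expand it to a (2, 1, 8, 8) causal mask.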


def replace_flash_attention_for_llama(model: Union[torch.nn.Module, LlamaForCausalLM]) -> None:
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(_llama_flash_attention_forward, module)
            logger.info("Replaced `LlamaAttention.forward` method.")
        if isinstance(module, LlamaModel):
            # replace attention mask computation.
            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_flash_attention_mask, module)
            logger.info("Replaced `LlamaModel._prepare_decoder_attention_mask` method.")


def test(use_self_defined_input=False):
    tokenizer = LlamaTokenizer.from_pretrained(args.name)
    tokenizer.pad_token_id = tokenizer.unk_token_id
    model = LlamaForCausalLM.from_pretrained(args.name, pad_token_id=tokenizer.eos_token_id)
    model = model.half()
    print("model config: ", model.config)

    replace_flash_attention_for_llama(model)
    model.to(torch.cuda.current_device())

    if not use_self_defined_input:
        input_tokens = {"input_ids": torch.randint(1, 1000, (1, 1024))}
    else:
        input_tokens = self_defined_tokens(tokenizer)

    input_len = 0
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
            # print(input_tokens[t].shape)
            input_len = input_tokens[t].shape[1]

    iters = 10 if args.test_performance else 2  # includes warmup iterations
    print("input token length is " + str(input_len))

    times = []
    warmup = 3
    prof_flag = 0
    generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False)

    for i in range(iters):
        if i >= warmup:
            prof_flag = 1
        torch.cuda.synchronize()
        start = time.time()
        outputs = model.generate(**input_tokens,
                                 **generate_kwargs, early_stopping=False)
        torch.cuda.synchronize()
        end = time.time()

        num_tokens_generation = outputs.shape[1] - input_len
        print(num_tokens_generation)
        print(f"generation time is {(end - start) * 1000} ms")
        # record per-token latency for this iteration
        time_spend = (end - start) / num_tokens_generation
        times.append(time_spend)

    print("outputs shape ", outputs.shape)
    outputs = tokenizer.batch_decode(outputs)
    print(outputs)

    if args.local_rank == 0:
        if args.test_performance:
            print_perf_stats(times, model.config)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            torch.cuda.synchronize()
            outputs = model.generate(**input_tokens,
                                     **generate_kwargs)
            torch.cuda.synchronize()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


if __name__ == "__main__":
    test(use_self_defined_input=False)