# Created August 7, 2023 08:36
import torch
import time
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from argparse import ArgumentParser
from transformers import LlamaForCausalLM, LlamaTokenizer
from inference import CaiInferenceConfig, convert_to_ds_model, recover_from_ds_model
from torch.profiler import profile, record_function, ProfilerActivity
from types import MethodType
from typing import Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaAttention, LlamaModel, LlamaForCausalLM
from einops import rearrange
from colossalai.logging import get_dist_logger
from inference.policy.attention_helper.llama_attention import _forward_v1
from inference.policy.attention_helper.llama2_attention import _forward_v2
logger = get_dist_logger()


def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` would call ``bool("False")``, which is
    ``True`` for every non-empty string, so such flags could never be turned
    off from the command line. Non-string values (e.g. the bool defaults) are
    passed through unchanged.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line configuration for the benchmark; `args` is read as a global by the functions below.
parser = ArgumentParser()
parser.add_argument("--name", default="/data3/users/lcjt/projs/chinese-llama2/pretrained/llama/llama-2-7b-hf", type=str, help="model_name")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
parser.add_argument("--max_tokens", default=2048, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=128, type=int, help="maximum new tokens to generate")
parser.add_argument("--greedy", default=False, type=_str2bool, help="greedy generation mode")
parser.add_argument("--use_cache", default=True, type=_str2bool, help="use cache for generation")
parser.add_argument("--test_performance", default=True, type=_str2bool, help="enable latency, bandwidth, and throughput testing")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
parser.add_argument("--kernel_type", type=str, default="triton", choices=["torch", "ds", "triton"], help="kernel implementation")
args = parser.parse_args()

# Attention forward that replace_flash_attention_for_llama binds onto every LlamaAttention module.
_llama_flash_attention_forward = _forward_v2
def self_defined_tokens(tokenizer, text="how is weather today? I want to know the weather of beijing. "):
    """Tokenize a single prompt as a batch of one sequence.

    Args:
        tokenizer: a tokenizer exposing ``batch_encode_plus`` (e.g. a Hugging
            Face ``LlamaTokenizer``).
        text: the prompt to encode; defaults to the fixed sentence the
            benchmark has always used.

    Returns:
        Whatever ``tokenizer.batch_encode_plus`` returns for ``[text]`` with
        ``padding=True`` and ``return_tensors="pt"``.
    """
    return tokenizer.batch_encode_plus([text], padding=True, return_tensors="pt")
def print_perf_stats(latency_set, config, warmup=3, dtype=None, batch_size=None):
    """Print per-token latency, bandwidth, FLOP-rate and throughput estimates.

    Args:
        latency_set: iterable of per-token latencies in seconds, one entry per
            benchmark iteration.
        config: model config exposing ``hidden_size`` and either
            ``num_layers`` or ``num_hidden_layers``.
        warmup: number of leading samples discarded as warm-up.
        dtype: weight dtype name ("float16", "float32" or anything else,
            treated as 1 byte); defaults to ``args.dtype``.
        batch_size: batch size used for the FLOP-rate estimate; defaults to
            ``args.batch_size``.
    """
    samples = list(latency_set)[warmup:]
    if not samples:
        print(f"No latency samples left after discarding {warmup} warm-up iterations; nothing to report.")
        return
    if dtype is None:
        dtype = args.dtype
    if batch_size is None:
        batch_size = args.batch_size

    avg = sum(samples) / len(samples)

    # getattr(config, "num_layers", config.num_hidden_layers) would evaluate the
    # default eagerly and raise on configs that only define `num_layers`.
    if hasattr(config, "num_layers"):
        num_layers = config.num_layers
    else:
        num_layers = config.num_hidden_layers
    # Approximation: 12 * hidden_size**2 weights per layer; embeddings are not counted.
    num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
    # Bytes per weight; every other dtype (int8 is the remaining CLI choice) counts as 1 byte.
    num_bytes = {"float16": 2, "float32": 4}.get(dtype, 1)

    tokens_per_second = 1 / avg
    print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
    # Assumes each decoding step reads every weight once (memory-bound decode).
    print("Avg BW: {0:8.2f} GB/s".format(tokens_per_second * num_parameters * num_bytes / 1e9))
    # ~2 FLOPs (multiply + add) per parameter per token in the batch; bytes-per-weight is irrelevant here.
    print("Avg flops: {0:8.2f} TFlops/s".format(tokens_per_second * 2 * num_parameters * batch_size / 1e12))
    print("Avg Throughput: tokens/s: {}".format(tokens_per_second))
def _prepare_decoder_flash_attention_mask(self: LlamaModel,
                                          attention_mask: torch.Tensor,
                                          input_shape: Union[torch.Size, Sequence[int]],
                                          inputs_embeds: torch.Tensor,
                                          past_key_values_length: int = 0) -> torch.Tensor:
    """Prepare the attention mask for a decoder-only LLM (e.g. Llama) when using flash-attn.

    Bound onto ``LlamaModel`` instances as ``_prepare_decoder_attention_mask``
    by ``replace_flash_attention_for_llama``. No causal or 4D expansion is done
    here: the 2D mask is returned unchanged. Presumably the flash-attention
    forward (``_forward_v2``) applies causal masking itself -- confirm there.

    Args:
        attention_mask: a (bsz, max_len) tensor holding the mini-batch 2D mask.
        input_shape: unused; kept so the signature matches the method it replaces.
        inputs_embeds: unused; kept for signature compatibility.
        past_key_values_length: unused; kept for signature compatibility.

    Returns:
        The same (bsz, max_len) mask tensor that was passed in.
    """
    return attention_mask
def replace_flash_attention_for_llama(model: Union[torch.nn.Module, LlamaForCausalLM]) -> None:
    """Monkey-patch a Llama model in place so it uses the flash-attention path.

    Walks ``model.modules()``. For every ``LlamaAttention`` it binds
    ``_llama_flash_attention_forward`` as that instance's ``forward``. For every
    ``LlamaModel`` it binds ``_prepare_decoder_flash_attention_mask`` as that
    instance's ``_prepare_decoder_attention_mask``. Only instance attributes are
    assigned; the classes themselves are left untouched. Each replacement is
    logged.

    Args:
        model: the model (or any module tree) containing Llama submodules.
    """
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(_llama_flash_attention_forward, module)
            logger.info("Replace `LlamaAttention.forward` method.")
        if isinstance(module, LlamaModel):
            # replace attention mask computation.
            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_flash_attention_mask, module)
            logger.info("Replace `LlamaModel._prepare_decoder_attention_mask` method.")
def test(use_self_defined_input=False):
    """Benchmark and profile ``model.generate`` for a Llama checkpoint.

    Loads the tokenizer and model from ``args.name`` and builds a prompt. It
    then runs ``generate`` several times, recording per-token latency, prints
    summary statistics via ``print_perf_stats`` (rank 0 with
    ``--test_performance`` only), and prints a ``torch.profiler`` table for one
    additional ``generate`` call.

    Args:
        use_self_defined_input: if falsy, the prompt is 1024 random token ids
            drawn from [1, 1000); otherwise it is the fixed sentence from
            ``self_defined_tokens``.
    """
    device = torch.cuda.current_device()

    tokenizer = LlamaTokenizer.from_pretrained(args.name)
    # Use <unk> as the padding id so batch_encode_plus(padding=True) has a pad token.
    tokenizer.pad_token_id = tokenizer.unk_token_id
    model = LlamaForCausalLM.from_pretrained(args.name, pad_token_id=tokenizer.eos_token_id)
    # Inputs are moved to the current CUDA device below, so the model must be there too.
    model = model.half().to(device)
    print("model config: ", model.config)

    if use_self_defined_input:
        input_tokens = self_defined_tokens(tokenizer)
    else:
        input_tokens = {"input_ids": torch.randint(1, 1000, (1, 1024))}

    input_len = 0
    for key in input_tokens:
        if torch.is_tensor(input_tokens[key]):
            input_tokens[key] = input_tokens[key].to(device)
            input_len = input_tokens[key].shape[1]

    iters = 10 if args.test_performance else 2  # includes warm-up iterations
    print("input token length is " + str(input_len))
    times = []
    generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False)
    for _ in range(iters):
        # Synchronize so the wall-clock interval covers only this generate call's GPU work.
        torch.cuda.synchronize()
        start = time.time()
        outputs = model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
        torch.cuda.synchronize()
        end = time.time()
        num_tokens_generation = outputs.shape[1] - input_len
        print(f"generation time is {(end - start) * 1000} ms")
        # Guard against division by zero if generation produced no new tokens.
        times.append((end - start) / max(num_tokens_generation, 1))
        print("outputs shape ", outputs.shape)

    if args.local_rank == 0 and args.test_performance:
        # print_perf_stats discards the leading warm-up samples itself.
        print_perf_stats(times, model.config)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            outputs = model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
if __name__ == "__main__":
