import time
from argparse import ArgumentParser

import torch
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from inference import CaiInferenceConfig, convert_to_ds_model, recover_from_ds_model
parser = ArgumentParser()
parser.add_argument("--name", default="bigscience/bloom-560m", type=str, help="model_name")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
parser.add_argument("--max_tokens", default=1024, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=128, type=int, help="maximum new tokens to generate")
parser.add_argument("--greedy", default=False, type=bool, help="greedy generation mode")
parser.add_argument("--use_cache", default=True, type=bool, help="use cache for generation")
parser.add_argument("--test_performance", default=True, type=bool , help="enable latency, bandwidth, and throughout testing")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
parser.add_argument("--kernel_type", type=str, default="triton", choices=["torch", "ds", "triton"], help="kernel implementation")
args = parser.parse_args()
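
# Example invocation (a sketch; "benchmark_inference.py" is a placeholder for whatever
# file name this script is saved under, and any causal LM from the Hugging Face Hub
# can be passed to --name):
#   python benchmark_inference.py --name bigscience/bloom-560m --batch_size 1 \
#       --dtype float16 --max_new_tokens 128 --kernel_type triton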
def print_perf_stats(latency_set, config, warmup=3):
    # trim warmup queries
    latency_set = list(latency_set)
    latency_set = latency_set[warmup:]
    count = len(latency_set)

    if count > 0:
        latency_set.sort()
        avg = sum(latency_set) / count
        num_layers = getattr(config, "num_layers", config.num_hidden_layers)
        # rough parameter count: ~12 * hidden_size^2 weights per transformer layer
        num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
        if args.dtype == "float16":
            num_bytes = 2
        elif args.dtype == "float32":
            num_bytes = 4
        else:
            num_bytes = 1
        print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
        # bandwidth/flops are rough estimates based on weight bytes moved per generated token
        print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
        print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * args.batch_size / 1e12))
# torch.cuda.set_device(7)
tokenizer = AutoTokenizer.from_pretrained(args.name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(args.name, pad_token_id=tokenizer.eos_token_id)
print(f"{model.config}")
model = model.half()

if args.kernel_type in ["ds", "triton"]:
    cai_inf_config = CaiInferenceConfig(fp16=True, device=torch.cuda.current_device())
    if args.kernel_type == "triton":
        cai_inf_config.use_triton = True
    model = convert_to_ds_model(model, cai_inf_config).half()

model = model.cuda().to(torch.cuda.current_device())
print(f"torch cuda device {torch.cuda.current_device()}")
text = "Replace me by any text you'd like. we need to work hard to help our nation." + \
"we need to think about something really powerful to strive to achieve our targets" + \
". do you have some idea how to start a busineess? for example, how to start a company? " + \
"If you can share some thoughts, it will be great!!! "
inputs = [text]
input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
input_len=1024
input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, input_len), device=torch.cuda.current_device()),
"attention_mask":torch.ones((args.batch_size, input_len), device=torch.cuda.current_device())}
print("inputs ", input_tokens)
print(input_tokens["input_ids"].shape)
input_len = input_tokens["input_ids"].shape[1]
iters = 10 if args.test_performance else 2  # the first `warmup` iterations are discarded
times = []
warmup = 3
prof_flag = 0
generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False)

for i in range(iters):
    if i >= warmup:
        prof_flag = 1
    torch.cuda.synchronize()
    start = time.time()
    outputs = model.generate(**input_tokens, **generate_kwargs)
    torch.cuda.synchronize()
    end = time.time()
    out_len = outputs.shape[1]
    print("generation time {} s".format(str(end - start)))
    # record average per-token latency for this iteration
    times.append((end - start) / (out_len - input_len))

print("outputs, ", len(outputs))
outputs = tokenizer.batch_decode(outputs)
if args.local_rank == 0:
    if args.test_performance:
        print_perf_stats(times, model.config)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    with record_function("model_inference"):
        torch.cuda.synchronize()
        outputs = model.generate(**input_tokens, **generate_kwargs)
        torch.cuda.synchronize()
# print("outputs shape, ", outputs.shape)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
prof.export_chrome_trace("trace.json")
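# The exported trace.json can be opened in the Chrome trace viewer (chrome://tracing)
# or at https://ui.perfetto.dev to inspect the CPU/CUDA timeline of the profiled generate() call.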