@richardliaw
Last active January 27, 2024 21:29
diff --git a/examples/run.py b/examples/run.py
index 9a05434..d7946db 100644
--- a/examples/run.py
+++ b/examples/run.py
@@ -135,6 +135,12 @@ def parse_arguments(args=None):
choices=["hf", "nemo"],
help="The source of lora checkpoint.")
+ parser.add_argument(
+ '--run_profiling',
+ default=False,
+ action='store_true',
+ help="Run several 10 iterations to profile the inference latencies.")
+
parser.add_argument(
'--num_prepend_vtokens',
nargs="+",
@@ -158,7 +164,11 @@ def parse_input(tokenizer,
         pad_id = tokenizer.pad_token_id

     batch_input_ids = []
-    if input_file is None:
+    print(input_text)
+    if input_text[0].startswith("DUMMY"):
+        batch_size = int(input_text[0].split("_")[1]) if "_" in input_text[0] else 1
+        batch_input_ids = [np.zeros(max_input_length, dtype=np.int32)] * batch_size
+    elif input_file is None:
         for curr_text in input_text:
             if prompt_template is not None:
                 curr_text = prompt_template.format(input_text=curr_text)
@@ -229,6 +239,7 @@ def print_output(tokenizer,
                 output_end = sequence_lengths[batch_idx][beam]
                 outputs = output_ids[batch_idx][beam][
                     output_begin:output_end].tolist()
+                print("Length Output", output_end - output_begin)
                 output_text = tokenizer.decode(outputs)
                 print(
                     f'Output [Text {batch_idx} Beam {beam}]: \"{output_text}\"')
@@ -270,6 +281,7 @@ def print_output(tokenizer,

 def main(args):
+    import tensorrt_llm
     runtime_rank = tensorrt_llm.mpi_rank()
     logger.set_level(args.log_level)
@@ -328,29 +340,33 @@ def main(args):
             max_beam_width=args.num_beams,
             max_attention_window_size=args.max_attention_window_size)
     runner = runner_cls.from_dir(**runner_kwargs)
-
-    with torch.no_grad():
-        outputs = runner.generate(
-            batch_input_ids,
-            max_new_tokens=args.max_output_len,
-            max_attention_window_size=args.max_attention_window_size,
-            end_id=end_id,
-            pad_id=pad_id,
-            temperature=args.temperature,
-            top_k=args.top_k,
-            top_p=args.top_p,
-            num_beams=args.num_beams,
-            length_penalty=args.length_penalty,
-            repetition_penalty=args.repetition_penalty,
-            stop_words_list=stop_words_list,
-            bad_words_list=bad_words_list,
-            lora_uids=args.lora_task_uids,
-            prompt_table_path=args.prompt_table_path,
-            prompt_tasks=args.prompt_tasks,
-            streaming=args.streaming,
-            output_sequence_lengths=True,
-            return_dict=True)
-        torch.cuda.synchronize()
+
+    def generate_with_runner():
+        with torch.no_grad():
+            outputs = runner.generate(
+                batch_input_ids,
+                max_new_tokens=args.max_output_len,
+                max_attention_window_size=args.max_attention_window_size,
+                end_id=end_id,
+                pad_id=pad_id,
+                temperature=args.temperature,
+                top_k=args.top_k,
+                top_p=args.top_p,
+                num_beams=args.num_beams,
+                length_penalty=args.length_penalty,
+                repetition_penalty=args.repetition_penalty,
+                stop_words_list=stop_words_list,
+                bad_words_list=bad_words_list,
+                lora_uids=args.lora_task_uids,
+                prompt_table_path=args.prompt_table_path,
+                prompt_tasks=args.prompt_tasks,
+                streaming=args.streaming,
+                output_sequence_lengths=True,
+                return_dict=True)
+            torch.cuda.synchronize()
+        return outputs
+
+    outputs = generate_with_runner()

     if runtime_rank == 0:
         if args.streaming:
@@ -381,6 +397,21 @@ def main(args):
                      context_logits=context_logits,
                      generation_logits=generation_logits,
                      output_logits_npy=args.output_logits_npy)
+
+    if args.run_profiling:
+        ite = 10
+        # warmup
+        for _ in range(ite):
+            generate_with_runner()
+
+        import tensorrt_llm.profiler
+        tensorrt_llm.profiler.start("tmp")
+        for _ in range(ite):
+            generate_with_runner()
+        tensorrt_llm.profiler.stop("tmp")
+        print(
+            f"batch_size: {len(batch_input_ids)}, avg latency of {ite} iterations: {tensorrt_llm.profiler.elapsed_time_in_sec('tmp') / ite} sec"
+        )
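
A few notes on the patch follow. With it applied, a profiling run can be launched directly from the command line. The engine and tokenizer paths below are placeholders; the other flags are standard examples/run.py options:

python examples/run.py \
    --engine_dir /path/to/engine \
    --tokenizer_dir /path/to/tokenizer \
    --input_text DUMMY_8 \
    --max_output_len 128 \
    --run_profiling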
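
The parse_input change introduces a DUMMY_<n> convention: when the first --input_text entry starts with "DUMMY", tokenization is skipped and a batch of n all-zero prompts of max_input_length tokens is synthesized (bare "DUMMY" means a batch of 1). A minimal standalone sketch of that logic, with names mirroring the patch:

import numpy as np

def make_dummy_batch(input_text, max_input_length):
    # Mirrors the patch: "DUMMY_8" -> 8 synthetic prompts, bare "DUMMY" -> 1.
    first = input_text[0]
    batch_size = int(first.split("_")[1]) if "_" in first else 1
    # Token IDs must be integers, so int32 zeros stand in for real prompts.
    return [np.zeros(max_input_length, dtype=np.int32)] * batch_size

assert len(make_dummy_batch(["DUMMY_8"], 512)) == 8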
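
Wrapping generation in generate_with_runner() lets the same call be reused three times: once for the normal run, ten times for warmup, and ten times under the profiler. The torch.cuda.synchronize() inside matters for timing, since runner.generate launches GPU work that may still be in flight when the Python call returns. Below is a generic sketch of the same warmup-then-measure pattern using only the standard-library timer (the patch itself uses tensorrt_llm.profiler); it assumes a CUDA device is present:

import time
import torch

def measure_avg_latency(fn, iters=10):
    for _ in range(iters):        # warmup: excludes one-time setup cost
        fn()
    torch.cuda.synchronize()      # drain pending kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()      # make sure all timed work has finished
    return (time.perf_counter() - start) / iters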
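
The printed average latency, combined with the batch size, is enough to derive throughput. Hypothetical numbers, only to show the arithmetic the printed line enables:

batch_size, avg_latency_s, max_new_tokens = 8, 0.5, 128
tokens_per_sec = batch_size * max_new_tokens / avg_latency_s  # -> 2048.0

This estimate assumes every sequence generates the full max_new_tokens; the "Length Output" print added to print_output is there to verify that assumption.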