diff --git a/examples/run.py b/examples/run.py
index 9a05434..d7946db 100644
--- a/examples/run.py
+++ b/examples/run.py
@@ -135,6 +135,12 @@ def parse_arguments(args=None):
         choices=["hf", "nemo"],
         help="The source of lora checkpoint.")
+    parser.add_argument(
+        '--run_profiling',
+        default=False,
+        action='store_true',
+        help="Run 10 iterations to profile the inference latencies.")
+
     parser.add_argument(
         '--num_prepend_vtokens',
         nargs="+",
@@ -158,7 +164,11 @@ def parse_input(tokenizer,
     pad_id = tokenizer.pad_token_id
     batch_input_ids = []
-    if input_file is None:
+    print(input_text)
+    if input_text[0].startswith("DUMMY"):
+        batch_size = int(input_text[0].split("_")[1]) if "_" in input_text[0] else 1
+        batch_input_ids = [np.zeros(max_input_length, dtype=np.int32)] * batch_size
+    elif input_file is None:
         for curr_text in input_text:
             if prompt_template is not None:
                 curr_text = prompt_template.format(input_text=curr_text)
@@ -229,6 +239,7 @@ def print_output(tokenizer,
                 output_end = sequence_lengths[batch_idx][beam]
                 outputs = output_ids[batch_idx][beam][
                     output_begin:output_end].tolist()
+                print("Length Output", output_end - output_begin)
                 output_text = tokenizer.decode(outputs)
                 print(
                     f'Output [Text {batch_idx} Beam {beam}]: \"{output_text}\"')
@@ -270,6 +281,7 @@ def print_output(tokenizer,
 def main(args):
+    import tensorrt_llm
     runtime_rank = tensorrt_llm.mpi_rank()
     logger.set_level(args.log_level)
@@ -328,29 +340,33 @@ def main(args):
         max_beam_width=args.num_beams,
         max_attention_window_size=args.max_attention_window_size)
     runner = runner_cls.from_dir(**runner_kwargs)
-
-    with torch.no_grad():
-        outputs = runner.generate(
-            batch_input_ids,
-            max_new_tokens=args.max_output_len,
-            max_attention_window_size=args.max_attention_window_size,
-            end_id=end_id,
-            pad_id=pad_id,
-            temperature=args.temperature,
-            top_k=args.top_k,
-            top_p=args.top_p,
-            num_beams=args.num_beams,
-            length_penalty=args.length_penalty,
-            repetition_penalty=args.repetition_penalty,
-            stop_words_list=stop_words_list,
-            bad_words_list=bad_words_list,
-            lora_uids=args.lora_task_uids,
-            prompt_table_path=args.prompt_table_path,
-            prompt_tasks=args.prompt_tasks,
-            streaming=args.streaming,
-            output_sequence_lengths=True,
-            return_dict=True)
-        torch.cuda.synchronize()
+
+    def generate_with_runner():
+        with torch.no_grad():
+            outputs = runner.generate(
+                batch_input_ids,
+                max_new_tokens=args.max_output_len,
+                max_attention_window_size=args.max_attention_window_size,
+                end_id=end_id,
+                pad_id=pad_id,
+                temperature=args.temperature,
+                top_k=args.top_k,
+                top_p=args.top_p,
+                num_beams=args.num_beams,
+                length_penalty=args.length_penalty,
+                repetition_penalty=args.repetition_penalty,
+                stop_words_list=stop_words_list,
+                bad_words_list=bad_words_list,
+                lora_uids=args.lora_task_uids,
+                prompt_table_path=args.prompt_table_path,
+                prompt_tasks=args.prompt_tasks,
+                streaming=args.streaming,
+                output_sequence_lengths=True,
+                return_dict=True)
+            torch.cuda.synchronize()
+        return outputs
+
+    outputs = generate_with_runner()
     if runtime_rank == 0:
         if args.streaming:
@@ -381,6 +397,21 @@ def main(args):
                 context_logits=context_logits,
                 generation_logits=generation_logits,
                 output_logits_npy=args.output_logits_npy)
+
+    if args.run_profiling:
+        ite = 10
+        # warmup
+        for _ in range(ite):
+            generate_with_runner()
+
+        import tensorrt_llm.profiler
+        tensorrt_llm.profiler.start("tmp")
+        for _ in range(ite):
+            generate_with_runner()
+        tensorrt_llm.profiler.stop("tmp")
+        print(
+            f"batch_size: {len(batch_input_ids)}, avg latency of {ite} iterations: {tensorrt_llm.profiler.elapsed_time_in_sec('tmp') / ite} sec"
+        )