vllm_batch_test.py
import argparse
import time
from itertools import cycle

from vllm import LLM, SamplingParams

# Set up command-line argument parsing
parser = argparse.ArgumentParser(description='Generate text from repeated prompts.')
parser.add_argument('n', type=int, help='Number of repetitions for the prompts.')
args = parser.parse_args()

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Use the value of n provided as a command-line argument
n = args.n

# Cycle through the base prompts until the batch holds n prompts
cycled_prompts = cycle(prompts)
prompts = [next(cycled_prompts) for _ in range(n)]

# ignore_eos=True forces every sequence to generate the full 200 tokens,
# which keeps timings comparable across batch sizes
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, ignore_eos=True, max_tokens=200)
llm = LLM(model="TheBloke/Llama-2-7B-fp16")

start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()

# Print one sample output
prompt = outputs[0].prompt
generated_text = outputs[0].outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print(f"batch: {len(prompts)}")
# 200 tokens per sequence, so this is the per-sequence decode speed;
# aggregate throughput is roughly len(prompts) times higher
print(f"speed: {200 / (end - start)}tps")