@SinanAkkoyun · Last active July 22, 2023
vllm_batch_test.py
import argparse
import time
from itertools import cycle

from vllm import LLM, SamplingParams

# Set up command-line argument parsing
parser = argparse.ArgumentParser(description='Generate text from repeated prompts.')
parser.add_argument('n', type=int, help='Total number of prompts in the batch.')
args = parser.parse_args()

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Repeat the base prompts in a cycle until the batch holds n prompts
cycled_prompts = cycle(prompts)
prompts = [next(cycled_prompts) for _ in range(args.n)]

# ignore_eos=True forces every sequence to generate exactly max_tokens tokens,
# which makes the timing below deterministic
max_tokens = 200
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, ignore_eos=True, max_tokens=max_tokens)

llm = LLM(model="TheBloke/Llama-2-7B-fp16")

start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()

# Print the first result as a sanity check
prompt = outputs[0].prompt
generated_text = outputs[0].outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print(f"batch: {len(prompts)}")
# Per-sequence decoding speed; multiply by the batch size for aggregate throughput
print(f"speed: {max_tokens / (end - start):.1f} tps")