Skip to content

Instantly share code, notes, and snippets.

@aleksandr-smechov
Last active January 3, 2024 06:40
Show Gist options
  • Save aleksandr-smechov/437f17a055d146229a1eb3f64c5c4b4b to your computer and use it in GitHub Desktop.
vLLM gradio server for skypilot
import argparse
import requests
import gradio as gr
def http_bot(prompt, model_input, api_key, temperature, max_tokens, top_p):
    """Send a non-streaming completion request to the vLLM server.

    Args:
        prompt: The text to complete.
        model_input: Model name forwarded to the server.
        api_key: API key for the endpoint (may be empty).
        temperature: Sampling temperature (0-1 from the UI slider).
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling cutoff (0-1 from the UI slider).

    Returns:
        The generated text of the first completion choice.

    Raises:
        requests.HTTPError: If the server returns an error status.
    """
    headers = {"User-Agent": "vLLM Client"}
    # OpenAI-compatible servers authenticate via a Bearer token header.
    # The key is also kept in the JSON body below for backward
    # compatibility with servers that read it from the payload.
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload = dict(
        model=model_input,
        api_key=api_key,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=False,
    )
    # NOTE(review): `args` is the module-level namespace parsed under
    # __main__ — this function must only be called after startup.
    # A timeout prevents the UI from hanging forever on a dead server.
    response = requests.post(args.model_url, headers=headers, json=payload, timeout=120)
    # Fail loudly on HTTP errors instead of a cryptic KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]["text"]
def build_demo():
    """Assemble and return the Gradio Blocks UI for the completion demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# vLLM text completion demo\n")

        # Request parameters entered by the user.
        prompt_box = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
        model_box = gr.Textbox(label="Model")
        key_box = gr.Textbox(label="API key")
        temp_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.1, value=0.7)
        tokens_box = gr.Number(label="Max Tokens", value=128)
        topp_slider = gr.Slider(label="Top P", minimum=0, maximum=1, step=0.1, value=1.0)

        # Model response.
        result_box = gr.Textbox(label="Output", placeholder="Generated result from the model")

        # Pressing ENTER in the prompt box fires a completion request.
        prompt_box.submit(
            http_bot,
            inputs=[prompt_box, model_box, key_box, temp_slider, tokens_box, topp_slider],
            outputs=result_box,
        )
    return demo
if __name__ == "__main__":
    # CLI: where to serve the UI and which vLLM endpoint to call.
    # `args` stays module-global: http_bot reads args.model_url at request time.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--model-url", type=str, default="http://localhost:8000/completions")
    args = parser.parse_args()

    demo = build_demo()
    # queue() bounds concurrent requests; share=True publishes a public link.
    demo.queue(concurrency_count=100).launch(
        server_name=args.host,
        server_port=args.port,
        share=True,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment