@cat-state
Created April 28, 2023 19:42
gradio + cluster inference

To use:

HF_API_TOKEN=<token> sbatch hf-infer.sbatch

then run

HF_API_TOKEN=<token> HOSTNAME=<hostname of inference server> python gradio-tgl.py
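
If you don't know which node the server landed on, squeue can report it (this assumes the hf-infer job name set in the sbatch file below):

squeue -u $USER --name=hf-infer -o "%N"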

Set up the environment following the HF text-generation-inference server instructions, but change /usr/local to the path of your conda env instead.
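
For reference, a rough sketch of that setup, assuming micromamba (as used in the sbatch file below) and the upstream text-generation-inference repo; the BUILD_EXTENSIONS install step is from that repo's README, which also lists other prerequisites (Rust, protoc):

micromamba create -n text-generation-inference python=3.9
micromamba activate text-generation-inference
git clone https://github.com/huggingface/text-generation-inference
cd text-generation-inference
BUILD_EXTENSIONS=True make install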

gradio-tgl.py:

import os
from string import Template

import torch
import gradio as gr
from transformers import AutoTokenizer
from text_generation import Client

hostname = os.environ.get("HOSTNAME")
auth_token = os.environ.get("HF_API_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    "CarperAI/stable-vicuna-13b-fp16",
    use_auth_token=auth_token if auth_token else True,
)

max_context_length = 2048
max_new_tokens = 768

prompt_template = Template("""\
### Human: $human
### Assistant: $bot\
""")

system_prompt = "### Assistant: I am StableVicuna, a large language model created by CarperAI. I am here to chat!"
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)

client = Client(f"http://{hostname}:8080")
def bot(history):
    history = history or []
    # Inject prompt formatting into the history
    prompt_history = []
    for human, bot_reply in history:
        if bot_reply is not None:
            bot_reply = bot_reply.replace("<br>", "\n").rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot_reply if bot_reply is not None else "")
        )
    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False,  # Use <BOS> from the system prompt
    )
    # Keep only the most recent context that still fits in the window after
    # reserving room for the system prompt and the new tokens, then prepend
    # the system prompt. The negative index slices from the end.
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
    input_tokens = torch.concat(
        [system_prompt_tokens['input_ids'], msg_tokens['input_ids'][:, max_tokens:]],
        dim=-1,
    )
    input_text = tokenizer.decode(input_tokens[0].cpu().tolist())
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.999,
        temperature=1.0,
    )
    partial_text = ""
    for resp in client.generate_stream(input_text, **generate_kwargs):
        new_text = resp.token.text.replace("<br>", "\n")
        if (partial_text + new_text).endswith('###'):
            # The model started a new "### Human:"/"### Assistant:" turn
            # marker; strip it and stop streaming.
            history[-1][1] = (partial_text + new_text)[:-3]
            yield history
            break
        elif new_text in ('#', '##'):
            # Might be the start of a turn marker; buffer it without yielding
            partial_text += new_text
        else:
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
    return partial_text
def user(user_message, history):
    return "", history + [[user_message, None]]

with gr.Blocks() as demo:
    gr.Markdown("# StableVicuna by CarperAI")
    gr.HTML("<a href='https://huggingface.co/CarperAI/stable-vicuna-13b-delta'><code>CarperAI/stable-vicuna-13b-delta</code></a>")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/CarperAI/StableVicuna?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=True)

demo.queue(max_size=32, concurrency_count=2)
demo.launch(share=True)
hf-infer.sbatch:

#!/bin/bash
#SBATCH -p g40
#SBATCH --account trlx
#SBATCH --gres=gpu:1
#SBATCH --output="%x.out"
#SBATCH --job-name=hf-infer
cd /fsx/home-uwu
source .bashrc
micromamba activate text-generation-inference
cd text-generation-inference
text-generation-launcher --model-id CarperAI/vicuna-13b-fine-tuned-rlhf --num-shard 1 --port 8080
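
Once the job is running, a quick smoke test against the server (the /generate route and payload shape are the standard text-generation-inference API; replace <hostname> with the node reported by squeue):

curl http://<hostname>:8080/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "### Human: Hello\n### Assistant:", "parameters": {"max_new_tokens": 32}}'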