@epicfilemcnulty
Created September 5, 2023 11:27
Simple HTTP API wrapper around HF transformers: loads a causal LM (GPTQ, NF4 or INT8 quantized, with optional LoRA adapters) and exposes PUT /prompt, POST /chat and POST /complete endpoints via Bottle.
import argparse
import time
import torch
import uuid
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig, GPTQConfig
from transformers import StoppingCriteria, StoppingCriteriaList
from auto_gptq import exllama_set_max_input_length
import transformers
from peft import PeftModel
from bottle import Bottle, run, route, request
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', required=True, type=str, help="Grasping Model")
parser.add_argument('-a', '--model_name', required=False, type=str, default="unknown", help="Grasping Model's Alias")
parser.add_argument('-l', '--lora_dir', required=False, type=str, default='', help="Path to lora directory")
parser.add_argument('-b', '--bits', required=False, type=int, default=0, help="Quantization: 0 = GPTQ (default), 4 = NF4, anything else = INT8")
parser.add_argument('-c', '--context', required=False, type=int, default=2048, help="Context length for exllama cache")
parser.add_argument('-r', '--remote', required=False, action='store_true', help="Trust remote code (default is False)")
parser.add_argument('--port', default=8013, required=False, type=int, help="Port to listen on")
parser.add_argument('--ip', default='127.0.0.1', required=False, type=str, help="IP to listen on")
args = parser.parse_args()
app = Bottle()
class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation as soon as any of the given stop token sequences appears at the end of the output."""
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False
def load_model():
    model_id = args.model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Leave ~2GB of headroom per GPU for activations and the KV cache.
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    max_memory = f'{free_in_GB - 2}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    if args.bits == 0:
        # GPTQ checkpoint, with the exllama kernels enabled.
        gptq_config = GPTQConfig(bits=4, disable_exllama=False)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', quantization_config=gptq_config, max_memory=max_memory, trust_remote_code=args.remote)
        model = exllama_set_max_input_length(model, args.context)
    elif args.bits == 4:
        # On-the-fly NF4 quantization via bitsandbytes.
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', quantization_config=nf4_config, max_memory=max_memory, trust_remote_code=args.remote)
    else:
        # Fall back to INT8 loading via bitsandbytes.
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', load_in_8bit=True, max_memory=max_memory, trust_remote_code=args.remote)
    print(model.generation_config)
    if args.lora_dir != '':
        model = PeftModel.from_pretrained(model, args.lora_dir)
    return model, tokenizer
llm, tokenizer = load_model()
conversations = {}
def full_conversation(idx):
    chat = ''
    for message in conversations[idx]['messages']:
        if message['role'] == 'system':
            chat += message['content']
        if message['role'] == 'user':
            chat += conversations[idx]['prefix'] + message['content'] + conversations[idx]['postfix']
        if message['role'] == 'assistant':
            chat += conversations[idx]['suffix'] + message['content'] + '\n'
    if conversations[idx]['messages'][-1]['role'] == 'user':
        chat += conversations[idx]['suffix']
    return chat
@app.route('/prompt', method='PUT')
def set_prompt():
    data = request.json
    conversation_uuid = data.get('uuid', str(uuid.uuid4()))
    messages = data.get('messages', [{'role': 'system', 'content': ''}])
    prefix = data.get('prefix', 'USER: ')
    postfix = data.get('postfix', '\n')
    suffix = data.get('suffix', 'ASSISTANT:')
    conversations[conversation_uuid] = {
        "messages": messages,
        "prefix": prefix,
        "suffix": suffix,
        "postfix": postfix
    }
    return {"message": "Prompt set", "uuid": conversation_uuid}
@app.route('/chat', method='POST')
def chat():
    data = request.json
    conversation_uuid = data['uuid']
    if conversation_uuid not in conversations:
        return {"uuid": conversation_uuid, "message": "not found"}
    temperature = data.get('temperature', 0.5)
    max_new_tokens = data.get('max_length', 256)
    query = data.get('query')
    conversations[conversation_uuid]['messages'].append({'role': 'user', 'content': query})
    full_ctx = full_conversation(conversation_uuid)
    # Stop generation when the model starts a new user turn or emits the EOS token.
    stop_words = [conversations[conversation_uuid]['prefix'].rstrip(), '</s>']
    stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
    start_time = time.time_ns()
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    input_ids = tokenizer(full_ctx, return_tensors="pt").input_ids.to('cuda')
    outputs = llm.generate(
        inputs=input_ids,
        do_sample=True,
        num_beams=1,
        stopping_criteria=stopping_criteria,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        num_return_sequences=1,
        remove_invalid_values=True,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Strip the prompt from the decoded output so only the new completion is returned.
    answer = answer.replace(full_ctx, "")
    conversations[conversation_uuid]['messages'].append({'role': 'assistant', 'content': answer})
    new_tokens = len(outputs[0]) - len(input_ids[0])
    end_time = time.time_ns()
    secs = (end_time - start_time) / 1e9
    return {
        "text": answer,
        "ctx": len(outputs[0]),
        "tokens": new_tokens,
        "rate": new_tokens / secs,
        "model": args.model_name,
    }
@app.route('/complete', method='POST')
def complete():
    data = request.json
    temperature = data.get('temperature', 0.5)
    max_new_tokens = data.get('max_length', 256)
    query = data.get('query')
    # Tokenize without a BOS token so the prompt is continued verbatim.
    tok = AutoTokenizer.from_pretrained(args.model, add_bos_token=False)
    start_time = time.time_ns()
    input_ids = tok(query, return_tensors="pt").input_ids.to('cuda')
    outputs = llm.generate(
        inputs=input_ids,
        do_sample=True,
        num_beams=1,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        num_return_sequences=1,
        remove_invalid_values=True,
    )
    answer = tokenizer.decode(outputs[0])
    new_tokens = len(outputs[0]) - len(input_ids[0])
    end_time = time.time_ns()
    secs = (end_time - start_time) / 1e9
    return {
        "text": answer,
        "ctx": len(outputs[0]),
        "tokens": new_tokens,
        "rate": new_tokens / secs,
        "model": args.model_name,
    }
run(app, host=args.ip, port=args.port)
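
For reference, a minimal client sketch for the three endpoints above. This is not part of the original gist: it assumes the server is running with the default --ip 127.0.0.1 and --port 8013 and that the third-party requests package is installed; the prompt template values mirror the defaults in set_prompt().

import requests

BASE = "http://127.0.0.1:8013"

# Register a conversation template; the server returns a uuid identifying the session.
prompt = requests.put(f"{BASE}/prompt", json={
    "messages": [{"role": "system", "content": "You are a helpful assistant.\n"}],
    "prefix": "USER: ",
    "postfix": "\n",
    "suffix": "ASSISTANT:",
}).json()
chat_id = prompt["uuid"]

# Ask a question within that conversation and print the reply plus the tokens/second rate.
reply = requests.post(f"{BASE}/chat", json={
    "uuid": chat_id,
    "query": "What is the capital of France?",
    "temperature": 0.5,
    "max_length": 128,
}).json()
print(reply["text"], reply["rate"])

# Raw completion, with no chat template applied.
completion = requests.post(f"{BASE}/complete", json={"query": "def fibonacci(n):", "max_length": 64}).json()
print(completion["text"])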