@epicfilemcnulty
Created July 11, 2023 17:54
simple HF transformers inference (HTTP API wrapped); an example client call follows the code below.
# utils/ntk_rope_scale.py: NTK-aware RoPE scaling monkey-patch (imported by the server below)
import transformers
import transformers.models.llama.modeling_llama

def enable_ntk_rope_scaling(alpha=4):
    # Patch LlamaRotaryEmbedding so the rotary base is rescaled and the
    # advertised context window is multiplied by alpha.
    old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

    def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
        max_position_embeddings = 2048 * alpha
        a = alpha
        base = base * a ** (dim / (dim - 2))
        old_init(self, dim, max_position_embeddings, base, device)

    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
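For intuition, here is a quick standalone sanity check (not one of the gist's files) of what the patch changes, assuming a typical LLaMA head dimension of 128 and alpha=4: the rotary base grows to a bit more than alpha times the default, and max_position_embeddings becomes 2048 * alpha.

alpha, dim, base = 4, 128, 10000
scaled_base = base * alpha ** (dim / (dim - 2))
print(scaled_base)    # ~40890, slightly above the naive 4 * 10000
print(2048 * alpha)   # 8192: the patched max_position_embeddings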
# Main inference server (imports the NTK patch above as utils/ntk_rope_scale.py).
import argparse
import time
import torch
import uuid
import os
from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline, BitsAndBytesConfig
from transformers import StoppingCriteria, StoppingCriteriaList
import transformers
from bottle import Bottle, run, route, request
from utils.ntk_rope_scale import enable_ntk_rope_scaling
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', required=False, type=str, default='/storage/models/LLaMA/FP16/Wizard-Vicuna-Uncensored-13B', help="Path to the model directory")
parser.add_argument('-a', '--model_name', required=False, type=str, default="WizVicUncen13.NF4", help="Model's alias")
parser.add_argument('-4', '--four_bit', required=False, action=argparse.BooleanOptionalAction, default=True, help="Load in 4 bit (default; use --no-four_bit to load in 8 bit instead)")
parser.add_argument('-A', '--alpha', default=1, required=False, type=int, help="NTK Scaled RoPE's alpha")
parser.add_argument('-c', '--context', default=2048, required=False, type=int, help="Context length")
parser.add_argument('--port', default=8013, required=False, type=int, help="Port to listen on")
parser.add_argument('--ip', default='127.0.0.1', required=False, type=str, help="IP to listen on")
args = parser.parse_args()
if args.alpha != 1:
    enable_ntk_rope_scaling(args.alpha)
app = Bottle()
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Stop as soon as the tail of the generated sequence matches any stop sequence.
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False
def load_model():
    model_id = args.model
    tokenizer = LlamaTokenizer.from_pretrained(model_id)
    # Cap per-GPU memory at the currently free amount minus ~2 GB of headroom.
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    max_memory = f'{free_in_GB - 2}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    if args.four_bit:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = LlamaForCausalLM.from_pretrained(model_id, device_map='auto', quantization_config=nf4_config, max_memory=max_memory)
    else:
        model = LlamaForCausalLM.from_pretrained(model_id, device_map='auto', load_in_8bit=True, max_memory=max_memory)
    print(model.generation_config)
    return model, tokenizer
llm, tokenizer = load_model()
conversations = {}  # in-memory conversation store, keyed by UUID
def full_conversation(idx):
    # Rebuild the full prompt text from the stored messages, using the
    # conversation's user prefix and assistant suffix.
    chat = ''
    for message in conversations[idx]['messages']:
        if message['role'] == 'system':
            chat += message['content'] + '\n\n'
        if message['role'] == 'user':
            chat += conversations[idx]['prefix'] + ' ' + message['content'] + '\n'
        if message['role'] == 'assistant':
            chat += conversations[idx]['suffix'] + ' ' + message['content'] + '\n'
    if conversations[idx]['messages'][-1]['role'] == 'user':
        chat += conversations[idx]['suffix']
    return chat
@app.route('/prompt', method='PUT')
def set_prompt():
    data = request.json
    conversation_uuid = data.get('uuid', str(uuid.uuid4()))
    prompt = data.get('prompt', '')
    messages = data.get('messages', [{'role': 'system', 'content': prompt}])
    prefix = data.get('prefix', 'USER:')
    suffix = data.get('suffix', 'ASSISTANT:')
    conversations[conversation_uuid] = {
        "messages": messages,
        "prefix": prefix,
        "suffix": suffix,
    }
    return {"message": "Prompt set", "uuid": conversation_uuid}
@app.route('/chat', method='POST')
def chat():
    data = request.json
    conversation_uuid = data['uuid']
    if conversation_uuid not in conversations:
        return {"uuid": conversation_uuid, "message": "not found"}
    temperature = data.get('temperature', 0.7)
    max_new_tokens = data.get('max_length', 512)
    query = data.get('query')
    conversations[conversation_uuid]['messages'].append({'role': 'user', 'content': query})
    full_ctx = full_conversation(conversation_uuid)
    # Stop generation as soon as the model starts a new user turn.
    stop_words = [conversations[conversation_uuid]['prefix']]
    stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
    start_time = time.time_ns()
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    input_ids = tokenizer(full_ctx, return_tensors="pt").input_ids.to('cuda')
    # Greedy decoding (do_sample=False); temperature is passed through but only
    # affects generation when sampling is enabled.
    outputs = llm.generate(
        input_ids,
        do_sample=False,
        num_beams=1,
        stopping_criteria=stopping_criteria,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        num_return_sequences=1,
        remove_invalid_values=True,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Strip the prompt that was fed in, leaving only the new completion.
    answer = answer.replace(full_ctx, "")
    conversations[conversation_uuid]['messages'].append({'role': 'assistant', 'content': answer})
    new_tokens = len(outputs[0]) - len(input_ids[0])
    end_time = time.time_ns()
    secs = (end_time - start_time) / 1e9
    return {
        "text": answer,
        "ctx": len(outputs[0]),
        "tokens": new_tokens,
        "rate": new_tokens / secs,
        "model": args.model_name,
    }
run(app, host=args.ip, port=args.port)
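A minimal client sketch for the two endpoints (assuming the default 127.0.0.1:8013 and that the `requests` package is installed; the prompt and query strings are just placeholders): PUT /prompt registers a conversation and returns its UUID, POST /chat appends a user query and returns the generated reply with a tokens-per-second rate.

import requests

BASE = "http://127.0.0.1:8013"  # default --ip/--port from the argparse flags above

# Register a conversation; the server returns the UUID to use for /chat.
r = requests.put(f"{BASE}/prompt", json={
    "prompt": "You are a helpful assistant.",
    "prefix": "USER:",
    "suffix": "ASSISTANT:",
})
conv = r.json()["uuid"]

# Send a query and print the answer plus the generation rate.
r = requests.post(f"{BASE}/chat", json={
    "uuid": conv,
    "query": "Summarize what NTK-aware RoPE scaling does.",
    "max_length": 256,
    "temperature": 0.7,
})
resp = r.json()
print(resp["text"])
print(f'{resp["tokens"]} tokens at {resp["rate"]:.1f} tok/s')

Because the server keeps conversation state in memory, the same UUID can be reused for follow-up /chat calls to continue the dialogue.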