exllama minimum example
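The gist is three files concatenated below: the Flask + SocketIO entry point (its filename is not given in this capture; app.py is assumed), generation/exllama_generator_wrapper.py, and generation/make_instruct.py. The entry point also expects a templates/index.html that opens a Socket.IO connection and handles the user_message, start_generation, stream_response, and end_generation events; that template is not part of the gist.

# --- Flask + SocketIO entry point (filename assumed: app.py) ---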
from flask import Flask, render_template
import torch
from flask_socketio import SocketIO, emit
from generation.make_instruct import get_generator_func
from generation.exllama_generator_wrapper import encode_message, encode_system, encode_header
import os, sys

app = Flask(__name__)
socketio = SocketIO(app)

# Load the model once at startup and grab the tokenizer plus the streaming
# generate function. The path below is a placeholder; point it at a local
# exl2 model directory.
tokenizer, generate = get_generator_func("path/to/exl2-model")

system_prompt = "Respond to all inputs with EEE"
# Seed the conversation: system prompt followed by one example user/assistant turn.
seed_msg = encode_message(tokenizer, "user", "hello world")
init_msg = encode_message(tokenizer, "assistant", "EEE")
init_head = encode_header(tokenizer, "assistant")
enc_sys_prompt = encode_system(tokenizer, system_prompt)
enc_sys_prompt.extend(seed_msg)
enc_sys_prompt.extend(init_msg)
testing = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)

@app.route('/')
def index():
    return render_template('index.html')

@socketio.on('send_message')
def handle_send_message(message):
    global generate, enc_sys_prompt, init_head
    emit('user_message', message, broadcast=True)
    emit('start_generation', broadcast=True)  # Emit event to indicate generation start

    # Append the new user turn plus an assistant header, then stream the reply.
    next_message = encode_message(tokenizer, "user", message)
    enc_sys_prompt.extend(next_message)
    enc_sys_prompt.extend(init_head)
    testing = torch.tensor(enc_sys_prompt).unsqueeze(dim=0)
    for fragment, count in generate(instruction_ids=testing):
        emit('stream_response', {'fragment': fragment}, broadcast=True)
    emit('end_generation', broadcast=True)  # Emit event to indicate generation end


if __name__ == '__main__':
    socketio.run(app)
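
# --- generation/exllama_generator_wrapper.py ---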
import sys, os, random
import torch
# A requirement for using exllamav2 api
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)

def load_model(model_directory):
    config = ExLlamaV2Config()
    config.model_dir = model_directory
    config.prepare()

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
    return config, tokenizer, cache, generator

def encode_system(tokenizer, system_prompt):
    bos_token = tokenizer.single_id("<|begin_of_text|>")
    eot_token = tokenizer.single_id("<|eot_id|>")
    tokens = [bos_token]
    tokens.extend(encode_header(tokenizer, "system"))
    system_ids = tokenizer.encode(system_prompt, add_bos = False).view(-1).tolist()
    tokens.extend(system_ids)
    tokens.append(eot_token)
    return tokens

def encode_header(tokenizer, username):
    tokens = []
    start_header = tokenizer.single_id("<|start_header_id|>")
    end_header = tokenizer.single_id("<|end_header_id|>")
    tokens.append(start_header)
    tokens.extend(tokenizer.encode(username, add_bos = False).view(-1).tolist())
    tokens.append(end_header)
    tokens.extend(tokenizer.encode("\n\n", add_bos = False).view(-1).tolist())
    return tokens

def encode_message(tokenizer, username, message):
    eot_token = tokenizer.single_id("<|eot_id|>")
    tokens = encode_header(tokenizer, username)
    tokens.extend(
        tokenizer.encode(message.strip(), add_bos = False).view(-1).tolist()
    )
    tokens.append(eot_token)
    return tokens
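
# Together, these helpers emit the Llama 3 instruct prompt layout (shown on
# separate lines here for readability; the token ids are simply concatenated):
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\n{user message}<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n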

def generate_response_stream(instruction_ids, tokenizer, generator, settings, stop_sequences=None):
    # Avoid a mutable default argument; copy so repeated calls don't accumulate stop tokens.
    stop_sequences = list(stop_sequences) if stop_sequences is not None else []
    generator.begin_stream_ex(instruction_ids, settings)
    stop_sequences.append(tokenizer.eos_token_id)
    stop_sequences.append(128009)  # Llama 3 <|eot_id|>
    generator.set_stop_conditions(stop_sequences)

    # Yield (text chunk, token count) pairs until a stop condition is hit.
    while True:
        res = generator.stream_ex()
        if res["eos"]:
            return
        chunk = res["chunk"]
        counts = len(res["chunk_token_ids"])
        yield chunk, counts
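
# --- generation/make_instruct.py ---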
import sys, os
from functools import partial
# Needed for exllamav2 lib
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2.generator import (
    ExLlamaV2Sampler,
)
from .exllama_generator_wrapper import load_model, generate_response_stream

def get_generator_func(model_path):
    abs_path = os.path.abspath(model_path)
    config, tokenizer, cache, generator = load_model(abs_path)

    # Sampler settings for the streaming generator.
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 2
    settings.top_k = 30
    settings.min_p = 0.1

    # Bind everything except the prompt ids so callers only pass instruction_ids.
    generate = partial(generate_response_stream, generator=generator, settings=settings, tokenizer=tokenizer)
    return tokenizer, generate
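
For reference, a minimal sketch of driving the generator from a console script instead of through Flask/SocketIO. The script name, project layout, and model path are assumptions rather than part of the gist; only the imported functions come from the files above.

# quick_test.py (hypothetical) - exercise the generator without the web front end.
import torch
from generation.make_instruct import get_generator_func
from generation.exllama_generator_wrapper import encode_system, encode_message, encode_header

# Assumption: an exl2-quantized Llama 3 instruct model lives at this path.
tokenizer, generate = get_generator_func("models/llama3-8b-instruct-exl2")

ids = encode_system(tokenizer, "You are a terse assistant.")
ids.extend(encode_message(tokenizer, "user", "Name three prime numbers."))
ids.extend(encode_header(tokenizer, "assistant"))  # open the assistant turn

instruction_ids = torch.tensor(ids).unsqueeze(dim=0)
for fragment, count in generate(instruction_ids=instruction_ids):
    print(fragment, end="", flush=True)
print()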