@brandon-lockaby
Created August 20, 2023 21:26
exllama_tasks
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.lora import ExLlamaLora
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator
from exllama import model_init
import argparse
import torch
import sys
import os
#import glob
torch.set_grad_enabled(False)
torch.cuda._lazy_init()
parser = argparse.ArgumentParser(description = "ExLlama server")
model_init.add_args(parser)
parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during benchmark")
parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during benchmark")
parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to LoRA config and binary. to use during benchmark")
parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 0.95)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 20)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.15)
parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 1)
parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 1)
args = parser.parse_args()
model_init.post_parse(args)
model_init.get_model_files(args)
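# Example invocation (a sketch, not from the original gist): assumes the file is saved
# as exllama_tasks.py and that model_init.add_args() provides the usual ExLlama options
# such as -d/--directory for the model folder; all paths and values below are placeholders.
#
#   python exllama_tasks.py -d /path/to/llama-gptq -ld /path/to/lora -temp 0.7 -topk 40 -topp 0.9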
# Paths
if args.lora_dir is not None:
    args.lora_config = os.path.join(args.lora_dir, "adapter_config.json")
    args.lora = os.path.join(args.lora_dir, "adapter_model.bin")
# Some feedback
print(f" -- Sequence length: {args.length}")
print(f" -- Temperature: {args.temperature:.2f}")
print(f" -- Top-K: {args.top_k}")
print(f" -- Top-P: {args.top_p:.2f}")
print(f" -- Min-P: {args.min_p:.2f}")
print(f" -- Repetition penalty: {args.repetition_penalty:.2f}")
print(f" -- Beams: {args.beams} x {args.beam_length}")
model_init.print_options(args)
# Globals
model_init.set_globals(args)
# Instantiate model and generator
config = model_init.make_config(args)
model = ExLlama(config)
cache = ExLlamaCache(model)
tokenizer = ExLlamaTokenizer(args.tokenizer)
model_init.print_stats(model)
# Load LoRA
lora = None
if args.lora:
    print(f" -- LoRA config: {args.lora_config}")
    print(f" -- Loading LoRA: {args.lora}")
    if args.lora_config is None:
        print(" ## Error: please specify the path to adapter_config.json with --lora_config")
        sys.exit()
    lora = ExLlamaLora(model, args.lora_config, args.lora)
    if lora.bias_ignored:
        print(" !! Warning: LoRA zero bias ignored")
class ExLlamaBeamSearchTask:

    def __init__(self, input_text, **kwargs):
        self.is_running = False
        # valid_keys = ['temperature', 'top_k', 'top_p', 'min_p', 'repetition_penalty', 'repetition_penalty_sustain', 'beams', 'beam_length']
        # kwargs = {key: value for key, value in kwargs.items() if key in valid_keys}
        print(input_text, dict(kwargs))

        # Per-task generator; sampling settings fall back to the command-line defaults
        generator = ExLlamaGenerator(model, tokenizer, cache)
        generator.settings = ExLlamaGenerator.Settings()
        generator.settings.temperature = kwargs.get('temperature', args.temperature)
        generator.settings.top_k = kwargs.get('top_k', args.top_k)
        generator.settings.top_p = kwargs.get('top_p', args.top_p)
        generator.settings.min_p = kwargs.get('min_p', args.min_p)
        generator.settings.token_repetition_penalty_max = kwargs.get('repetition_penalty', args.repetition_penalty)
        generator.settings.token_repetition_penalty_sustain = kwargs.get('repetition_penalty_sustain', args.repetition_penalty_sustain)
        generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2
        generator.settings.beams = kwargs.get('beams', args.beams)
        generator.settings.beam_length = kwargs.get('beam_length', args.beam_length)
        generator.lora = lora
        self.generator = generator

        self.input_text = input_text
        self.output_text = ''
        self.num_output_tokens = 0
        self.max_new_tokens = kwargs.get('max_new_tokens', 2048)  # todo: warn on poor stopping criteria

        # Tokenize the prompt and prime the generator for beam search
        self.input_ids = self.generator.tokenizer.encode(self.input_text)
        self.num_input_tokens = self.input_ids.shape[-1]
        self.generator.gen_begin(self.input_ids)
        self.generator.begin_beam_search()
        self.is_running = True

        # Iterator that yields (new_text, token) until the sentinel (None, None)
        self.next_token = iter(self.generate_next_token, (None, None))

    def stop(self):
        if self.is_running:
            self.is_running = False
            self.generator.end_beam_search()

    def generate_next_token(self):
        if not self.is_running:
            return None, None
        if self.num_output_tokens >= self.max_new_tokens:
            return None, None
        token = self.generator.beam_search().item()
        if token == self.generator.tokenizer.eos_token_id:
            return None, None
        self.num_output_tokens += 1
        # Decode the whole generated tail and emit only the newly added text
        text = self.generator.tokenizer.decode(self.generator.sequence_actual[:, -self.num_output_tokens:][0])
        new_text = text[len(self.output_text):]
        self.output_text += new_text
        if not self.is_running:
            return None, None
        return new_text, token
class ExLlamaTask:

    def __init__(self, input_text, **kwargs):
        self.is_running = False
        print(input_text, dict(kwargs))

        # Per-task generator; sampling settings fall back to the command-line defaults.
        # Note: generate_next_token() below decodes greedily with a manual forward pass,
        # so these sampling settings are not actually applied by this class.
        generator = ExLlamaGenerator(model, tokenizer, cache)
        generator.settings = ExLlamaGenerator.Settings()
        generator.settings.temperature = kwargs.get('temperature', args.temperature)
        generator.settings.top_k = kwargs.get('top_k', args.top_k)
        generator.settings.top_p = kwargs.get('top_p', args.top_p)
        generator.settings.min_p = kwargs.get('min_p', args.min_p)
        generator.settings.token_repetition_penalty_max = kwargs.get('repetition_penalty', args.repetition_penalty)
        generator.settings.token_repetition_penalty_sustain = kwargs.get('repetition_penalty_sustain', args.repetition_penalty_sustain)
        generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2
        generator.settings.beams = kwargs.get('beams', args.beams)
        generator.settings.beam_length = kwargs.get('beam_length', args.beam_length)
        generator.lora = lora
        self.generator = generator

        self.input_text = input_text
        self.output_text = ''
        self.num_output_tokens = 0
        self.max_new_tokens = kwargs.get('max_new_tokens', 2048)  # todo: warn on poor stopping criteria

        # Tokenize the prompt and prefill the cache
        self.sequence = tokenizer.encode(self.input_text)
        print('self.sequence =', self.sequence)
        self.num_input_tokens = self.sequence.shape[-1]
        self.is_running = True

        # Iterator that yields (new_text, token_id, top_probs) until the sentinel (None, None)
        self.next_token = iter(self.generate_next_token, (None, None))
        self.generator.gen_begin(self.sequence)

    def stop(self):
        if self.is_running:
            self.is_running = False

    def generate_next_token(self):
        if not self.is_running:
            return None, None
        if self.num_output_tokens >= self.max_new_tokens:
            return None, None
        self.generator.end_beam_search()  # no-op unless a beam search is active

        # Greedy decoding: forward only the last token through the cached model, take the argmax
        logits = model.forward(self.sequence[:, -1:], cache, lora = lora)[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        token = torch.argmax(probs, dim=-1)
        probs = torch.topk(probs, 50, -1)  # keep the top-50 probabilities for the caller
        if token.item() == tokenizer.eos_token_id:
            return None, None
        self.num_output_tokens += 1
        self.sequence = torch.cat([self.sequence, token.unsqueeze(0)], dim=-1)

        # Decode the whole generated tail and emit only the newly added text
        text = tokenizer.decode(self.sequence[:, -self.num_output_tokens:][0])
        new_text = text[len(self.output_text):]
        self.output_text += new_text
        if not self.is_running:
            return None, None
        return new_text, token.item(), probs
# let's go
if True:
    input_text = "### Instruction: What is a dog?\n### Response:"
    task = ExLlamaTask(input_text, max_new_tokens=1024, temperature=1000)
    for text, id, probs in task.next_token:
        print(text, id, probs)
    print(task.output_text)
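
# The demo above only exercises ExLlamaTask. Below is a minimal sketch (not part of the
# original gist) of driving ExLlamaBeamSearchTask the same way; it reuses the module-level
# model, tokenizer and cache, and the prompt and generation settings are illustrative only.
# Flip the guard to True to run it.
if False:
    beam_task = ExLlamaBeamSearchTask("### Instruction: What is a dog?\n### Response:",
                                      max_new_tokens = 256, beams = 4, beam_length = 8)
    for text, token in beam_task.next_token:
        print(text, end = '', flush = True)
    beam_task.stop()  # ends the beam search on the underlying generator
    print()
    print(beam_task.output_text)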