exllama_tasks
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.lora import ExLlamaLora
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator
from exllama import model_init
import argparse
import torch
import sys
import os

torch.set_grad_enabled(False)
torch.cuda._lazy_init()
parser = argparse.ArgumentParser(description = "ExLlama server")
model_init.add_args(parser)

parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during benchmark")
parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during benchmark")
parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to LoRA config and binary to use during benchmark")
parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 0.95)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 20)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.15)
parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 1)
parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 1)

args = parser.parse_args()
model_init.post_parse(args)
model_init.get_model_files(args)
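
# Example invocation (illustrative sketch, not from the original gist): the model-related flags
# come from exllama's model_init (see model_init.add_args for the exact set), and the paths below
# are placeholders, e.g.:
#
#   python exllama_tasks.py -d /path/to/model -ld /path/to/lora -temp 0.7 -topk 40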
# Paths

if args.lora_dir is not None:
    args.lora_config = os.path.join(args.lora_dir, "adapter_config.json")
    args.lora = os.path.join(args.lora_dir, "adapter_model.bin")

# Some feedback

print(f" -- Sequence length: {args.length}")
print(f" -- Temperature: {args.temperature:.2f}")
print(f" -- Top-K: {args.top_k}")
print(f" -- Top-P: {args.top_p:.2f}")
print(f" -- Min-P: {args.min_p:.2f}")
print(f" -- Repetition penalty: {args.repetition_penalty:.2f}")
print(f" -- Beams: {args.beams} x {args.beam_length}")

model_init.print_options(args)
# Globals

model_init.set_globals(args)

# Instantiate model and generator

config = model_init.make_config(args)
model = ExLlama(config)
cache = ExLlamaCache(model)
tokenizer = ExLlamaTokenizer(args.tokenizer)

model_init.print_stats(model)
# Load LoRA

lora = None
if args.lora:
    print(f" -- LoRA config: {args.lora_config}")
    print(f" -- Loading LoRA: {args.lora}")
    if args.lora_config is None:
        print(f" ## Error: please specify lora path to adapter_config.json")
        sys.exit()
    lora = ExLlamaLora(model, args.lora_config, args.lora)
    if lora.bias_ignored:
        print(f" !! Warning: LoRA zero bias ignored")
class ExLlamaBeamSearchTask:

    def __init__(self, input_text, **kwargs):

        self.is_running = False
        print(input_text, dict(kwargs))

        # Build a generator around the shared model/cache, with per-task sampling settings
        # falling back to the command-line defaults
        generator = ExLlamaGenerator(model, tokenizer, cache)
        generator.settings = ExLlamaGenerator.Settings()
        generator.settings.temperature = kwargs.get('temperature', args.temperature)
        generator.settings.top_k = kwargs.get('top_k', args.top_k)
        generator.settings.top_p = kwargs.get('top_p', args.top_p)
        generator.settings.min_p = kwargs.get('min_p', args.min_p)
        generator.settings.token_repetition_penalty_max = kwargs.get('repetition_penalty', args.repetition_penalty)
        generator.settings.token_repetition_penalty_sustain = kwargs.get('repetition_penalty_sustain', args.repetition_penalty_sustain)
        generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2
        generator.settings.beams = kwargs.get('beams', args.beams)
        generator.settings.beam_length = kwargs.get('beam_length', args.beam_length)
        generator.lora = lora
        self.generator = generator

        self.input_text = input_text
        self.output_text = ''
        self.num_output_tokens = 0
        self.max_new_tokens = kwargs.get('max_new_tokens', 2048)  # todo warn on poor stopping criteria

        # Tokenize the prompt and start the beam search
        self.input_ids = self.generator.tokenizer.encode(self.input_text)
        self.num_input_tokens = self.input_ids.shape[-1]
        self.generator.gen_begin(self.input_ids)
        self.generator.begin_beam_search()

        self.is_running = True

        # Iterator that yields (new_text, token) pairs until generate_next_token returns the (None, None) sentinel
        self.next_token = iter(self.generate_next_token, (None, None))

    def stop(self):
        if self.is_running:
            self.is_running = False
            self.generator.end_beam_search()

    def generate_next_token(self):
        if not self.is_running:
            return None, None
        if self.num_output_tokens >= self.max_new_tokens:
            return None, None

        token = self.generator.beam_search().item()
        if token == self.generator.tokenizer.eos_token_id:
            return None, None
        self.num_output_tokens += 1

        # Decode the whole generated tail and emit only the newly added text
        text = self.generator.tokenizer.decode(self.generator.sequence_actual[:, -self.num_output_tokens:][0])
        new_text = text[len(self.output_text):]
        self.output_text += new_text

        if not self.is_running:
            return None, None
        return new_text, token
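
# Illustrative sketch (not from the original gist): consuming ExLlamaBeamSearchTask the same way the
# demo at the bottom drives ExLlamaTask. The iterator stops on EOS or max_new_tokens; call stop()
# afterwards to end the beam search. The prompt and sampling values here are example placeholders.
#
#   task = ExLlamaBeamSearchTask("### Instruction: What is a dog?\n### Response:",
#                                max_new_tokens = 256, beams = 4, beam_length = 8)
#   for text, token in task.next_token:
#       print(text, end = '', flush = True)
#   task.stop()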
class ExLlamaTask:

    def __init__(self, input_text, **kwargs):

        self.is_running = False
        print(input_text, dict(kwargs))

        # Build a generator around the shared model/cache, with per-task sampling settings
        # falling back to the command-line defaults
        generator = ExLlamaGenerator(model, tokenizer, cache)
        generator.settings = ExLlamaGenerator.Settings()
        generator.settings.temperature = kwargs.get('temperature', args.temperature)
        generator.settings.top_k = kwargs.get('top_k', args.top_k)
        generator.settings.top_p = kwargs.get('top_p', args.top_p)
        generator.settings.min_p = kwargs.get('min_p', args.min_p)
        generator.settings.token_repetition_penalty_max = kwargs.get('repetition_penalty', args.repetition_penalty)
        generator.settings.token_repetition_penalty_sustain = kwargs.get('repetition_penalty_sustain', args.repetition_penalty_sustain)
        generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2
        generator.settings.beams = kwargs.get('beams', args.beams)
        generator.settings.beam_length = kwargs.get('beam_length', args.beam_length)
        generator.lora = lora
        self.generator = generator

        self.input_text = input_text
        self.output_text = ''
        self.num_output_tokens = 0
        self.max_new_tokens = kwargs.get('max_new_tokens', 2048)  # todo warn on poor stopping criteria

        # Tokenize the prompt and prime the cache
        self.sequence = tokenizer.encode(self.input_text)
        print('self.sequence =', self.sequence)
        self.num_input_tokens = self.sequence.shape[-1]

        self.is_running = True

        # Iterator that yields (new_text, token_id, top-50 probs) until the (None, None) sentinel
        self.next_token = iter(self.generate_next_token, (None, None))
        self.generator.gen_begin(self.sequence)

    def stop(self):
        if self.is_running:
            self.is_running = False

    def generate_next_token(self):
        if not self.is_running:
            return None, None
        if self.num_output_tokens >= self.max_new_tokens:
            return None, None

        self.generator.end_beam_search()

        # Greedy decoding: forward only the last token through the cached model, pick the argmax,
        # and keep the top-50 probabilities for inspection
        logits = model.forward(self.sequence[:, -1:], cache, lora = lora)[:, -1, :]
        probs = torch.softmax(logits, dim = -1)
        token = torch.argmax(probs, dim = -1)
        probs = torch.topk(probs, 50, -1)

        if token.item() == tokenizer.eos_token_id:
            return None, None
        self.num_output_tokens += 1

        self.sequence = torch.cat([self.sequence, token.unsqueeze(0)], dim = -1)

        # Decode the whole generated tail and emit only the newly added text
        text = tokenizer.decode(self.sequence[:, -self.num_output_tokens:][0])
        new_text = text[len(self.output_text):]
        self.output_text += new_text

        if not self.is_running:
            return None, None
        return new_text, token.item(), probs
# let's go

if True:
    input_text = "### Instruction: What is a dog?\n### Response:"
    task = ExLlamaTask(input_text, max_new_tokens = 1024, temperature = 1000)
    for text, id, probs in task.next_token:
        print(text, id, probs)
    print(task.output_text)