Modified llama.cpp chat class (mildly edited from the original to fit text-generation-webui)
import sys
from time import time
from os import cpu_count

import llama_cpp

from modules.llamacpp_common import GptParams, gpt_params_parse, gpt_random_prompt
from modules import shared

ANSI_COLOR_RESET = "\x1b[0m"
ANSI_COLOR_YELLOW = "\x1b[33m"
ANSI_BOLD = "\x1b[1m"
ANSI_COLOR_GREEN = "\x1b[32m"

CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
class LlamaCppModel:
    def __init__(self) -> None:
        self.initialized = False

    @classmethod
    def from_pretrained(self, path):
        # input args
        self.params = GptParams()
        self.params.model = str(path)
        self.params.interactive = True
        if (self.params.perplexity):
            raise NotImplementedError("""************
please use the 'perplexity' tool for perplexity calculations
************""")

        if (self.params.embedding):
            raise NotImplementedError("""************
please use the 'embedding' tool for embedding calculations
************""")

        if (self.params.n_ctx > 2048):
            print(f"""warning: model does not support \
context sizes greater than 2048 tokens ({self.params.n_ctx} \
specified); expect poor results""", file=sys.stderr)

        if (self.params.seed <= 0):
            self.params.seed = int(time())

        print(f"seed = {self.params.seed}", file=sys.stderr)

        if (self.params.random_prompt):
            self.params.prompt = gpt_random_prompt(self.params.seed)
        # runtime args
        self.input_consumed = 0
        self.n_past = 0
        # the default "User:" antiprompt is tokenized once the context exists,
        # since entries must be token lists rather than strings
        self.first_antiprompt = []
        self.remaining_tokens = self.params.n_predict
        self.output_echo = self.params.input_echo
        # model load
        self.lparams = llama_cpp.llama_context_default_params()
        self.lparams.n_ctx = self.params.n_ctx
        self.lparams.n_parts = self.params.n_parts
        self.lparams.seed = self.params.seed
        self.lparams.memory_f16 = self.params.memory_f16
        self.lparams.use_mlock = self.params.use_mlock

        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
        if (not self.ctx):
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")

        print(file=sys.stderr)
        print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)
        # determine the required inference memory per token:
        if (self.params.mem_test):
            tmp = [0, 1, 2, 3]
            llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.params.n_threads)
            llama_cpp.llama_print_timings(self.ctx)
            self.exit()
            return
        # create internal context
        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)

        # tokenize the default reverse prompt now that the context exists
        self.first_antiprompt.append(self._tokenize("User:", False))

        # Add a space in front of the first character to match OG llama tokenizer behavior
        self.params.prompt = " " + self.params.prompt
        # Load prompt file
        if (self.params.file):
            with open(self.params.file) as f:
                self.params.prompt = f.read()

        # tokenize the prompt
        self.embd = []
        self.embd_inp = self._tokenize(self.params.prompt)

        if (len(self.embd_inp) > self.params.n_ctx - 4):
            raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")

        # number of tokens to keep when resetting context
        if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
            self.params.n_keep = len(self.embd_inp)

        self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
        self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)

        # in instruct mode, we inject a prefix and a suffix to each input by the user
        if (self.params.instruct):
            self.params.interactive_start = True
            self.first_antiprompt.append(self._tokenize(self.params.instruct_inp_prefix.strip(), False))

        # enable interactive mode if reverse prompt or interactive start is specified
        if (len(self.params.antiprompt) != 0 or self.params.interactive_start):
            self.params.interactive = True

        # determine newline token
        self.llama_token_newline = self._tokenize("\n", False)
        if (self.params.verbose_prompt):
            print(f"""
prompt: '{self.params.prompt}'
number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)

            for i in range(len(self.embd_inp)):
                print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)

            if (self.params.n_keep > 0):
                print("static prompt based on n_keep: '", file=sys.stderr)
                for i in range(self.params.n_keep):
                    print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
                print("'", file=sys.stderr)
            print(file=sys.stderr)
        if (self.params.interactive):
            print("interactive mode on.", file=sys.stderr)

            if (len(self.params.antiprompt) > 0):
                for antiprompt in self.params.antiprompt:
                    print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)

            if len(self.params.input_prefix) > 0:
                print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)

        print(f"""sampling: temp = {self.params.temp}, \
top_k = {self.params.top_k}, \
top_p = {self.params.top_p}, \
repeat_last_n = {self.params.repeat_last_n}, \
repeat_penalty = {self.params.repeat_penalty}

generate: n_ctx = {self.n_ctx}, \
n_batch = {self.params.n_batch}, \
n_predict = {self.params.n_predict}, \
n_keep = {self.params.n_keep}
""", file=sys.stderr)
        # determine antiprompt tokens
        for i in self.params.antiprompt:
            self.first_antiprompt.append(self._tokenize(i, False))

        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesn't support slices

        if (self.params.interactive):
            print("""== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to LLaMA.
 - If you want to submit another line, end your input in '\\'.

""", file=sys.stderr)
        self.set_color(CONSOLE_COLOR_PROMPT)

        # text-generation-webui expects a (model, tokenizer) pair; this class acts as both
        return self, self
    # tokenize a prompt
    @classmethod
    def _tokenize(self, prompt, bos=True):
        _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
        return _arr[:_n]
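    # Note on the buffer above: `(len(prompt) + 1)` slots assume roughly one
    # token per character plus an optional BOS token, which holds for typical
    # ASCII prompts. The slice returned is a plain Python list of token ids,
    # which is why callers concatenate results with `+` / `+=` in this file.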
    @classmethod
    def use_antiprompt(self):
        return len(self.first_antiprompt) > 0

    @classmethod
    def set_color(self, c):
        if (self.params.use_color):
            print(c, end="")
    # generate tokens
    @classmethod
    def generate(self):
        while self.remaining_tokens > 0 or self.params.interactive:
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                if (self.n_past + len(self.embd) > self.n_ctx):
                    n_left = self.n_past - self.params.n_keep
                    self.n_past = self.params.n_keep

                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
                    ]
                    self.embd = _insert + self.embd
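                    # Worked example with illustrative numbers: n_ctx = 2048,
                    # n_keep = 128 and n_past = 2040 give n_left = 1912, so the
                    # most recent n_left/2 = 956 evaluated tokens are pulled
                    # from last_n_tokens and re-evaluated right after the kept
                    # prompt, freeing roughly half the window for new tokens.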
                if (llama_cpp.llama_eval(
                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
                ) != 0):
                    raise Exception("Failed to llama_eval!")

            self.n_past += len(self.embd)
            self.embd = []
            if len(self.embd_inp) <= self.input_consumed:
                # out of user input, sample next token
                # TODO: self.params.ignore_eos
                _arr = self.last_n_tokens[-min(self.params.repeat_last_n, self.n_past):]
                id = llama_cpp.llama_sample_top_p_top_k(
                    self.ctx,
                    (llama_cpp.llama_token * len(_arr))(*_arr),
                    len(_arr),
                    self.params.top_k,
                    self.params.top_p,
                    self.params.temp,
                    self.params.repeat_penalty,
                )
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(id)

                # replace end of text token with newline token when in interactive mode
                if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
                    id = self.llama_token_newline[0]
                    if (self.use_antiprompt()):
                        # tokenize and inject first reverse prompt
                        self.embd_inp += self.first_antiprompt[0]

                # add it to the context
                self.embd.append(id)

                # echo this to console
                self.output_echo = True

                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.params.input_echo

                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    self.embd.append(self.embd_inp[self.input_consumed])
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                    self.input_consumed += 1
                    if len(self.embd) >= self.params.n_batch:
                        break
            # display tokens
            if self.output_echo:
                for id in self.embd:
                    yield id

            # reset color to default if there is no pending user input
            if (self.params.input_echo and len(self.embd_inp) == self.input_consumed):
                self.set_color(CONSOLE_COLOR_DEFAULT)

            if (self.params.interactive and len(self.embd_inp) <= self.input_consumed):
                # if antiprompt is present, stop
                if (self.use_antiprompt()):
                    # suffix match: stop once the last len(i) tokens equal an antiprompt
                    if True in [
                        i == self.last_n_tokens[-len(i):]
                        for i in self.first_antiprompt
                    ]:
                        break

                # if we are using instruction mode, and we have processed the initial prompt
                if (self.n_past > 0 and self.params.interactive_start):
                    break
            # end of text token
            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                if (not self.params.instruct):
                    for i in " [end of text]\n":
                        yield i
                break

            # respect n_predict even if antiprompt is present
            if (self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1):
                # If we aren't in instruction mode, fix the current generation by appending the antiprompt.
                # Makes it so if chat ends prematurely you don't append the AI's text etc.
                if not self.params.instruct:
                    self.embd_inp += self.first_antiprompt[0]
                self.remaining_tokens = self.params.n_predict
                break

        self.params.interactive_start = False
    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self.exit()

    def exit(self):
        llama_cpp.llama_free(self.ctx)
        self.set_color(CONSOLE_COLOR_DEFAULT)

    # return past text
    def past(self):
        for id in self.last_n_tokens[-self.n_past:]:
            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
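    # last_n_tokens is a fixed-size rolling window over every evaluated token,
    # so the slice above reconstructs the visible history without re-tokenizing.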
    # write input
    @classmethod
    def input(self, prompt: str):
        if (self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix):
            self.embd_inp += self.inp_prefix
        self.embd_inp += self._tokenize(prompt)
        if (self.params.instruct):
            self.embd_inp += self.inp_suffix
    # write output
    @classmethod
    def output(self):
        self.remaining_tokens = self.params.n_predict
        for id in self.generate():
            # generate() yields raw text (not token ids) for the end-of-text marker
            if isinstance(id, str):
                yield id
            else:
                yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")

    @classmethod
    def output_static(self):
        self.remaining_tokens = self.params.n_predict
        output = ""
        for id in self.generate():
            if isinstance(id, str):
                output += id
            else:
                output += llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
        return output
    # read user input
    def read_input(self):
        out = ""
        while (t := input()).endswith("\\"):
            out += t[:-1] + "\n"
        return out + t + "\n"
    # interactive mode
    def interact(self):
        for i in self.output():
            print(i, end="", flush=True)
        self.params.input_echo = False

        while self.params.interactive:
            self.set_color(CONSOLE_COLOR_USER_INPUT)
            if (self.params.instruct):
                print('\n> ', end="")
                self.input(self.read_input())
            else:
                print(self.params.input_prefix, end="")
                self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.output_postfix}")
                print(self.params.output_postfix, end="")
            self.set_color(CONSOLE_COLOR_DEFAULT)

            try:
                for i in self.output():
                    print(i, end="", flush=True)
            except KeyboardInterrupt:
                self.set_color(CONSOLE_COLOR_DEFAULT)
                if not self.params.instruct:
                    print(self.params.fix_prefix, end="")
                    self.input(self.params.fix_prefix)
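

# A minimal standalone driver sketch (an assumption for illustration, not part
# of the webui wiring; the model path and the prompt below are placeholders):
if __name__ == "__main__":
    # from_pretrained returns the class twice because text-generation-webui
    # expects a (model, tokenizer) pair; this class plays both roles.
    model, tokenizer = LlamaCppModel.from_pretrained("models/ggml-model-q4_0.bin")
    model.input("User: Hello!")       # queue a user turn (tokenized internally)
    for piece in model.output():      # stream generated tokens as decoded text
        print(piece, end="", flush=True)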