Skip to content

Instantly share code, notes, and snippets.

@digiwombat
Created April 9, 2023 02:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save digiwombat/da67cb7acb02a01d1a27b2c923aca17a to your computer and use it in GitHub Desktop.
Save digiwombat/da67cb7acb02a01d1a27b2c923aca17a to your computer and use it in GitHub Desktop.
Modified llama.cpp chat class (lightly edited from the original to fit text-generation-webui)
import sys
from time import time
from os import cpu_count
import llama_cpp
from modules.llamacpp_common import GptParams, gpt_params_parse, gpt_random_prompt
from modules import shared
# Raw ANSI SGR escape sequences used for console colouring.
ANSI_COLOR_RESET = "\x1b[0m"
ANSI_BOLD = "\x1b[1m"
ANSI_COLOR_GREEN = "\x1b[32m"
ANSI_COLOR_YELLOW = "\x1b[33m"

# Semantic aliases: what each part of the console output is drawn with.
CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
CONSOLE_COLOR_USER_INPUT = "".join((ANSI_BOLD, ANSI_COLOR_GREEN))
class LlamaCppModel:
    """Interactive llama.cpp chat session adapted for text-generation-webui.

    Apart from ``__init__``, ``__enter__``, ``__exit__`` and ``read_input``,
    every method is a ``@classmethod``: all session state (params, context,
    token buffers) is stored on the *class* object, and ``from_pretrained``
    returns the class itself twice (as "model" and "tokenizer").  Only one
    session can therefore exist per process.  Classmethods remain callable
    through instances, so ``LlamaCppModel()`` objects see the same state.
    """

    def __init__(self) -> None:
        self.initialized = False

    @classmethod
    def from_pretrained(self, path):
        """Load the model at *path* and prepare prompt/interactive state.

        ``self`` is the class here (classmethod); all attributes are set on it.

        Returns:
            ``(cls, cls)`` -- the class doubles as model and tokenizer -- or
            ``None`` when ``params.mem_test`` is set (the context is freed
            right after the memory test).

        Raises:
            NotImplementedError: perplexity or embedding mode was requested.
            RuntimeError: the model failed to load, or the prompt does not fit
                in the context window.
        """
        # input args
        self.params = GptParams()
        self.params.model = str(path)
        self.params.interactive = True

        if self.params.perplexity:
            raise NotImplementedError("""************
please use the 'perplexity' tool for perplexity calculations
************""")

        if self.params.embedding:
            raise NotImplementedError("""************
please use the 'embedding' tool for embedding calculations
************""")

        if self.params.n_ctx > 2048:
            print(f"""warning: model does not support \
context sizes greater than 2048 tokens ({self.params.n_ctx} \
specified) expect poor results""", file=sys.stderr)

        if self.params.seed <= 0:
            self.params.seed = int(time())

        print(f"seed = {self.params.seed}", file=sys.stderr)

        if self.params.random_prompt:
            self.params.prompt = gpt_random_prompt(self.params.seed)

        # runtime args
        self.input_consumed = 0
        self.n_past = 0
        # Reverse prompts, each stored as a list of token ids. The default
        # "User:" prompt is tokenized below, once a context exists; storing it
        # as a plain string broke both injection (characters were spliced into
        # the token list) and the stop check (str never equals a list).
        self.first_antiprompt = []
        self.remaining_tokens = self.params.n_predict
        self.output_echo = self.params.input_echo

        # model load
        self.lparams = llama_cpp.llama_context_default_params()
        self.lparams.n_ctx = self.params.n_ctx
        self.lparams.n_parts = self.params.n_parts
        self.lparams.seed = self.params.seed
        self.lparams.memory_f16 = self.params.memory_f16
        self.lparams.use_mlock = self.params.use_mlock

        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
        if not self.ctx:
            raise RuntimeError(f"error: failed to load model '{self.params.model}'")

        print(file=sys.stderr)
        print(f"system_info: n_threads = {self.params.n_threads} / {cpu_count()} \
| {llama_cpp.llama_print_system_info().decode('utf8')}", file=sys.stderr)

        # determine the required inference memory per token:
        if self.params.mem_test:
            tmp = [0, 1, 2, 3]
            # thread count lives on params (there is no self.n_threads)
            llama_cpp.llama_eval(
                self.ctx, (llama_cpp.llama_token * len(tmp))(*tmp), len(tmp), 0, self.params.n_threads
            )
            llama_cpp.llama_print_timings(self.ctx)
            self.exit()
            return None

        # create internal context
        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)

        # Default reverse prompt; must be index 0 because generate() injects
        # first_antiprompt[0] when the model emits EOS or runs out of budget.
        self.first_antiprompt.append(self._tokenize("User:", False))

        # Load the prompt file *before* adding the leading space, so file
        # prompts get the same treatment as command-line prompts.
        if self.params.file:
            with open(self.params.file, encoding="utf8") as f:
                self.params.prompt = f.read()

        # Add a space in front of the first character to match OG llama tokenizer behavior
        self.params.prompt = " " + self.params.prompt

        # tokenize the prompt
        self.embd = []
        self.embd_inp = self._tokenize(self.params.prompt)

        if len(self.embd_inp) > self.params.n_ctx - 4:
            raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")

        # number of tokens to keep when resetting context
        if self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct:
            self.params.n_keep = len(self.embd_inp)

        self.inp_prefix = self._tokenize(self.params.instruct_inp_prefix)
        self.inp_suffix = self._tokenize(self.params.instruct_inp_suffix, False)

        # in instruct mode, we inject a prefix and a suffix to each input by the user
        if self.params.instruct:
            self.params.interactive_start = True
            self.first_antiprompt.append(self._tokenize(self.params.instruct_inp_prefix.strip(), False))

        # enable interactive mode if reverse prompt or interactive start is specified
        if len(self.params.antiprompt) != 0 or self.params.interactive_start:
            self.params.interactive = True

        # determine newline token
        self.llama_token_newline = self._tokenize("\n", False)

        if self.params.verbose_prompt:
            print(f"""
prompt: '{self.params.prompt}'
number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)

            for token in self.embd_inp:
                print(f"{token} -> '{llama_cpp.llama_token_to_str(self.ctx, token)}'", file=sys.stderr)

            if self.params.n_keep > 0:
                # diagnostics all go to stderr (this line used to hit stdout)
                print("static prompt based on n_keep: '", file=sys.stderr)
                for token in self.embd_inp[:self.params.n_keep]:
                    print(llama_cpp.llama_token_to_str(self.ctx, token), file=sys.stderr)
                print("'", file=sys.stderr)
            print(file=sys.stderr)

        if self.params.interactive:
            print("interactive mode on.", file=sys.stderr)

            if len(self.params.antiprompt) > 0:
                for antiprompt in self.params.antiprompt:
                    print(f"Reverse prompt: '{antiprompt}'", file=sys.stderr)

            if len(self.params.input_prefix) > 0:
                print(f"Input prefix: '{self.params.input_prefix}'", file=sys.stderr)

        print(f"""sampling: temp = {self.params.temp},\
top_k = {self.params.top_k},\
top_p = {self.params.top_p},\
repeat_last_n = {self.params.repeat_last_n},\
repeat_penalty = {self.params.repeat_penalty}

generate: n_ctx = {self.n_ctx}, \
n_batch = {self.params.n_batch}, \
n_predict = {self.params.n_predict}, \
n_keep = {self.params.n_keep}
""", file=sys.stderr)

        # determine antiprompt tokens
        for antiprompt in self.params.antiprompt:
            self.first_antiprompt.append(self._tokenize(antiprompt, False))

        self.last_n_tokens = [0] * self.n_ctx  # TODO: deque doesnt support slices

        if self.params.interactive:
            print("""== Running in interactive mode. ==
- Press Ctrl+C to interject at any time.
- Press Return to return control to LLaMa.
- If you want to submit another line, end your input in '\\'.
""", file=sys.stderr)
            self.set_color(CONSOLE_COLOR_PROMPT)

        return self, self

    @classmethod
    def _tokenize(self, prompt, bos=True):
        """Tokenize *prompt* with the loaded context and return a list of ids.

        The output buffer is sized from the UTF-8 byte length plus one slot for
        the optional BOS token: with byte-fallback tokenization a string can
        produce more tokens than it has characters.

        Raises:
            RuntimeError: llama_tokenize reported failure (negative count).
        """
        data = prompt.encode("utf8")
        _arr = (llama_cpp.llama_token * (len(data) + 1))()
        _n = llama_cpp.llama_tokenize(self.ctx, data, _arr, len(_arr), bos)
        if _n < 0:
            # previously a negative count silently sliced off the tail
            raise RuntimeError(f"error: failed to tokenize prompt ({-_n} tokens needed)")
        return _arr[:_n]

    @classmethod
    def _detokenize(self, tokens):
        """Yield the text of *tokens* as soon as it forms complete characters.

        A single token may hold only part of a multi-byte UTF-8 character, so
        decoding each token separately can raise UnicodeDecodeError. An
        incremental decoder buffers partial bytes until the character is
        complete; invalid byte sequences become U+FFFD.
        """
        import codecs

        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        for token in tokens:
            text = decoder.decode(llama_cpp.llama_token_to_str(self.ctx, token))
            if text:
                yield text
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail

    @classmethod
    def use_antiprompt(self):
        """Return True when at least one reverse prompt is configured."""
        return len(self.first_antiprompt) > 0

    @classmethod
    def set_color(self, c):
        """Print escape sequence *c* (no newline) if params.use_color is set."""
        if self.params.use_color:
            print(c, end="")

    @classmethod
    def generate(self):
        """Evaluate pending input and sample new tokens, yielding token ids.

        Runs while the sampling budget (remaining_tokens) is positive or the
        session is interactive. Stops on a reverse prompt, on EOS (outside
        instruct mode), after the initial prompt when interactive_start is
        set, or when the n_predict budget is exhausted in interactive mode.
        Every yielded value is a token id.
        """
        while self.remaining_tokens > 0 or self.params.interactive:
            # predict
            if len(self.embd) > 0:
                # infinite text generation via context swapping
                # if we run out of context:
                # - take the n_keep first tokens from the original prompt (via n_past)
                # - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
                if self.n_past + len(self.embd) > self.n_ctx:
                    n_left = self.n_past - self.params.n_keep
                    self.n_past = self.params.n_keep

                    # insert n_left/2 tokens at the start of embd from last_n_tokens
                    _insert = self.last_n_tokens[
                        self.n_ctx - int(n_left / 2) - len(self.embd):-len(self.embd)
                    ]
                    self.embd = _insert + self.embd

                if llama_cpp.llama_eval(
                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
                ) != 0:
                    raise Exception("Failed to llama_eval!")

            self.n_past += len(self.embd)
            self.embd = []

            if len(self.embd_inp) <= self.input_consumed:
                # out of user input, sample next token

                # TODO: self.params.ignore_eos

                # Slice by explicit start index: "[-0:]" would return the whole
                # history when repeat_last_n (or n_past) is 0.
                n_repeat = max(0, min(self.params.repeat_last_n, self.n_past))
                _arr = self.last_n_tokens[len(self.last_n_tokens) - n_repeat:]
                token = llama_cpp.llama_sample_top_p_top_k(
                    self.ctx,
                    (llama_cpp.llama_token * len(_arr))(*_arr),
                    len(_arr),
                    self.params.top_k,
                    self.params.top_p,
                    self.params.temp,
                    self.params.repeat_penalty,
                )
                self.last_n_tokens.pop(0)
                self.last_n_tokens.append(token)

                # replace end of text token with newline token when in interactive mode
                if token == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct:
                    token = self.llama_token_newline[0]
                    if self.use_antiprompt():
                        # inject first reverse prompt (already a token list)
                        self.embd_inp += self.first_antiprompt[0]

                # add it to the context
                self.embd.append(token)

                # echo this to console
                self.output_echo = True

                # decrement remaining sampling budget
                self.remaining_tokens -= 1
            else:
                # output to console if input echo is on
                self.output_echo = self.params.input_echo

                # some user input remains from prompt or interaction, forward it to processing
                while len(self.embd_inp) > self.input_consumed:
                    self.embd.append(self.embd_inp[self.input_consumed])
                    self.last_n_tokens.pop(0)
                    self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                    self.input_consumed += 1
                    if len(self.embd) >= self.params.n_batch:
                        break

            # display tokens
            if self.output_echo:
                yield from self.embd

            # reset color to default if we there is no pending user input
            if self.params.input_echo and len(self.embd_inp) == self.input_consumed:
                self.set_color(CONSOLE_COLOR_DEFAULT)

            if self.params.interactive and len(self.embd_inp) <= self.input_consumed:
                # if antiprompt is present, stop
                if self.use_antiprompt():
                    if any(
                        antiprompt == self.last_n_tokens[-len(antiprompt):]
                        for antiprompt in self.first_antiprompt
                    ):
                        break

                # if we are using instruction mode, and we have processed the initial prompt
                if self.n_past > 0 and self.params.interactive_start:
                    break

            # end of text token
            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
                if not self.params.instruct:
                    # Yield token ids, not characters: every consumer converts
                    # yielded values with llama_token_to_str.
                    yield from self._tokenize(" [end of text]\n", False)
                    break

            # respect n_predict even if antiprompt is present
            if self.params.interactive and self.remaining_tokens <= 0 and self.params.n_predict != -1:
                # If we arent in instruction mode, fix the current generation by appending the antiprompt.
                # Makes it so if chat ends prematurely you dont append the AI's text etc.
                if not self.params.instruct:
                    self.embd_inp += self.first_antiprompt[0]
                # refill the budget actually read by the loop (was a dead n_remain)
                self.remaining_tokens = self.params.n_predict
                break

        self.params.interactive_start = False

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self.exit()

    @classmethod
    def exit(self):
        """Free the llama context (once) and restore the default console colour.

        A classmethod so that from_pretrained (where ``self`` is the class) can
        call it; it remains callable through instances.
        """
        ctx = getattr(self, "ctx", None)
        if ctx:
            llama_cpp.llama_free(ctx)
            self.ctx = None  # guard against a double free
        self.set_color(CONSOLE_COLOR_DEFAULT)

    @classmethod
    def past(self):
        """Yield the text of the last n_past tokens recorded in last_n_tokens."""
        start = max(0, len(self.last_n_tokens) - self.n_past)
        yield from self._detokenize(self.last_n_tokens[start:])

    @classmethod
    def input(self, prompt: str):
        """Queue user text *prompt* for evaluation.

        In instruct mode the instruction prefix is added unless it is already
        the tail of the history, and the response suffix is appended.
        """
        if self.params.instruct and self.last_n_tokens[-len(self.inp_prefix):] != self.inp_prefix:
            self.embd_inp += self.inp_prefix

        # No BOS for follow-up turns (llama.cpp's main example tokenizes user
        # lines with bos=false); a mid-context BOS derails the model.
        self.embd_inp += self._tokenize(prompt, False)

        if self.params.instruct:
            self.embd_inp += self.inp_suffix

    @classmethod
    def output(self):
        """Reset the sampling budget and yield generated text incrementally."""
        self.remaining_tokens = self.params.n_predict
        yield from self._detokenize(self.generate())

    @classmethod
    def output_static(self):
        """Reset the sampling budget and return all generated text as one string."""
        self.remaining_tokens = self.params.n_predict
        return "".join(self._detokenize(self.generate()))

    @staticmethod
    def read_input():
        """Read one line from stdin; a trailing backslash continues to the next line."""
        out = ""
        while (t := input()).endswith("\\"):
            out += t[:-1] + "\n"
        return out + t + "\n"

    @classmethod
    def interact(self):
        """Run the console chat loop: print output, read input, repeat.

        Ctrl+C during generation stops the current reply; outside instruct
        mode fix_prefix is then printed and queued as input.
        """
        for text in self.output():
            print(text, end="", flush=True)
        self.params.input_echo = False

        while self.params.interactive:
            self.set_color(CONSOLE_COLOR_USER_INPUT)
            if self.params.instruct:
                print('\n> ', end="")
                self.input(self.read_input())
            else:
                print(self.params.input_prefix, end="")
                self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.output_postfix}")
                print(self.params.output_postfix, end="")
            self.set_color(CONSOLE_COLOR_DEFAULT)

            try:
                for text in self.output():
                    print(text, end="", flush=True)
            except KeyboardInterrupt:
                self.set_color(CONSOLE_COLOR_DEFAULT)
                if not self.params.instruct:
                    print(self.params.fix_prefix, end="")
                    self.input(self.params.fix_prefix)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment