@Firstbober
Created May 15, 2023 20:41

Initiate the chat with load_context and then continue with generate.

import multiprocessing
import llama_cpp

class LLaMAChat():
    def __init__(
        self,
        model: str,  # model path
        contextSize: int,  # size of the prompt context
        modelParts: int,  # number of model parts
        seed: int = 0,  # RNG seed, 0 for random
    ):
        print("[LLaMAChat] Init")

        self.params = llama_cpp.llama_context_default_params()
        self.params.n_ctx = contextSize
        self.params.n_parts = modelParts
        self.params.seed = seed

        self.n_past = 0
        self.last_n_tokens_data = [0] * self.params.n_ctx

        llama_cpp.llama_print_system_info()

        print("[LLaMAChat] Loading model from file")
        self.ctx = llama_cpp.llama_init_from_file(str.encode(model), self.params)

    def free(self):
        llama_cpp.llama_print_timings(self.ctx)
        llama_cpp.llama_free(self.ctx)

    def get_tokens_from_str(self, text: bytes):
        # Tokenize `text` and trim the buffer to the number of tokens actually produced
        tokens = (llama_cpp.llama_token * (len(text) + 1))()
        n_of_tok = llama_cpp.llama_tokenize(self.ctx, text, tokens, len(tokens), True)
        tokens = tokens[:n_of_tok]
        return [tokens, n_of_tok]

    def load_context(self, prompt: str):
        print("[LLaMAChat] Tokenizing context prompt")
        t, n = self.get_tokens_from_str(str.encode(prompt))
        self.input_tokens = t

        print("[LLaMAChat] Making first evaluation of input tokens")
        llama_cpp.llama_eval(
            self.ctx,
            (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens),
            len(self.input_tokens),
            0,
            multiprocessing.cpu_count(),
        )
        self.n_past += len(self.input_tokens)

    def generate(
        self,
        prompt: str,
        predictN: int = 256,
        keepNOfCtx: int = 48,
        stop: list = [],
    ):
print(f"[LLaMAChat] Generating completion based on '{prompt}'")
prompt_tokens, _ = self.get_tokens_from_str(str.encode(prompt))
self.input_tokens += prompt_tokens
# Context management, so we don't overflow
if len(self.input_tokens) > self.params.n_ctx:
print("[LLmAChat] Context overflow, re-evaluating")
starting_len = len(self.input_tokens)
self.n_past = max(1, keepNOfCtx)
keptContext = self.input_tokens[:keepNOfCtx]
self.input_tokens = self.input_tokens[keepNOfCtx:]
self.input_tokens = self.input_tokens[(starting_len - self.params.n_ctx) + keepNOfCtx:]
self.input_tokens = keptContext + self.input_tokens
llama_cpp.llama_eval(
self.ctx, (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens), len(self.input_tokens), self.n_past, multiprocessing.cpu_count()
)

        # Do the generation
        remaining_tokens = predictN
        input_consumed = len(prompt_tokens)
        n_batch = 1024
        last_n_repeat = 64
        last_n_tokens_data = [0] * 64
        repeat_penalty = 1
        frequency_penalty = 0.0
        presence_penalty = 0.0

        tokens_as_string = ""
        tokens = prompt_tokens
        # Keep a sliding window of the last `last_n_repeat` tokens for the
        # repetition/frequency/presence penalties
        last_n_tokens_data = (last_n_tokens_data + prompt_tokens)[-last_n_repeat:]

        while remaining_tokens > 0:
            # Every time we have tokens, use them to get new logits and probabilities
            if len(tokens) > 0:
                llama_cpp.llama_eval(
                    self.ctx,
                    (llama_cpp.c_int * len(tokens))(*tokens),
                    len(tokens),
                    self.n_past,
                    multiprocessing.cpu_count(),
                )
                self.n_past += len(tokens)
            tokens = []

            if len(prompt_tokens) <= input_consumed:
                logits = llama_cpp.llama_get_logits(self.ctx)
                n_vocab = llama_cpp.llama_n_vocab(self.ctx)

                _arr = (llama_cpp.llama_token_data * n_vocab)(*[
                    llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                    for token_id in range(n_vocab)
                ])
                candidates_p = llama_cpp.ctypes.pointer(
                    llama_cpp.llama_token_data_array(_arr, len(_arr), False))

                # Apply penalties, then top-k / top-p / temperature sampling
                _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
                llama_cpp.llama_sample_repetition_penalty(
                    self.ctx, candidates_p, _arr, last_n_repeat, repeat_penalty)
                llama_cpp.llama_sample_frequency_and_presence_penalties(
                    self.ctx, candidates_p, _arr, last_n_repeat, frequency_penalty, presence_penalty)

                llama_cpp.llama_sample_top_k(self.ctx, candidates_p, 40, 1)
                llama_cpp.llama_sample_top_p(self.ctx, candidates_p, 0.8, 1)
                llama_cpp.llama_sample_temperature(self.ctx, candidates_p, 0.2)
                id = llama_cpp.llama_sample_token(self.ctx, candidates_p)

                last_n_tokens_data = last_n_tokens_data[1:] + [id]
                tokens.append(id)
                remaining_tokens -= 1

            # Stream the newly sampled token(s) back to the caller
            for id in tokens:
                token_str = llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
                tokens_as_string += token_str

                for ss in stop:
                    if tokens_as_string.endswith(ss) and remaining_tokens < predictN - 1:
                        return

                if token_str != '\n' and remaining_tokens < predictN - 1:
                    yield token_str

            if len(tokens) > 0 and tokens[-1] == llama_cpp.llama_token_eos():
                break
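
# Example usage: a minimal sketch of the load_context / generate flow described
# above. The model path and prompts are placeholders for illustration, not part
# of the original gist.
if __name__ == "__main__":
    chat = LLaMAChat("./models/7B/ggml-model.bin", contextSize=512, modelParts=1)

    # Initiate the chat with load_context, then continue with generate
    chat.load_context(
        "A chat between a curious user and a helpful assistant.\n"
        "User: Hello!\nAssistant: Hi, how can I help?\n"
    )

    for token_str in chat.generate("User: What is llama.cpp?\nAssistant:", stop=["User:"]):
        print(token_str, end="", flush=True)
    print()

    chat.free()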