Initiate the chat with load_context, then continue with generate.
Created May 15, 2023 20:41
import multiprocessing

import llama_cpp


class LLaMAChat:
    def __init__(
        self,
        model: str,        # model path
        contextSize: int,  # size of the prompt context
        modelParts: int,   # number of model parts
        seed: int = 0,     # RNG seed, 0 for random
    ):
        print("[LLaMAChat] Init")
        self.params = llama_cpp.llama_context_default_params()
        self.params.n_ctx = contextSize
        self.params.n_parts = modelParts
        self.params.seed = seed
        self.n_past = 0
        self.last_n_tokens_data = [0] * self.params.n_ctx
        llama_cpp.llama_print_system_info()
        print("[LLaMAChat] Loading model from file")
        self.ctx = llama_cpp.llama_init_from_file(str.encode(model), self.params)

    def free(self):
        # Print timing stats and release the llama.cpp context.
        llama_cpp.llama_print_timings(self.ctx)
        llama_cpp.llama_free(self.ctx)

    def get_tokens_from_str(self, text: bytes):
        # Tokenize raw bytes; returns the token list and its length.
        tokens = (llama_cpp.llama_token * (len(text) + 1))()
        n_of_tok = llama_cpp.llama_tokenize(self.ctx, text, tokens, len(tokens), True)
        tokens = tokens[:n_of_tok]
        return tokens, n_of_tok
    def load_context(
        self,
        prompt: str,
    ):
        # Tokenize and evaluate the initial prompt once; n_past tracks how many
        # tokens the model has already seen.
        print("[LLaMAChat] Tokenizing context prompt")
        t, n = self.get_tokens_from_str(str.encode(prompt))
        self.input_tokens = t
        print("[LLaMAChat] Making first evaluation of input tokens")
        llama_cpp.llama_eval(
            self.ctx,
            (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens),
            len(self.input_tokens),
            0,
            multiprocessing.cpu_count(),
        )
        self.n_past += len(self.input_tokens)
    def generate(
        self,
        prompt: str,
        predictN: int = 256,
        keepNOfCtx: int = 48,
        stop: list = [],
    ):
        print(f"[LLaMAChat] Generating completion based on '{prompt}'")
        prompt_tokens, _ = self.get_tokens_from_str(str.encode(prompt))
        self.input_tokens += prompt_tokens
        # Context management, so we don't overflow the n_ctx window: keep the
        # first keepNOfCtx tokens, drop enough of the middle to fit, then
        # re-evaluate the trimmed history.
        if len(self.input_tokens) > self.params.n_ctx:
            print("[LLaMAChat] Context overflow, re-evaluating")
            starting_len = len(self.input_tokens)
            self.n_past = max(1, keepNOfCtx)
            keptContext = self.input_tokens[:keepNOfCtx]
            self.input_tokens = self.input_tokens[keepNOfCtx:]
            self.input_tokens = self.input_tokens[(starting_len - self.params.n_ctx) + keepNOfCtx:]
            self.input_tokens = keptContext + self.input_tokens
            llama_cpp.llama_eval(
                self.ctx,
                (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens),
                len(self.input_tokens),
                self.n_past,
                multiprocessing.cpu_count(),
            )
        # Do the generation
        remaining_tokens = predictN
        input_consumed = len(prompt_tokens)
        n_batch = 1024
        last_n_repeat = 64
        last_n_tokens_data = [0] * last_n_repeat
        repeat_penalty = 1.0
        frequency_penalty = 0.0
        presence_penalty = 0.0
        tokens_as_string = ""
        tokens = prompt_tokens
        # Keep a sliding window of the last 64 tokens for the repetition penalty,
        # seeded with the tail of the new prompt.
        last_n_tokens_data = (last_n_tokens_data + prompt_tokens)[-last_n_repeat:]
        while remaining_tokens > 0:
            # Every time we have tokens, use them to get new logits and probabilities
            if len(tokens) > 0:
                llama_cpp.llama_eval(
                    self.ctx,
                    (llama_cpp.c_int * len(tokens))(*tokens),
                    len(tokens),
                    self.n_past,
                    multiprocessing.cpu_count(),
                )
            self.n_past += len(tokens)
            tokens = []
            if len(prompt_tokens) <= input_consumed:
                # Build the candidate array from the raw logits, apply repetition /
                # frequency / presence penalties, then sample with top-k, top-p and
                # temperature.
                logits = llama_cpp.llama_get_logits(self.ctx)
                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
                _arr = (llama_cpp.llama_token_data * n_vocab)(*[
                    llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                    for token_id in range(n_vocab)
                ])
                candidates_p = llama_cpp.ctypes.pointer(
                    llama_cpp.llama_token_data_array(_arr, len(_arr), False)
                )
                _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
                llama_cpp.llama_sample_repetition_penalty(
                    self.ctx, candidates_p, _arr, last_n_repeat, repeat_penalty
                )
                llama_cpp.llama_sample_frequency_and_presence_penalties(
                    self.ctx, candidates_p, _arr, last_n_repeat, frequency_penalty, presence_penalty
                )
                llama_cpp.llama_sample_top_k(self.ctx, candidates_p, 40, 1)
                llama_cpp.llama_sample_top_p(self.ctx, candidates_p, 0.8, 1)
                llama_cpp.llama_sample_temperature(self.ctx, candidates_p, 0.2)
                id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
                last_n_tokens_data = last_n_tokens_data[1:] + [id]
                tokens.append(id)
                remaining_tokens -= 1
            for id in tokens:
                token_str = llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
                tokens_as_string += token_str
                # Stop early once the accumulated output ends with a stop sequence.
                for ss in stop:
                    if tokens_as_string.endswith(ss) and remaining_tokens < predictN - 1:
                        return
                if token_str != '\n' and remaining_tokens < predictN - 1:
                    yield token_str
            if len(tokens) > 0 and tokens[-1] == llama_cpp.llama_token_eos():
                break
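A minimal usage sketch of the class above, assuming a local GGML model file; the model path, prompts, and parameter values below are placeholders, not part of the gist. The flow is the one the description names: load the shared context once with load_context, then iterate generate to stream tokens.

# Hypothetical example: model path and prompts are placeholders.
chat = LLaMAChat(model="./models/7B/ggml-model.bin", contextSize=2048, modelParts=1)
chat.load_context("A chat between a curious user and a helpful assistant.\n")

# generate() is a generator; iterate it to stream the sampled tokens.
for token in chat.generate("User: Hello!\nAssistant:", predictN=128, stop=["User:"]):
    print(token, end="", flush=True)
print()

chat.free()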