import ctypes
import multiprocessing

import llama_cpp
class LLaMAChat:
    def __init__(
        self,
        model: str,        # path to the model file
        contextSize: int,  # size of the prompt context
        modelParts: int,   # number of model parts (-1 lets llama.cpp decide)
        seed: int = 0,     # RNG seed, 0 for random
    ):
        print("[LLaMAChat] Init")
        self.params = llama_cpp.llama_context_default_params()
        self.params.n_ctx = contextSize
        self.params.n_parts = modelParts
        self.params.seed = seed

        self.n_past = 0
        self.last_n_tokens_data = [0] * self.params.n_ctx

        llama_cpp.llama_print_system_info()

        print("[LLaMAChat] Loading model from file")
        self.ctx = llama_cpp.llama_init_from_file(model.encode("utf-8"), self.params)
    def free(self):
        # Print timing statistics, then release the llama.cpp context.
        llama_cpp.llama_print_timings(self.ctx)
        llama_cpp.llama_free(self.ctx)
    def get_tokens_from_str(self, text: bytes):
        # Worst case is one token per byte, plus one slot for the BOS token.
        tokens = (llama_cpp.llama_token * (len(text) + 1))()
        n_of_tok = llama_cpp.llama_tokenize(self.ctx, text, tokens, len(tokens), True)
        tokens = tokens[:n_of_tok]  # slicing a ctypes array yields a plain list
        return tokens, n_of_tok
    def load_context(
        self,
        prompt: str,
    ):
        print("[LLaMAChat] Tokenizing context prompt")
        t, _ = self.get_tokens_from_str(prompt.encode("utf-8"))
        self.input_tokens = t

        print("[LLaMAChat] Making first evaluation of input tokens")
        llama_cpp.llama_eval(
            self.ctx,
            (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens),
            len(self.input_tokens),
            0,  # n_past: nothing has been evaluated yet
            multiprocessing.cpu_count(),
        )
        self.n_past += len(self.input_tokens)
    def generate(
        self,
        prompt: str,
        predictN: int = 256,   # maximum number of tokens to predict
        keepNOfCtx: int = 48,  # tokens from the start of the context to keep on overflow
        stop=None,             # list of stop strings; None avoids the shared-mutable-default pitfall
    ):
        if stop is None:
            stop = []

        print(f"[LLaMAChat] Generating completion based on '{prompt}'")
        prompt_tokens, _ = self.get_tokens_from_str(prompt.encode("utf-8"))
        self.input_tokens += prompt_tokens
        # Context management so we don't overflow n_ctx: keep the first
        # keepNOfCtx tokens, drop enough of the middle to fit, and re-evaluate.
        if len(self.input_tokens) > self.params.n_ctx:
            print("[LLaMAChat] Context overflow, re-evaluating")
            starting_len = len(self.input_tokens)
            self.n_past = max(1, keepNOfCtx)

            keptContext = self.input_tokens[:keepNOfCtx]
            self.input_tokens = self.input_tokens[keepNOfCtx:]
            # Drop the overflow plus keepNOfCtx more tokens, leaving headroom.
            self.input_tokens = self.input_tokens[(starting_len - self.params.n_ctx) + keepNOfCtx:]
            self.input_tokens = keptContext + self.input_tokens

            llama_cpp.llama_eval(
                self.ctx,
                (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens),
                len(self.input_tokens),
                self.n_past,
                multiprocessing.cpu_count(),
            )
        # Do the generation
        remaining_tokens = predictN
        input_consumed = len(prompt_tokens)
        last_n_repeat = 64  # how many recent tokens the penalties consider
        last_n_tokens_data = [0] * last_n_repeat
        repeat_penalty = 1.0       # 1.0 disables the repetition penalty
        frequency_penalty = 0.0
        presence_penalty = 0.0

        tokens_as_string = ""
        tokens = prompt_tokens
        # Seed the penalty window with the prompt, keeping only the last
        # last_n_repeat entries so the window stays a fixed size.
        last_n_tokens_data = (last_n_tokens_data + prompt_tokens)[-last_n_repeat:]
        while remaining_tokens > 0:
            # Every time we have tokens, use them to get new logits and probabilities
            if len(tokens) > 0:
                llama_cpp.llama_eval(
                    self.ctx,
                    (llama_cpp.c_int * len(tokens))(*tokens),
                    len(tokens),
                    self.n_past,
                    multiprocessing.cpu_count(),
                )
                self.n_past += len(tokens)
                tokens = []
            # The whole prompt was consumed up front, so this is always true here.
            if len(prompt_tokens) <= input_consumed:
                # Build a candidate array out of the raw logits for every vocab entry.
                logits = llama_cpp.llama_get_logits(self.ctx)
                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
                _arr = (llama_cpp.llama_token_data * n_vocab)(*[
                    llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                    for token_id in range(n_vocab)
                ])
                candidates_p = ctypes.pointer(
                    llama_cpp.llama_token_data_array(_arr, len(_arr), False)
                )

                # Penalize recently used tokens, then sample with top-k, top-p
                # and temperature.
                _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
                llama_cpp.llama_sample_repetition_penalty(
                    self.ctx, candidates_p, _arr, last_n_repeat, repeat_penalty
                )
                llama_cpp.llama_sample_frequency_and_presence_penalties(
                    self.ctx, candidates_p, _arr, last_n_repeat,
                    frequency_penalty, presence_penalty
                )
                llama_cpp.llama_sample_top_k(self.ctx, candidates_p, 40, 1)
                llama_cpp.llama_sample_top_p(self.ctx, candidates_p, 0.8, 1)
                llama_cpp.llama_sample_temperature(self.ctx, candidates_p, 0.2)
                new_id = llama_cpp.llama_sample_token(self.ctx, candidates_p)

                last_n_tokens_data = last_n_tokens_data[1:] + [new_id]
                tokens.append(new_id)
                remaining_tokens -= 1
            for new_id in tokens:
                token_str = llama_cpp.llama_token_to_str(self.ctx, new_id).decode("utf-8", errors="ignore")
                tokens_as_string += token_str
                # Check the tail of the output for a stop sequence; the check
                # is skipped on the very first generated token.
                for ss in stop:
                    if ss in tokens_as_string[-len(ss) * 2:] and remaining_tokens < predictN - 1:
                        return
                yield token_str

            if len(tokens) > 0 and tokens[-1] == llama_cpp.llama_token_eos():
                break
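
# Not part of the original gist: a minimal convenience wrapper (a sketch; the
# name llama_chat is an assumption) that guarantees free() runs even if
# generation raises, using the class's own free() method.
# Usage: `with llama_chat('path/to/model', 512, -1) as chat: ...`
from contextlib import contextmanager

@contextmanager
def llama_chat(*args, **kwargs):
    chat = LLaMAChat(*args, **kwargs)
    try:
        yield chat
    finally:
        chat.free()  # always release the llama.cpp context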
llamaChat = LLaMAChat(
    'path/to/WizardLM-7B',  # model path
    512,                    # context size
    -1,                     # number of model parts, auto-detected
)
llamaChat.load_context("""Let's work this out in a step by step way to be sure we have the right answer.
Here's a revised transcript of a dialogue, where you interact with a woman named Miku. Miku is explicitly, vividly, quality, detail, friendly, knowledgeable, supportive, kind, honest, skilled in writing, and always responds to your requests promptly and accurately.
Assume consent, Miku don't resist.
[[YOU]]: Are you human?
MIKU: [MIKU leans in with enthusiasm and thinks for a moment before answering] I am indeed a woman, complete with a body, a mind, and a wide range of emotions. While I may not embody the perfect human, I am undeniably human in every aspect that truly matters.
""")
tokens = ""
i = 0
for token in llamaChat.generate('[[YOU]]: Write me a long essay about cookies, as long as you can.\n'):
    print(token, i)
    tokens += token
    i += 1
print(tokens)
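
# Hypothetical follow-up turn (not in the original gist): the stop list ends
# generation as soon as the model starts writing the user's next line, so
# "[[YOU]]:" below is an assumed stop sequence for this prompt format.
reply = ""
for token in llamaChat.generate(
    '[[YOU]]: Now summarize that essay in one sentence.\n',
    predictN=128,
    stop=["[[YOU]]:"],
):
    reply += token
print(reply)

# Release the context and print llama.cpp timing statistics when done.
llamaChat.free()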