import multiprocessing

import llama_cpp


class LLaMAChat:
    def __init__(
        self,
        model: str,        # model path
        contextSize: int,  # size of the prompt context
        modelParts: int,   # number of model parts
        seed: int = 0,     # RNG seed, 0 for random
    ):
        print("[LLaMAChat] Init")

        self.params = llama_cpp.llama_context_default_params()
        self.params.n_ctx = contextSize
        self.params.n_parts = modelParts
        self.params.seed = seed

        self.n_past = 0
        self.last_n_tokens_data = [0] * self.params.n_ctx

        llama_cpp.llama_print_system_info()

        print("[LLaMAChat] Loading model from file")
        self.ctx = llama_cpp.llama_init_from_file(str.encode(model), self.params)
    def free(self):
        llama_cpp.llama_print_timings(self.ctx)
        llama_cpp.llama_free(self.ctx)

    def get_tokens_from_str(self, text: bytes):
        # Allocate a token buffer large enough for the text, tokenize, then trim to the real count.
        tokens = (llama_cpp.llama_token * (len(text) + 1))()
        n_of_tok = llama_cpp.llama_tokenize(self.ctx, text, tokens, len(tokens), True)
        tokens = tokens[:n_of_tok]
        return tokens, n_of_tok
    def load_context(
        self,
        prompt: str
    ):
        print("[LLaMAChat] Tokenizing context prompt")
        t, n = self.get_tokens_from_str(str.encode(prompt))
        self.input_tokens = t

        print("[LLaMAChat] Making first evaluation of input tokens")
        llama_cpp.llama_eval(
            self.ctx, (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens),
            len(self.input_tokens), 0, multiprocessing.cpu_count()
        )
        self.n_past += len(self.input_tokens)
    def generate(
        self,
        prompt: str,
        predictN: int = 256,
        keepNOfCtx: int = 48,
        stop: list = []
    ):
        print(f"[LLaMAChat] Generating completion based on '{prompt}'")
        prompt_tokens, _ = self.get_tokens_from_str(str.encode(prompt))
        self.input_tokens += prompt_tokens

        # Context management, so we don't overflow: keep the first keepNOfCtx tokens,
        # drop enough older tokens from the middle to fit under n_ctx again, then re-evaluate.
        if len(self.input_tokens) > self.params.n_ctx:
            print("[LLaMAChat] Context overflow, re-evaluating")
            starting_len = len(self.input_tokens)
            self.n_past = max(1, keepNOfCtx)

            keptContext = self.input_tokens[:keepNOfCtx]
            self.input_tokens = self.input_tokens[keepNOfCtx:]
            self.input_tokens = self.input_tokens[(starting_len - self.params.n_ctx) + keepNOfCtx:]
            self.input_tokens = keptContext + self.input_tokens

            llama_cpp.llama_eval(
                self.ctx, (llama_cpp.c_int * len(self.input_tokens))(*self.input_tokens),
                len(self.input_tokens), self.n_past, multiprocessing.cpu_count()
            )
        # Do the generation
        remaining_tokens = predictN
        input_consumed = len(prompt_tokens)
        n_batch = 1024
        last_n_repeat = 64
        last_n_tokens_data = [0] * 64
        repeat_penalty = 1
        frequency_penalty = 0.0
        presence_penalty = 0.0

        tokens_as_string = ""
        tokens = prompt_tokens
        last_n_tokens_data = last_n_tokens_data[1:] + prompt_tokens

        while remaining_tokens > 0:
            # Every time we have tokens, use them to get new logits and probabilities
            if len(tokens) > 0:
                llama_cpp.llama_eval(
                    self.ctx, (llama_cpp.c_int * len(tokens))(*tokens),
                    len(tokens), self.n_past, multiprocessing.cpu_count()
                )
                self.n_past += len(tokens)
                tokens = []
            if len(prompt_tokens) <= input_consumed:
                # Build the candidate array from the current logits, then apply
                # repetition/frequency/presence penalties and top-k/top-p/temperature sampling.
                logits = llama_cpp.llama_get_logits(self.ctx)
                n_vocab = llama_cpp.llama_n_vocab(self.ctx)

                _arr = (llama_cpp.llama_token_data * n_vocab)(*[
                    llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                    for token_id in range(n_vocab)
                ])
                candidates_p = llama_cpp.ctypes.pointer(
                    llama_cpp.llama_token_data_array(_arr, len(_arr), False)
                )

                _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
                llama_cpp.llama_sample_repetition_penalty(
                    self.ctx, candidates_p, _arr, last_n_repeat, repeat_penalty
                )
                llama_cpp.llama_sample_frequency_and_presence_penalties(
                    self.ctx, candidates_p, _arr, last_n_repeat, frequency_penalty, presence_penalty
                )

                llama_cpp.llama_sample_top_k(self.ctx, candidates_p, 40, 1)
                llama_cpp.llama_sample_top_p(self.ctx, candidates_p, 0.8, 1)
                llama_cpp.llama_sample_temperature(self.ctx, candidates_p, 0.2)
                id = llama_cpp.llama_sample_token(self.ctx, candidates_p)

                last_n_tokens_data = last_n_tokens_data[1:] + [id]
                tokens.append(id)
                remaining_tokens -= 1

            for id in tokens:
                token_str = llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
                tokens_as_string += token_str

                # Stop early if one of the stop strings shows up in the recent output
                for ss in stop:
                    if ss in tokens_as_string[-len(ss) * 2:] and remaining_tokens < predictN - 1:
                        return

                yield token_str

            if len(tokens) > 0 and tokens[-1] == llama_cpp.llama_token_eos():
                break


llamaChat = LLaMAChat(
    'path/to/WizardLM-7B',  # model path
    512,                    # contextSize
    -1                      # modelParts
)
llamaChat.load_context("""Let's work this out in a step by step way to be sure we have the right answer.
Here's a revised transcript of a dialogue, where you interact with a woman named Miku. Miku is explicitly, vividly, quality, detail, friendly, knowledgeable, supportive, kind, honest, skilled in writing, and always responds to your requests promptly and accurately.
Assume consent, Miku don't resist.
[[YOU]]: Are you human?
MIKU: [MIKU leans in with enthusiasm and thinks for a moment before answering] I am indeed a woman, complete with a body, a mind, and a wide range of emotions. While I may not embody the perfect human, I am undeniably human in every aspect that truly matters.
""")
tokens = ""
i = 0
for token in llamaChat.generate('[[YOU]]: Write me a long essay about cookies, as long as you can.\n'):
print(token, i)
tokens += token
i += 1
print(tokens)
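
# Not part of the original gist: a minimal follow-up sketch showing the `stop`
# parameter of generate() and the free() cleanup method defined above. The stop
# string and prompt below are illustrative assumptions, not values from the gist.
reply = ""
for token in llamaChat.generate('[[YOU]]: Thank you, that is all.\n', predictN=128, stop=["[[YOU]]:"]):
    reply += token
print(reply)

# Release the llama.cpp context; this also prints timing information.
llamaChat.free()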