# engine/contextflow.py
from ctypes import c_float, c_size_t, c_void_p, c_char, c_int, c_uint8, c_int8, c_int32, pointer, byref
import logging
import multiprocessing
import numpy as np
import os
from typing import Any, List, Optional, Dict

import llama_cpp
from llama_cpp._internals import _LlamaTokenDataArray

from .grammar import load_grammar
from .output import OutputHandler

logger = logging.getLogger(__name__)

class EngineException(Exception):
    """An exception for errors in the FlowEngine class"""

    def __init__(self, message, tokens):
        super().__init__(message)
        self.tokens = tokens

""" | |
void llama_batch_clear(struct llama_batch & batch) { | |
batch.n_tokens = 0; | |
} | |
void llama_batch_add( | |
struct llama_batch & batch, | |
llama_token id, | |
llama_pos pos, | |
const std::vector<llama_seq_id> & seq_ids, | |
bool logits) { | |
batch.token [batch.n_tokens] = id; | |
batch.pos [batch.n_tokens] = pos, | |
batch.n_seq_id[batch.n_tokens] = seq_ids.size(); | |
for (size_t i = 0; i < seq_ids.size(); ++i) { | |
batch.seq_id[batch.n_tokens][i] = seq_ids[i]; | |
} | |
batch.logits [batch.n_tokens] = logits; | |
batch.n_tokens++; | |
} | |
""" | |
def llama_batch_clear(batch: llama_cpp.llama_batch):
    batch.n_tokens = 0


def llama_batch_add(batch: llama_cpp.llama_batch, id: int, pos: int, seq_ids: list[int], logits: int):
    batch.token[batch.n_tokens] = c_int32(id)
    batch.pos[batch.n_tokens] = c_int32(pos)
    batch.n_seq_id[batch.n_tokens] = c_int32(len(seq_ids))
    # for i in range(len(seq_ids)):
    #     batch.seq_id[batch.n_tokens][i] = c_int32(seq_ids[i])
    # batch.logits[batch.n_tokens] = c_int8(logits)
    batch.n_tokens += 1

def load_context(ctx: llama_cpp.llama_context_p, save_file: str, **kwargs) -> int:
    if not os.path.exists(save_file):
        logger.error(f"Error: {save_file} does not exist")
        return -1
    # Load state from file
    state_size = llama_cpp.llama_get_state_size(ctx)
    logger.debug(f"Context: State size: {state_size}")
    state_mem = (c_uint8 * state_size)()
    with open(save_file, "rb") as fp_read:
        bytes_read = fp_read.readinto(state_mem)
        logger.debug(f"Context: Read {bytes_read} of {state_size} bytes from {save_file}")
        if bytes_read != state_size:
            logger.error("Error: failed to read state")
            return -1
    rc = llama_cpp.llama_set_state_data(ctx, state_mem)
    return rc

def save_context(ctx: llama_cpp.llama_context_p, save_file: str, **kwargs) -> int:
    state_size = llama_cpp.llama_get_state_size(ctx)
    state_mem = (c_uint8 * state_size)()
    rc = llama_cpp.llama_copy_state_data(ctx, state_mem)
    if rc < 0:
        logger.error("Failed to copy state data")
        return rc
    # Save the data to a binary file
    with open(save_file, "wb") as fp:
        fp.write(state_mem)
    llama_cpp.llama_set_state_data(ctx, state_mem)
    return rc

def cache_context(ctx: llama_cpp.llama_context_p, **kwargs) -> Any:
    state_size = llama_cpp.llama_get_state_size(ctx)
    llama_state = (llama_cpp.c_uint8 * int(state_size))()
    n_bytes = llama_cpp.llama_copy_state_data(ctx, llama_state)
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")
    llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
    llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
    return llama_state_compact

def restore_context(ctx: llama_cpp.llama_context_p, state: Any, **kwargs) -> int:
    if state is None:
        return -1
    state_size = len(state)
    llama_state = (llama_cpp.c_uint8 * int(state_size))()
    llama_cpp.ctypes.memmove(llama_state, state, int(state_size))
    rc = llama_cpp.llama_set_state_data(ctx, llama_state)
    return rc

class StateContainer:
    def __init__(self, last_n_size: int = 64, **kwargs: Dict[str, Any]):
        self.tokens: List[llama_cpp.llama_token] = []
        self.n_past = 0
        self.last_n_size = last_n_size
        self.last_n_tokens_data = [0] * self.last_n_size
        self.prompt = ""

    def reset(self, **kwargs):
        self.n_past = 0
        self.tokens = []
        self.last_n_tokens_data = [0] * self.last_n_size
        self.prompt = ""

    def clone(self, **kwargs):
        new_state = StateContainer(last_n_size=self.last_n_size, **kwargs)
        new_state.n_past = self.n_past
        new_state.tokens = self.tokens.copy()
        new_state.last_n_tokens_data = self.last_n_tokens_data.copy()
        new_state.prompt = self.prompt
        return new_state

    def add_tokens(self, tokens: List[llama_cpp.llama_token], **kwargs):
        self.tokens += tokens
        self.n_past += len(tokens)

    def set_prompt(self, prompt: str, **kwargs):
        self.prompt = prompt

    def commit_prompt(self, prompt: str, **kwargs):
        self.prompt += prompt

class ContextContainer:
    def __init__(self, grammar: Optional[llama_cpp.LlamaGrammar] = None, **kwargs: Dict[str, Any]):
        self.grammar = grammar
        self.mem_cache: Optional[Any] = None
        self.config = kwargs
        self.session = StateContainer(**kwargs)
        self.cache = StateContainer(**kwargs)

    @property
    def n_past(self) -> int:
        return self.session.n_past

    def reset(self, ctx: llama_cpp.llama_context_p, reset_cache: bool = False, **kwargs: Dict[str, Any]):
        if reset_cache:
            self.cache.reset(**kwargs)
            self.mem_cache = None
        self.read_checkpoint(ctx, **kwargs)

    def set_checkpoint(self, ctx: llama_cpp.llama_context_p, **kwargs) -> bool:
        """Set a checkpoint for the current state"""
        self.mem_cache = cache_context(ctx, **kwargs)
        self.cache = self.session.clone(**kwargs)
        return True

    def read_checkpoint(self, ctx: llama_cpp.llama_context_p, **kwargs) -> bool:
        """Load a checkpoint for the current state"""
        restore_context(ctx, self.mem_cache, **kwargs)
        self.session = self.cache.clone(**kwargs)
        return True

    def clear_saved_context(self, **kwargs) -> int:
        """Clear the saved context"""
        self.mem_cache = None

    def add_tokens(self, tokens: List[llama_cpp.llama_token], **kwargs):
        self.session.add_tokens(tokens, **kwargs)

    def accept_token(self, token_id: int, ctx: llama_cpp.llama_context_p, **kwargs):
        self.session.add_tokens([token_id], **kwargs)
        if self.grammar is not None:
            llama_cpp.llama_grammar_accept_token(ctx=ctx, token=llama_cpp.llama_token(token_id), grammar=self.grammar.grammar)
        self.set_checkpoint(ctx, **kwargs)

    @classmethod
    def from_config(cls, grammar_path: str, grammar_type: Optional[str] = None, **kwargs):
        """
        Create a new ContextContainer from the given configuration

        Args:
            grammar_path (str): The path to the grammar file
            grammar_type (str, optional): The type of grammar to use. Defaults to None.
        """
        grammar = None
        if grammar_type is not None:
            grammar = load_grammar(grammar_file=f'{grammar_type}.gbnf', grammar_path=grammar_path)
        return cls(grammar=grammar, **kwargs)

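# Illustrative sketch (not part of the original gist): one way the container and its
# checkpoint mechanism are typically used. The grammar path and grammar type below are
# assumptions for the example; `ctx` is assumed to be a llama_cpp context created
# elsewhere (see FlowEngine below).
#
#   cc = ContextContainer.from_config(grammar_path="./grammars", grammar_type="json")
#   cc.set_checkpoint(ctx)    # snapshot the llama state plus the token bookkeeping
#   ...                       # feed/read against the context
#   cc.read_checkpoint(ctx)   # roll the context back to the snapshot
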
class FlowEngine:
    """
    FlowEngine is a wrapper around the llama_cpp library that provides a simple interface for reading and writing to the model
    """

    @staticmethod
    def get_mparams(
        n_gpu_layers: int = 0,
        main_gpu: int = 0,
        tensor_split: Optional[List[float]] = None,
        vocab_only: bool = False,
        use_mmap: bool = True,
        use_mlock: bool = False,
        **kwargs
    ):
        """Generate a llama_model_params struct with the given parameters"""
        mparams = llama_cpp.llama_model_default_params()
        mparams.n_gpu_layers = (
            0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
        )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
        mparams.main_gpu = main_gpu
        mparams.vocab_only = vocab_only
        mparams.use_mmap = use_mmap
        mparams.use_mlock = use_mlock
        return mparams
    @classmethod
    def ctx_from_model(cls, model: llama_cpp.llama_model_p, n_ctx: int, **kwargs):
        cparams = cls.get_cparams(n_ctx=n_ctx, **kwargs)
        ctx = llama_cpp.llama_new_context_with_model(model, cparams)
        return ctx
    @staticmethod
    def get_cparams(
        seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
        n_ctx: int = 4096,
        n_batch: int = 512,
        n_threads: Optional[int] = None,
        n_threads_batch: Optional[int] = None,
        rope_freq_base: float = 0.0,
        rope_freq_scale: float = 0.0,
        mul_mat_q: bool = True,
        f16_kv: bool = True,
        logits_all: bool = False,
        embedding: bool = False,
        **kwargs
    ):
        """Generate a llama_context_params struct with the given parameters"""
        cparams = llama_cpp.llama_context_default_params()
        n_batch = min(n_ctx, n_batch)  # We don't want a batch being larger than our context
        n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
        n_threads_batch = n_threads_batch or max(
            multiprocessing.cpu_count() // 2, 1
        )
        cparams.seed = seed
        cparams.n_ctx = n_ctx
        cparams.n_batch = n_batch
        cparams.n_threads = n_threads
        cparams.n_threads_batch = n_threads_batch
        cparams.rope_freq_base = (
            rope_freq_base if rope_freq_base != 0.0 else 0
        )
        cparams.rope_freq_scale = (
            rope_freq_scale if rope_freq_scale != 0.0 else 0
        )
        cparams.mul_mat_q = mul_mat_q
        cparams.f16_kv = f16_kv
        cparams.logits_all = logits_all
        cparams.embedding = embedding
        return cparams
    @classmethod
    def from_config(cls, model_path: str, model_file: str, n_ctx: int, output: Optional[OutputHandler] = None, **kwargs):
        """Create a new FlowEngine with the given parameters"""
        llama_cpp.llama_backend_init(numa=False)
        model_loc = os.path.join(model_path, model_file)
        mparams = cls.get_mparams(**kwargs)
        model = llama_cpp.llama_load_model_from_file(model_loc.encode('utf-8'), mparams)
        ctx = cls.ctx_from_model(model, n_ctx, **kwargs)
        return cls(model=model, ctx=ctx, n_ctx=n_ctx, output=output)
    def __init__(self, model: c_void_p, ctx: llama_cpp.llama_context_p, n_ctx: int, output: Optional[OutputHandler] = None):
        self.model = model
        self.ctx = ctx
        self.n_ctx = n_ctx
        self.output = output

    def set_output_handler(self, output: OutputHandler):
        self.output = output
    def feed(self, prompt: str, cc: ContextContainer, n_batch: int, scope: Optional[str] = None, show_progress: bool = False, **kwargs) -> int:
        """Feed the given prompt to the model"""
        if prompt is None:
            logger.warning("Feeding empty prompt")
            return -1
        logger.debug(f"Feeding {prompt}")
        kwargs['n_ctx'] = self.n_ctx
        b_prompt = prompt.encode('ascii', 'ignore')
        b_prompt = b" " + b_prompt
        pl = len(b_prompt)
        # I hate that we alloc all of this extra space, but otherwise we overrun our buffer
        embd_inp = (llama_cpp.llama_token * (pl + 1))()
        n_of_tok = llama_cpp.llama_tokenize(
            model=self.model, text=b_prompt, text_len=pl, tokens=embd_inp, n_max_tokens=embd_inp._length_,
            add_bos=True, special=False)
        embd_inp = embd_inp[:n_of_tok]
        clearance = self.n_ctx - cc.n_past - n_of_tok - 100
        if clearance < 0:
            raise EngineException("Too many tokens in prompt", clearance)
        # This is a hack to make sure we don't overrun our buffer
        # Lets create an offset that clips to the last n_ctx - 100 tokens
        n_ctx_floor = self.n_ctx - 100
        n_ctx_floor = n_ctx_floor if n_of_tok > n_ctx_floor else n_of_tok
        embd_inp = embd_inp[-n_ctx_floor:]
        input_consumed = 0
        first_n = cc.n_past
        logger.debug(f"Feeding ({pl} chars -> {n_of_tok} tokens), {input_consumed} consumed, {len(embd_inp)} remaining")
        logger.debug(f"```{prompt}```")
        if self.output is not None and show_progress:
            if scope is not None:
                self.output.handle_token(f"{scope} - ")
            self.output.handle_progress(0.0)
        embd = []
        while len(embd_inp) > input_consumed:
            cc.session.last_n_tokens_data.copy()
            while len(embd_inp) > input_consumed:
                if len(embd) >= n_batch:
                    break
                embd.append(embd_inp[input_consumed])
                cc.session.last_n_tokens_data = cc.session.last_n_tokens_data[1:] + [embd_inp[input_consumed]]
                input_consumed += 1
            if len(embd) > 0:
                # Docs say to use llama_decode and llama_batch
                tokens = (llama_cpp.llama_token * len(embd))(*embd)
                logger.debug(f"Writing to model {len(embd)} tokens, {input_consumed} consumed")
                return_code = llama_cpp.llama_eval(ctx=self.ctx, tokens=tokens, n_tokens=len(embd), n_past=cc.n_past)
                if return_code != 0:
                    logger.error(f"Break - Model Eval return code {return_code}")
                    break
                cc.add_tokens(tokens=tokens, **kwargs)
                embd = []
            if self.output is not None and show_progress:
                self.output.handle_progress(float(input_consumed) / len(embd_inp))
        cc.session.commit_prompt(prompt)
        return cc.n_past - first_n
    def read(self, cc: ContextContainer, max_tokens: int = 512, abort_tokens: list = [], stop_tokens: list = [],
             sequence_tokens: list = [], log_chunk_length: int = 25, n_temp: float = 0.7,
             mirostat: int = 0, mirostat_tau: float = 0, mirostat_eta: float = 0, top_k: int = 40,
             min_p: float = 0.0, min_keep: int = 1,
             n_tfs_z: float = 0.0, n_typical_p: float = 0.0, n_top_p: float = 0.0,
             penalty_last_n: int = 1024, penalty_repeat: float = 1.08, penalty_freq: float = 0.0, penalty_present: float = 0.0,
             **kwargs) -> Optional[List[Any]]:
        """Read from the model until the given number of tokens is reached"""
        remaining_tokens = max_tokens
        stop_set = set(stop_tokens)
        abort_set = set(abort_tokens)
        sequence_set = set([tuple(o) for o in sequence_tokens])
        response_tokens = []
        n_generated = 0
        buf = (c_char * 32)()
        log_chunks = []
        log_ids = []
        last_piece = ''
        last_id = 0
        n_vocab = llama_cpp.llama_n_vocab(self.model)
        nl_token = llama_cpp.llama_token_nl(self.model)
        token_data_array = _LlamaTokenDataArray(
            n_vocab=n_vocab
        )  # TODO: Only create this once
        try:
            while remaining_tokens > 0:
                # Mirroring llama.cpp/common/sampling.cpp
                logits = llama_cpp.llama_get_logits(self.ctx)
                n_vocab = llama_cpp.llama_n_vocab(self.model)
                logits_array = np.array(
                    llama_cpp.ctypes.cast(logits, llama_cpp.ctypes.POINTER(llama_cpp.ctypes.c_float * n_vocab)).contents,
                    dtype=np.single,
                )
                token_data_array.copy_logits(logits_array)
                logit_arr = (llama_cpp.llama_token_data * n_vocab)(*[
                    llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                    for token_id in range(n_vocab)
                ])
                candidates_p = pointer(llama_cpp.llama_token_data_array(logit_arr, len(logit_arr), False))
                nl_logit = logits[nl_token]
                last_n_arr = (c_int * len(cc.session.last_n_tokens_data))(*cc.session.last_n_tokens_data)
                llama_cpp.llama_sample_repetition_penalties(
                    ctx=self.ctx, candidates=candidates_p, last_tokens_data=last_n_arr,
                    penalty_last_n=c_size_t(penalty_last_n), penalty_repeat=c_float(penalty_repeat),
                    penalty_freq=c_float(penalty_freq), penalty_present=c_float(penalty_present))
                # This lovely little hack gets rid of our newline penalty
                logit_arr[nl_token] = llama_cpp.llama_token_data(nl_token, nl_logit, 0.0)
                if cc.grammar is not None and cc.grammar.grammar is not None:
                    llama_cpp.llama_sample_grammar(ctx=self.ctx, candidates=candidates_p, grammar=cc.grammar.grammar)
                if min_p > 0.0:
                    llama_cpp.llama_sample_min_p(ctx=self.ctx, candidates=candidates_p, p=c_float(min_p), min_keep=min_keep)
                if n_temp < 0.0:
                    id = llama_cpp.llama_sample_softmax(ctx=self.ctx, candidates=candidates_p)
                elif n_temp == 0:
                    # Greedy sampling
                    id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
                elif mirostat == 1:
                    mirostat_mu = 2.0 * mirostat_tau
                    mirostat_m = 100
                    llama_cpp.llama_sample_temp(ctx=self.ctx, candidates=candidates_p, temp=c_float(n_temp))
                    id = llama_cpp.llama_sample_token_mirostat(
                        ctx=self.ctx, candidates=candidates_p,
                        tau=c_float(mirostat_tau), eta=c_float(mirostat_eta), m=c_size_t(mirostat_m), mu=c_float(mirostat_mu))
                elif mirostat == 2:
                    mirostat_mu = 2.0 * mirostat_tau
                    llama_cpp.llama_sample_temp(ctx=self.ctx, candidates=candidates_p, temp=c_float(n_temp))
                    id = llama_cpp.llama_sample_token_mirostat_v2(
                        ctx=self.ctx, candidates=candidates_p,
                        tau=c_float(mirostat_tau), eta=c_float(mirostat_eta), mu=c_float(mirostat_mu))
                else:
                    # Temperature sampling
                    min_keep = c_size_t(1)
                    llama_cpp.llama_sample_top_k(ctx=self.ctx, candidates=candidates_p,
                                                 k=top_k, min_keep=min_keep)
                    llama_cpp.llama_sample_tail_free(ctx=self.ctx, candidates=candidates_p,
                                                     z=c_float(n_tfs_z), min_keep=min_keep)
                    llama_cpp.llama_sample_typical(ctx=self.ctx, candidates=candidates_p,
                                                   p=c_float(n_typical_p), min_keep=min_keep)
                    llama_cpp.llama_sample_top_p(ctx=self.ctx, candidates=candidates_p,
                                                 p=c_float(n_top_p), min_keep=min_keep)
                    llama_cpp.llama_sample_min_p(ctx=self.ctx, candidates=candidates_p,
                                                 p=c_float(min_p), min_keep=min_keep)
                    llama_cpp.llama_sample_temp(ctx=self.ctx, candidates=candidates_p,
                                                temp=c_float(n_temp))
                    id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
                n = llama_cpp.llama_token_to_piece(
                    self.model, llama_cpp.llama_token(id), buf, 32
                )
                piece = buf[:n].decode('utf-8', 'ignore')
                cc.session.last_n_tokens_data = cc.session.last_n_tokens_data[1:] + [id]
                running = True
                if piece in abort_set:
                    logger.debug(f"Break ({len(log_chunks)}): Aborting on {piece} ({id})")
                    running = False
                    # TODO Do I need to inject a newline in-context here?
                    id = None
                    return response_tokens
                elif id == 2 and (n_generated == 0 or last_id == 13):
                    # 2 is '', repeating, this is bad model output.
                    running = False
                    id = None
                    return response_tokens
                elif (last_piece, piece) in sequence_set:
                    logger.debug(f"Break ({len(log_chunks)}): sequence {last_piece}, {piece} ({id})")
                    running = False
                    id = None
                elif piece == '\n' and last_piece == '\n':
                    logger.debug(f"Break ({len(log_chunks)}): Double Newline ({id})")
                    running = False
                    id = None
                elif id == llama_cpp.llama_token_eos(self.model):
                    logger.debug(f"Break ({len(log_chunks)}): EOS ({id})")
                    running = False
                    # TODO Do I need to inject a newline in-context here?
                    id = 13
                if id is not None:
                    # tokens = (llama_cpp.llama_token * 1)(id)
                    # return_code = llama_cpp.llama_decode(ctx=self.ctx, batch=llama_cpp.llama_batch_get_one(tokens=tokens, n_tokens=1, pos_0=self.n_past, seq_id=0))
                    tokens = (llama_cpp.llama_token * 1)(id)
                    return_code = llama_cpp.llama_eval(ctx=self.ctx, tokens=tokens, n_tokens=1, n_past=cc.session.n_past)
                    log_chunks.append(piece)
                    log_ids.append(id)
                    if return_code != 0:
                        logger.error(f"Break - Model Eval return code {return_code}")
                        running = False
                    else:
                        cc.session.n_past += 1
                    n_generated += 1
                    response_tokens.append(piece)
                    if self.output is not None:
                        self.output.handle_token(piece)
                    remaining_tokens -= 1
                    last_piece = piece
                    last_id = id
                    cc.accept_token(id, self.ctx)
                    if piece in stop_set:
                        running = False
                if len(log_chunks) > 0 and (not running or len(log_chunks) % log_chunk_length == 0):
                    # Generally superseded by the output handler
                    # logger.debug(f"Generated ({n_generated}): {''.join(log_chunks).strip()}")
                    # logger.debug(f"Tokens ({n_generated}): {log_chunks}, {log_ids}")
                    log_chunks = []
                if not running:
                    break
        finally:
            # llama_cpp.llama_batch_free(llama_batch)
            pass
        cc.session.commit_prompt(''.join(response_tokens))
        return response_tokens
    def __del__(self):
        llama_cpp.llama_free(self.ctx)
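
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original gist: the model path/file, prompt
    # text, and sampling values below are placeholder assumptions; adjust them for your
    # own environment. It shows the intended flow: build the engine, build a context
    # container, feed a prompt, then read a completion.
    logging.basicConfig(level=logging.DEBUG)
    engine = FlowEngine.from_config(model_path="./models", model_file="model.Q4_K_M.gguf",
                                    n_ctx=4096, n_gpu_layers=0)
    cc = ContextContainer.from_config(grammar_path="./grammars", grammar_type=None)
    engine.feed("### Instruction:\nSay hello.\n### Response:\n", cc, n_batch=512)
    pieces = engine.read(cc, max_tokens=64, n_temp=0.7)
    print(''.join(pieces or []))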