@maldevide
Created March 2, 2024 17:20
# engine/contextflow.py
from ctypes import c_float, c_size_t, c_void_p, c_char, c_int, c_uint8, c_int8, c_int32, pointer, byref
import logging
import multiprocessing
import numpy as np
import os
from typing import Any, List, Optional, Dict
import llama_cpp
from llama_cpp._internals import _LlamaTokenDataArray
from .grammar import load_grammar
from .output import OutputHandler
logger = logging.getLogger(__name__)
class EngineException(Exception):
    """An exception for errors in the FlowEngine class"""
    def __init__(self, message, tokens):
        super().__init__(message)
        self.tokens = tokens
"""
void llama_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos,
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.n_tokens++;
}
"""
def llama_batch_clear(batch: llama_cpp.llama_batch):
    batch.n_tokens = 0
def llama_batch_add(batch: llama_cpp.llama_batch, id: int, pos: int, seq_ids: List[int], logits: bool):
    batch.token[batch.n_tokens] = c_int32(id)
    batch.pos[batch.n_tokens] = c_int32(pos)
    batch.n_seq_id[batch.n_tokens] = c_int32(len(seq_ids))
    # Mirror the C++ reference above: record the sequence ids and the logits flag
    for i in range(len(seq_ids)):
        batch.seq_id[batch.n_tokens][i] = c_int32(seq_ids[i])
    batch.logits[batch.n_tokens] = c_int8(logits)
    batch.n_tokens += 1
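
# Hedged usage sketch (not part of the original gist): how these two helpers are
# typically paired with llama_decode. Assumes the low-level bindings expose
# llama_batch_init / llama_decode / llama_batch_free; the names below are illustrative.
def _example_decode_with_batch(ctx: llama_cpp.llama_context_p, token_ids: List[int], n_past: int) -> int:
    batch = llama_cpp.llama_batch_init(len(token_ids), 0, 1)
    try:
        llama_batch_clear(batch)
        for i, tok in enumerate(token_ids):
            # Only request logits for the final token of the batch
            llama_batch_add(batch, tok, n_past + i, [0], i == len(token_ids) - 1)
        return llama_cpp.llama_decode(ctx, batch)
    finally:
        llama_cpp.llama_batch_free(batch)
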
def load_context(ctx: llama_cpp.llama_context_p, save_file: str, **kwargs) -> int:
    if not os.path.exists(save_file):
        logger.error(f"Error: {save_file} does not exist")
        return -1
    # Load state from file
    state_size = llama_cpp.llama_get_state_size(ctx)
    logger.debug(f"Context: State size: {state_size}")
    state_mem = (c_uint8 * state_size)()
    with open(save_file, "rb") as fp_read:
        bytes_read = fp_read.readinto(state_mem)
        logger.debug(f"Context: Read {bytes_read} of {state_size} bytes from {save_file}")
        if bytes_read != state_size:
            logger.error("Error: failed to read state")
            return -1
    rc = llama_cpp.llama_set_state_data(ctx, state_mem)
    return rc
def save_context(ctx: llama_cpp.llama_context_p, save_file: str, **kwargs) -> int:
    state_size = llama_cpp.llama_get_state_size(ctx)
    state_mem = (c_uint8 * state_size)()
    rc = llama_cpp.llama_copy_state_data(ctx, state_mem)
    if rc < 0:
        logger.error("Failed to copy state data")
        return rc
    # Save the data to a binary file
    with open(save_file, "wb") as fp:
        fp.write(state_mem)
    llama_cpp.llama_set_state_data(ctx, state_mem)
    return rc
def cache_context(ctx: llama_cpp.llama_context_p, **kwargs) -> Any:
    state_size = llama_cpp.llama_get_state_size(ctx)
    llama_state = (llama_cpp.c_uint8 * int(state_size))()
    n_bytes = llama_cpp.llama_copy_state_data(ctx, llama_state)
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")
    llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
    llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
    return llama_state_compact
def restore_context(ctx: llama_cpp.llama_context_p, state: Any, **kwargs) -> int:
    if state is None:
        return -1
    state_size = len(state)
    llama_state = (llama_cpp.c_uint8 * int(state_size))()
    llama_cpp.ctypes.memmove(llama_state, state, int(state_size))
    rc = llama_cpp.llama_set_state_data(ctx, llama_state)
    return rc
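
# Hedged example (not part of the original gist): a minimal sketch of how the
# four state helpers above pair up. The file name is purely illustrative.
def _example_state_roundtrip(ctx: llama_cpp.llama_context_p) -> None:
    # Disk round trip: persist the llama state, then read it back in.
    save_context(ctx, "session_state.bin")
    load_context(ctx, "session_state.bin")
    # In-memory round trip: snapshot now, roll back later.
    snapshot = cache_context(ctx)
    restore_context(ctx, snapshot)
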
class StateContainer:
    def __init__(self, last_n_size: int = 64, **kwargs: Dict[str, Any]):
        self.tokens: List[llama_cpp.llama_token] = []
        self.n_past = 0
        self.last_n_size = last_n_size
        self.last_n_tokens_data = [0] * self.last_n_size
        self.prompt = ""
    def reset(self, **kwargs):
        self.n_past = 0
        self.tokens = []
        self.last_n_tokens_data = [0] * self.last_n_size
        self.prompt = ""
    def clone(self, **kwargs):
        new_state = StateContainer(last_n_size=self.last_n_size, **kwargs)
        new_state.n_past = self.n_past
        new_state.tokens = self.tokens.copy()
        new_state.last_n_tokens_data = self.last_n_tokens_data.copy()
        new_state.prompt = self.prompt
        return new_state
    def add_tokens(self, tokens: List[llama_cpp.llama_token], **kwargs):
        self.tokens += tokens
        self.n_past += len(tokens)
    def set_prompt(self, prompt: str, **kwargs):
        self.prompt = prompt
    def commit_prompt(self, prompt: str, **kwargs):
        self.prompt += prompt
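
# Hedged illustration (not part of the original gist): StateContainer tracks the
# running token count (n_past), a fixed-size last-n window used for repetition
# penalties, and the accumulated prompt text; clone() gives an independent copy.
def _example_state_container() -> None:
    state = StateContainer(last_n_size=64)
    state.add_tokens([1, 2, 3])      # n_past advances by the number of tokens fed
    branch = state.clone()
    branch.add_tokens([4])           # mutating the clone leaves the original untouched
    assert state.n_past == 3 and branch.n_past == 4
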
class ContextContainer:
    def __init__(self, grammar: Optional[llama_cpp.LlamaGrammar] = None, **kwargs: Dict[str, Any]):
        self.grammar = grammar
        self.mem_cache: Optional[Any] = None
        self.config = kwargs
        self.session = StateContainer(**kwargs)
        self.cache = StateContainer(**kwargs)
    @property
    def n_past(self) -> int:
        return self.session.n_past
    def reset(self, ctx: llama_cpp.llama_context_p, reset_cache: bool = False, **kwargs: Dict[str, Any]):
        if reset_cache:
            self.cache.reset(**kwargs)
            self.mem_cache = None
        self.read_checkpoint(ctx, **kwargs)
    def set_checkpoint(self, ctx: llama_cpp.llama_context_p, **kwargs) -> bool:
        """Set a checkpoint for the current state"""
        self.mem_cache = cache_context(ctx, **kwargs)
        self.cache = self.session.clone(**kwargs)
        return True
    def read_checkpoint(self, ctx: llama_cpp.llama_context_p, **kwargs) -> bool:
        """Load a checkpoint for the current state"""
        restore_context(ctx, self.mem_cache, **kwargs)
        self.session = self.cache.clone(**kwargs)
        return True
    def clear_saved_context(self, **kwargs) -> None:
        """Clear the saved context"""
        self.mem_cache = None
    def add_tokens(self, tokens: List[llama_cpp.llama_token], **kwargs):
        self.session.add_tokens(tokens, **kwargs)
    def accept_token(self, token_id: int, ctx: llama_cpp.llama_context_p, **kwargs):
        self.session.add_tokens([token_id], **kwargs)
        if self.grammar is not None:
            llama_cpp.llama_grammar_accept_token(ctx=ctx, token=llama_cpp.llama_token(token_id), grammar=self.grammar.grammar)
        self.set_checkpoint(ctx, **kwargs)
    @classmethod
    def from_config(cls, grammar_path: str, grammar_type: Optional[str] = None, **kwargs):
        """
        Create a new ContextContainer from the given model

        Args:
            grammar_path (str): The path to the grammar file
            grammar_type (str, optional): The type of grammar to use. Defaults to None.
        """
        grammar = None
        if grammar_type is not None:
            grammar = load_grammar(grammar_file=f'{grammar_type}.gbnf', grammar_path=grammar_path)
        return cls(grammar=grammar, **kwargs)
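
# Hedged sketch (not part of the original gist): the checkpoint/rollback pattern
# ContextContainer provides and FlowEngine.read() relies on via accept_token().
def _example_checkpoint_rollback(ctx: llama_cpp.llama_context_p) -> None:
    cc = ContextContainer()
    cc.set_checkpoint(ctx)   # snapshot the llama state plus session bookkeeping
    # ... feed or generate speculative text against ctx here ...
    cc.reset(ctx)                    # roll llama and the session back to the snapshot
    cc.reset(ctx, reset_cache=True)  # or drop the snapshot and start clean
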
class FlowEngine:
    """
    FlowEngine is a wrapper around the llama_cpp library that provides a simple
    interface for reading from and writing to the model.
    """
    @staticmethod
    def get_mparams(
        n_gpu_layers: int = 0,
        main_gpu: int = 0,
        tensor_split: Optional[List[float]] = None,
        vocab_only: bool = False,
        use_mmap: bool = True,
        use_mlock: bool = False,
        **kwargs
    ):
        """Generate a llama_model_params struct with the given parameters"""
        mparams = llama_cpp.llama_model_default_params()
        mparams.n_gpu_layers = (
            0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
        )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
        mparams.main_gpu = main_gpu
        mparams.vocab_only = vocab_only
        mparams.use_mmap = use_mmap
        mparams.use_mlock = use_mlock
        return mparams
    @classmethod
    def ctx_from_model(cls, model: llama_cpp.llama_model_p, n_ctx: int, **kwargs):
        cparams = cls.get_cparams(n_ctx=n_ctx, **kwargs)
        ctx = llama_cpp.llama_new_context_with_model(model, cparams)
        return ctx
    @staticmethod
    def get_cparams(
        seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
        n_ctx: int = 4096,
        n_batch: int = 512,
        n_threads: Optional[int] = None,
        n_threads_batch: Optional[int] = None,
        rope_freq_base: float = 0.0,
        rope_freq_scale: float = 0.0,
        mul_mat_q: bool = True,
        f16_kv: bool = True,
        logits_all: bool = False,
        embedding: bool = False,
        **kwargs
    ):
        """Generate a llama_context_params struct with the given parameters"""
        cparams = llama_cpp.llama_context_default_params()
        n_batch = min(n_ctx, n_batch)  # We don't want a batch being larger than our context
        n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
        n_threads_batch = n_threads_batch or max(
            multiprocessing.cpu_count() // 2, 1
        )
        cparams.seed = seed
        cparams.n_ctx = n_ctx
        cparams.n_batch = n_batch
        cparams.n_threads = n_threads
        cparams.n_threads_batch = n_threads_batch
        cparams.rope_freq_base = (
            rope_freq_base if rope_freq_base != 0.0 else 0
        )
        cparams.rope_freq_scale = (
            rope_freq_scale if rope_freq_scale != 0.0 else 0
        )
        cparams.mul_mat_q = mul_mat_q
        cparams.f16_kv = f16_kv
        cparams.logits_all = logits_all
        cparams.embedding = embedding
        return cparams
    @classmethod
    def from_config(cls, model_path: str, model_file: str, n_ctx: int, output: Optional[OutputHandler] = None, **kwargs):
        """Create a new FlowEngine with the given parameters"""
        llama_cpp.llama_backend_init(numa=False)
        model_loc = os.path.join(model_path, model_file)
        mparams = cls.get_mparams(**kwargs)
        model = llama_cpp.llama_load_model_from_file(model_loc.encode('utf-8'), mparams)
        ctx = cls.ctx_from_model(model, n_ctx, **kwargs)
        return cls(model=model, ctx=ctx, n_ctx=n_ctx, output=output)
    def __init__(self, model: c_void_p, ctx: llama_cpp.llama_context_p, n_ctx: int, output: Optional[OutputHandler] = None):
        self.model = model
        self.ctx = ctx
        self.n_ctx = n_ctx
        self.output = output
    def set_output_handler(self, output: OutputHandler):
        self.output = output
    def feed(self, prompt: str, cc: ContextContainer, n_batch: int, scope: Optional[str] = None, show_progress: bool = False, **kwargs) -> int:
        """Feed the given prompt to the model"""
        if prompt is None:
            logger.warning("Feeding empty prompt")
            return -1
        logger.debug(f"Feeding {prompt}")
        kwargs['n_ctx'] = self.n_ctx
        b_prompt = prompt.encode('ascii', 'ignore')
        b_prompt = b" " + b_prompt
        pl = len(b_prompt)
        # I hate that we alloc all of this extra space, but otherwise we overrun our buffer
        embd_inp = (llama_cpp.llama_token * (pl + 1))()
        n_of_tok = llama_cpp.llama_tokenize(
            model=self.model, text=b_prompt, text_len=pl, tokens=embd_inp, n_max_tokens=embd_inp._length_,
            add_bos=True, special=False)
        embd_inp = embd_inp[:n_of_tok]
        clearance = self.n_ctx - cc.n_past - n_of_tok - 100
        if clearance < 0:
            raise EngineException("Too many tokens in prompt", clearance)
        # This is a hack to make sure we don't overrun our buffer
        # Let's create an offset that clips to the last n_ctx - 100 tokens
        n_ctx_floor = self.n_ctx - 100
        n_ctx_floor = n_ctx_floor if n_of_tok > n_ctx_floor else n_of_tok
        embd_inp = embd_inp[-n_ctx_floor:]
        input_consumed = 0
        first_n = cc.n_past
        logger.debug(f"Feeding ({pl} chars -> {n_of_tok} tokens), {input_consumed} consumed, {len(embd_inp)} remaining")
        logger.debug(f"```{prompt}```")
        if self.output is not None and show_progress:
            if scope is not None:
                self.output.handle_token(f"{scope} - ")
            self.output.handle_progress(0.0)
        embd = []
        while len(embd_inp) > input_consumed:
            while len(embd_inp) > input_consumed:
                if len(embd) >= n_batch:
                    break
                embd.append(embd_inp[input_consumed])
                cc.session.last_n_tokens_data = cc.session.last_n_tokens_data[1:] + [embd_inp[input_consumed]]
                input_consumed += 1
            if len(embd) > 0:
                # Docs say to use llama_decode and llama_batch
                tokens = (llama_cpp.llama_token * len(embd))(*embd)
                logger.debug(f"Writing to model {len(embd)} tokens, {input_consumed} consumed")
                return_code = llama_cpp.llama_eval(ctx=self.ctx, tokens=tokens, n_tokens=len(embd), n_past=cc.n_past)
                if return_code != 0:
                    logger.error(f"Break - Model Eval return code {return_code}")
                    break
                cc.add_tokens(tokens=tokens, **kwargs)
                embd = []
            if self.output is not None and show_progress:
                self.output.handle_progress(float(input_consumed) / len(embd_inp))
        cc.session.commit_prompt(prompt)
        return cc.n_past - first_n
    def read(self, cc: ContextContainer, max_tokens: int = 512, abort_tokens: list = [], stop_tokens: list = [],
             sequence_tokens: list = [], log_chunk_length: int = 25, n_temp: float = 0.7,
             mirostat: int = 0, mirostat_tau: float = 0, mirostat_eta: float = 0, top_k: int = 40,
             min_p: float = 0.0, min_keep: int = 1,
             n_tfs_z: float = 0.0, n_typical_p: float = 0.0, n_top_p: float = 0.0,
             penalty_last_n: int = 1024, penalty_repeat: float = 1.08, penalty_freq: float = 0.0, penalty_present: float = 0.0,
             **kwargs) -> Optional[List[Any]]:
        """Read from the model until the given number of tokens is reached"""
        remaining_tokens = max_tokens
        stop_set = set(stop_tokens)
        abort_set = set(abort_tokens)
        sequence_set = set([tuple(o) for o in sequence_tokens])
        response_tokens = []
        n_generated = 0
        buf = (c_char * 32)()
        log_chunks = []
        log_ids = []
        last_piece = ''
        last_id = 0
        n_vocab = llama_cpp.llama_n_vocab(self.model)
        nl_token = llama_cpp.llama_token_nl(self.model)
        token_data_array = _LlamaTokenDataArray(
            n_vocab=n_vocab
        )  # TODO: Only create this once
        try:
            while remaining_tokens > 0:
                # Mirroring llama.cpp/common/sampling.cpp
                logits = llama_cpp.llama_get_logits(self.ctx)
                n_vocab = llama_cpp.llama_n_vocab(self.model)
                logits_array = np.array(
                    llama_cpp.ctypes.cast(logits, llama_cpp.ctypes.POINTER(llama_cpp.ctypes.c_float * n_vocab)).contents,
                    dtype=np.single,
                )
                token_data_array.copy_logits(logits_array)
                logit_arr = (llama_cpp.llama_token_data * n_vocab)(*[
                    llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
                    for token_id in range(n_vocab)
                ])
                candidates_p = pointer(llama_cpp.llama_token_data_array(logit_arr, len(logit_arr), False))
                nl_logit = logits[nl_token]
                last_n_arr = (c_int * len(cc.session.last_n_tokens_data))(*cc.session.last_n_tokens_data)
                llama_cpp.llama_sample_repetition_penalties(ctx=self.ctx, candidates=candidates_p, last_tokens_data=last_n_arr,
                                                            penalty_last_n=c_size_t(penalty_last_n), penalty_repeat=c_float(penalty_repeat),
                                                            penalty_freq=c_float(penalty_freq), penalty_present=c_float(penalty_present))
                # This lovely little hack gets rid of our newline penalty
                logit_arr[nl_token] = llama_cpp.llama_token_data(nl_token, nl_logit, 0.0)
                if cc.grammar is not None and cc.grammar.grammar is not None:
                    llama_cpp.llama_sample_grammar(ctx=self.ctx, candidates=candidates_p, grammar=cc.grammar.grammar)
                if min_p > 0.0:
                    llama_cpp.llama_sample_min_p(ctx=self.ctx, candidates=candidates_p, p=c_float(min_p), min_keep=min_keep)
                if n_temp < 0.0:
                    # Softmax-only sampling: sort candidates by probability and take the best one
                    llama_cpp.llama_sample_softmax(ctx=self.ctx, candidates=candidates_p)
                    id = candidates_p.contents.data[0].id
                elif n_temp == 0:
                    # Greedy sampling
                    id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
                elif mirostat == 1:
                    mirostat_mu = 2.0 * mirostat_tau
                    mirostat_m = 100
                    llama_cpp.llama_sample_temp(ctx=self.ctx, candidates=candidates_p, temp=c_float(n_temp))
                    id = llama_cpp.llama_sample_token_mirostat(ctx=self.ctx, candidates=candidates_p,
                                                               tau=c_float(mirostat_tau), eta=c_float(mirostat_eta), m=c_size_t(mirostat_m), mu=c_float(mirostat_mu))
                elif mirostat == 2:
                    mirostat_mu = 2.0 * mirostat_tau
                    llama_cpp.llama_sample_temp(ctx=self.ctx, candidates=candidates_p, temp=c_float(n_temp))
                    id = llama_cpp.llama_sample_token_mirostat_v2(ctx=self.ctx, candidates=candidates_p,
                                                                  tau=c_float(mirostat_tau), eta=c_float(mirostat_eta), mu=c_float(mirostat_mu))
                else:
                    # Temperature sampling
                    min_keep = c_size_t(1)
                    llama_cpp.llama_sample_top_k(ctx=self.ctx, candidates=candidates_p,
                                                 k=top_k, min_keep=min_keep)
                    llama_cpp.llama_sample_tail_free(ctx=self.ctx, candidates=candidates_p,
                                                     z=c_float(n_tfs_z), min_keep=min_keep)
                    llama_cpp.llama_sample_typical(ctx=self.ctx, candidates=candidates_p,
                                                   p=c_float(n_typical_p), min_keep=min_keep)
                    llama_cpp.llama_sample_top_p(ctx=self.ctx, candidates=candidates_p,
                                                 p=c_float(n_top_p), min_keep=min_keep)
                    llama_cpp.llama_sample_min_p(ctx=self.ctx, candidates=candidates_p,
                                                 p=c_float(min_p), min_keep=min_keep)
                    llama_cpp.llama_sample_temp(ctx=self.ctx, candidates=candidates_p,
                                                temp=c_float(n_temp))
                    id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
                n = llama_cpp.llama_token_to_piece(
                    self.model, llama_cpp.llama_token(id), buf, 32
                )
                piece = buf[:n].decode('utf-8', 'ignore')
                cc.session.last_n_tokens_data = cc.session.last_n_tokens_data[1:] + [id]
                running = True
                if piece in abort_set:
                    logger.debug(f"Break ({len(log_chunks)}): Aborting on {piece} ({id})")
                    running = False
                    # TODO Do I need to inject a newline in-context here?
                    id = None
                    return response_tokens
                elif id == 2 and (n_generated == 0 or last_id == 13):
                    # Token 2 decodes to '' (the LLaMA EOS); emitting it first or right
                    # after a newline is bad model output.
                    running = False
                    id = None
                    return response_tokens
                elif (last_piece, piece) in sequence_set:
                    logger.debug(f"Break ({len(log_chunks)}): sequence {last_piece}, {piece} ({id})")
                    running = False
                    id = None
                elif piece == '\n' and last_piece == '\n':
                    logger.debug(f"Break ({len(log_chunks)}): Double Newline ({id})")
                    running = False
                    id = None
                elif id == llama_cpp.llama_token_eos(self.model):
                    logger.debug(f"Break ({len(log_chunks)}): EOS ({id})")
                    running = False
                    # TODO Do I need to inject a newline in-context here?
                    id = 13  # replace EOS with the newline token
                if id is not None:
                    # tokens = (llama_cpp.llama_token * 1)(id)
                    # return_code = llama_cpp.llama_decode(ctx=self.ctx, batch=llama_cpp.llama_batch_get_one(tokens=tokens, n_tokens=1, pos_0=self.n_past, seq_id=0))
                    tokens = (llama_cpp.llama_token * 1)(id)
                    return_code = llama_cpp.llama_eval(ctx=self.ctx, tokens=tokens, n_tokens=1, n_past=cc.session.n_past)
                    log_chunks.append(piece)
                    log_ids.append(id)
                    if return_code != 0:
                        logger.error(f"Break - Model Eval return code {return_code}")
                        running = False
                    else:
                        cc.session.n_past += 1
                    n_generated += 1
                    response_tokens.append(piece)
                    if self.output is not None:
                        self.output.handle_token(piece)
                    remaining_tokens -= 1
                    last_piece = piece
                    last_id = id
                    cc.accept_token(id, self.ctx)
                    if piece in stop_set:
                        running = False
                if len(log_chunks) > 0 and (not running or len(log_chunks) % log_chunk_length == 0):
                    # Generally superseded by the output handler
                    # logger.debug(f"Generated ({n_generated}): {''.join(log_chunks).strip()}")
                    # logger.debug(f"Tokens ({n_generated}): {log_chunks}, {log_ids}")
                    log_chunks = []
                if not running:
                    break
        finally:
            # llama_cpp.llama_batch_free(llama_batch)
            pass
        cc.session.commit_prompt(''.join(response_tokens))
        return response_tokens
    def __del__(self):
        llama_cpp.llama_free(self.ctx)
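
# Hedged end-to-end sketch (not part of the original gist): wiring the pieces
# together. The model directory and file name below are placeholders.
def _example_generate(prompt: str) -> str:
    engine = FlowEngine.from_config(
        model_path="./models",           # hypothetical directory
        model_file="model.Q4_K_M.gguf",  # hypothetical GGUF file
        n_ctx=4096,
    )
    cc = ContextContainer()
    engine.feed(prompt, cc, n_batch=512)
    pieces = engine.read(cc, max_tokens=128) or []
    return ''.join(pieces)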