@dnhkng
Created January 14, 2024 11:12
Testing the cache
# to use this, first install python and exllamav2 (https://github.com/turboderp/exllamav2)
# load a model, rearrange the layers as you like, set generation parameters, and chat with it
# duplicate layers need no extra memory
# WARNING: duplicate layers share the same cache, even though they shouldn't.
# This makes the model even more demented than a frankenmerge - use at your own risk.
# public domain - Silphendio
import random
import sys
import torch
from exllamav2 import *
from exllamav2.attn import ExLlamaV2Attention
from exllamav2.generator import *
random.seed(1234) # for semi-determinism
config = ExLlamaV2Config()
config.model_dir = "./models/TinyLlama-1.1B-Chat-v1.0-5.0bpw-h6-exl2"
config.prepare()
model = ExLlamaV2(config)
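# 8-bit KV cache; lazy = True defers allocation until the model is loaded with load_autosplit() below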
cache = ExLlamaV2Cache_8bit(model, lazy = True)
print("Loading model...")
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.set_stop_conditions([tokenizer.eos_token_id])
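# default sampler settings; tweak e.g. gen_settings.temperature or gen_settings.top_p here if desired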
gen_settings = ExLlamaV2Sampler.Settings()
class ExLlamaV2AttentionWrapper(ExLlamaV2Attention):
    def __init__(self, obj, new_idx):
        object.__setattr__(self, '_obj', obj)
        object.__setattr__(self, '_new_idx', new_idx)

    def __getattribute__(self, name):
        if name == 'layer_idx':
            return object.__getattribute__(self, '_new_idx')
        # Delegate all other attributes to the wrapped object
        try:
            return getattr(object.__getattribute__(self, '_obj'), name)
        except AttributeError:
            return object.__getattribute__(self, name)
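# The wrapper only overrides layer_idx, so each duplicated attention module reports a unique index;
# the intent is that set_cache_map() then gives every copy its own cache slot instead of sharing one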
## mix layers here
layer_arrangement = list(range(0,14)) + list(range(8,22))
# modules arrangement: [embedding, [...layers], rms-norm, head]
# where each layer is [attention, mlp]
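# i.e. the attention module of decoder layer i is model.modules[i*2 + 1] and its MLP
# is model.modules[i*2 + 2]; the index arithmetic below relies on this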
### silphendio's code
# old_modules = model.modules
# model.modules = old_modules[:1]
# for idx in layer_arrangement:
#     model.modules += old_modules[idx*2 + 1 : idx*2 + 3]
# model.modules += old_modules[-2:]
# model.head_layer_idx = len(model.modules) -1
# model.config.num_hidden_layers = len(layer_arrangement)
# model.last_kv_layer_idx = len(model.modules) -4
## mix layers end
layers = list(range(0,14)) + list(range(8,14)) + list(range(8,14)) + list(range(8,22))
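# TinyLlama-1.1B has 22 decoder layers; layers 8-13 appear four times in this arrangement,
# so the duplicated layers make any unintended cache sharing easy to spot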
## zpin's code
orig_modules = model.modules
model.modules = orig_modules[:1]
for i, idx in enumerate(layers):
    model.modules.append(ExLlamaV2AttentionWrapper(orig_modules[idx*2 + 1], i))
    model.modules.append(orig_modules[idx*2 + 2])
model.modules += orig_modules[-2:]
num_layers = int((len(model.modules) - 3) / 2)
model.head_layer_idx = len(model.modules) -1
model.config.num_hidden_layers = num_layers
model.last_kv_layer_idx = len(model.modules) -4
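# the existing cache was built for the original layer arrangement, so rebuild the cache map
# and allocate a fresh cache of the same class for the new module list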
cache_class = type(cache)
del generator
del cache
print('Re-creating cache')
model.cache_map = {}
model.set_cache_map()
cache = cache_class(model)
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
## Test the system!
max_tokens = 1024
# chat using the TinyLlama-Chat (Zephyr-style) prompt format
system_prompt = "<|system|>You are a chatbot who can tell bedtime stories!</s>\n"
text: str = system_prompt
instruction = "User: Tell me a story about a princess and a dragon."
print()
print("Assistant: ", end = "")
text += f"<|user|>{instruction}</s>\n<|assistant|>"
instruction_ids = tokenizer.encode(text, add_bos = True)
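# on the first turn generator.sequence_ids is None; on later turns the new instruction is
# appended to the existing sequence so the cached context is reused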
context_ids = instruction_ids if generator.sequence_ids is None \
    else torch.cat([generator.sequence_ids, instruction_ids], dim = -1)
generator.begin_stream(context_ids, gen_settings)
for _ in range(max_tokens):
    chunk, eos, _ = generator.stream()
    if eos: break
    text += chunk
    if text.endswith("<|user|>"):
        break
    print(chunk, end = "")
    sys.stdout.flush()
text += "\n"
for i, cache_layer in enumerate(cache.key_states):
    print(i+1, cache_layer[0,1,1,1])
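# prints one element of each layer's key cache; if the duplicated layers really have separate
# cache slots, their values should generally differ rather than repeat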