@dnhkng
Created January 14, 2024 16:30
# to use this, first install Python and exllamav2 (https://github.com/turboderp/exllamav2)
# load a model, rearrange the layers as you like, set generation parameters, and chat with it
# as posted, this script runs a wikitext perplexity test on the rearranged model rather than a chat loop
# duplicate layers need no extra memory
# WARNING: duplicate layers share the same cache, even though they shouldn't.
# This makes the model even more demented than a frankenmerge - use at your own risk.
# public domain - Silphendio
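# Rough flow: load the model, rebuild model.modules so the chosen layer ranges repeat,
# recreate the KV cache for the longer stack, run a wikitext perplexity test on the
# result, and finally dump a small slice of the cache for inspection.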
import copy
import math
import random
import sys
import time
import torch
import torch.nn.functional as F
from conversion.tokenize import get_tokens  # run from the exllamav2 repo root so the conversion/ helpers are importable
from exllamav2 import *
from exllamav2.attn import ExLlamaV2Attention
from exllamav2.generator import *
random.seed(1234) # for semi-determinism
config = ExLlamaV2Config()
config.model_dir = "./models/MythoMax-L2-13B-EXL2/"
config.prepare()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache_8bit(model, lazy = True)
print("Loading model...")
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
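# The two proxy classes below let one physical attention module appear at several
# positions in the rebuilt module list without copying its weights: every attribute is
# delegated to the wrapped module, except 'layer_idx', which is overridden with the new
# position. ExLlamaV2AttentionWrapperNoCache delegates methods too, so forward() runs
# bound to the original module and sees the original layer_idx, which appears to be why
# duplicated layers end up sharing one cache slot (the warning at the top).
# ExLlamaV2AttentionWrapper (defined but not used below) only delegates non-callable
# attributes, so its methods see the new layer_idx instead.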
class ExLlamaV2AttentionWrapperNoCache(ExLlamaV2Attention):
    def __init__(self, obj, new_idx):
        object.__setattr__(self, '_obj', obj)
        object.__setattr__(self, '_new_idx', new_idx)

    def __getattribute__(self, name):
        if name == 'layer_idx':
            return object.__getattribute__(self, '_new_idx')

        # Delegate all other attributes to the wrapped object
        try:
            return getattr(object.__getattribute__(self, '_obj'), name)
        except AttributeError:
            return object.__getattribute__(self, name)
class ExLlamaV2AttentionWrapper(ExLlamaV2Attention):
    def __init__(self, obj, new_idx):
        object.__setattr__(self, '_obj', obj)
        object.__setattr__(self, '_new_idx', new_idx)

    def __getattribute__(self, name):
        # Mask layer_idx
        if name == 'layer_idx':
            return object.__getattribute__(self, '_new_idx')

        try:
            # Delegate all other attributes to the wrapped object
            attr = getattr(object.__getattribute__(self, '_obj'), name)
            if not callable(attr):
                return attr
        except AttributeError:
            pass

        return object.__getattribute__(self, name)
layers = list(range(40))
repeats = [(0,20),(10,30),(20,40)]
# repeats = [(x, x+10) for x in range(0,31,5)]
# repeats = [(x, x+2) for x in range(39)]
layers = [list(range(*interval)) for interval in repeats]
layers = [item for sublist in layers for item in sublist]
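# With repeats = [(0,20), (10,30), (20,40)] the 40 original layers expand to 60:
# layers = [0..19, 10..29, 20..39], an interleaved self-merge. Duplicated entries
# reuse the same weights, so the rearrangement costs no extra VRAM (see note above).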
print(layers)
## zpin's code
orig_modules = model.modules
model.modules = orig_modules[:1]
print('building model')
for i, idx in enumerate(layers):
    # wrap the original attention module so it reports its new position i,
    # without duplicating its weights
    nextModule = ExLlamaV2AttentionWrapperNoCache(orig_modules[idx*2 + 1], i)
    model.modules.append(nextModule)
    # shallow-copy the matching MLP module and point it at the new position
    nextModule = copy.copy(orig_modules[idx*2 + 2])
    nextModule.layer_idx = i
    model.modules.append(nextModule)
model.modules += orig_modules[-2:]
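# Bookkeeping for the rebuilt stack. The module list is assumed to be laid out as
# [embedding] + [attention, MLP] * n_layers + [final norm, LM head], which is why each
# layer occupies two slots (idx*2 + 1 / idx*2 + 2 above) and why the layer count below
# is (len(modules) - 3) / 2, with the head at the last index.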
num_layers = int((len(model.modules) - 3) / 2)
model.head_layer_idx = len(model.modules) -1
model.config.num_hidden_layers = num_layers
model.last_kv_layer_idx = len(model.modules) -4
cache_class = type(cache)
del cache
print('Re-creating cache')
model.cache_map = {}
model.set_cache_map()
print(dir(model))
print(model.cache_map)
for i, m in enumerate(model.modules):
    if hasattr(m, 'layer_idx'):
        print(i, m.key, m.layer_idx)
    else:
        print(i, m.key)
cache = cache_class(model)
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.set_stop_conditions([tokenizer.eos_token_id])
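# The streaming generator is set up so you can chat with the rearranged model, but the
# perplexity evaluation below bypasses it and calls model.forward() directly.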
start = time.time()
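# Perplexity is computed the standard way: accumulate log P(token | context) for every
# predicted token, then take exp(-mean log-prob) at the end of the block.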
with torch.inference_mode():

    print(f" -- Running perplexity test")

    eval_dataset = "wikitext-2-v1_wikitext-test.parquet"
    eval_rows = 128
    eval_length = 2048

    print(f" -- Dataset: {eval_dataset}")
    print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")

    eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
    eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]

    logprob_sum = 0.0
    logprob_count = 0

    def ppl(input_ids__, logits__, lengths__):

        logprob_sum_ = 0.0
        logprob_count_ = 0

        assert logits__.shape[0] == input_ids__.shape[0]
        ll = logits__.shape[1]

        for bi in range(logits__.shape[0]):
            cl = max(ll - lengths__[bi], 0)
            logits_ = logits__[bi:bi+1, cl:, :]
            input_ids_ = input_ids__[bi:bi+1, cl:]

            chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1
            b_ = 0
            while b_ < logits_.shape[1]:
                a_ = b_
                b_ = min(b_ + chunksize, logits_.shape[1])

                logits_f = logits_[:, a_:b_, :].float() + 1e-10
                target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)

                log_probs = F.log_softmax(logits_f, dim=-1)
                token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
                logprob_sum_ += token_log_probs.sum().item()
                logprob_count_ += target_ids.numel()

        return logprob_sum_, logprob_count_

    print(f" -- Inference", end = "")
    sys.stdout.flush()

    eval_length = 2048
    if cache is None:
        cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None

    for i in range(eval_tokens.shape[0]):

        if i % 10 == 0: print(".", end = "")
        sys.stdout.flush()

        input_ids = eval_tokens[i:i+1, :]

        input_ids = input_ids[:, :]
        if cache is not None: cache.current_seq_len = 0
        logits = model.forward(input_ids, cache)
        logits = logits[:, :-1, :]

        logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1])
        logprob_sum += logprob_sum__
        logprob_count += logprob_count__

    print()

    mean_log_prob = logprob_sum / logprob_count
    perplexity = math.exp(-mean_log_prob)
    print(f" -- Evaluation perplexity: {perplexity:.4f}")
    print(f"dur: {time.time() - start}")
for i, cache_layer in enumerate(cache.key_states):
    print(i+1, cache_layer[0,1:4,1,1])