@edk208
Created March 3, 2024 06:05
ExllamaV2 LoRA with Dynamic Layers
# to use this, first install python and exllamav2 (https://github.com/turboderp/exllamav2)
# load a model, rearrange the layers as you like, set generation parameters, and run it
# duplicate layers share tensors, but still need extra memory for the cache
# thanks to @dnhkng for showing that the cache needs to be re-created
# licensed under WTFPL (http://www.wtfpl.net/about/) - Silphendio
# Additional updates to support LoRA with duplicate layers:
# model.modules_dict is updated to include the new (duplicated) layers
# The LoRA must be trained against the static frankenmerge model first,
# then it can be applied on top of the dynamic layer arrangement
# also licensed under WTFPL (http://www.wtfpl.net/about/) - edk208
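# NOTE: model_dir and lora_directory below are the author's local example paths;
# point them at your own exl2-quantized model and LoRA adapter directory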
from exllamav2 import *
from exllamav2.generator import *
import sys, torch
from copy import copy
config = ExLlamaV2Config()
config.model_dir = "../Documents/huggingface/Yi-34B-Chat-4.65bpw-h6-exl2"
config.prepare()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache_8bit(model, lazy = True)
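# lazy=True defers cache allocation; load_autosplit() below loads the model and
# allocates the cache while splitting both across the available GPUs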
print("Loading model...")
print("Keys in model.modules_dict:")
for key in model.modules_dict.keys():
    print(key)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
gen_settings = ExLlamaV2Sampler.Settings()
## mix layers here
layer_arrangement = list(range(0,20)) + list(range(10,30))+ list(range(20,40)) + list(range(30,50))+list(range(40,60))
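# five overlapping 20-layer windows (offset by 10) over the base model's 60 layers,
# giving a 100-layer runtime self-merge ("frankenmerge"); duplicated layers share
# their weight tensors, only the cache grows (see header note)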
# modules arrangement: [embedding, [...layers], rms-norm, head]
# where each layer is [attention, mlp]
old_modules = model.modules
model.modules = old_modules[:1]
model.modules_dict[model.modules[-1].key] = model.modules[-1]
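# keep only the embedding module (modules[0]) and re-register its key;
# the transformer layers are re-appended below in the chosen order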
for i, idx in enumerate(layer_arrangement):
    model.modules += [copy(old_modules[idx*2 + 1])]  # attention block of source layer idx
    model.modules[-1].layer_idx = i  # duplicate layers get a distinct index so they use a different cache slot
    for m in model.modules[-1].submodules:
        # Rewrite the submodule key so its layer index matches the new position i
        key_parts = m.key.split('.')
        key_parts[2] = str(i)  # key_parts[2] is the original layer index; replace it with the new one
        updated_key = '.'.join(key_parts)
        # Register the submodule under the new key
        model.modules_dict[updated_key] = m
    print(f"layer {i} is {model.modules[-1].key}")
    model.modules += [old_modules[idx*2 + 2]]  # MLP block of source layer idx
    for m in model.modules[-1].submodules:
        # Same key rewrite for the MLP submodules
        key_parts = m.key.split('.')
        key_parts[2] = str(i)
        updated_key = '.'.join(key_parts)
        model.modules_dict[updated_key] = m
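# modules_dict now also maps keys with the new layer indices (model.layers.<i>. ...)
# to the duplicated modules; this is what lets a LoRA trained on the equivalent static
# frankenmerge resolve its target modules by key (see header note)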
model.modules += old_modules[-2:]
model.head_layer_idx = len(model.modules) -1
model.config.num_hidden_layers = len(layer_arrangement)
model.last_kv_layer_idx = len(model.modules) -4
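# with the layout [embedding, (attention, mlp) x N, norm, head], len-1 is the head
# and len-4 is the final attention block, the last module that uses the KV cache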
print('Re-creating cache')
del cache
model.cache_map = {}
model.set_cache_map()
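# clear and rebuild the layer-to-cache mapping, then allocate a fresh cache sized
# for the new (larger) number of layers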
cache = ExLlamaV2Cache_8bit(model)
print("Keys in new model.modules_dict:")
for key in model.modules_dict.keys():
    print(key)
# this needs to be re-created after rearranging layers
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.set_stop_conditions([tokenizer.eos_token_id])
## mix layers end
# adjust generation settings
gen_settings.temperature = 0.0 # for deterministic results
#gen_settings.top_k = 50
#gen_settings.top_p = 0.8
#gen_settings.min_p = 0
max_response_length = 512
print("starting generation")
text = """<|im_start|>system
You are a chatbot who can help code!<|im_end|>
<|im_start|>user
Write me a python script to blink an LED on a raspberry PI.<|im_end|>
<|im_start|>assistant
"""
print("\n" + text, end="")
instruction_ids = tokenizer.encode(text, add_bos = True)
context_ids = instruction_ids if generator.sequence_ids is None \
    else torch.cat([generator.sequence_ids, instruction_ids], dim = -1)
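# on the first call generator.sequence_ids is None, so the context is just the prompt;
# otherwise the new prompt is appended to the running conversation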
# Load LoRA
lora_directory = "../Documents/huggingface/lora_frank_yi"
lora = ExLlamaV2Lora.from_directory(model, lora_directory)
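# the adapter here is assumed to have been trained against the static frankenmerge
# with the same layer arrangement (see header note); it is passed per-request via
# the loras argument below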
generator.begin_stream(context_ids, gen_settings, loras = lora)
for _ in range(max_response_length):
    chunk, eos, _ = generator.stream()
    if eos: break
    text += chunk
    if text.endswith("<|im_end|>"):
        break
    print(chunk, end = "")
    sys.stdout.flush()
text += "\n"
# cleanup
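# restore the original module list so unload() releases the original layers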
model.modules = old_modules
model.unload()
del cache