ExllamaV2 LoRA with Dynamic Layers
# to use this, first install python and exllamav2 (https://github.com/turboderp/exllamav2)
# load a model, rearrange the layers as you like, set generation parameters, and run it
# duplicate layers share tensors, but still need extra memory for the cache
# thanks to @dnhkng for showing that the cache needs to be re-created
# licensed under WTFPL (http://www.wtfpl.net/about/) - Silphendio

# Additional updates to use LoRA with duplicate layers:
# - model.modules_dict is updated to include the new (duplicated) layers
# - the LoRA must be created with the static frankenmerge model first,
#   then it can be used on top of the dynamic layers
# also licensed under WTFPL (http://www.wtfpl.net/about/) - edk208
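
# Assumed workflow, as I read the notes above (not part of the original script):
#   1. build a static frankenmerge of the base model with the same layer arrangement
#      used below, and save it as a full model
#   2. train the LoRA adapter against that static merged model
#   3. run this script: load the original base model, duplicate/rearrange its layers
#      in memory, then apply the LoRA from step 2 on top of the rearranged layers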
from exllamav2 import *
from exllamav2.generator import *
import sys, torch
from copy import copy

config = ExLlamaV2Config()
config.model_dir = "../Documents/huggingface/Yi-34B-Chat-4.65bpw-h6-exl2"
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache_8bit(model, lazy = True)

print("Loading model...")
print("Keys in model.modules_dict:")
for key in model.modules_dict.keys():
    print(key)

model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
gen_settings = ExLlamaV2Sampler.Settings()

## mix layers here
layer_arrangement = list(range(0, 20)) + list(range(10, 30)) + list(range(20, 40)) \
                  + list(range(30, 50)) + list(range(40, 60))
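
# For reference (assuming the 60-layer Yi-34B model above): the five overlapping ranges
# give 100 layer slots in total, with original layers 10-49 each appearing twice and
# layers 0-9 and 50-59 appearing once. Purely an illustrative sanity check:
assert len(layer_arrangement) == 100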
# modules arrangement: [embedding, [...layers], rms-norm, head]
# where each layer is [attention, mlp]
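# Worked example of the indexing used below (my annotation, following the layout above):
#   old_modules[0]           -> embedding
#   old_modules[idx*2 + 1]   -> attention block of original layer idx
#   old_modules[idx*2 + 2]   -> mlp block of original layer idx
#   old_modules[-2:]         -> final rms-norm and head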
old_modules = model.modules
model.modules = old_modules[:1]  # keep only the embedding layer for now
model.modules_dict[model.modules[-1].key] = model.modules[-1]
for i, idx in enumerate(layer_arrangement):
    model.modules += [copy(old_modules[idx*2 + 1])]
    model.modules[-1].layer_idx = i  # for duplicate layers to use a different cache
    for m in model.modules[-1].submodules:
        # split the key into parts and replace the layer index with the new position i
        key_parts = m.key.split('.')
        key_parts[2] = str(i)
        updated_key = '.'.join(key_parts)
        # update the dictionary entry with the new key
        model.modules_dict[updated_key] = m
    print(f"layer {i} is {model.modules[-1].key}")

    model.modules += [old_modules[idx*2 + 2]]
    for m in model.modules[-1].submodules:
        # same key rewrite for the mlp submodules
        key_parts = m.key.split('.')
        key_parts[2] = str(i)
        updated_key = '.'.join(key_parts)
        model.modules_dict[updated_key] = m
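
# Example of the key rewrite above (hypothetical values, assuming keys of the form
# "model.layers.<n>.<submodule>"): if original layer idx = 10 is placed at position
# i = 25, a submodule key like "model.layers.10.self_attn.q_proj" is re-registered
# in model.modules_dict as "model.layers.25.self_attn.q_proj" - presumably the key
# the LoRA trained on the static frankenmerge expects to find.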
model.modules += old_modules[-2:]
model.head_layer_idx = len(model.modules) - 1
model.config.num_hidden_layers = len(layer_arrangement)
model.last_kv_layer_idx = len(model.modules) - 4
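
# My reading of the index bookkeeping above (not from the original comments):
# the module list now ends with [..., last attention, last mlp, rms-norm, head],
# so len - 1 is the head and len - 4 is the last attention block that still needs
# a KV cache; num_hidden_layers is set to the new (duplicated) layer count so the
# re-created cache allocates one slot per rearranged layer.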
print('Re-creating cache')
del cache
model.cache_map = {}
model.set_cache_map()
cache = ExLlamaV2Cache_8bit(model)
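
# Note (restating the header comment): duplicated layers share their weight tensors,
# but each position in layer_arrangement gets its own slot in the new cache, so
# KV-cache memory grows with the rearranged layer count even though the weights do not.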
print("Keys in new model.modules_dict:") | |
for key in model.modules_dict.keys(): | |
print(key) | |
# this needs to be re-created after rearranging layers | |
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer) | |
generator.set_stop_conditions([tokenizer.eos_token_id]) | |
## mix layers end

# adjust generation settings
gen_settings.temperature = 0.0  # for deterministic results
#gen_settings.top_k = 50
#gen_settings.top_p = 0.8
#gen_settings.min_p = 0

max_response_length = 512
print("starting generation") | |
text = """<|im_start|>system | |
You are a chatbot who can help code!<|im_end|> | |
<|im_start|>user | |
Write me a python script to blink an LED on a raspberry PI.<|im_end|> | |
<|im_start|>assistant | |
""" | |
print("\n" + text, end="") | |
instruction_ids = tokenizer.encode(text, add_bos = True) | |
context_ids = instruction_ids if generator.sequence_ids is None \ | |
else torch.cat([generator.sequence_ids, instruction_ids], dim = -1) | |
# Load LoRA
lora_directory = "../Documents/huggingface/lora_frank_yi"
lora = ExLlamaV2Lora.from_directory(model, lora_directory)

generator.begin_stream(context_ids, gen_settings, loras = lora)
for _ in range(max_response_length):
    chunk, eos, _ = generator.stream()
    if eos: break
    text += chunk
    if text.endswith("<|im_end|>"):
        break
    print(chunk, end = "")
    sys.stdout.flush()

text += "\n"

# cleanup
model.modules = old_modules
model.unload()
del cache