@ChrisHayduk
Last active November 18, 2024 12:49
Merging QLoRA weights with quantized model
"""
The code below combines approaches published by both @eugene-yh and @jinyongyoo on Github.
Thanks for the contributions guys!
"""
import torch
import peft
import json
import shutil
from peft.utils import _get_submodules
import os
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, CodeLlamaTokenizer
import gc
import copy

def save_model(model, tokenizer, to):
    print(f"Saving dequantized model to {to}...")
    model.save_pretrained(to)
    tokenizer.save_pretrained(to)

    # Strip quantization metadata so the saved checkpoint loads as a regular model
    config_path = os.path.join(to, 'config.json')
    with open(config_path, 'r') as config:
        config_data = json.load(config)
    config_data.pop("quantization_config", None)
    config_data.pop("pretraining_tp", None)
    with open(config_path, 'w') as config:
        config.write(json.dumps(config_data, indent=2))

def dequantize_model(model, tokenizer, to='./dequantized_model', dtype=torch.bfloat16, device="cpu"):
    """
    'model': the 4-bit quantized model to dequantize (in the script below, the base model before the adapter is attached).
    'tokenizer': the model's corresponding Hugging Face tokenizer.
    'to': directory to save the dequantized model
    'dtype': dtype that the model was trained using
    'device': device to load the model to
    """

    # Remove the output directory if it already exists
    if os.path.exists(to):
        shutil.rmtree(to)

    os.makedirs(to, exist_ok=True)

    cls = bnb.nn.Linear4bit

    with torch.no_grad():
        for name, module in model.named_modules():
            if isinstance(module, cls):
                print(f"Dequantizing `{name}`...")
                quant_state = copy.deepcopy(module.weight.quant_state)

                quant_state[2] = dtype

                weights = dequantize_4bit(module.weight.data, quant_state=quant_state, quant_type="nf4").to(dtype)

                # Replace the 4-bit layer with a plain Linear layer holding the dequantized weights
                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=None, dtype=dtype)
                new_module.weight = torch.nn.Parameter(weights)
                new_module.to(device=device, dtype=dtype)

                parent, target, target_name = _get_submodules(model, name)
                setattr(parent, target_name, new_module)

    model.is_loaded_in_4bit = False

    save_model(model, tokenizer, to)

    return model

model_path = 'Huggingface-base-model/path-goes-here'
adapter_path = 'Huggingface-adapter/path-goes-here'
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

try:
    print(f"Starting to load the model {model_path} into memory")

    model = LlamaForCausalLM.from_pretrained(
        model_path,
        load_in_4bit=True,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map="auto"
    )
    print(model)
    tok = LlamaTokenizer.from_pretrained(model_path)

    # Note: This function outputs the dequantized model without merging the adapter yet
    # The code below it will merge the adapter and then save it to disk
    model = dequantize_model(model, tok, to='output-folder-for-dequantized-model-here')

    print(model)
    model = PeftModel.from_pretrained(model=model, model_id=adapter_path)
    print(model)
    model = model.merge_and_unload()
    print(model)

    print(f"Successfully loaded the model {model_path} into memory")

    # Note that the output folder here should be different than the one you used for dequantize_model
    # This save will output the model merged with LoRA weights
    save_model(model, tok, "put-output-folder-here")

    print(f"Successfully saved merged model {model_path} to disk")

except Exception as e:
    print(f"An error occurred: {e}")

    # Delete the model object if it exists
    if 'model' in locals():
        del model

    # Clear the GPU cache
    torch.cuda.empty_cache()

    # Run the garbage collection
    gc.collect()

    print("Model, GPU cache, and garbage have been cleared.")

@jaredquekjz

Hi Chris. Thanks for looking into it. McGovern provides a fairly detailed way of reproducing it in huggingface/transformers#26492. I did something similar, except that I use llama.cpp, and when I quiz the model with set questions I get much worse answers than with a non-merged model with adapters applied (the model is loaded in 4-bit with nf4=true). ...

Two possible explanations. First, make sure you are using the correct type in bnb_4bit_compute_dtype=compute_dtype (torch.bfloat16 or torch.float16); Chris's code uses torch.bfloat16 by default. Second, the merged model should not apply LoRA again at inference time, i.e., comment out model = PeftModel(model, config).

I am using the 'default' method in PEFT now, but thanks for clarifying!

@sampbarrow

sampbarrow commented Dec 6, 2023

I get:

Dequantizing `model.layers.0.self_attn.q_proj`...
An error occurred: 'QuantState' object does not support item assignment
Model, GPU cache, and garbage have been cleared.

Edit: Fixed by downgrading bitsandbytes to 0.41.0

@NeelMishra

@sampbarrow @ChrisHayduk I am seeing this error as well. How can we fix it? Downgrading should not be the solution.

@michal-kajstura

@sampbarrow @NeelMishra Just change quant_state[2] = dtype to quant_state.dtype = dtype
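
A minimal sketch of how the two behaviours reported in this thread could be handled in one place, so the script does not depend on the installed bitsandbytes version. It assumes, based only on the gist code and the error above, that older releases expose the quant state as a list with the dtype at index 2, while newer releases expose a `QuantState` object with a `dtype` attribute. The helper name `set_quant_state_dtype` is hypothetical and not part of bitsandbytes.

```python
def set_quant_state_dtype(quant_state, dtype):
    """Set the target dtype on a copied quant state (hypothetical helper).

    Assumption: list-style quant states keep the dtype at index 2 (as the
    gist does); object-style quant states expose a `dtype` attribute.
    """
    if isinstance(quant_state, list):
        # Older layout: item assignment works
        quant_state[2] = dtype
    else:
        # Newer layout: 'QuantState' object does not support item assignment
        quant_state.dtype = dtype
    return quant_state
```

In `dequantize_model`, the line `quant_state[2] = dtype` would then become `set_quant_state_dtype(quant_state, dtype)`.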

@drewskidang

You can't pass load_in_4bit or load_in_8bit as a kwarg when passing quantization_config

Any recs?
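
One possible way around this error, sketched under the assumption that the 4-bit flag is already set inside `BitsAndBytesConfig` (as it is in the gist): drop the duplicate `load_in_4bit` / `load_in_8bit` keyword from the `from_pretrained` call and pass only `quantization_config`. The helper `build_load_kwargs` below is hypothetical and only illustrates the filtering; it does not call transformers itself.

```python
def build_load_kwargs(quantization_config, **kwargs):
    """Return from_pretrained kwargs without duplicate quantization flags.

    Hypothetical helper: the quantization settings are assumed to live in
    `quantization_config`, so the standalone flags are removed.
    """
    load_kwargs = dict(kwargs)
    load_kwargs.pop("load_in_4bit", None)
    load_kwargs.pop("load_in_8bit", None)
    load_kwargs["quantization_config"] = quantization_config
    return load_kwargs

# Intended usage with the names from the gist (not executed here):
# load_kwargs = build_load_kwargs(quantization_config,
#                                 torch_dtype=torch.bfloat16,
#                                 device_map="auto")
# model = LlamaForCausalLM.from_pretrained(model_path, **load_kwargs)
```

For the script as posted, the equivalent manual change is simply to delete the `load_in_4bit=True` line from the `from_pretrained` call.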

@nzw0301

nzw0301 commented May 14, 2024

Thank you for sharing this helpful script. I've noticed that the current script ignores bias. To handle it, the module construction can be changed to:

                    has_bias = module.bias is not None
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=has_bias, dtype=dtype)
                    new_module.weight = torch.nn.Parameter(weights, requires_grad=False)
                    if has_bias:
                        new_module.bias.data = module.bias.data.detach().to(dtype)

@npbool

npbool commented Aug 6, 2024

Thanks for sharing. I found that AutoModel has a dequantize function that dequantizes the whole model, and it works as expected after LoRA adapter merging (same perplexity), so you don't have to do it yourself. However, quantization_config still has to be removed manually from the saved config.json of the dequantized model.
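
The manual cleanup step mentioned above can be sketched with the standard library alone. It mirrors what `save_model` in the gist does to `config.json`; the function name `strip_quantization_config` is hypothetical. The dequantize call appears only as a comment, since whether it is available depends on the installed transformers version.

```python
import json
import os

def strip_quantization_config(model_dir):
    """Remove `quantization_config` from a saved config.json, in place.

    Returns True if the key was present and removed, False otherwise.
    """
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path, "r") as f:
        config_data = json.load(f)

    removed = config_data.pop("quantization_config", None) is not None

    with open(config_path, "w") as f:
        json.dump(config_data, f, indent=2)
    return removed

# Assumed flow, following the comment above (not executed here):
# model = model.dequantize()              # built-in dequantization, if available
# model.save_pretrained("output-dir")     # "output-dir" is a placeholder path
# strip_quantization_config("output-dir")
```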
