@ChrisHayduk
Last active May 14, 2024 07:15
Merging QLoRA weights with quantized model
"""
The code below combines approaches published by both @eugene-yh and @jinyongyoo on Github.
Thanks for the contributions guys!
"""
import torch
import peft
import json
import shutil
from peft.utils import _get_submodules
import os
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, CodeLlamaTokenizer
import gc
import copy


def save_model(model, tokenizer, to):
    print(f"Saving dequantized model to {to}...")
    model.save_pretrained(to)
    tokenizer.save_pretrained(to)
    config_data = json.loads(open(os.path.join(to, 'config.json'), 'r').read())
    config_data.pop("quantization_config", None)
    config_data.pop("pretraining_tp", None)
    with open(os.path.join(to, 'config.json'), 'w') as config:
        config.write(json.dumps(config_data, indent=2))


def dequantize_model(model, tokenizer, to='./dequantized_model', dtype=torch.bfloat16, device="cpu"):
    """
    'model': the 4-bit quantized model to dequantize (in the script below, the base model).
    'tokenizer': the model's corresponding Hugging Face tokenizer.
    'to': directory to save the dequantized model to
    'dtype': dtype that the model was trained with
    'device': device to move the dequantized layers to
    """
    # Remove the output directory if it already exists, then recreate it
    if os.path.exists(to):
        shutil.rmtree(to)
    os.makedirs(to, exist_ok=True)
    cls = bnb.nn.Linear4bit
    with torch.no_grad():
        for name, module in model.named_modules():
            if isinstance(module, cls):
                print(f"Dequantizing `{name}`...")
                quant_state = copy.deepcopy(module.weight.quant_state)
                # On bitsandbytes versions where quant_state is a QuantState object,
                # this line must be `quant_state.dtype = dtype` (see the comments below).
                quant_state[2] = dtype
                weights = dequantize_4bit(module.weight.data, quant_state=quant_state, quant_type="nf4").to(dtype)
                # Replace the 4-bit layer with a plain Linear layer holding the dequantized weights
                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=None, dtype=dtype)
                new_module.weight = torch.nn.Parameter(weights)
                new_module.to(device=device, dtype=dtype)
                parent, target, target_name = _get_submodules(model, name)
                setattr(parent, target_name, new_module)
        # Mark the model as no longer 4-bit so that save_pretrained accepts it
        model.is_loaded_in_4bit = False
        save_model(model, tokenizer, to)
        return model


model_path = 'Huggingface-base-model/path-goes-here'
adapter_path = 'Huggingface-adapter/path-goes-here'
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

try:
    print(f"Starting to load the model {model_path} into memory")
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        # Some transformers versions reject load_in_4bit when quantization_config
        # is also passed (see the comments below); the config already sets it.
        load_in_4bit=True,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map="auto"
    )
    print(model)
    tok = LlamaTokenizer.from_pretrained(model_path)
    # Note: This function outputs the dequantized model without merging the adapter yet
    # The code below it will merge the adapter and then save it to disk
    model = dequantize_model(model, tok, to='output-folder-for-dequantized-model-here')
    print(model)
    model = PeftModel.from_pretrained(model=model, model_id=adapter_path)
    print(model)
    model = model.merge_and_unload()
    print(model)
    print(f"Successfully loaded the model {model_path} into memory")
    # Note that the output folder here should be different than the one you used for dequantize_model
    # This save will output the model merged with LoRA weights
    save_model(model, tok, "put-output-folder-here")
    print(f"Successfully saved merged model {model_path} to disk")
except Exception as e:
    print(f"An error occurred: {e}")
    # Delete the model object if it exists
    if 'model' in locals():
        del model
    # Clear the GPU cache
    torch.cuda.empty_cache()
    # Run the garbage collection
    gc.collect()
    print("Model, GPU cache, and garbage have been cleared.")

@Yu-Yuqing

Yu-Yuqing commented Nov 10, 2023

> Hi Chris. Thanks for looking into it. McGovern provides a fairly detailed way of reproducing it in huggingface/transformers#26492. I did similar things - except I actually use llama.cpp - and when I quiz the model with set questions - I get much worse answers than using a non-merged model with adapters applied (the model is loaded in 4 bit with nf4=true).

Two possible explanations. First, make sure you are using the correct type in bnb_4bit_compute_dtype=compute_dtype (torch.bfloat16 or torch.float16); Chris's code uses torch.bfloat16 by default. Second, the merged model should not apply LoRA again during inference, i.e., comment out model = PeftModel(model, config).

@jaredquekjz

>> Hi Chris. Thanks for looking into it. McGovern provides a fairly detailed way of reproducing it in huggingface/transformers#26492. I did similar things - except I actually use llama.cpp - and when I quiz the model with set questions - I get much worse answers than using a non-merged model with adapters applied (the model is loaded in 4 bit with nf4=true).

> Two possible explanations. First, make sure you are using the correct type in bnb_4bit_compute_dtype=compute_dtype (torch.bfloat16 or torch.float16); Chris's code uses torch.bfloat16 by default. Second, the merged model should not apply LoRA again during inference, i.e., comment out model = PeftModel(model, config).

I am currently using the 'default' method in PEFT, but thanks for clarifying!

@sampbarrow

sampbarrow commented Dec 6, 2023

I get:

Dequantizing `model.layers.0.self_attn.q_proj`...
An error occurred: 'QuantState' object does not support item assignment
Model, GPU cache, and garbage have been cleared.

Edit: Fixed by downgrading bitsandbytes to 0.41.0

@NeelMishra

@sampbarrow @ChrisHayduk I am seeing this error too. How can we fix it? Downgrading should not be the solution.

@michal-kajstura

@sampbarrow @NeelMishra Just change quant_state[2] = dtype to quant_state.dtype = dtype
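
A minimal sketch of a version-tolerant form of this fix, assuming (as the error message and the downgrade workaround above suggest) that older bitsandbytes releases keep the quant state as a list with the dtype at index 2, while newer ones use an object with a `dtype` attribute. The helper name `set_quant_state_dtype` is hypothetical and not part of bitsandbytes, and the stand-in values below are only placeholders for a real `module.weight.quant_state`:

```python
import copy
from types import SimpleNamespace


def set_quant_state_dtype(quant_state, dtype):
    """Return a copy of quant_state with its dtype overridden.

    Hypothetical helper: handles the list layout (dtype at index 2)
    and the object layout (a `dtype` attribute).
    """
    quant_state = copy.deepcopy(quant_state)
    if isinstance(quant_state, list):
        quant_state[2] = dtype
    else:
        quant_state.dtype = dtype
    return quant_state


# Placeholder stand-ins for the two layouts (assumed shapes, not real quant states).
old_style = ["absmax", (4, 4), "float16", 64]
new_style = SimpleNamespace(absmax="absmax", shape=(4, 4), dtype="float16")

print(set_quant_state_dtype(old_style, "bfloat16")[2])
print(set_quant_state_dtype(new_style, "bfloat16").dtype)
```

In the script, this would stand in for the `copy.deepcopy(...)` line and the `quant_state[2] = dtype` assignment inside `dequantize_model`.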

@drewskidang

You can't pass load_in_4bitor load_in_8bit as a kwarg when passing quantization_config

Any recommendations?
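
One possible way around this, sketched on the assumption that the error comes from passing `load_in_4bit=True` alongside `quantization_config` as the script does: drop the bare flags and let the `BitsAndBytesConfig` carry the setting. The helper `clean_load_kwargs` is hypothetical and works on a plain dict; a string stands in for the real config object:

```python
def clean_load_kwargs(kwargs):
    """Drop load_in_4bit/load_in_8bit when a quantization_config is present."""
    cleaned = dict(kwargs)  # leave the caller's dict untouched
    if cleaned.get("quantization_config") is not None:
        cleaned.pop("load_in_4bit", None)
        cleaned.pop("load_in_8bit", None)
    return cleaned


# Placeholder kwargs; in the script, quantization_config is a BitsAndBytesConfig.
kwargs = {
    "load_in_4bit": True,
    "quantization_config": "<BitsAndBytesConfig>",
    "device_map": "auto",
}
print(clean_load_kwargs(kwargs))
```

The cleaned dict could then be unpacked into `LlamaForCausalLM.from_pretrained(model_path, **cleaned)`. In the script itself, the equivalent change is simply deleting the `load_in_4bit=True` argument, since the config already sets it.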

@nzw0301

nzw0301 commented May 14, 2024

Thank you for sharing this helpful script. I've noticed that the current script ignores the bias of each quantized layer. To handle it, the construction of new_module can be changed to:

                    has_bias = module.bias is not None
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=has_bias, dtype=dtype)
                    new_module.weight = torch.nn.Parameter(weights, requires_grad=False)
                    if has_bias:
                        new_module.bias.data = module.bias.data.detach().to(dtype)
