@ChrisHayduk
Last active May 14, 2024 07:15
Merging QLoRA weights with quantized model
```python
"""
The code below combines approaches published by both @eugene-yh and @jinyongyoo on GitHub.
Thanks for the contributions guys!
"""
import copy
import gc
import json
import os
import shutil

import bitsandbytes as bnb
import torch
from bitsandbytes.functional import dequantize_4bit
from peft import PeftModel
from peft.utils import _get_submodules
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer


def save_model(model, tokenizer, to):
    print(f"Saving dequantized model to {to}...")
    model.save_pretrained(to)
    tokenizer.save_pretrained(to)
    # Strip the quantization settings so the checkpoint loads as a regular (unquantized) model
    config_path = os.path.join(to, "config.json")
    with open(config_path, "r") as f:
        config_data = json.load(f)
    config_data.pop("quantization_config", None)
    config_data.pop("pretraining_tp", None)
    with open(config_path, "w") as f:
        json.dump(config_data, f, indent=2)


def dequantize_model(model, tokenizer, to="./dequantized_model", dtype=torch.bfloat16, device="cpu"):
    """
    'model': the base model loaded in 4-bit (NF4) with bitsandbytes, before the adapter is applied.
    'tokenizer': the model's corresponding Hugging Face tokenizer.
    'to': directory to save the dequantized model to.
    'dtype': dtype the model was trained with.
    'device': device to place the dequantized layers on.
    """
    # Start from an empty output directory
    if os.path.exists(to):
        shutil.rmtree(to)
    os.makedirs(to, exist_ok=True)

    cls = bnb.nn.Linear4bit

    with torch.no_grad():
        for name, module in model.named_modules():
            if isinstance(module, cls):
                print(f"Dequantizing `{name}`...")
                quant_state = copy.deepcopy(module.weight.quant_state)
                # Older bitsandbytes releases (e.g. 0.41.0) store quant_state as a list with
                # the dtype at index 2; newer ones use a QuantState object with a .dtype attribute.
                if isinstance(quant_state, list):
                    quant_state[2] = dtype
                else:
                    quant_state.dtype = dtype

                weights = dequantize_4bit(module.weight.data, quant_state=quant_state, quant_type="nf4").to(dtype)

                # Keep the bias if the original layer had one
                has_bias = module.bias is not None
                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=has_bias, dtype=dtype)
                new_module.weight = torch.nn.Parameter(weights, requires_grad=False)
                if has_bias:
                    new_module.bias.data = module.bias.data.detach().to(dtype)
                new_module.to(device=device, dtype=dtype)

                # Swap the 4-bit layer for the dequantized one
                parent, target, target_name = _get_submodules(model, name)
                setattr(parent, target_name, new_module)

    model.is_loaded_in_4bit = False

    save_model(model, tokenizer, to)

    return model


model_path = "Huggingface-base-model/path-goes-here"
adapter_path = "Huggingface-adapter/path-goes-here"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

try:
    print(f"Starting to load the model {model_path} into memory")

    # load_in_4bit is already set in quantization_config; passing it again as a
    # separate kwarg is rejected by recent transformers versions.
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map="auto",
    )
    print(model)
    tok = LlamaTokenizer.from_pretrained(model_path)

    # Note: this function outputs the dequantized model without merging the adapter yet.
    # The code below merges the adapter and then saves the result to disk.
    model = dequantize_model(model, tok, to="output-folder-for-dequantized-model-here")

    print(model)
    model = PeftModel.from_pretrained(model=model, model_id=adapter_path)
    print(model)
    model = model.merge_and_unload()
    print(model)

    print(f"Successfully loaded the model {model_path} into memory")

    # Note that the output folder here should be different from the one used for dequantize_model.
    # This save writes the model merged with the LoRA weights.
    save_model(model, tok, "put-output-folder-here")

    print(f"Successfully saved merged model {model_path} to disk")

except Exception as e:
    print(f"An error occurred: {e}")

    # Delete the model object if it exists
    if "model" in locals():
        del model

    # Clear the GPU cache
    torch.cuda.empty_cache()

    # Run the garbage collection
    gc.collect()

    print("Model, GPU cache, and garbage have been cleared.")
```
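For readers wondering what `dequantize_4bit` does conceptually: in blockwise 4-bit quantization, each weight is stored as a 4-bit index into a 16-entry codebook, and every block of weights shares one scale (its absolute maximum). The pure-Python sketch below illustrates only that idea and is not the bitsandbytes implementation. The evenly spaced `CODEBOOK` is a hypothetical stand-in for the NF4 table, the block size of 64 is an assumption, and double quantization (where the per-block scales are themselves quantized) is ignored.

```python
# Illustrative blockwise 4-bit quantization; NOT the bitsandbytes/NF4 implementation.
# Hypothetical codebook: 16 evenly spaced levels in [-1, 1]
# (NF4 uses levels derived from normal-distribution quantiles instead).
CODEBOOK = [-1.0 + 2.0 * i / 15 for i in range(16)]


def quantize_blockwise(values, blocksize=64):
    """Map each float to a 4-bit code (0-15) plus one absmax scale per block."""
    codes, absmaxes = [], []
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero for all-zero blocks
        absmaxes.append(absmax)
        for v in block:
            normalized = v / absmax  # now in [-1, 1]
            # Pick the index of the nearest codebook entry
            codes.append(min(range(16), key=lambda i: abs(CODEBOOK[i] - normalized)))
    return codes, absmaxes


def dequantize_blockwise(codes, absmaxes, blocksize=64):
    """Invert the mapping: look up each code and rescale by its block's absmax."""
    return [CODEBOOK[c] * absmaxes[i // blocksize] for i, c in enumerate(codes)]


if __name__ == "__main__":
    weights = [((i * 37) % 101 - 50) / 500 for i in range(200)]
    codes, scales = quantize_blockwise(weights)
    restored = dequantize_blockwise(codes, scales)
    worst = max(abs(w - r) for w, r in zip(weights, restored))
    print(f"{len(scales)} blocks, max abs error {worst:.4f}")
```

Because dequantization can only reproduce the 16 codebook levels times the block scale, the recovered weights are an approximation of whatever the model was trained against; with this evenly spaced codebook the error per weight is at most `absmax / 15`.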
@ChrisHayduk
Author

Another way to save GPU memory is to change new_module.to(device=device, dtype=dtype) to new_module.to(device="cpu", dtype=dtype). In my tests with a 4-bit quantized Llama 7B model, running on the GPU took 18-19 GB of VRAM, while running on the CPU required 5-6 GB of VRAM and 14 GB of RAM. The whole process took 272 s in the former case and 380 s in the latter. I think this is acceptable.

Good call! I've added this change to the code: the default `device` in `dequantize_model` is now `"cpu"`.
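The 14 GB of RAM reported above appears to line up with simple arithmetic: once the layers are dequantized to bfloat16, each parameter takes 2 bytes, so roughly 7 billion parameters need about 14 GB. A weights-only estimator as a sketch; `weight_memory_gb` is a hypothetical helper, the 7e9 parameter count is a round-number assumption, and activations, CUDA context, and the small per-block scale overhead of NF4 are all ignored.

```python
# Approximate storage cost per parameter (weights only).
# "nf4" ignores the per-block absmax scales, which add a small overhead.
BYTES_PER_PARAM = {"float32": 4.0, "bfloat16": 2.0, "float16": 2.0, "nf4": 0.5}


def weight_memory_gb(num_params, dtype):
    """Weights-only memory estimate in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    n = 7e9  # assumed round-number parameter count for "Llama 7B"
    for dtype in ("nf4", "bfloat16", "float32"):
        print(f"{dtype:>8}: ~{weight_memory_gb(n, dtype):.1f} GB")
```

This gives about 3.5 GB for the 4-bit weights and about 14 GB once they are expanded to bfloat16, which is why keeping the dequantized layers on the CPU moves most of the footprint from VRAM to RAM.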

@jaredquekjz

jaredquekjz commented Sep 29, 2023

Thanks! I'll definitely try out the CPU optimisation. Could you clarify how your gist differs from, or matches, the recent change to the PEFT library:

huggingface/peft@140a69b

I notice it refers to this gist as well. Is it the same thing but incorporated directly into Peft?

@ChrisHayduk
Author

It looks like it should actually do the same thing! This seems to be the official PEFT implementation of my gist, so from what I can tell you should be able to use that code instead of the gist and get the same result.

@lapp0

lapp0 commented Oct 2, 2023

I'm running into an issue because `module.weight.quant_state` is `None`. Any idea what might cause this? I'm trying to perform this operation without a GPU, and I've set `device_map={"": "cpu"}`.

Dequantizing `model.layers.0.self_attn.q_proj`...
An error occurred: 'NoneType' object does not support item assignment
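One likely cause (an assumption based on how bitsandbytes has historically behaved, not verified against every release): `Linear4bit` weights are only quantized when they are moved to a CUDA device, so with `device_map={"": "cpu"}` they may stay unquantized and `quant_state` remains `None`. A pre-flight check can report the affected layers before dequantization starts. `find_unquantized_layers` is a hypothetical helper; it only relies on each module exposing `weight.quant_state`, so it can be exercised with stand-in objects instead of a real model.

```python
def find_unquantized_layers(named_modules, linear4bit_cls):
    """Return names of 4-bit linear layers whose weight has no quant_state.

    named_modules: iterable of (name, module) pairs, e.g. model.named_modules().
    linear4bit_cls: the 4-bit layer class, e.g. bitsandbytes.nn.Linear4bit.
    """
    missing = []
    for name, module in named_modules:
        # Only 4-bit layers are expected to carry a quant_state
        if isinstance(module, linear4bit_cls) and getattr(module.weight, "quant_state", None) is None:
            missing.append(name)
    return missing
```

In the gist, this could be called just before the dequantization loop as `find_unquantized_layers(model.named_modules(), bnb.nn.Linear4bit)`, stopping with a clear message if the returned list is non-empty rather than failing on the first `None` item assignment.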

@jaredquekjz

jaredquekjz commented Oct 2, 2023

Hi Chris. There's a serious issue noted by Ronan McGovern and me (huggingface/transformers#26492) that requires your advice. It's also highlighted by Benjamin Marie in the final part of the notebook in: https://kaitchup.substack.com/p/lora-adapters-when-a-naive-merge.

It basically involves the merged model mysteriously losing its fine-tuning quality (higher perplexity) when you load it again in 4-bit. It does not seem to happen if you keep it at fp16. Would you have any advice on why this happens and how we can resolve it? I'm working with a 70B model, and it's not practical for GPU-poor folks like us to keep it at fp16...

Really hope we can merge qLora adapters well as it's such a useful technique!

@ChrisHayduk
Author

Hey Jared, I'll double-check here. I've never tried using `load_in_4bit` after quantizing, but llama.cpp's quantization methods seem to work for me without issue. Do you have a model or code I can use to reproduce the issue?

@jaredquekjz

jaredquekjz commented Oct 4, 2023

Hi Chris. Thanks for looking into it. McGovern provides a fairly detailed way of reproducing it in huggingface/transformers#26492. I did something similar, except that I actually use llama.cpp, and when I quiz the model with a set of questions I get much worse answers than from the non-merged model with the adapters applied (the model loaded in 4-bit with nf4=true).

Benjamin Marie uses this function to measure perplexity:

```python
import torch
from tqdm import tqdm


def ppl_model(model, tokenizer, dataset):
    nlls = []
    max_length = 2048
    stride = 512
    for s in tqdm(range(len(dataset["text"]))):
        encodings = tokenizer(dataset["text"][s], return_tensors="pt")
        seq_len = encodings.input_ids.size(1)
        prev_end_loc = 0
        for begin_loc in range(0, seq_len, stride):
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # only score tokens not covered by the previous window
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100  # mask the overlapping context out of the loss
            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
                neg_log_likelihood = outputs.loss
            nlls.append(neg_log_likelihood)
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break
    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl
```
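The sliding-window bookkeeping in `ppl_model` is easy to misread: each window sees up to `max_length` tokens of context, but only the `trg_len` tokens not already scored by the previous window contribute to the loss (the rest are masked with -100). A stdlib-only sketch of just that index arithmetic, with no model involved (`window_spans` is a hypothetical helper that copies the loop's logic), shows that the scored spans tile each document exactly once.

```python
def window_spans(seq_len, max_length=2048, stride=512):
    """Return (begin_loc, end_loc, trg_len) tuples exactly as ppl_model's inner loop computes them."""
    spans = []
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # number of tokens actually scored in this window
        spans.append((begin_loc, end_loc, trg_len))
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    return spans


if __name__ == "__main__":
    for span in window_spans(5000):
        print(span)
```

For a 5000-token document, the first window scores 2048 tokens, the next five score 512 each, and the last scores the remaining 392, so every token is scored once.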

So when he loads the merged model in 4-bit and measures perplexity, he gets a bizarrely high value of 5.2509, versus a baseline of 3.7411 when he uses the merged model in fp16.

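```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumed value: should match the dtype the model was trained/merged with (bfloat16 in the gist above)
compute_dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained("./drive/MyDrive/dqz_merge/", quantization_config=bnb_config, device_map={"": 0})
# tokenizer and dataset are assumed to be loaded already
ppl = ppl_model(model, tokenizer, dataset)
print(ppl)
```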

As you can imagine, this is practically a breaking issue. If you can solve it, the community will be very grateful.
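To put the two perplexity numbers in perspective: `ppl_model` returns the exponential of the average of the per-window losses $\mathcal{L}_k$ (mean cross-entropy in nats over $K$ windows), so the gap can be translated back into loss units:

$$
\mathrm{PPL} = \exp\!\Big(\frac{1}{K}\sum_{k=1}^{K}\mathcal{L}_k\Big)
\quad\Longrightarrow\quad
\frac{1}{K}\sum_{k=1}^{K}\mathcal{L}_k = \ln \mathrm{PPL}
$$

$$
\ln 5.2509 \approx 1.658,\qquad \ln 3.7411 \approx 1.319,\qquad 1.658 - 1.319 \approx 0.339\ \text{nats},\qquad \frac{5.2509}{3.7411} \approx 1.40
$$

In other words, reloading the merged model in 4-bit raises the average loss by roughly 0.34 nats per window, which shows up as about a 40% increase in perplexity.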

@Yu-Yuqing

Yu-Yuqing commented Nov 10, 2023

Two possible explanations. First, make sure you are using the correct dtype in `bnb_4bit_compute_dtype=compute_dtype` (`torch.bfloat16` or `torch.float16`); Chris's code uses `torch.bfloat16` by default. Second, the merged model should not use LoRA again during inference, i.e., comment out `model = PeftModel(model, config)`.
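The first point can be checked mechanically instead of hard-coded: `save_pretrained` normally records the model's dtype as `torch_dtype` in `config.json`, so the compute dtype for a 4-bit reload can be read from the merged checkpoint. A stdlib-only sketch; `dtype_from_config` and `saved_dtype_name` are hypothetical helpers, and the fallback to a `dtype` key is a defensive assumption in case a newer transformers release stores it under that name.

```python
import json
import os


def dtype_from_config(config):
    """Return the dtype name stored in a parsed config.json dict, or None if absent."""
    # "torch_dtype" is the usual key; "dtype" is also checked in case a newer
    # transformers release writes it under that name (an assumption, not verified here).
    return config.get("torch_dtype") or config.get("dtype")


def saved_dtype_name(model_dir):
    """Read model_dir/config.json and return e.g. "bfloat16" or "float16" (or None)."""
    with open(os.path.join(model_dir, "config.json")) as f:
        return dtype_from_config(json.load(f))
```

With torch available, `compute_dtype = getattr(torch, saved_dtype_name("./drive/MyDrive/dqz_merge/"))` turns the name into `torch.bfloat16` or `torch.float16`, which can then be passed as `bnb_4bit_compute_dtype`.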

@jaredquekjz

I am currently using the 'default' method in Peft now - but thanks for clarifying!

@sampbarrow

sampbarrow commented Dec 6, 2023

I get:

Dequantizing `model.layers.0.self_attn.q_proj`...
An error occurred: 'QuantState' object does not support item assignment
Model, GPU cache, and garbage have been cleared.

Edit: Fixed by downgrading bitsandbytes to 0.41.0

@NeelMishra

@sampbarrow @ChrisHayduk I'm seeing this error too. How can we fix it? Downgrading shouldn't be the solution.

@michal-kajstura

@sampbarrow @NeelMishra Just change `quant_state[2] = dtype` to `quant_state.dtype = dtype`.

@drewskidang

You can't pass `load_in_4bit` or `load_in_8bit` as a kwarg when also passing `quantization_config`.

Any recommendations?

@nzw0301

nzw0301 commented May 14, 2024

Thank you for sharing this helpful script. I've noticed that the current script ignores the bias. To handle it, replace the `torch.nn.Linear` construction inside the dequantization loop with:

```python
# Recreate the layer with a bias only if the 4-bit layer had one, and copy it over
has_bias = module.bias is not None
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=has_bias, dtype=dtype)
new_module.weight = torch.nn.Parameter(weights, requires_grad=False)
if has_bias:
    new_module.bias.data = module.bias.data.detach().to(dtype)
```
