@SunMarc
Last active May 2, 2024 16:41
Finetune GPTQ model with peft and trl
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPTQConfig,
    HfArgumentParser,
    TrainingArguments,
)
from trl import SFTTrainer
# This example fine-tunes Llama 2 model on Guanaco dataset
# using GPTQ and peft.
# Use it by correctly passing --model_name argument when running the
# script. The default model is ybelkada/llama-7b-GPTQ-test
# Versions used:
# accelerate == 0.21.0
# auto-gptq == 0.4.2
# trl == 0.4.7
# peft from source
# transformers from source
# optimum from source
# For models that have `config.pretraining_tp > 1` install:
# pip install git+https://github.com/huggingface/transformers.git
@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are,
    and what size model you want to train.
    """

    local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.001)
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    max_seq_length: Optional[int] = field(default=512)
    model_name: Optional[str] = field(
        default="ybelkada/llama-7b-GPTQ-test",
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        },
    )
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    num_train_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "The number of training epochs."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use packing when creating the dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    optim: Optional[str] = field(
        default="adamw_hf",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant is a bit better than cosine and has an advantage for analysis."},
    )
    max_steps: int = field(default=10000, metadata={"help": "How many optimizer update steps to take."})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for."})
    group_by_length: bool = field(
        default=True,
        metadata={
            "help": "Group sequences into batches of the same length. Saves memory and speeds up training considerably."
        },
    )
    save_steps: int = field(default=10, metadata={"help": "Save a checkpoint every X update steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X update steps."})
    merge_and_push: Optional[bool] = field(
        default=False,
        metadata={"help": "Merge and push weights after training."},
    )
    output_dir: str = field(
        default="./results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


def create_and_prepare_model(args):
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
        print("=" * 80)

    # Load the entire model on GPU 0.
    device_map = {"": 0}
    # Switch to `device_map = "auto"` for multi-GPU.
    # device_map = "auto"

    # We need to disable the exllama kernel: it is not very stable for training.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map=device_map,
        quantization_config=GPTQConfig(bits=4, disable_exllama=True),
    )

    # Check: https://github.com/huggingface/transformers/pull/24906
    model.config.pretraining_tp = 1

    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    return model, peft_config, tokenizer
training_arguments = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
    max_grad_norm=script_args.max_grad_norm,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    group_by_length=script_args.group_by_length,
    lr_scheduler_type=script_args.lr_scheduler_type,
)

model, peft_config, tokenizer = create_and_prepare_model(script_args)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.config.use_cache = False

dataset = load_dataset(script_args.dataset_name, split="train")

# Fix weird overflow issue with fp16 training.
tokenizer.padding_side = "right"

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=script_args.packing,
)

trainer.train()

if script_args.merge_and_push:
    output_dir = os.path.join(script_args.output_dir, "final_checkpoints")
    trainer.model.save_pretrained(output_dir)

    # Free memory for merging weights.
    del model
    torch.cuda.empty_cache()
@SunMarc (Author) commented Aug 25, 2023

This happened because you probably passed a non-quantized model. For this script to work, you need to pass a GPTQ-quantized model such as TheBloke/CodeLlama-13B-GPTQ.
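
If you are not sure whether a checkpoint is already GPTQ-quantized, one quick sanity check (the model id below is just an example) is to look at its config:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("TheBloke/CodeLlama-13B-GPTQ")
    # A GPTQ checkpoint that transformers can load directly carries a quantization_config
    # entry in its config.json; this prints None for a regular, non-quantized model.
    print(getattr(config, "quantization_config", None))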

@CesarCalvoCobo commented Aug 26, 2023

Working for me with a quantized model. Is it possible to merge the resulting adapter with the base model? Any example for this?

@SunMarc (Author) commented Aug 28, 2023

No, merging is not supported yet.

@tridungduong16

Thanks for sharing. I wonder about the dataset format: can we modify this to a seq2seq trainer with an "input" and "response" format?

@SunMarc (Author) commented Sep 1, 2023

If it works with fp32 or fp16 models with adapter layers, it should definitely work with a GPTQ model with adapter layers!

@aleksanderhan

If merging is not supported, how do I load the saved adapter and apply it to the original model later?

@SunMarc (Author) commented Sep 6, 2023

You can load the trained adapters onto the model and use it as is. Of course, it will be slower than if it were merged.

@aleksanderhan

Just like this:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=GPTQConfig(bits=4, disable_exllama=True),
        device_map="auto",
    )
    model = PeftModel.from_pretrained(model, adapter_folder)

Just checking, because I see no performance improvements on my task after I've fine-tuned the quantized model. I'm thinking I'm loading it wrong.

@SunMarc (Author) commented Sep 6, 2023

Check whether the LoRA weights are the same as the ones you trained. With the newest version of transformers, you can follow this guide to easily save/load your adapters.
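
For reference, a minimal sketch of that flow, assuming a recent transformers release with the PEFT integration; the model id and adapter path are placeholders:

    from transformers import AutoModelForCausalLM, GPTQConfig

    model = AutoModelForCausalLM.from_pretrained(
        "TheBloke/CodeLlama-13B-GPTQ",
        device_map="auto",
        quantization_config=GPTQConfig(bits=4, disable_exllama=True),
    )
    model.load_adapter("./results/final_checkpoints")

    # Sanity check: LoRA B matrices are initialized to zero, so trained adapters
    # should show non-zero values here.
    for name, param in model.named_parameters():
        if "lora_B" in name:
            print(name, param.abs().sum().item())
            break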

@M-H482 commented Sep 18, 2023

Will merging be supported? Thank you!

@SunMarc (Author) commented Sep 18, 2023

Currently, it is not on the roadmap. You would be better off fine-tuning your model with bitsandbytes and using GPTQ for production. See this blog post for more information: https://huggingface.co/blog/overview-quantization-transformers
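
For completeness, a rough sketch of the bitsandbytes (QLoRA) route, which does support merging; the model name and paths below are placeholders:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Train with LoRA on a 4-bit bitsandbytes model, swapping the quantization_config
    # used in the script above.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", quantization_config=bnb_config, device_map="auto"
    )
    # ... fine-tune with LoRA as in the script, then save the adapters ...

    # After training: reload the base model in half precision and merge the adapters into it.
    base = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
    )
    merged = PeftModel.from_pretrained(base, "./results/final_checkpoints").merge_and_unload()
    merged.save_pretrained("./merged-model")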

@M-H482 commented Sep 18, 2023

Currently, it is not on the roadmap. You would be better off fine-tuning your model with bitsandbytes and using GPTQ for production. See this blog post for more information: https://huggingface.co/blog/overview-quantization-transformers

OK, thank you!

@KartavyaBagga

How do I run this code? Can you provide a shell command for fine-tuning, in which I can set my custom dataset file and --model_name?

@zhengshuo1

Thanks for sharing. I wonder about the dataset format: can we modify this to a seq2seq trainer with an "input" and "response" format?

Hello, have you found the proper format for the dataset?

@SunMarc (Author) commented Nov 6, 2023

Hi @KartavyaBagga, to load the dataset, we use dataset = load_dataset(script_args.dataset_name, split="train") from the datasets library. The simplest approach is to add your dataset there and load it by passing the --dataset_name argument.
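
For example, something along these lines; the script filename is hypothetical (use whatever name you saved this gist under), and the dataset is assumed to be pushed to the Hub with a "text" column, like the default Guanaco dataset:

    python gptq_peft_finetune.py \
        --model_name TheBloke/CodeLlama-13B-GPTQ \
        --dataset_name your-username/your-dataset \
        --output_dir ./results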

@MuhammadShifa

Hi @SunMarc, thank you so much for this script.
I am fine-tuning a Llama-2 GPTQ model and the script is working fine on my dataset. But the question is: how can I use this model for chat conversation if I am unable to merge the model? I am using this repository for GPTQ and transformers models.

@SunMarc (Author) commented Dec 18, 2023

Hi @MuhammadShifa, while it is true that you can't merge the model, you should still be able to run it. It will be slower than if you could merge it, however. Otherwise, if you really want to merge the adapters, you can use bitsandbytes 4-bit quantization instead, as it supports merging (QLoRA).
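
A rough inference sketch with the unmerged adapters attached; the model id, adapter path, and Guanaco-style prompt format are assumptions based on the script above:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

    model_id = "TheBloke/Llama-2-7B-GPTQ"  # your GPTQ base model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=GPTQConfig(bits=4, disable_exllama=True),
    )
    model = PeftModel.from_pretrained(model, "./results/final_checkpoints")

    # Prompt format matching the default openassistant-guanaco training data.
    prompt = "### Human: Hello, who are you?### Assistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(output[0], skip_special_tokens=True))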

@MuhammadShifa

Thanks @SunMarc for your quick response.
Yes, I have tried to merge it but am facing an issue. Could you please share a resource link on how I can use this model for real-time conversation?

@SunMarc (Author) commented Dec 18, 2023

@MuhammadShifa

Thank you so much @SunMarc

@glf1030 commented Dec 27, 2023

ValueError: Please specify target_modules in peft_config. Hi, I am wondering how to fix this. My quantized model is:

    QWenLMHeadModel(
      (transformer): QWenModel(
        (wte): Embedding(152064, 8192)
        (drop): Dropout(p=0.0, inplace=False)
        (rotary_emb): RotaryEmbedding()
        (h): ModuleList(
          (0-79): 80 x QWenBlock(
            (ln_1): RMSNorm()
            (attn): QWenAttention(
              (attn_dropout): Dropout(p=0.0, inplace=False)
              (c_attn): QuantLinear()
              (c_proj): QuantLinear()
            )
            (ln_2): RMSNorm()
            (mlp): QWenMLP(
              (c_proj): QuantLinear()
              (w1): QuantLinear()
              (w2): QuantLinear()
            )
          )
        )
        (ln_f): RMSNorm()
      )
      (lm_head): Linear(in_features=8192, out_features=152064, bias=False)
    )

@SunMarc (Author) commented Dec 27, 2023

Hi @glf1030, this happens because the Qwen model is not supported by default. See the supported list here. Hence you need to pass target_modules in peft_config. For your model, you can pass target_modules = ["c_attn"].
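
For instance, with the same LoRA hyperparameters as the script above:

    from peft import LoraConfig

    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["c_attn"],  # Qwen's fused attention projection
    )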

@glf1030 commented Dec 28, 2023

Hi @glf1030, this happens because the Qwen model is not supported by default. See the supported list here. Hence you need to pass target_modules in peft_config. For your model, you can pass target_modules = ["c_attn"].

Hi, thanks for your reply. I passed target_modules=['c_attn'] and it works for training, but for inference I used the following code:

    from peft import PeftConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "/data/lifan/Qwen-72B-Chat-Int4"
    adapter_model_id = "/data/lifan/cicc_lora_after_gptq_training_checkpoint"

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True)
    peft_config = PeftConfig.from_pretrained(adapter_model_id)
    # to initiate with random weights
    peft_config.init_lora_weights = False

    model.add_adapter(peft_config)
    model.enable_adapters()
    output = model.generate(**inputs)

    return tokenizer.decode(output[0])

and I got the following error:

    File "/data/lifan/miniconda3/envs/llama_factory/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 90, in __init__
        self.inject_adapter(self.model, adapter_name)
    File "/data/lifan/miniconda3/envs/llama_factory/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 247, in inject_adapter
        self._create_and_replace(peft_config, adapter_name, target, target_name, parent, **optional_kwargs)
    File "/data/lifan/miniconda3/envs/llama_factory/lib/python3.9/site-packages/peft/tuners/lora/model.py", line 202, in _create_and_replace
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
    File "/data/lifan/miniconda3/envs/llama_factory/lib/python3.9/site-packages/peft/tuners/lora/model.py", line 355, in _create_new_module
        raise ValueError(
    ValueError: Target module QuantLinear() is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D.

Would you please take a look at it? Best.

@SunMarc (Author) commented Dec 28, 2023

Hi @glf1030, make sure you have the latest version of peft.
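
For example, to upgrade and confirm the installed version:

    pip install --upgrade peft
    python -c "import peft; print(peft.__version__)"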

@glf1030 commented Jan 2, 2024

Hi @glf1030, make sure you have the latest version of peft.

Hi, my version is 0.7.1:
(llama_factory) [lifan@iZ0jld5hy53xg1wwoistghZ ~]$ pip show peft
Name: peft
Version: 0.7.1
Summary: Parameter-Efficient Fine-Tuning (PEFT)
Home-page: https://github.com/huggingface/peft
Author: The HuggingFace team
Author-email: sourab@huggingface.co
License: Apache
Location: /data/lifan/miniconda3/envs/llama_factory/lib/python3.9/site-packages
Requires: accelerate, huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch, tqdm, transformers
Required-by: auto-gptq

@DimensionZer0

Working for me with a quantized model. Is it possible to merge the resulting adapter with the GPTQ base model? Any example for this?
