# This is a modified version of TRL's `SFTTrainer` example (https://github.com/huggingface/trl/blob/main/examples/scripts/sft_trainer.py),
# adapted to run with DeepSpeed ZeRO-3 and Mistral-7B-v0.1. The settings below were run on 1 node of 8 x A100 (80GB) GPUs.
#
# Usage:
#   - Install the latest transformers & accelerate versions: `pip install -U transformers accelerate`
#   - Install deepspeed: `pip install deepspeed==0.9.5`
#   - Install TRL from main: `pip install git+https://github.com/huggingface/trl.git`
#   - Clone the repo: `git clone https://github.com/huggingface/trl.git`
#   - Copy this Gist into `trl/examples/scripts`
#   - Run from the root of the trl repo with:
#     `accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml --gradient_accumulation_steps 8 examples/scripts/sft_trainer.py`
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from trl import SFTTrainer

tqdm.pandas()


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The name of the causal LM model we wish to fine-tune with SFTTrainer
    """

    model_name: Optional[str] = field(default="mistralai/Mistral-7B-v0.1", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(default="stingning/ultrachat", metadata={"help": "the dataset name"})
    dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2.0e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=8, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=5, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "Number of update steps between two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits the total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push the model to the HF Hub"})
    hub_model_id: Optional[str] = field(
        default="mistral-7b-finetuned-ultrachat", metadata={"help": "The name of the model on the HF Hub"}
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Step 1: Load the dataset
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
dataset = load_dataset(script_args.dataset_name, split="train[:20000]")
dataset = dataset.train_test_split(test_size=0.1)


def prepare_dialogue(example):
    # UltraChat's `data` column alternates turns: even indices are the user,
    # odd indices are the assistant.
    text = ""
    for idx, msg in enumerate(example["data"]):
        if idx % 2 == 0:
            text += f"<|user|>\n{msg}{tokenizer.eos_token}\n"
        else:
            text += f"<|assistant|>\n{msg}{tokenizer.eos_token}\n"
    example["text"] = text
    return example


dataset = dataset.map(prepare_dialogue, num_proc=4, remove_columns=["id", "data"])

# Step 2: Load the model
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # Copy the model to each device
    device_map = {"": Accelerator().local_process_index}
    torch_dtype = torch.bfloat16
else:
    device_map = None
    quantization_config = None
    torch_dtype = None

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    use_auth_token=script_args.use_auth_token,
)

# Step 3: Define the training arguments
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    gradient_checkpointing=True,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
    save_steps=script_args.save_steps,
    save_total_limit=script_args.save_total_limit,
    push_to_hub=script_args.push_to_hub,
    hub_model_id=script_args.hub_model_id,
    bf16=True,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    evaluation_strategy="epoch",
    logging_first_step=True,
)

# Step 4: Define the LoraConfig
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    peft_config = None

# Step 5: Define the Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    max_seq_length=script_args.seq_length,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field=script_args.dataset_text_field,
    peft_config=peft_config,
    packing=True,
)

trainer.train()

# Step 6: Save the model
trainer.save_model(script_args.output_dir)
I am trying to train with the same script using 3 x A100 (80GB), but I am not sure multi-GPU is working as expected, because I get the same CUDA OOM error, with exactly the same "memory required" and "memory available" figures, when using 2 GPUs and 3 GPUs. I have activated PEFT and tried training on both setups, i.e. 2 GPUs and 3 GPUs; on both, training runs for 0.5 epochs and then hits the CUDA OOM error again. The only difference is my max_length, which I have set to 4800.

Is this behavior expected? Are 3 x A100 (80GB) insufficient to fine-tune the Mistral 7B model even with the LoRA setup? I keep trying to train the model on multiple GPUs, but I am not having any success; I always get stuck on a CUDA OOM error.

Also, I can't find a proper resource on how to optimally use multi-GPU with accelerate and deepspeed; I only see code presented directly with the claim that it will activate multi-GPU. If there are any guides or resources I can follow, please share them; it would be really helpful. Thank you.
I'm not sure if you're chunking the dataset to have uniform lengths, but in my experience this trick reduces the memory spikes.
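For illustration, here is a minimal sketch of that trick, assuming you want to do the regrouping yourself rather than rely on `SFTTrainer`'s `packing=True` (which performs a similar concatenate-and-split internally). The names `block_size` and `group_texts` are illustrative, not from the script above; `tokenizer` is the one loaded in Step 1.

```python
# Hedged sketch: tokenize the formatted dialogues, concatenate them, and
# re-split into fixed-size blocks so every example has the same length.
from itertools import chain

block_size = 1024  # keep in line with --seq_length to avoid memory spikes


def group_texts(batch):
    # `batch["text"]` is a list of dialogue strings (batched map).
    token_lists = tokenizer(batch["text"])["input_ids"]
    concatenated = list(chain.from_iterable(token_lists))
    # Drop the trailing remainder so every block is exactly block_size long.
    total_length = (len(concatenated) // block_size) * block_size
    return {
        "input_ids": [
            concatenated[i : i + block_size]
            for i in range(0, total_length, block_size)
        ]
    }


dataset = dataset.map(group_texts, batched=True, remove_columns=["text"])
```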
> I am trying to train with the same script using 3 x A100 (80GB) [...] I always get stuck on a CUDA OOM error.
After some research, I got it running successfully by using ZeRO-3 to offload the model to the CPU.

To follow along, edit the following two values in the file trl/examples/accelerate_configs/deepspeed_zero3.yaml:

offload_optimizer_device: 'cpu'
offload_param_device: 'cpu'
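For orientation, the `deepspeed_config` section of that file would then look roughly like the sketch below. This is a hedged reconstruction: the surrounding keys come from a typical accelerate ZeRO-3 config and may differ in your checkout; only the two `offload_*` values are the actual change.

```yaml
# Sketch of examples/accelerate_configs/deepspeed_zero3.yaml after the edit.
# Only the two offload_* values change; other keys are shown for context.
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: 'cpu'  # previously 'none'
  offload_param_device: 'cpu'      # previously 'none'
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
mixed_precision: bf16
```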
What versions of transformers and torch are you using? (I'm getting errors running the code.) Thanks!
> What versions of transformers and torch are you using?
Ubuntu 20.04
transformers==4.36.0
deepspeed==0.9.5
torch==2.1.0
I had a similar question: shouldn't <|user|> and <|assistant|> be added as special tokens? @TheBlackHacker @lewtun
> Shouldn't <|user|> and <|assistant|> be added as special tokens?
I don't think it's needed; special tokens are for marking the start/end of a sequence or doing something special to a sentence. <|user|> and <|assistant|> just define the turns in the prompt. You can even try with just "User:" and "Assistant:".
Yes, whether to treat <|user|> and <|assistant|> as special tokens is optional. One advantage of doing so is that you will save some tokens during inference (i.e. more space in your context window), since e.g. <|user|> won't be split into subwords.
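If you do want to register them, a minimal sketch (not part of the original script) looks like this; the key detail is resizing the embeddings after adding the tokens:

```python
# Hedged sketch: register <|user|> and <|assistant|> as special tokens so
# the tokenizer emits one id per marker instead of several subwords.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

num_added = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|user|>", "<|assistant|>"]}
)
if num_added > 0:
    # The embedding matrix must grow to cover the new token ids.
    model.resize_token_embeddings(len(tokenizer))
```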
Thanks, I ran this on 4 A10G GPUs with batch size = 1, using FSDP.
One very quick question: if the model fits on a single GPU with batch size = 1, let's say phi-2 with Flash Attention 2 + tf32 + bfloat16 + packing + LoRA, will using DeepSpeed ZeRO-3 give better performance than DDP or any other strategy? I would be running this script with torchrun ... on the configuration given above. Thanks in advance :)

@lewtun ?
> If the model fits on a single GPU with batch size = 1 [...] will using DeepSpeed ZeRO-3 give better performance than DDP or any other strategy?
ZeRO-2 would be better in terms of performance. You can try ZeRO++; it's better than ZeRO-3 because it can combine both GPU and CPU computing power.

Btw, you should run the script with accelerate launch ... rather than torchrun.
Thanks @TheBlackHacker for the prompt response. I had a notion that the ZeRO stages are needed/useful when the model can't fit on one GPU, due to offloading etc. I'll definitely try it instead of the vanilla DDP approach. Any idea on how to run this script with ZeRO-2 or ZeRO-3? I'll use Phi-2 once its Flash Attention 2 issue is resolved.
> Any idea on how to run this script with ZeRO-2 or ZeRO-3?
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml --gradient_accumulation_steps 8 examples/scripts/sft_trainer.py
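For ZeRO-2 specifically, one option is a variant of that config with the stage lowered. A hedged, untested sketch (the filename is hypothetical, the keys follow the standard accelerate schema, and the ZeRO-3-only flags are dropped):

```yaml
# Hypothetical deepspeed_zero2.yaml, derived from the ZeRO-3 config above:
# set zero_stage to 2 and drop the zero3_* flags.
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  zero_stage: 2
distributed_type: DEEPSPEED
mixed_precision: bf16
num_machines: 1
num_processes: 8
```

You would then launch with accelerate launch --config_file=path/to/deepspeed_zero2.yaml examples/scripts/sft_trainer.py.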
> Thanks, I ran this on 4 A10G GPUs with batch size = 1, using FSDP.

Could you share the script?
As a point of reference, this is what the training loss looks like after 1 epoch of training: [training loss curve]