Skip to content

Instantly share code, notes, and snippets.

@ruslanmv
Last active May 9, 2024 16:39
Show Gist options
  • Save ruslanmv/331c98a5411ea5c5f13caabd5e85e505 to your computer and use it in GitHub Desktop.
Save ruslanmv/331c98a5411ea5c5f13caabd5e85e505 to your computer and use it in GitHub Desktop.
Finetuning Llama 3 on Multiple GPUs
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from peft import LoraConfig
from trl import SFTTrainer
from transformers import TrainingArguments
from accelerate import PartialState
# Load the training split of the dataset from the Hugging Face Hub.
dataset_name = "ruslanmv/ai-medical-dataset"
dataset = load_dataset(dataset_name, split="train")
# Keep only the first 100 rows so the run stays small; raise this number to
# train on more data. The count is clamped to the dataset length so the index
# range never runs past the end of a shorter split.
num_train_rows = min(100, len(dataset))
dataset = dataset.select(range(num_train_rows))
# Device map: pick ONE of the two assignments below.
#   "auto" -> PP mode, run with `python test_sft.py`
#   "DDP"  -> DDP mode, run with `accelerate launch test_sft.py`
device_map = "auto"  # for PP and running with `python test_sft.py`
# device_map = "DDP"  # for DDP and running with `accelerate launch test_sft.py`
if device_map == "DDP":
    # Each launched process loads the whole model onto its own device.
    # Use the index of this process on the *current machine*: the global
    # `process_index` would point at a non-existent GPU on multi-node runs,
    # while both values are identical on a single node.
    device_string = PartialState().local_process_index
    # The empty-string key applies the placement to the entire model.
    device_map = {"": device_string}
# Checkpoint shared by the tokenizer and the model weights.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# Tokenizer: the end-of-sequence token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Quantization: 4-bit NF4 weights with float16 compute.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Model: loaded quantized, with the KV cache disabled, and placed according
# to the `device_map` chosen earlier in the script.
model_load_kwargs = {
    "quantization_config": bnb_config,
    "trust_remote_code": True,
    "use_cache": False,
    "device_map": device_map,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_kwargs)
# LoRA hyperparameters.
lora_r = 32  # 64
lora_alpha = 16
lora_dropout = 0.1

# Projection layers that receive LoRA adapters.
lora_target_modules = [
    "k_proj",
    "q_proj",
    "v_proj",
    "up_proj",
    "down_proj",
    "gate_proj",
]
# Modules listed here are passed to LoRA's `modules_to_save`.
lora_modules_to_save = [
    "embed_tokens",
    "input_layernorm",
    "post_attention_layernorm",
    "norm",
]

# Adapter configuration handed to the trainer.
peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
    modules_to_save=lora_modules_to_save,
    task_type="CAUSAL_LM",
    bias="none",
)
# Training hyperparameters.
max_seq_length = 512
output_dir = "./results"
per_device_train_batch_size = 4  # reduced batch size to avoid OOM
gradient_accumulation_steps = 2
optim = "adamw_torch"
save_steps = 10
logging_steps = 1
learning_rate = 2e-4
max_grad_norm = 0.3
max_steps = 1  # 300 Approx the size of guanaco at bs 8, ga 2, 2 GPUs.
warmup_ratio = 0.1
lr_scheduler_type = "cosine"

# In DDP mode the device-map section replaces `device_map` with a
# {"": <process index>} dict; in PP mode it stays the string "auto".
# The two DDP-sensitive flags below are derived from that, so switching modes
# no longer requires editing them by hand.
use_ddp = isinstance(device_map, dict)

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    gradient_checkpointing=True,  # gradient checkpointing
    # Must be False for DDP; stays True (the previous value) otherwise.
    gradient_checkpointing_kwargs={"use_reentrant": not use_ddp},
    # False under DDP, True (the previous value) otherwise.
    ddp_find_unused_parameters=not use_ddp,
    report_to="wandb",
)
# Supervised fine-tuning trainer: wires together the quantized model, the
# LoRA config, the tokenizer and the dataset's "context" text column.
sft_trainer_kwargs = {
    "model": model,
    "args": training_arguments,
    "train_dataset": dataset,
    "dataset_text_field": "context",
    "max_seq_length": max_seq_length,
    "tokenizer": tokenizer,
    "peft_config": peft_config,
}
trainer = SFTTrainer(**sft_trainer_kwargs)

# Run the training loop.
trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment