@amangup
Created October 15, 2023 04:52
Llama 2 70B training script on multiple GPUs, using Accelerate and 4-bit QLoRA
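# A possible launch command (a sketch, assuming this file is saved as
# train_llama70b.py and Accelerate has been configured for the 8 GPUs,
# e.g. via `accelerate config`):
#
#   accelerate launch --num_processes 8 train_llama70b.py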
import wandb
wandb.login()

# wandb run settings; the run itself is created by the Trainer via report_to="wandb".
project = 'llama_multigpu_peft'
job_type = 'fine-tuning'

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
# Route the Trainer's wandb logging to the project named above.
os.environ["WANDB_PROJECT"] = project
import accelerate
import bitsandbytes
import torch
import optuna
import json
import math
import gc
import yaml
from pynvml import *
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from accelerate import Accelerator
from peft import LoraConfig, PeftModel, TaskType
from trl import SFTTrainer
from base_chat_prompt import conv_to_base_chat_prompt
from qna_dataset import get_qna_dataset
def print_gpu_utilization():
    # Report memory usage of GPU 0 via NVML.
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
dataset = get_qna_dataset()
print(dataset)
print(dataset['train'][0]['text'])
model_name = 'meta-llama/Llama-2-70b-hf'
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# Apply LoRA to all attention and MLP projection layers.
target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj']

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Llama 2 has no pad token; reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

output_dir = './llama70_qlora_multigpu'
def train_model(model, params):
    training_arguments = TrainingArguments(
        output_dir=f"{output_dir}",
        num_train_epochs=params['num_train_epochs'],
        per_device_train_batch_size=params['batch_size'],
        per_device_eval_batch_size=params['batch_size'],
        learning_rate=params['learning_rate'],
        weight_decay=params['weight_decay'],
        optim="adamw_torch_fused",
        gradient_accumulation_steps=4,
        gradient_checkpointing=False,
        bf16=True,
        save_steps=5,
        logging_steps=5,
        load_best_model_at_end=True,
        evaluation_strategy="steps",
        report_to="wandb",
        log_level="debug",
        max_steps=-1,
    )

    peft_config = LoraConfig(
        lora_alpha=1,
        lora_dropout=0.1,
        target_modules=target_modules,
        r=params['lora_r'],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    # SFTTrainer attaches the LoRA adapters (via peft_config) to the quantized
    # base model and tokenizes the raw 'text' column of the dataset.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=1024,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=False,
    )

    result = trainer.train()
    print_summary(result)

    eval_metrics = trainer.evaluate()
    print(json.dumps(eval_metrics, indent=4))
    print(f"Perplexity: {math.exp(eval_metrics['eval_loss']):.2f}")

    return model, trainer, eval_metrics
def train_optimal():
    # Each Accelerate process loads the 4-bit model onto its own GPU.
    current_device = Accelerator().process_index
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map={"": current_device},
    )
    print_gpu_utilization()

    params = {
        "learning_rate": 0.0025,
        "weight_decay": 0.025,
        "num_train_epochs": 1,
        "batch_size": 1,
        "lora_r": 4,
    }
    model, trainer, eval_metrics = train_model(model, params)
    return model, trainer
if __name__ == "__main__":
    # Train!
    tuned_model, trainer = train_optimal()
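# Not part of the original run: a minimal sketch of reloading the trained LoRA
# adapter for inference, assuming the adapter has been saved to output_dir
# (e.g. with trainer.save_model(output_dir)). PeftModel is already imported above.
#
# base = AutoModelForCausalLM.from_pretrained(
#     model_name, quantization_config=bnb_config, device_map="auto"
# )
# tuned = PeftModel.from_pretrained(base, output_dir)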