@rohan-paul
Created October 15, 2023 23:06
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
)
from peft.tuners.lora import LoraLayer
from trl import SFTTrainer
from dataclasses import dataclass, field
from typing import Optional
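
# This script fine-tunes a 4-bit-quantized causal LM (Falcon-7B by default) with
# LoRA adapters (the QLoRA recipe) using TRL's SFTTrainer. Model and quantization
# options live in ModelArguments; training and data options live in ScriptArguments.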

@dataclass
class ModelArguments:
    """
    Arguments for creating and preparing the model.
    """
    model_name: str = field(
        default="tiiuae/falcon-7b",
        metadata={"help": "The model name or path from the Hugging Face hub."},
    )
    use_4bit: bool = field(
        default=True,
        metadata={"help": "Activate 4-bit precision base model loading"},
    )
    use_nested_quant: bool = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4-bit base models"},
    )
    bnb_4bit_compute_dtype: str = field(
        default="float16",
        metadata={"help": "Compute dtype for 4-bit base models"},
    )
    bnb_4bit_quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization type: fp4 or nf4"},
    )
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.1)
    lora_r: int = field(default=64)
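
# Note: lora_r is the adapter rank, lora_alpha the scaling factor (LoRA scales the
# adapter update by lora_alpha / lora_r), and lora_dropout the dropout applied to
# the adapter inputs. These feed the LoraConfig built in
# get_model_peftconfig_tokenizer below.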

@dataclass
class ScriptArguments:
    """
    Arguments for model training and data handling.
    """
    local_rank: int = field(default=-1, metadata={"help": "Used for multi-GPU training"})
    per_device_train_batch_size: int = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.001)
    max_seq_length: Optional[int] = field(default=512)
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The instruction dataset to use."},
    )
    num_train_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "The number of training epochs."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use sequence packing when building the training dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant is slightly better than cosine and is easier to analyze."},
    )
    max_steps: int = field(default=10000, metadata={"help": "How many optimizer update steps to take."})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to use for warmup."})
    group_by_length: bool = field(
        default=True,
        metadata={
            "help": "Group sequences into batches of similar length. Saves memory and speeds up training considerably."
        },
    )
    save_steps: int = field(default=10, metadata={"help": "Save a checkpoint every X update steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X update steps."})
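
# Note: when max_steps is set to a positive value (10000 here) it takes precedence
# over num_train_epochs in TrainingArguments, so the training length is step-driven.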

def get_model_peftconfig_tokenizer(args: ModelArguments):
    """
    Create the model, tokenizer, and peft_config based on provided arguments.
    """
    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
    # Configure BitsAndBytes for model quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=args.use_4bit,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=args.use_nested_quant,
    )
    # Alert when the GPU supports bfloat16 acceleration
    if compute_dtype == torch.float16 and args.use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16; you can accelerate training with --bf16")
            print("=" * 80)
    # Load the model with the quantization configuration
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, quantization_config=bnb_config, device_map={"": 0}, trust_remote_code=True
    )
    # Define the LoRA configuration
    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[  # Falcon's attention / MLP projection module names
            "query_key_value",
            "dense",
            "dense_h_to_4h",
            "dense_4h_to_h",
        ],
    )
    # Load the tokenizer and set the padding token.
    # Models like Falcon-7B and GPT-2 do not define an official pad token,
    # so reuse the EOS token for padding.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    return model, peft_config, tokenizer
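
# Optional sanity check (a sketch, not executed by this script): wrapping the
# quantized model with the LoRA config and printing the trainable-parameter count
# confirms that only the adapter weights will be updated.
#
#   from peft import get_peft_model
#   peft_model = get_peft_model(model, peft_config)
#   peft_model.print_trainable_parameters()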

def parse_arguments():
    """
    Parse Model and Script Arguments.
    Returns:
        ModelArguments, ScriptArguments
    """
    parser = HfArgumentParser((ModelArguments, ScriptArguments))
    return parser.parse_args_into_dataclasses()
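
# HfArgumentParser exposes every dataclass field above as a CLI flag, so the script
# can be launched with overrides, e.g. (the file name is a placeholder, not part of the gist):
#
#   python finetune_falcon_qlora.py --model_name tiiuae/falcon-7b --bf16 True --max_steps 500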

def load_training_data(dataset_name: str):
    """
    Load dataset for training.
    Args:
        dataset_name (str): Name or path of the dataset.
    Returns:
        Dataset object
    """
    return load_dataset(dataset_name, split="train")
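
# The default timdettmers/openassistant-guanaco split already provides a "text"
# column, which is what SFTTrainer consumes via dataset_text_field="text" below.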

def get_training_args(script_args: ScriptArguments):
    """
    Get Training Arguments from ScriptArguments.
    Args:
        script_args (ScriptArguments): Parsed ScriptArguments.
    Returns:
        TrainingArguments
    """
    return TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        optim=script_args.optim,
        save_steps=script_args.save_steps,
        logging_steps=script_args.logging_steps,
        learning_rate=script_args.learning_rate,
        fp16=script_args.fp16,
        bf16=script_args.bf16,
        max_grad_norm=script_args.max_grad_norm,
        max_steps=script_args.max_steps,
        warmup_ratio=script_args.warmup_ratio,
        group_by_length=script_args.group_by_length,
        lr_scheduler_type=script_args.lr_scheduler_type,
    )

def adjust_model_for_bf16(trainer, bf16: bool):
    """
    Adjust Model Layers for bf16.
    Args:
        trainer (SFTTrainer): Initialized SFTTrainer object.
        bf16 (bool): Flag to indicate usage of bf16.
    """
    for name, module in trainer.model.named_modules():
        # Cast LoRA adapter layers to bf16 when bf16 training is requested
        if isinstance(module, LoraLayer) and bf16:
            module = module.to(torch.bfloat16)
        # Keep normalization layers in float32 for numerical stability
        if "norm" in name:
            module = module.to(torch.float32)
        # Cast the output head and embeddings to bf16 if they are still in float32
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight") and bf16 and module.weight.dtype == torch.float32:
                module = module.to(torch.bfloat16)

# Main Execution:
model_args, script_args = parse_arguments()
model, peft_config, tokenizer = get_model_peftconfig_tokenizer(model_args)
# Disable the KV cache; it is not compatible with gradient checkpointing during training
model.config.use_cache = False
dataset = load_training_data(script_args.dataset_name)
training_arguments = get_training_args(script_args)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=script_args.packing,
)
adjust_model_for_bf16(trainer, script_args.bf16)

# Train the Model
trainer.train()
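
# After training, the LoRA adapter (and tokenizer) would typically be persisted,
# e.g. (a sketch; the output path is an arbitrary choice, not part of the original gist):
#
#   trainer.model.save_pretrained("./results/final_adapter")
#   tokenizer.save_pretrained("./results/final_adapter")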