@younesbelkada
Created August 3, 2023 09:43
Train adapters using the transformers integration of PEFT
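# Fine-tune facebook/opt-350m on the OpenAssistant-Guanaco dataset with LoRA adapters,
# loading the base model in 4-bit (NF4) via bitsandbytes and training with trl's SFTTrainer.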
from datasets import load_dataset
import torch
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
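
# Load the Guanaco-formatted OpenAssistant dataset (a single `text` column per example).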
dataset_name = "timdettmers/openassistant-guanaco"
dataset = load_dataset(dataset_name, split="train")
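
# Base model and 4-bit NF4 (QLoRA-style) quantization config with float16 compute.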
model_name = "facebook/opt-350m"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
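
# Load the quantized base model on GPU 0 and disable the KV cache during training.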
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
model.config.use_cache = False
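
# Tokenizer setup: use the EOS token for padding and pad on the right.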
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
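
# LoRA hyperparameters: rank 64, scaling factor alpha=16, 10% dropout on the adapter layers.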
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)
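
# Prepare the 4-bit model for k-bit (QLoRA-style) training and attach the LoRA adapter
# through the transformers PEFT integration.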
model = prepare_model_for_kbit_training(model)
model.add_adapter(peft_config)
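
# Training hyperparameters; `paged_adamw_32bit` is the paged AdamW optimizer from bitsandbytes.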
output_dir = "./train_opt_350m"
per_device_train_batch_size = 2
gradient_accumulation_steps = 16
optim = "paged_adamw_32bit"
save_steps = 10
logging_steps = 1
learning_rate = 1e-4
max_grad_norm = 0.3
max_steps = 1000
warmup_ratio = 0.03
lr_scheduler_type = "linear"
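
# Pack the hyperparameters into TrainingArguments; fp16 mixed precision matches the 4-bit compute dtype.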
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,
)
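
# Supervised fine-tuning with trl's SFTTrainer, packing examples from the dataset's
# `text` field into 512-token sequences.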
max_seq_length = 512
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=True,
)
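
# Start training; checkpoints are written to `output_dir` every `save_steps` steps.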
trainer.train()
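
# Optional follow-up (not part of the original gist): explicitly save the trained weights
# to `output_dir` after training; with the PEFT integration this should write only the
# LoRA adapter, not the full base model.
trainer.save_model(output_dir)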