
@spmurrayzzz
Forked from younesbelkada/finetune_sft_trl.py
Created September 12, 2023 20:53
Benchmarking SFT trainer with 8bit models
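The script below benchmarks trl's SFTTrainer under different quantization settings (fp16, 8-bit, or 4-bit fp4/nf4 via bitsandbytes), with optional LoRA adapters from PEFT. Every field of the `ScriptArguments` dataclass is exposed as a command-line flag through `HfArgumentParser`; a few example invocations are sketched in the comments at the end of the file.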
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from tqdm import tqdm
from accelerate import Accelerator
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    LlamaTokenizer,
    HfArgumentParser,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTTrainer
tqdm.pandas()
########################################################################
# This is a fully working, simple example of how to use trl's SFTTrainer.
#
# It fine-tunes any causal language model (GPT-2, GPT-Neo, etc.) with the
# SFTTrainer from trl, leveraging the PEFT library to fine-tune adapters
# on top of the model.
#
########################################################################
@dataclass
class ScriptArguments:
    """
    Define the arguments used in this script.
    """

    model_name: Optional[str] = field(default="decapoda-research/llama-7b-hf", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(default="ybelkada/oasst1-tiny-subset", metadata={"help": "the dataset name"})
    use_8_bit: Optional[bool] = field(default=False, metadata={"help": "use 8 bit precision"})
    use_seq2seq_lm: Optional[bool] = field(default=False, metadata={"help": "use seq2seq LM"})
    use_4_bit: Optional[bool] = field(default=True, metadata={"help": "use 4 bit precision"})
    bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "specify the quantization type (fp4 or nf4)"})
    use_bnb_nested_quant: Optional[bool] = field(default=False, metadata={"help": "use nested quantization"})
    use_multi_gpu: Optional[bool] = field(default=False, metadata={"help": "use multi GPU"})
    use_adapters: Optional[bool] = field(default=True, metadata={"help": "use adapters"})
    batch_size: Optional[int] = field(default=1, metadata={"help": "input batch size"})
    max_seq_length: Optional[int] = field(default=512, metadata={"help": "max sequence length"})
    optimizer_name: Optional[str] = field(default="adamw_hf", metadata={"help": "Optimizer name"})
def get_current_device():
    return Accelerator().process_index
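# Parse the dataclass above from the command line; each field becomes a --flag
# (e.g. --model_name, --use_4_bit).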
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
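# Only the first 1% of the train split is loaded, to keep benchmark runs short.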
dataset = load_dataset(script_args.dataset_name, split="train[:1%]")
# We load the model
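# device_map="auto" lets accelerate spread the model across all visible GPUs;
# otherwise the whole model is pinned to this process's device, which is what
# you want when each process (e.g. under DDP) owns its own copy.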
if script_args.use_multi_gpu:
    device_map = "auto"
else:
    device_map = {"": get_current_device()}

if script_args.use_8_bit and script_args.use_4_bit:
    raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
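# For 4 bit runs, build a BitsAndBytesConfig: weights are quantized to fp4 or nf4,
# compute happens in fp16, and nested (double) quantization is optional.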
if script_args.use_4_bit:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type=script_args.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=script_args.use_bnb_nested_quant,
    )
else:
    bnb_config = None
transformers_class = AutoModelForSeq2SeqLM if script_args.use_seq2seq_lm else AutoModelForCausalLM
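# Load the base model, quantized on the fly when 8 bit or 4 bit is requested.
# device_map is only passed in the quantized cases, since that is when the
# weights need to be dispatched by bitsandbytes/accelerate at load time.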
model = transformers_class.from_pretrained(
    script_args.model_name,
    load_in_8bit=script_args.use_8_bit,
    load_in_4bit=script_args.use_4_bit,
    device_map=device_map if (script_args.use_8_bit or script_args.use_4_bit) else None,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)
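# When adapters are enabled, train small LoRA matrices on top of the frozen
# (possibly quantized) base model instead of updating the full weights.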
if script_args.use_adapters:
    peft_config = LoraConfig(
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM" if not script_args.use_seq2seq_lm else "SEQ_2_SEQ_LM",
    )
else:
    peft_config = None

    if script_args.use_8_bit:
        raise ValueError("You need to use adapters to use 8 bit precision")
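# The original LLaMA checkpoints ship without a pad token, so one is added
# explicitly; other models fall back to their own AutoTokenizer.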
if "llama" in script_args.model_name:
tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
else:
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
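# Run a short 10-step training job into a throwaway directory: long enough to
# measure speed and memory, short enough to finish quickly.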
with tempfile.TemporaryDirectory() as tmp_dir:
    training_arguments = TrainingArguments(
        per_device_train_batch_size=script_args.batch_size,
        max_steps=10,
        gradient_accumulation_steps=4,
        per_device_eval_batch_size=script_args.batch_size,
        output_dir=tmp_dir,
        report_to=["none"],
        optim=script_args.optimizer_name,
        fp16=True,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="messages",
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        args=training_arguments,
    )

    trainer.train()

    # Save the trained adapter into tmp_dir so the check below can find it
    # (with only 10 steps, the default checkpointing interval is never reached
    # and nothing would otherwise be written).
    trainer.save_model(tmp_dir)

    assert "adapter_model.bin" in os.listdir(tmp_dir)