@FoobarProtocol
Created October 23, 2023 21:51
This is one iteration of the fine-tuning script for CodeT5+. Warning: I don't think this script is complete.
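
For reference, a hedged example of how the script might be launched (assuming it is saved as train.py, a Hugging Face token is configured for use_auth_token, and the dataset exposes train/test splits; the model path is illustrative, chosen to match the StarCoder-style target modules used below):

python train.py --model_path bigcode/starcoderbase --dataset_name HuggingFaceH4/CodeAlpaca_20K --seq_length 2048 --max_steps 1000 --batch_size 1 --gradient_accumulation_steps 16 --output_dir ./checkpoints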
import argparse
import os
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, set_peft_model_state_dict
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer, TrainerCallback,
                          TrainerControl, TrainerState, TrainingArguments, logging, set_seed)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
"""
Fine-Tune StarCoder on Code Alpaca/SE
"""

class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        # Save only the PEFT adapter weights and write an empty pytorch_model.bin so the
        # Trainer's checkpoint layout stays intact without duplicating the base model.
        kwargs["model"].save_pretrained(checkpoint_folder)
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        torch.save({}, pytorch_model_path)
        return control

class LoadBestPeftModelCallback(TrainerCallback):
    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Note: state.best_model_checkpoint is only populated when the Trainer tracks a best
        # metric (e.g. metric_for_best_model is set); otherwise it will be None here.
        print(f"Loading best peft model from {state.best_model_checkpoint} (score: {state.best_metric}).")
        best_model_path = os.path.join(state.best_model_checkpoint, "adapter_model.bin")
        adapters_weights = torch.load(best_model_path)
        model = kwargs["model"]
        set_peft_model_state_dict(model, adapters_weights)
        return control

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="bigcode/large-model")
    parser.add_argument("--dataset_name", type=str, default="HuggingFaceH4/CodeAlpaca_20K")
    parser.add_argument("--subset", type=str)
    parser.add_argument("--split", type=str)
    parser.add_argument("--size_valid_set", type=int, default=10000)
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--shuffle_buffer", type=int, default=5000)
    parser.add_argument("--input_column_name", type=str, default="prompt")
    parser.add_argument("--output_column_name", type=str, default="completion")
    parser.add_argument("--seq_length", type=int, default=2048)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--eos_token_id", type=int, default=49152)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--no_fp16", action="store_false")
    parser.add_argument("--bf16", action="store_true", default=True)
    # passing this flag disables gradient checkpointing (and re-enables the KV cache) below
    parser.add_argument("--no_gradient_checkpointing", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints")
    parser.add_argument("--log_freq", default=100, type=int)
    parser.add_argument("--eval_freq", default=100, type=int)
    parser.add_argument("--save_freq", default=1000, type=int)
    return parser.parse_args()

def chars_token_ratio(dataset, tokenizer, input_column_name="prompt", output_column_name="completion", nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example, input_column_name, output_column_name)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))
    return total_characters / total_tokens
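
# Rough worked example (illustrative numbers, not taken from the dataset): if the 400 sampled
# examples total ~140,000 characters and ~40,000 tokens, the ratio is 3.5 characters per token.
# ConstantLengthDataset below uses that ratio to size its character buffer, i.e.
# seq_length * chars_per_token * num_of_sequences = 2048 * 3.5 * 1024 ≈ 7.3M characters per refill.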

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

def prepare_sample_text(example, input_column_name="prompt", output_column_name="completion"):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example[input_column_name]}\n\nAnswer: {example[output_column_name]}"
    return text

class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.
        Args:
            tokenizer (Tokenizer): The processor used for processing the data.
            dataset (dataset.Dataset): Dataset with text files.
            infinite (bool): If True the iterator is reset after dataset reaches end else stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences (int): Number of token sequences to keep in buffer.
            chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        input_column_name="prompt",
        output_column_name="completion",
    ):
        self.tokenizer = tokenizer
        # Falls back to the module-level `args` when the tokenizer has no EOS token.
        self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else args.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.input_column_name = input_column_name
        self.output_column_name = output_column_name

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(prepare_sample_text(next(iterator), self.input_column_name, self.output_column_name))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                    }
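
# Hedged usage sketch (not from the original gist): driving ConstantLengthDataset on its own.
# The tokenizer checkpoint and toy samples are illustrative assumptions.
#
#   tok = AutoTokenizer.from_pretrained("Salesforce/codet5p-220m")
#   toy = [{"prompt": f"Add {i} numbers", "completion": "def add(*xs):\n    return sum(xs)"} for i in range(32)]
#   packed = ConstantLengthDataset(tok, toy, infinite=False, seq_length=64, num_of_sequences=4)
#   sample = next(iter(packed))  # {"input_ids": LongTensor(64), "labels": LongTensor(64)}
#
# Note that only full seq_length chunks are yielded: a buffer holding fewer than seq_length
# tokens in total produces no samples.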

def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
    chars_per_token = chars_token_ratio(train_data, tokenizer, args.input_column_name, args.output_column_name)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        input_column_name=args.input_column_name,
        output_column_name=args.output_column_name,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
        input_column_name=args.input_column_name,
        output_column_name=args.output_column_name,
    )
    return train_dataset, valid_dataset

def run_training(args, train_data, val_data):
    print("Loading the model")
    # disable the caching mechanism when gradient checkpointing is used
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        use_auth_token=True,
        use_cache=args.no_gradient_checkpointing,
        load_in_8bit=True,
        device_map={"": Accelerator().process_index},
    )
    model = prepare_model_for_int8_training(model)

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["c_proj", "c_attn", "q_attn"],
    )
    model = get_peft_model(model, lora_config)
    print_trainable_parameters(model)

    train_data.start_iteration = 0

    print("Starting main loop")
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=not args.no_gradient_checkpointing,
        fp16=not args.no_fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="StarCoder-finetuned",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        callbacks=[SavePeftModelCallback, LoadBestPeftModelCallback],
    )

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))

def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_auth_token=True)
    train_dataset, eval_dataset = create_datasets(tokenizer, args)
    run_training(args, train_dataset, eval_dataset)


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    logging.set_verbosity_error()
    main(args)
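
# Hedged sketch (not from the original gist): reloading the saved LoRA adapter for inference.
# "final_checkpoint" is the directory written at the end of run_training(); the base model path
# must match the one used for fine-tuning (an illustrative StarCoder checkpoint is shown here).
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained("bigcode/starcoderbase", device_map="auto")
#   model = PeftModel.from_pretrained(base, "./checkpoints/final_checkpoint")
#   model.eval()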