from datasets import load_dataset, Features, Value, ClassLabel, Sequence
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
from random import randrange
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, HfArgumentParser, TrainingArguments
from datasets import concatenate_datasets
import evaluate
import numpy as np
import argparse
import os
import sys
from dataclasses import dataclass, field
from src.train_utils import load_instruction_dataset, compute_metrics, postprocess_text, preprocess_function
from torch import nn
import torch.distributed as dist


@dataclass
class OtherArgs:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    train_file_path: str = field(
        metadata={"help": "Path of training data"}
    )
    valid_file_path: str = field(
        default=None, metadata={"help": "Path of val data"}
    )

if __name__ == "__main__":
    parser = HfArgumentParser((OtherArgs, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        args, training_args = parser.parse_args_into_dataclasses()

    print(args)
    print("*" * 30)
    print(training_args)

    dataset = load_instruction_dataset(train_path=args.train_file_path,
                                       valid_path=args.valid_file_path)

    print(f"Train dataset size: {len(dataset['train'])}")
    print(f"Validation dataset size: {len(dataset['valid'])}")

    model_id = "google/flan-t5-large"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    print(dataset["train"][0])

    # Tokenize prompt + " " + input_text across train+valid to find the longest source sequence
    tokenized_inputs = concatenate_datasets([dataset["train"], dataset["valid"]]).map(
        lambda x: tokenizer([p + " " + i for p, i in zip(x["prompt"], x["input_text"])], truncation=True),
        batched=True,
        remove_columns=["input_text", "output_text", "prompt"])
    max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]])
    print(f"Max source length: {max_source_length}")

    # Tokenize the targets across train+valid to find the longest target sequence
    tokenized_targets = concatenate_datasets([dataset["train"], dataset["valid"]]).map(
        lambda x: tokenizer(x["output_text"], truncation=True),
        batched=True,
        remove_columns=["input_text", "output_text", "prompt"])
    max_target_length = max([len(x) for x in tokenized_targets["input_ids"]])
    print(f"Max target length: {max_target_length}")
    tokenized_dataset = dataset.map(preprocess_function,
                                    fn_kwargs={"tokenizer": tokenizer,
                                               "max_source_length": max_source_length,
                                               "max_target_length": max_target_length},
                                    batched=True,
                                    remove_columns=["prompt", "input_text", "output_text"])
    print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")

    # ROUGE metric used by compute_metrics during evaluation
    metric = evaluate.load("rouge")

    # Load the base model in 8-bit and attach LoRA adapters to the attention query/value projections
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q", "v"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM
    )

    # Prepare the int8 model for training (casts norm/output layers to fp32, enables gradient checkpointing)
    model = prepare_model_for_int8_training(model)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # We want to ignore the tokenizer pad token in the loss
    label_pad_token_id = -100

    # Data collator: dynamically pads inputs and labels, filling labels with -100
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=training_args.output_dir,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        predict_with_generate=True,
        fp16=False,  # Overflows with fp16
        learning_rate=1e-3,
        num_train_epochs=5,
        # logging & evaluation strategies
        evaluation_strategy="epoch",
        save_strategy="no",
        push_to_hub=False,
    )

    # Wrap the PEFT model in DDP for multi-GPU training (the launcher must set local_rank)
    model = nn.parallel.DistributedDataParallel(model.cuda(),
                                                device_ids=[training_args.local_rank],
                                                output_device=training_args.local_rank)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["valid"],
        compute_metrics=lambda x: compute_metrics(tokenizer, metric, x),
    )

    trainer.train()

    # Save the LoRA adapter weights and tokenizer
    peft_model_id = training_args.output_dir
    trainer.model.save_pretrained(peft_model_id)
    tokenizer.save_pretrained(peft_model_id)
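
# A minimal inference sketch (not part of the original gist), assuming the adapter was saved to the
# output_dir above; it uses the standard PEFT API for loading a LoRA adapter on top of the base model
# and belongs in a separate script:
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   from peft import PeftModel
#
#   peft_model_id = "outputs/flan-t5-large-lora"  # the output_dir used for training
#   base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", load_in_8bit=True, device_map="auto")
#   model = PeftModel.from_pretrained(base_model, peft_model_id)
#   tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
#
#   inputs = tokenizer("your prompt here", return_tensors="pt").to(model.device)
#   outputs = model.generate(**inputs, max_new_tokens=128)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))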