@pacman100
Created February 17, 2023 04:39
# accelerate_lora_t5.py
import gc
import os
import sys
import psutil
import threading
import argparse
import transformers
import datasets
import numpy as np
import pandas as pd
import torch
import evaluate
import nltk
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    set_seed,
)
from datasets import load_from_disk
from peft import LoraConfig, TaskType, get_peft_model, get_peft_model_state_dict
from tqdm import tqdm
from nltk.tokenize import sent_tokenize

nltk.download("punkt", quiet=True)
# Converting bytes to megabytes
def b2mb(x):
    return int(x / 2**20)
# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """Get resident set size memory for the current process."""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001)  # 1msec
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
# Metric
metric = evaluate.load("rouge")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
    return preds, labels
def parse_arge():
    """Parse the arguments."""
    parser = argparse.ArgumentParser()
    # add model id and dataset path arguments
    parser.add_argument("--model_id", type=str, default="google/flan-t5-xl", help="Model id to use for training.")
    parser.add_argument("--dataset_path", type=str, default="data", help="Path to the already processed dataset.")
    # add training hyperparameters for epochs, batch size, learning rate, and seed
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train for.")
    parser.add_argument("--train_batch_size", type=int, default=8, help="Batch size to use for training.")
    parser.add_argument("--eval_batch_size", type=int, default=8, help="Batch size to use for evaluation.")
    parser.add_argument("--lr", type=float, default=3e-3, help="Learning rate to use for training.")
    parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
    parser.add_argument("--use_peft", action="store_true", help="Whether to enable PEFT (LoRA).")
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
            " Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
            " bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
    parser.add_argument("--tracking_steps", type=int, default=100, help="Interval (in steps) for logging/tracking.")
    args = parser.parse_args()
    return args
def training_function(args):
    # set seed
    set_seed(args.seed)

    # Initialize the accelerator; distributed/DeepSpeed settings come from the
    # config file passed to `accelerate launch --config_file` (ds_zero3_cpu.yaml below)
    accelerator = (
        Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
    )

    # To have only one message (and not 8) per log of Transformers or Datasets,
    # set the logging verbosity higher on the main process and to error elsewhere
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_id)

    # define LoRA fine-tuning config
    if args.use_peft:
        peft_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        # Create PEFT model with LoraConfig
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    # load dataset from disk and tokenizer
    train_dataset = load_from_disk(os.path.join(args.dataset_path, "train"))
    eval_dataset = load_from_disk(os.path.join(args.dataset_path, "eval"))
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # we want to ignore the tokenizer pad token in the loss
    label_pad_token_id = -100
    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8,
    )
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.train_batch_size
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.eval_batch_size)

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    # lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * args.epochs),
    )

    # Prepare model, optimizer and dataloaders
    model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
    )
    accelerator.print(model)
    is_ds_zero_3 = False
    if getattr(accelerator.state, "deepspeed_plugin", None):
        is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3

    if args.with_tracking:
        if args.use_peft:
            run_name = f"{args.use_peft=}_{args.model_id}_{peft_config.peft_type}_{peft_config.task_type}"
        else:
            run_name = f"{args.use_peft=}_{args.model_id}"
        experiment_config = vars(args)
        accelerator.init_trackers(
            "FlanT5_Dialog_Summarization", config=experiment_config, init_kwargs={"wandb": {"name": run_name}}
        )

    # Add a progress bar to keep track of training.
    progress_bar = tqdm(range(args.epochs * len(train_dataloader)), disable=not accelerator.is_main_process)
    # Now we train the model
    for epoch in range(args.epochs):
        total_train_loss = 0
        with TorchTracemalloc() as tracemalloc:
            model.train()
            for step, batch in enumerate(train_dataloader):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                total_train_loss += loss.item()
                if (step + 1) % args.tracking_steps == 0:
                    train_loss_so_far = total_train_loss / (step + 1)
                    accelerator.print(f"train loss: {train_loss_so_far}")
                    accelerator.log(
                        {"train/loss": train_loss_so_far},
                        step=epoch * len(train_dataloader) + step + 1,
                    )
                # break

        accelerator.print("GPU Memory before entering the train: {}".format(b2mb(tracemalloc.begin)))
        accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
        accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
        accelerator.print(
            "GPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.peaked + b2mb(tracemalloc.begin)
            )
        )
        accelerator.print("CPU Memory before entering the train: {}".format(b2mb(tracemalloc.cpu_begin)))
        accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used))
        accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
        accelerator.print(
            "CPU Total Peak Memory consumed during the train (max): {}".format(
                tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
            )
        )
        if args.with_tracking:
            accelerator.log(
                {
                    "gpu/(begin)": b2mb(tracemalloc.begin),
                    "gpu/(end-begin)": tracemalloc.used,
                    "gpu/(max-begin)": tracemalloc.peaked,
                    "gpu/train_total_peak_memory(max)": tracemalloc.peaked + b2mb(tracemalloc.begin),
                    "cpu/(begin)": b2mb(tracemalloc.cpu_begin),
                    "cpu/(end-begin)": tracemalloc.cpu_used,
                    "cpu/(max-begin)": tracemalloc.cpu_peaked,
                    "cpu/train_total_peak_memory(max)": tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin),
                },
                step=epoch,
            )
        # Evaluation loop: generate summaries, compute ROUGE, and collect a few
        # prediction/label pairs for inspection.
        model.eval()  # disable dropout during generation
        pred_df = pd.DataFrame()
        for step, batch in enumerate(tqdm(eval_dataloader)):
            with torch.no_grad():
                gen_kwargs = {
                    "early_stopping": True,
                    "length_penalty": 2.0,
                    "max_new_tokens": 50,
                    "min_length": 30,
                    "no_repeat_ngram_size": 3,
                    "num_beams": 4,
                }
                generated_tokens = accelerator.unwrap_model(model).generate(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    synced_gpus=is_ds_zero_3,
                    **gen_kwargs,
                )
                generated_tokens = accelerator.pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
                labels = batch["labels"]
                generated_tokens, labels = accelerator.gather_for_metrics((generated_tokens, labels))
                generated_tokens = generated_tokens.cpu().numpy()
                labels = labels.cpu().numpy()
                # Replace -100 in the labels as we can't decode them.
                labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
                if isinstance(generated_tokens, tuple):
                    generated_tokens = generated_tokens[0]
                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
                decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
                metric.add_batch(
                    predictions=decoded_preds,
                    references=decoded_labels,
                )
                if (step + 1) % args.tracking_steps == 0:
                    pred_df = pd.concat(
                        [pred_df, pd.DataFrame({"decoded_preds": decoded_preds, "decoded_labels": decoded_labels})]
                    ).reset_index(drop=True)
                    accelerator.print(pred_df)
            # break

        result = metric.compute(use_stemmer=True)
        result = {k: round(v * 100, 4) for k, v in result.items()}
        # print results
        accelerator.print(f"epoch {epoch}:", result)
        if args.with_tracking:
            accelerator.log({f"eval/{k}": v for k, v in result.items()}, step=epoch)
            accelerator.log({"eval/comparison_between_pred_and_true_summaries": pred_df})
    accelerator.wait_for_everyone()
    if args.use_peft:
        checkpoint_name = (
            f"{args.model_id}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt".replace("/", "_")
        )
        if args.output_dir is not None:
            checkpoint_name = os.path.join(args.output_dir, checkpoint_name)
        accelerator.save(
            get_peft_model_state_dict(model, state_dict=accelerator.get_state_dict(model)), checkpoint_name
        )
        accelerator.wait_for_everyone()
    else:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
def main():
    args = parse_arge()
    training_function(args)


if __name__ == "__main__":
    main()
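# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original gist): loading the LoRA
# checkpoint saved by training_function for inference. The checkpoint path and
# model id below are placeholders; reuse whatever --model_id / --output_dir you
# trained with, and the same LoraConfig as above. Depending on the peft
# version, set_peft_model_state_dict may need to be imported from peft.utils.
# ---------------------------------------------------------------------------
# import torch
# from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# from peft import LoraConfig, TaskType, get_peft_model, set_peft_model_state_dict
#
# base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl")
# tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
# peft_config = LoraConfig(
#     task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=True, r=8, lora_alpha=32, lora_dropout=0.1
# )
# model = get_peft_model(base_model, peft_config)
# lora_state_dict = torch.load("temp/<checkpoint_name>.pt", map_location="cpu")  # hypothetical path
# set_peft_model_state_dict(model, lora_state_dict)
# model.eval()
#
# inputs = tokenizer("summarize: <dialogue goes here>", return_tensors="pt")
# print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0], skip_special_tokens=True))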
# ---------------------------------------------------------------------------
# Data preparation (separate snippet in the gist): tokenize the samsum dataset
# and save the processed train/eval splits that accelerate_lora_t5.py loads
# from --dataset_path.
# ---------------------------------------------------------------------------
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
import numpy as np

dataset_id = "samsum"
model_id = "google/flan-t5-base"

# Load dataset from the hub
dataset = load_dataset(dataset_id)
# Load tokenizer of FLAN-T5-base
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")

# The maximum total input sequence length after tokenization.
# Sequences longer than this will be truncated; sequences shorter will be padded.
tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map(
    lambda x: tokenizer(x["dialogue"], truncation=True), batched=True, remove_columns=["dialogue", "summary"]
)
input_lengths = [len(x) for x in tokenized_inputs["input_ids"]]
# use the 85th percentile of input lengths rather than the absolute maximum
max_source_length = int(np.percentile(input_lengths, 85))
print(f"Max source length: {max_source_length}")

# The maximum total sequence length for target text after tokenization.
# Sequences longer than this will be truncated; sequences shorter will be padded.
tokenized_targets = concatenate_datasets([dataset["train"], dataset["test"]]).map(
    lambda x: tokenizer(x["summary"], truncation=True), batched=True, remove_columns=["dialogue", "summary"]
)
target_lengths = [len(x) for x in tokenized_targets["input_ids"]]
# use the 90th percentile of target lengths rather than the absolute maximum
max_target_length = int(np.percentile(target_lengths, 90))
print(f"Max target length: {max_target_length}")
def preprocess_function(sample, padding="max_length"):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue"]]
    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True)
    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["dialogue", "summary", "id"])
print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")
tokenized_dataset["train"].save_to_disk("data/train")
tokenized_dataset["test"].save_to_disk("data/eval")
# ds_zero3_cpu.yaml: accelerate config (DeepSpeed ZeRO-3 with parameter/optimizer CPU offload)
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
use_cpu: false
accelerate launch --config_file "ds_zero3_cpu.yaml" \
accelerate_lora_t5.py \
--model_id google/flan-t5-xxl \
--dataset_path data \
--epochs 3 \
--train_batch_size 8 \
--eval_batch_size 8 \
--lr 1e-3 \
--output_dir "temp" \
--with_tracking \
--report_to "wandb" \
--use_peft
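# Hypothetical variation (not from the gist): a quicker smoke test with the
# smaller flan-t5-base checkpoint and the default (non-DeepSpeed) accelerate setup:
#
# accelerate launch accelerate_lora_t5.py \
#     --model_id google/flan-t5-base \
#     --dataset_path data \
#     --epochs 1 \
#     --train_batch_size 8 \
#     --eval_batch_size 8 \
#     --lr 1e-3 \
#     --output_dir "temp" \
#     --use_peft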