demo_sft_script
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizerFast, set_seed, AutoModelForCausalLM, AutoConfig
from tqdm import tqdm
import argparse
import torch
import torch.nn as nn
import logging
from typing import Dict, Tuple
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from functools import partial
from datasets import load_dataset
from dataclasses import dataclass
from transformers import AdamW, get_linear_schedule_with_warmup
def parse_arguments(parser: argparse.ArgumentParser):
# data Hyperparameters
parser.add_argument('--batch_size', type=int, default=2) ## for A100
parser.add_argument('--train_num', type=int, default=20, help="The number of training examples, -1 means all data")
parser.add_argument('--test_num', type=int, default=20, help="The number of test examples, -1 means all data")
parser.add_argument('--max_length', type=int, default=512, help="maximum length for training")
parser.add_argument('--pretrained_model_path', type=str, default="facebook/galactica-125m") # meta-llama/Llama-2-7b-hf
# model
parser.add_argument('--seed', type=int, default=42, help="random seed")
# training
parser.add_argument('--mode', type=str, default="train", choices=["train", "test"], help="whether to run training or testing")
parser.add_argument('--learning_rate', type=float, default=2e-5, help="learning rate of the AdamW optimizer")
parser.add_argument('--max_grad_norm', type=float, default=1.0, help="The maximum gradient norm")
parser.add_argument('--num_epochs', type=int, default=20, help="The number of epochs to run")
parser.add_argument('--fp16', type=int, default=0, choices=[0, 1], help="using fp16 to train the model")
parser.add_argument('--num_workers', type=int, default=8, help="number of workers in data loader")
args = parser.parse_args()
# Print out the arguments
for k in args.__dict__:
logger.info(f"{k} = {args.__dict__[k]}")
return args
def get_optimizers(model: nn.Module, learning_rate:float, num_training_steps: int, weight_decay:float = 0.01,
warmup_step: int = -1, eps:float = 1e-8, use_lora: bool = False) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
# no_decay = ["b ias", "LayerNorm.weight", 'LayerNorm.bias']
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
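# note: with the default use_lora=False the line below falls back to plain model.parameters(),
# so the weight-decay grouping above is only used when use_lora is True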
optimizer_grouped_parameters = optimizer_grouped_parameters if use_lora else model.parameters()
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=eps, no_deprecation_warning=True) # , correct_bias=False)
# optimizer = AdamW(optimizer_grouped_parameters, eps=eps) # , correct_bias=False)
logger.info(f"optimizer: {optimizer}")
warmup_step = warmup_step if warmup_step >= 0 else int(0.1 * num_training_steps)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_step, num_training_steps=num_training_steps
)
return optimizer, scheduler
def tokenize_function(examples: Dict,
tokenizer: PreTrainedTokenizerFast,
is_train: bool,
max_length: int): # note: be careful, we only add this for validation loss evaluation
# Tokenize the texts
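# For training, the question and answer token ids are concatenated (plus a trailing EOS) so the causal LM is trained on the full sequence;
# for evaluation only the question is kept and the answer is generated.
# source_len stores the question length so the collator can mask it out of the labels.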
features = {"question": [], "input_ids": [], "attention_mask": [], "answer": [], "source_len": []}
for idx, question in enumerate(examples["question"]):
source_text_res = tokenizer(question)
answer_text = examples["answer"][idx]
answer_text_res = tokenizer(answer_text)
# Because llama tokenizer will have <bos> in front of the sequence
answer_text_res_input_ids = answer_text_res["input_ids"] if len(answer_text_res["input_ids"]) > 0 and answer_text_res["input_ids"][0] != tokenizer.bos_token_id else answer_text_res["input_ids"][1:]
input_ids = source_text_res["input_ids"] + answer_text_res_input_ids if is_train else source_text_res["input_ids"]
if is_train:
input_ids = input_ids + [tokenizer.eos_token_id]
input_ids = input_ids[:max_length]
attention_mask = [1]*len(input_ids)
source_len = len(source_text_res["input_ids"])
if source_len > max_length:
source_len = max_length
features["source_len"].append(source_len)
features["question"].append(question)
features["input_ids"].append(input_ids)
features["attention_mask"].append(attention_mask)
features["answer"].append(answer_text)
return features
@dataclass
class PaddedCollator:
tokenizer: PreTrainedTokenizerFast
label_pad_token_id: int = -100
def __call__(self, features, return_tensors=None):
batch = {"input_ids": [], "attention_mask": [], "labels": []}
max_input_length = max(len(x["input_ids"]) for x in features)
for feature in features:
# change to left padding
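# left padding keeps every prompt right-aligned, so batched generate() can start decoding right after the last prompt token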
left_padded_length = max_input_length - len(feature["input_ids"])
input_ids = [self.tokenizer.pad_token_id] * left_padded_length + feature["input_ids"]
attention_mask = [0] * (max_input_length - len(feature["attention_mask"])) + feature["attention_mask"]
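# mask both the padding and the question tokens with label_pad_token_id (-100) so the loss is computed only on the answer tokens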
labels = [self.label_pad_token_id] * (left_padded_length + feature["source_len"]) + feature["input_ids"][feature["source_len"]:]
batch["input_ids"].append(input_ids)
batch["attention_mask"].append(attention_mask)
batch["labels"].append(labels)
batch["input_ids"] = torch.tensor(batch["input_ids"])
batch["attention_mask"] = torch.tensor(batch["attention_mask"])
batch["labels"] = torch.tensor(batch["labels"])
return batch
def train(args, train_dataloader: DataLoader, num_epochs: int,
auto_model_name: str,
tokenizer: PreTrainedTokenizerFast,
test_dataloader: DataLoader = None):
gradient_accumulation_steps = 1
t_total = int(len(train_dataloader) // gradient_accumulation_steps * num_epochs)
is_llama = "llama" in auto_model_name.lower()
if args.fp16:
current_dtype = torch.bfloat16 if is_llama else torch.float16
else:
current_dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained(auto_model_name,
pad_token_id=tokenizer.pad_token_id,
low_cpu_mem_usage=not is_llama,
torch_dtype=current_dtype,
return_dict=True)
optimizer, scheduler = get_optimizers(model=model, learning_rate=args.learning_rate, num_training_steps=t_total, warmup_step=-1)
# only use the line below for torch 2.0
# model = torch.compile(model)
best_pfm = -1
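# accelerator.prepare wraps the model, optimizer, dataloaders and scheduler for the device/distributed setup chosen via `accelerate config`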
model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)
global_step = 0
for epoch in range(num_epochs):
total_loss = 0
model.train()
for iter, feature in tqdm(enumerate(train_dataloader, 1), desc="--training batch", total=len(train_dataloader)):
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast(enabled=bool(args.fp16), dtype=current_dtype):
loss = model(**feature).loss
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
total_loss += loss.item()
optimizer.step()
scheduler.step()
global_step += 1
if global_step % 10 == 0 and accelerator.is_main_process:
logger.info(f"epoch: {epoch}, global_step: {global_step}, current mean loss: {total_loss / iter:.2f}")
if accelerator.is_main_process:
logger.info(f"Finish epoch: {epoch}, loss: {total_loss:.2f}, mean loss: {total_loss / len(train_dataloader):.2f}")
if test_dataloader is not None:
test_pfm = evaluate(test_dataloader, model, fp16=bool(args.fp16), tokenizer=tokenizer, max_gen_len=args.max_length,
name="test", current_dtype=current_dtype)
if test_pfm > best_pfm:
# because if use lora, we use loss to compare
logger.info(f"[Model Info] Saving the best model with best valid performance {test_pfm:.6f} at epoch {epoch}")
best_pfm = test_pfm
# you can also save model here
if accelerator.is_main_process:
logger.info(f"[Model Info] Best validation performance: {best_pfm}")
def evaluate(valid_dataloader: DataLoader, model: nn.Module, fp16: bool,
tokenizer,
max_gen_len: int,
name: str,
current_dtype: torch.dtype) -> float:
model.eval()
predictions = []
with torch.no_grad():
for index, feature in tqdm(enumerate(valid_dataloader), desc="--validation", total=len(valid_dataloader)):
with torch.cuda.amp.autocast(enabled=fp16, dtype=current_dtype):
## Note: need to check if the underlying model has revised the "prepare_inputs_for_generation" method
module = accelerator.unwrap_model(model)
generated_ids = module.generate(input_ids=feature["input_ids"],
attention_mask=feature["attention_mask"],
num_beams=1,
max_length=max_gen_len,
eos_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
do_sample=False).sequences
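# generate() returns the (padded) prompt followed by the newly generated tokens, so drop the prompt portion before decoding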
generated_ids = generated_ids[:, feature["input_ids"].size(1):].contiguous()
generated_ids = accelerator.pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.eos_token_id)
generated_ids = accelerator.gather_for_metrics(generated_ids)
# note: keep clean_up_tokenization_spaces=False if you want to preserve space and newline tokens.
prediction = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
predictions.extend(prediction)
dataset = valid_dataloader.dataset
all_predicted_vals = []
results = []
answer_values = []
for idx, (prediction_step, answer_value) in enumerate(zip(predictions, dataset["answer"])):
## step extraction and cleaning
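# GSM8K solutions end with "#### <final number>", so the text after "####" is treated as the predicted value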
split_vals = prediction_step.split("####")
all_predicted_vals.append(split_vals[1].strip() if len(split_vals) > 1 else None)
results.append({
"question": dataset["question"][idx],
"answer_value": answer_value,
"prediction": split_vals,
})
answer_values.append(answer_value)
## eval the overall accuracy (on the main process only)
accuracy = 0
if accelerator.is_main_process:
corr = 0
for idx, (predicted_value, gold_val) in enumerate(zip(all_predicted_vals, answer_values)):
try:
# the prediction and the gold answer are strings; the GSM8K gold value follows "####"
gold_value = float(gold_val.split("####")[-1].strip().replace(",", ""))
correctness = abs(float(predicted_value.replace(",", "")) - gold_value) <= 1e-2
except (TypeError, ValueError, AttributeError):
correctness = False
results[idx]["correctness"] = correctness
# sometimes the solving res is not json serializable
results[idx]["solving_res"] = predicted_value
if correctness:
corr += 1
accuracy = corr / len(results) * 100
logger.info(f"[Eval Info] {name} accuracy: {accuracy:.2f}%.")
return accuracy
def main():
parser = argparse.ArgumentParser(description="any_project_name")
args = parse_arguments(parser)
set_seed(args.seed)
logger.info("[Data Info] Reading all data")
# read dataset
# load dataset from huggingface (see https://huggingface.co/datasets/gsm8k)
dataset = load_dataset("gsm8k", "main")
train_data = dataset["train"]
test_data = dataset["test"]
# perform data tokenization
# load llama tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, use_fast=True)
collator = PaddedCollator(tokenizer=tokenizer) # padding side is left
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
if tokenizer.eos_token_id is None and tokenizer.pad_token_id is None:
# because galactica does not have eos and pad token
# only for galactica
tokenizer.eos_token_id = 2
tokenizer.pad_token_id = 1
tokenizer.bos_token_id = 0
assert tokenizer.convert_tokens_to_ids(["</s>"])[0] == 2
assert tokenizer.convert_tokens_to_ids(["<pad>"])[0] == 1
assert tokenizer.convert_tokens_to_ids(["<s>"])[0] == 0
train_dataset = train_data.select(range(args.train_num)) if args.train_num > 0 else train_data
test_dataset = test_data.select(range(args.test_num)) if args.test_num > 0 else test_data
train_dataset = train_dataset.map(
lambda examples: tokenize_function(examples,
tokenizer,
is_train=True,
max_length=args.max_length),
batched=True, batch_size=1000,
remove_columns=train_dataset.column_names,
load_from_cache_file=False, num_proc=8,
desc="Running tokenizer on train dataset",
)
test_dataset = test_dataset.map(
lambda examples: tokenize_function(examples, tokenizer, is_train=False,
max_length=args.max_length,),
batched=True, batch_size=1000,
remove_columns=test_dataset.column_names,
load_from_cache_file=False, num_proc=8,
desc="Running tokenizer on test dataset",
)
logger.info(f"[Train Data ] after tokenized: {train_dataset}")
logger.info(f"[Test Data ] after tokenized: {test_dataset}")
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collator)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=collator)
# Train the model
train(args, train_dataloader,
num_epochs=args.num_epochs,
auto_model_name=args.pretrained_model_path,
test_dataloader=test_dataloader,
tokenizer=tokenizer)
if __name__ == "__main__":
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
# accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
accelerator = Accelerator()
logger = get_logger(__name__, log_level="INFO")
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
tqdm = partial(tqdm, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', disable=not accelerator.is_local_main_process)
main()
allanj commented Jan 11, 2024

The task in this example: math problem solving

Requirements

pip3 install transformers
pip3 install accelerate
pip3 install sentencepiece
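
The script also imports torch and datasets, so install those as well if they are not already available:

pip3 install torch
pip3 install datasets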

If using LLaMA:
change "facebook/galactica-125m" to "meta-llama/Llama-2-7b-hf" (or pass it via --pretrained_model_path; see the example launch command at the end).

train_num and test_num are set to 20 for demonstration purposes; set them to -1 to use the full data (see the example launch command at the end).

Running the script

Before experiments, configure your environment:

accelerate config

Reference: https://github.com/huggingface/accelerate

accelerate launch --main_process_port 8888 demo_sft.py --train_num=10 --test_num=10
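
For example, to fine-tune Llama-2 on the full training set and evaluate on the full test set (this assumes you have access to the meta-llama/Llama-2-7b-hf weights on the Hugging Face Hub):

accelerate launch --main_process_port 8888 demo_sft.py --pretrained_model_path=meta-llama/Llama-2-7b-hf --train_num=-1 --test_num=-1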
