@weiqi-dyania
Last active May 19, 2023 04:47
"""
An example script to fine-tune stanford-crfm/BioMedLM on dummy data without the Trainer API or Accelerate.
"""
import math
import copy
import datasets
import torch
from datasets import Dataset, load_dataset
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AdamW,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    get_scheduler,
    set_seed,
)
# training parameters
per_device_train_batch_size = 8
per_device_eval_batch_size = 8
learning_rate = 1e-5
weight_decay = 1e-6
gradient_accumulation_steps = 1
lr_scheduler_type = 'linear'
num_warmup_steps = 0
num_train_epochs = 3
set_seed(1111)
# load tokenizer
model_name = "stanford-crfm/BioMedLM"
tokenizer = AutoTokenizer.from_pretrained(model_name)
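# BioMedLM uses a GPT-2-style tokenizer that ships without a pad token, so reuse EOS for padding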
tokenizer.pad_token = tokenizer.eos_token
# dummy training data
train_tokens = tokenizer(["this is a dummy train sentence" * 15] * 8)
train_tokens['labels'] = copy.deepcopy(train_tokens['input_ids'])
train_dataset = Dataset.from_dict(train_tokens)
# dummy validation data
eval_tokens = tokenizer(["this is a dummy eval sentence" * 15] * 8)
eval_tokens['labels'] = copy.deepcopy(eval_tokens['input_ids'])
eval_dataset = Dataset.from_dict(eval_tokens)
# prepare data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
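# With mlm=False the collator builds causal-LM batches: it pads input_ids and sets labels to a
# copy of input_ids with padding positions masked to -100 (overwriting the labels copied above)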
# prepare dataloader
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=data_collator, batch_size=per_device_train_batch_size
)
eval_dataloader = DataLoader(
    eval_dataset, collate_fn=data_collator, batch_size=per_device_eval_batch_size
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
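# Note: the weights are loaded directly in float16 to save GPU memory; training purely in fp16
# without loss scaling can underflow, so mixed precision (torch.cuda.amp autocast + GradScaler)
# is the more common setup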
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
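# AdamW here is the transformers implementation, which newer releases deprecate in favor of torch.optim.AdamW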
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch
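# With 8 dummy examples, a per-device batch size of 8 and no accumulation, len(train_dataloader) == 1,
# so this works out to one update per epoch and max_train_steps == 3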
lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_train_steps,
)
# Train
total_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {num_train_epochs}")
print(f" Instantaneous batch size per device = {per_device_train_batch_size}")
print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
print(f" Total optimization steps = {max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(max_train_steps))
completed_steps = 0
for epoch in range(num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss = loss / gradient_accumulation_steps
        loss.backward()
        if step % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            completed_steps += 1
        if completed_steps >= max_train_steps:
            break
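# Minimal evaluation sketch (not part of the original loop): the eval_dataloader built earlier
# is otherwise unused, so this simply assumes we report mean loss and perplexity on the dummy
# validation data after training
model.eval()
eval_losses = []
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    eval_losses.append(outputs.loss.item())
eval_loss = sum(eval_losses) / len(eval_losses)
print(f"eval_loss = {eval_loss:.4f}, perplexity = {math.exp(eval_loss):.2f}")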