Skip to content

Instantly share code, notes, and snippets.

@raphael0202
Last active June 1, 2022 01:40
Show Gist options
  • Select an option

  • Save raphael0202/85580b29b27a27ddaae8d393f686f891 to your computer and use it in GitHub Desktop.

Select an option

Save raphael0202/85580b29b27a27ddaae8d393f686f891 to your computer and use it in GitHub Desktop.
Training script for the blog post "How many layers of my BERT model should I freeze?"
scikit-learn==0.23.2
datasets==1.1.3
torch==1.7.0
transformers==3.5.1
import argparse
import math
import os
from pathlib import Path
import shutil
from typing import Optional, Union
import uuid

import datasets
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
def get_random_seed():
    """Return a fresh seed built from 4 bytes of OS randomness (an int in [0, 2**32))."""
    entropy = os.urandom(4)
    seed = int.from_bytes(entropy, byteorder="big")
    return seed
# CLI dataset name -> positional arguments for `datasets.load_dataset`
# (dataset path, optionally followed by a config name).
DATASET_MAP = {"sst2": ("glue", "sst2"), "cola": ("glue", "cola"), "imdb": ("imdb",)}
# Directory holding the pre-computed `.npy` files of training-sample indexes.
SPLIT_DIR = Path("split")


def get_split_path(dataset_name: str, train_size: int) -> Path:
    """Return the `.npy` split file path for a dataset and a training-set size.

    The file name joins every element of the dataset's `DATASET_MAP` entry and
    the training size with dashes, e.g. ``split/glue-sst2-1000.npy`` or
    ``split/imdb-1000.npy``.

    Raises:
        ValueError: if `dataset_name` is not a key of `DATASET_MAP`.
    """
    if dataset_name not in DATASET_MAP:
        raise ValueError(f"unknown dataset: {dataset_name}")
    # Join all parts instead of indexing [0] and [1]: single-element entries
    # such as ("imdb",) used to raise IndexError here.
    prefix = "-".join(DATASET_MAP[dataset_name])
    return SPLIT_DIR / f"{prefix}-{train_size}.npy"
def get_dataset(tokenizer, dataset_name: str, split: str, split_path: Optional[Path] = None):
    """Load, optionally subsample, and tokenize one split of a `DATASET_MAP` dataset.

    Args:
        tokenizer: tokenizer applied in batches (truncation on, no padding) to
            the dataset's text column.
        dataset_name: key of `DATASET_MAP`.
        split: split name passed to `datasets.load_dataset` (e.g. "train").
        split_path: optional `.npy` file listing the `idx` values to keep.

    Returns:
        The dataset with torch formatting on the columns
        `input_ids`, `attention_mask` and `label`.

    Raises:
        ValueError: if `split_path` is given but the dataset has no `idx`
            column, or if no supported text column is found.
    """
    ds = datasets.load_dataset(*DATASET_MAP[dataset_name], split=split)
    # Fixed seed: every run sees the same example order.
    ds = ds.shuffle(seed=42)
    column_names = ds.column_names
    if split_path is not None:
        # split_path is a npy file containing indexes of samples to keep,
        # matched against the dataset's "idx" column.
        if "idx" not in column_names:
            raise ValueError(
                f"dataset {dataset_name!r} has no 'idx' column; "
                f"cannot apply split file {split_path}"
            )
        print(f"Using split file {split_path}")
        split_ids = set(np.load(split_path).tolist())
        ds = ds.filter(lambda idx: idx in split_ids, input_columns="idx")
    # The GLUE tasks used here store their input under "sentence"; fall back to
    # "text" (presumably the IMDB column name — confirm against the dataset
    # card) rather than hard-coding "sentence" for every dataset.
    text_column = next((name for name in ("sentence", "text") if name in column_names), None)
    if text_column is None:
        raise ValueError(
            f"dataset {dataset_name!r} has no 'sentence' or 'text' column: {column_names}"
        )
    ds = ds.map(lambda e: tokenizer(e[text_column], padding=False, truncation=True), batched=True)
    ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    return ds
def compute_metrics(pred):
    """Compute binary-classification metrics from a Trainer evaluation output.

    Args:
        pred: object exposing `label_ids` (gold labels) and `predictions`
            (per-class scores; the arg-max over the last axis is the predicted
            class).

    Returns:
        dict with keys "accuracy", "f1", "precision" and "recall".
    """
    gold = pred.label_ids
    predicted = pred.predictions.argmax(-1)
    prfs = precision_recall_fscore_support(gold, predicted, average="binary")
    prec, rec, f_score = prfs[0], prfs[1], prfs[2]
    accuracy = accuracy_score(gold, predicted)
    return {
        "accuracy": accuracy,
        "f1": f_score,
        "precision": prec,
        "recall": rec,
    }
def _freeze_layers(model, freeze_layer_count: int) -> None:
    """Disable gradients on the embeddings and the first encoder layers of `model`.

    0 freezes nothing; -1 freezes only the embeddings; N > 0 freezes the
    embeddings plus `model.roberta.encoder.layer[:N]`.
    """
    if not freeze_layer_count:
        return
    # We freeze here the embeddings of the model
    for param in model.roberta.embeddings.parameters():
        param.requires_grad = False
    if freeze_layer_count == -1:
        # -1 means "embeddings only".
        return
    for layer in model.roberta.encoder.layer[:freeze_layer_count]:
        for param in layer.parameters():
            param.requires_grad = False


def train(
    output_dir: Union[str, Path],
    dataset_name: str,
    train_size: Optional[int] = None,
    freeze_layer_count: int = 0,
):
    """Fine-tune `roberta-base` on `dataset_name`, optionally freezing layers.

    Args:
        output_dir: directory the Trainer writes checkpoints to; its last path
            component is also used as the run name.
        dataset_name: key of `DATASET_MAP`.
        train_size: if given, train only on the samples listed in the split
            file returned by `get_split_path(dataset_name, train_size)`.
        freeze_layer_count: 0 freezes nothing; -1 freezes only the embeddings;
            N > 0 freezes the embeddings and the first N encoder layers.

    Raises:
        ValueError: if `freeze_layer_count` is smaller than -1 (a negative
            slice would otherwise freeze "all but the last |N|" layers).
    """
    if freeze_layer_count < -1:
        raise ValueError(f"freeze_layer_count must be >= -1, got {freeze_layer_count}")
    # Accept plain strings too: `.name` below requires a Path.
    output_dir = Path(output_dir)
    split_path = get_split_path(dataset_name, train_size) if train_size is not None else None
    args_dict = {
        "evaluation_strategy": "steps",
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 16,
        "learning_rate": 5e-5,
        "num_train_epochs": 10,
        "logging_first_step": True,
        "save_total_limit": 1,
        "fp16": True,
        "dataloader_num_workers": 1,
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        # we need to generate a random seed manually, as otherwise
        # the same constant random seed is used during training for each run
        "seed": get_random_seed(),
    }
    model = AutoModelForSequenceClassification.from_pretrained("roberta-base", return_dict=True)
    _freeze_layers(model, freeze_layer_count)
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    train_ds = get_dataset(tokenizer, dataset_name, split="train", split_path=split_path)
    # NOTE(review): this assumes every dataset has a "validation" split;
    # imdb presumably only ships train/test — confirm before using it.
    val_ds = get_dataset(tokenizer, dataset_name, split="validation")
    # Step counts use the per-device batch size only, i.e. they assume a
    # single training device.
    epoch_steps = len(train_ds) / args_dict["per_device_train_batch_size"]
    args_dict["warmup_steps"] = math.ceil(epoch_steps)  # 1 epoch
    args_dict["logging_steps"] = max(1, math.ceil(epoch_steps * 0.5))  # 0.5 epoch
    # Save at the same cadence as logging so evaluated steps have a checkpoint
    # for load_best_model_at_end (assumes eval_steps defaults to logging_steps).
    args_dict["save_steps"] = args_dict["logging_steps"]
    args_dict["run_name"] = output_dir.name
    training_args = TrainingArguments(output_dir=str(output_dir), **args_dict)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )
    trainer.train()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Restrict to known datasets so a typo fails at parse time instead of as a
    # KeyError/ValueError deep inside training.
    parser.add_argument("dataset_name", choices=sorted(DATASET_MAP))
    parser.add_argument("--freeze_layer_count", type=int, default=0)
    parser.add_argument("--train_size", type=int, default=None)
    # "--keep_checkpoint" is an alias matching the other underscore-style flags;
    # argparse derives dest "keep_checkpoint" from the first long option.
    parser.add_argument(
        "--keep-checkpoint", "--keep_checkpoint", default=False, action="store_true"
    )
    args = parser.parse_args()
    OUTPUT_DIR = Path("training")
    print(f"** Train size: {args.train_size} **")
    print(f"** Freeze layers: {args.freeze_layer_count} **")
    # Unique directory per run so concurrent/repeated runs never collide.
    output_dir = OUTPUT_DIR / str(uuid.uuid4())
    train(
        output_dir=output_dir,
        dataset_name=args.dataset_name,
        train_size=args.train_size,
        freeze_layer_count=args.freeze_layer_count,
    )
    # Only reached when training succeeds; a failed run leaves its directory
    # in place for inspection.
    if not args.keep_checkpoint:
        shutil.rmtree(output_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment