Last active
June 1, 2022 01:40
-
-
Save raphael0202/85580b29b27a27ddaae8d393f686f891 to your computer and use it in GitHub Desktop.
Training script for the blog post "How many layers of my BERT model should I freeze?"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| scikit-learn==0.23.2 | |
| datasets==1.1.3 | |
| torch==1.7.0 | |
| transformers==3.5.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import math | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| from typing import Optional | |
| import uuid | |
| import datasets | |
| import numpy as np | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| Trainer, | |
| TrainingArguments, | |
| ) | |
def get_random_seed():
    """Return a random 32-bit unsigned integer taken from ``os.urandom``.

    Four bytes of OS-provided randomness are interpreted as a big-endian
    integer, so the result lies in ``[0, 2**32 - 1]``.
    """
    entropy = os.urandom(4)
    seed = int.from_bytes(entropy, byteorder="big")
    return seed
# Maps a short dataset name to the positional arguments of `datasets.load_dataset`.
DATASET_MAP = {"sst2": ("glue", "sst2"), "cola": ("glue", "cola"), "imdb": ("imdb",)}
# Directory holding the .npy files that list the sample indexes of each training subset.
SPLIT_DIR = Path("split")


def get_split_path(dataset_name: str, train_size: int):
    """Return the path of the split file for a dataset and a training-set size.

    The file name is built from every component of the dataset's
    ``DATASET_MAP`` entry joined with ``-``, followed by the train size, e.g.
    ``split/glue-sst2-100.npy`` or ``split/imdb-100.npy``.

    Args:
        dataset_name: key of ``DATASET_MAP``.
        train_size: number of training samples the split file refers to.

    Returns:
        A ``Path`` located under ``SPLIT_DIR``.

    Raises:
        ValueError: if ``dataset_name`` is not a key of ``DATASET_MAP``.
    """
    if dataset_name not in DATASET_MAP:
        raise ValueError(f"unknown dataset: {dataset_name}")

    # Join all components instead of indexing [0] and [1]: single-element
    # entries such as ("imdb",) would otherwise raise an IndexError.
    dataset_prefix = "-".join(DATASET_MAP[dataset_name])
    return SPLIT_DIR / f"{dataset_prefix}-{train_size}.npy"
def get_dataset(tokenizer, dataset_name: str, split: str, split_path: Optional[Path] = None):
    """Load, optionally subsample, and tokenize one split of a dataset.

    Args:
        tokenizer: callable applied to batches of raw text; it is invoked with
            ``padding=False, truncation=True``.
        dataset_name: key of ``DATASET_MAP``; its value is forwarded as
            positional arguments to ``datasets.load_dataset``.
        split: name of the split to load (e.g. ``"train"``).
        split_path: optional ``.npy`` file containing the values of the
            ``idx`` column to keep. When ``None`` the whole split is used.

    Returns:
        The dataset, formatted as torch tensors and restricted to the
        ``input_ids``, ``attention_mask`` and ``label`` columns.
    """
    ds = datasets.load_dataset(*DATASET_MAP[dataset_name], split=split)
    # Fixed seed: the sample order is the same on every run.
    ds = ds.shuffle(seed=42)

    if split_path is not None:
        # split_path is a npy file containing indexes of samples to keep
        print(f"Using split file {split_path}")
        split_ids = set(np.load(split_path).tolist())
        # NOTE(review): filtering relies on an "idx" column, which the GLUE
        # datasets provide; confirm it exists before using split files with
        # other datasets.
        ds = ds.filter(lambda idx: idx in split_ids, input_columns="idx")

    # The GLUE tasks expose the raw text as "sentence" whereas imdb uses
    # "text"; pick whichever the loaded dataset provides instead of
    # hard-coding "sentence".
    text_column = "sentence" if "sentence" in ds.column_names else "text"
    ds = ds.map(
        lambda batch: tokenizer(batch[text_column], padding=False, truncation=True),
        batched=True,
    )
    ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    return ds
def compute_metrics(pred):
    """Compute binary classification metrics from a prediction object.

    ``pred.label_ids`` holds the reference labels and ``pred.predictions``
    the per-class scores; the predicted class is the argmax over the last
    axis.

    Returns:
        A dict with the keys ``accuracy``, ``f1``, ``precision`` and
        ``recall``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # Index 3 of the returned tuple (the support) is not reported.
    prf = precision_recall_fscore_support(y_true, y_pred, average="binary")
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": prf[2],
        "precision": prf[0],
        "recall": prf[1],
    }
def train(
    output_dir: str,
    dataset_name: str,
    train_size: Optional[int] = None,
    freeze_layer_count: int = 0,
):
    """Fine-tune ``roberta-base`` for sequence classification.

    Args:
        output_dir: directory handed to ``TrainingArguments``; its last path
            component is also used as the run name. Both ``str`` and
            ``Path`` values are accepted.
        dataset_name: key of ``DATASET_MAP``.
        train_size: when given, the training set is restricted to the samples
            listed in the split file returned by ``get_split_path``.
        freeze_layer_count: ``0`` freezes nothing; ``-1`` freezes only the
            embeddings; any other value freezes the embeddings plus the
            encoder layers selected by ``encoder.layer[:freeze_layer_count]``.
    """
    # The body needs `output_dir.name`, which a plain string does not have;
    # normalise here so that str and Path callers both work.
    output_dir = Path(output_dir)
    split_path = get_split_path(dataset_name, train_size) if train_size is not None else None

    args_dict = {
        "evaluation_strategy": "steps",
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 16,
        "learning_rate": 5e-5,
        "num_train_epochs": 10,
        "logging_first_step": True,
        "save_total_limit": 1,
        "fp16": True,
        "dataloader_num_workers": 1,
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        # we need to generate a random seed manually, as otherwise
        # the same constant random seed is used during training for each run
        "seed": get_random_seed(),
    }

    model = AutoModelForSequenceClassification.from_pretrained("roberta-base", return_dict=True)

    if freeze_layer_count:
        # We freeze here the embeddings of the model
        for param in model.roberta.embeddings.parameters():
            param.requires_grad = False

        if freeze_layer_count != -1:
            # if freeze_layer_count == -1, we only freeze the embedding layer
            # otherwise we freeze the first `freeze_layer_count` encoder layers
            for layer in model.roberta.encoder.layer[:freeze_layer_count]:
                for param in layer.parameters():
                    param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    train_ds = get_dataset(tokenizer, dataset_name, split="train", split_path=split_path)
    val_ds = get_dataset(tokenizer, dataset_name, split="validation")

    # Number of batches per epoch for one device (may be fractional).
    epoch_steps = len(train_ds) / args_dict["per_device_train_batch_size"]
    args_dict["warmup_steps"] = math.ceil(epoch_steps)  # 1 epoch
    # max(1, ...) keeps the interval valid for very small training sets.
    args_dict["logging_steps"] = max(1, math.ceil(epoch_steps * 0.5))  # 0.5 epoch
    # Checkpoints are written at the same interval as logging.
    args_dict["save_steps"] = args_dict["logging_steps"]
    args_dict["run_name"] = output_dir.name

    training_args = TrainingArguments(output_dir=str(output_dir), **args_dict)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )
    trainer.train()
if __name__ == "__main__":
    # Command-line entry point: parse the options, train in a fresh
    # uniquely-named directory, then discard the checkpoints unless asked
    # to keep them.
    parser = argparse.ArgumentParser()
    # Restrict to known datasets so that a typo is reported at parse time
    # rather than after the model has been loaded inside train().
    parser.add_argument("dataset_name", choices=sorted(DATASET_MAP))
    parser.add_argument("--freeze_layer_count", type=int, default=0)
    parser.add_argument("--train_size", type=int, default=None)
    parser.add_argument("--keep-checkpoint", default=False, action="store_true")
    args = parser.parse_args()

    OUTPUT_DIR = Path("training")
    print(f"** Train size: {args.train_size} **")
    print(f"** Freeze layers: {args.freeze_layer_count} **")

    # A random UUID keeps concurrent or repeated runs from sharing a directory.
    output_dir = OUTPUT_DIR / str(uuid.uuid4())
    try:
        train(
            output_dir=output_dir,
            dataset_name=args.dataset_name,
            train_size=args.train_size,
            freeze_layer_count=args.freeze_layer_count,
        )
    finally:
        # Clean up even when training fails or is interrupted, so aborted
        # runs do not leave checkpoints behind. ignore_errors covers the case
        # where the directory was never created and avoids masking the
        # original exception.
        if not args.keep_checkpoint:
            shutil.rmtree(output_dir, ignore_errors=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment