Oríon + wandb
# -*- coding: utf-8 -*-
"""Cifar-10 Image Classification using PyTorch Lightning

From the WandB tutorial written by Ayush Thakur:
https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY

Original file is located at
https://colab.research.google.com/drive/12oNQ8XGeJFMiSBGsQ8uth8TaghVBC9H3
"""
import argparse
import os
import re
import yaml

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.metrics.functional import accuracy

# from sklearn.metrics import precision_recall_curve
# from sklearn.preprocessing import label_binarize
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

from orion.client import report_objective, build_experiment


class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir="./", split_seed=1):
        super().__init__()
        self.data_dir = data_dir
        self.split_seed = split_seed
        self.batch_size = batch_size
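
        # The normalization constants below are the commonly used per-channel
        # CIFAR-10 mean and standard deviation.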
        self.transform_train = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        )
        self.transform_test = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        )

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            # Use two dataset objects so the validation split keeps the
            # deterministic test-time transform without overriding the training
            # augmentations (random_split subsets share one underlying dataset,
            # so mutating `.dataset.transform` would affect both splits).
            cifar_train_full = CIFAR10(
                self.data_dir, train=True, transform=self.transform_train
            )
            cifar_val_full = CIFAR10(
                self.data_dir, train=True, transform=self.transform_test
            )
            # Splitting twice with the same seed yields identical index splits.
            self.cifar_train, _ = random_split(
                cifar_train_full,
                [45000, 5000],
                generator=torch.Generator().manual_seed(self.split_seed),
            )
            _, self.cifar_val = random_split(
                cifar_val_full,
                [45000, 5000],
                generator=torch.Generator().manual_seed(self.split_seed),
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(
                self.data_dir, train=False, transform=self.transform_test
            )

    def train_dataloader(self):
        return DataLoader(
            self.cifar_train, batch_size=self.batch_size, shuffle=True, num_workers=5
        )

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size, num_workers=5)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size, num_workers=5)


class LitModel(pl.LightningModule):
    def __init__(
        self,
        input_shape,
        num_classes,
        batch_size=128,
        learning_rate=2e-4,
        weight_decay=0,
        momentum=0.8,
        gamma=0.99,
    ):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.gamma = gamma

        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)

        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    # returns the size of the output tensor going into the Linear layer from the conv block
    def _get_conv_output(self, shape):
        batch_size = 1
        with torch.no_grad():
            output_feat = self._forward_features(torch.rand(batch_size, *shape))
        n_size = output_feat.view(batch_size, -1).size(1)
        return n_size

    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x

    # will be used during inference
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

    # logic for a single training step
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, logger=True)
        return loss

    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    # logic for a single testing step
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # test metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, self.gamma)
        return [optimizer], [lr_scheduler]


def cli(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--epochs", default=50, type=int, help="Number of epochs to train."
    )
    parser.add_argument(
        "--batch-size",
        default=None,
        type=int,
        help="Batch size during training. If None, set by trainer.tune().",
    )
    parser.add_argument(
        "--learning-rate",
        default=0.1,
        type=float,
        help="Learning rate of the optimizer.",
    )
    parser.add_argument(
        "--gamma",
        default=0.99,
        type=float,
        help="Gamma decay rate for the exponential learning rate schedule.",
    )
    parser.add_argument(
        "--weight-decay",
        default=0,
        type=float,
        help="L2 regularization.",
    )
    parser.add_argument(
        "--momentum",
        default=0.8,
        type=float,
        help="Momentum for the optimizer.",
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="Configuration file to set hyperparameters. Overrides those set on the command line.",
    )
    parser.add_argument(
        "--project-name",
        default=None,
        type=str,
        help="Name of the project in WandB.",
    )
    parser.add_argument(
        "--trial-dir",
        default=".",
        type=str,
        help="Folder to save checkpoints of the trial.",
    )
    parser.add_argument(
        "--trial-id",
        default="model",
        type=str,
        help="ID of the trial. Used for resuming.",
    )

    options = vars(parser.parse_args(argv))
    config = options.pop("config")
    if config:
        with open(config, "r") as f:
            config_text = f.read()
        print(config_text)
        # safe_load avoids the deprecated default loader of yaml.load.
        options.update(yaml.safe_load(config_text))

    main(**options)


def main(
    epochs=50,
    batch_size=None,
    learning_rate=0.1,
    gamma=0.99,
    n_gamma=None,
    weight_decay=1e-10,
    momentum=0.8,
    project_name="test",
    trial_dir=".",
    trial_id=None,
):
    print(epochs)
    if batch_size is None:
        backup_batch_size = 128
    else:
        backup_batch_size = batch_size

    # Init our data pipeline
    dm = CIFAR10DataModule(batch_size=backup_batch_size)
    # To access the x_dataloader we need to call prepare_data and setup.
    dm.prepare_data()
    dm.setup()

    early_stop_callback = EarlyStopping(
        monitor="val_acc", patience=5, verbose=False, mode="max"
    )
    lr_monitor = LearningRateMonitor()
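
    # `n_gamma` is an alternative parametrization of the exponential decay that
    # is easier to search over: gamma = 1 - 1/n_gamma, so gamma**n_gamma is
    # approximately 1/e and n_gamma roughly corresponds to the number of epochs
    # over which the learning rate decays by a factor of e.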
    if n_gamma is not None:
        gamma = (n_gamma - 1) / n_gamma

    # Init our model
    model = LitModel(
        dm.size(),
        dm.num_classes,
        batch_size=backup_batch_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        gamma=gamma,
        momentum=momentum,
    )

    # Initialize wandb logger
    wandb_logger = WandbLogger(project=project_name, job_type="train", id=trial_id)

    # MODEL_CKPT = "model/{id}-{{epoch:02d}}-{{val_loss:.2f}}"
    checkpoint_filename = "last"
    checkpoint_path = os.path.join(trial_dir, "last.ckpt")
    print(trial_id)
    print(checkpoint_path)
    checkpoint_callback = ModelCheckpoint(dirpath=trial_dir, save_last=True)
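
    # The checkpoint lives under the trial's own directory and the wandb run id
    # is the trial hash, so a trial promoted by Hyperband to a higher fidelity
    # (more epochs) resumes from its last checkpoint and keeps logging to the
    # same wandb run instead of starting a new one.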

    # Initialize a trainer
    trainer = pl.Trainer(
        max_epochs=epochs,
        progress_bar_refresh_rate=20,
        gpus=[1],
        logger=wandb_logger,
        callbacks=[early_stop_callback, lr_monitor],
        checkpoint_callback=checkpoint_callback,
        resume_from_checkpoint=checkpoint_path,
        default_root_dir=trial_dir,
        auto_scale_batch_size="power",
        auto_select_gpus=True,
    )

    # Find the largest possible batch size
    if batch_size is None:
        trainer.tune(model, dm)

    # TODO: Maybe changing epochs when resuming a trial will cause some issue
    wandb_logger.log_hyperparams(
        {
            "epochs": epochs,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "gamma": gamma,
            "weight_decay": weight_decay,
            "momentum": momentum,
        }
    )

    # Train the model ⚡🚅⚡
    trainer.fit(model, dm)

    # Evaluate the model on the held out test set ⚡⚡
    trainer.test()

    # Close wandb run
    wandb.finish()
    # Fetch the final metrics through the wandb public API now that the run is closed.
    wandb_run = wandb.Api().run(wandb_logger.experiment.path)
    valid_error_rate = 1 - wandb_run.summary["val_acc"]

    # Send back the final validation error to Oríon (used when launched via `orion hunt`)
    report_objective(valid_error_rate)
    return valid_error_rate


def run_hpo(working_dir=".", **kwargs):
    kwargs.setdefault("epochs", 50)
    kwargs.setdefault("batch_size", 8192)

    # Specify the database where the experiments are stored. We use a local PickleDB here.
    storage = {
        "type": "legacy",
        "database": {
            "type": "pickleddb",
            "host": "./db.pkl",
        },
    }

    # Build the experiment, or load it from storage if it already exists.
    # `epochs` is the fidelity dimension that Hyperband uses to allocate budgets.
    experiment = build_experiment(
        "test-orion-hyperband-cifar10",
        space={
            "epochs": "fidelity(1, 120, base=4)",
            "learning_rate": "loguniform(1e-5, 0.5)",
            "momentum": "uniform(0.8, 0.99)",
            "weight_decay": "loguniform(1e-10, 1e-2)",
            "n_gamma": "loguniform(10, 10000)",
        },
        algorithms={
            "hyperband": {
                "seed": 1,
                "repetitions": 2,
            },
        },
        storage=storage,
        working_dir=working_dir,
    )
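
    # Standard Oríon ask/tell loop: `suggest()` returns the next trial to run
    # (or None if the algorithm cannot produce one right now) and `observe()`
    # reports the objective back. Oríon minimizes the objective, hence the
    # validation error rate (1 - accuracy) rather than the accuracy itself.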
    while not experiment.is_done:
        trial = experiment.suggest()
        if trial is None:
            break
        kwargs.update(trial.params)
        valid_error_rate = main(
            **kwargs,
            project_name=f"{experiment.name}-v{experiment.version}",
            trial_dir=f"{experiment.working_dir}/{trial.hash_params}",
            trial_id=trial.hash_params,
        )
        experiment.observe(
            trial,
            [
                {
                    "name": "valid_error_rate",
                    "type": "objective",
                    "value": valid_error_rate,
                }
            ],
        )


if __name__ == "__main__":
    cli()
    # run_hpo()
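
# Example invocations (the file name and experiment name below are placeholders,
# assuming this gist is saved as e.g. `cifar10_orion_wandb.py`):
#   * Script-driven HPO: swap `cli()` above for `run_hpo()` and run
#       python cifar10_orion_wandb.py
#   * Oríon-driven HPO from the command line, where `report_objective()` returns
#     the result of each trial, e.g.:
#       orion hunt -n cifar10-wandb \
#           python cifar10_orion_wandb.py --learning-rate~'loguniform(1e-5, 0.5)'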