@bouthilx
Created April 19, 2022 19:00
Oríon + wandb
# -*- coding: utf-8 -*-
"""Cifar-10 Image Classification using PyTorch Lightning

From WandB tutorial written by Ayush Thakur:
https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY

Original file is located at
https://colab.research.google.com/drive/12oNQ8XGeJFMiSBGsQ8uth8TaghVBC9H3
"""
import argparse
import os
import re
import yaml

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.metrics.functional import accuracy

# from sklearn.metrics import precision_recall_curve
# from sklearn.preprocessing import label_binarize
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

from orion.client import report_objective, build_experiment


class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir="./", split_seed=1):
        super().__init__()
        self.data_dir = data_dir
        self.split_seed = split_seed
        self.batch_size = batch_size

        self.transform_train = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        )
        self.transform_test = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        )

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(
                self.data_dir, train=True, transform=self.transform_train
            )
            self.cifar_train, self.cifar_val = random_split(
                cifar_full,
                [45000, 5000],
                generator=torch.Generator().manual_seed(self.split_seed),
            )
            # Give the validation split its own copy of the dataset with the
            # test-time transform. Overwriting `cifar_full.transform` directly
            # would also disable the training augmentations, since both splits
            # share the same underlying dataset.
            self.cifar_val.dataset = CIFAR10(
                self.data_dir, train=True, transform=self.transform_test
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(
                self.data_dir, train=False, transform=self.transform_test
            )
    def train_dataloader(self):
        return DataLoader(
            self.cifar_train, batch_size=self.batch_size, shuffle=True, num_workers=5
        )

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size, num_workers=5)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size, num_workers=5)


class LitModel(pl.LightningModule):
    def __init__(
        self,
        input_shape,
        num_classes,
        batch_size=128,
        learning_rate=2e-4,
        weight_decay=0,
        momentum=0.8,
        gamma=0.99,
    ):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.gamma = gamma

        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)

        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)
    # returns the size of the output tensor going into the Linear layer from the conv block.
    def _get_conv_output(self, shape):
        batch_size = 1
        tmp_input = torch.rand(batch_size, *shape)
        output_feat = self._forward_features(tmp_input)
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size
    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x

    # will be used during inference
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

    # logic for a single training step
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, logger=True)
        return loss

    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss
    # logic for a single testing step
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # test metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return loss
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, self.gamma)
        return [optimizer], [lr_scheduler]


def cli(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--epochs", default=50, type=int, help="Number of epochs to train."
    )
    parser.add_argument(
        "--batch-size",
        default=None,
        type=int,
        help="Batch size during training. If None, set by trainer.tune()",
    )
    parser.add_argument(
        "--learning-rate",
        default=0.1,
        type=float,
        help="Learning rate of the optimizer.",
    )
    parser.add_argument(
        "--gamma",
        default=0.99,
        type=float,
        help="Gamma decay rate for exponential learning rate schedule.",
    )
    parser.add_argument(
        "--weight-decay",
        default=0,
        type=float,
        help="L2 regularization.",
    )
    parser.add_argument(
        "--momentum",
        default=0.8,
        type=float,
        help="Momentum for the optimizer.",
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="Configuration file to set hyperparameters. Will override those set on the command line.",
    )
    parser.add_argument(
        "--project-name",
        default=None,
        type=str,
        help="Name of project in WandB",
    )
    parser.add_argument(
        "--trial-dir",
        default=".",
        type=str,
        help="Folder to save checkpoints of the trial.",
    )
    parser.add_argument(
        "--trial-id",
        default="model",
        type=str,
        help="ID of the trial. Used for resuming.",
    )
    options = vars(parser.parse_args(argv))

    config = options.pop("config")
    if config:
        with open(config, "r") as f:
            print(f.read())
        with open(config, "r") as f:
            options.update(yaml.safe_load(f))

    main(**options)


def main(
    epochs=50,
    batch_size=None,
    learning_rate=0.1,
    gamma=0.99,
    n_gamma=None,
    weight_decay=1e-10,
    momentum=0.8,
    project_name="test",
    trial_dir=".",
    trial_id=None,
):
    print(epochs)

    if batch_size is None:
        backup_batch_size = 128
    else:
        backup_batch_size = batch_size

    # Init our data pipeline
    dm = CIFAR10DataModule(batch_size=backup_batch_size)
    # To access the x_dataloader we need to call prepare_data and setup.
    dm.prepare_data()
    dm.setup()

    early_stop_callback = EarlyStopping(
        monitor="val_acc", patience=5, verbose=False, mode="max"
    )
    lr_monitor = LearningRateMonitor()
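    # `n_gamma` is an alternative parametrization of the LR decay used by the
    # Hyperband search space in run_hpo(): gamma = (n_gamma - 1) / n_gamma, so
    # n_gamma in [10, 10000] maps to gamma in [0.9, 0.9999].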
    if n_gamma is not None:
        gamma = (n_gamma - 1) / n_gamma

    # Init our model
    model = LitModel(
        dm.size(),
        dm.num_classes,
        batch_size=backup_batch_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        gamma=gamma,
        momentum=momentum,
    )
    # Initialize wandb logger
    wandb_logger = WandbLogger(project=project_name, job_type="train", id=trial_id)

    # MODEL_CKPT = "model/{id}-{{epoch:02d}}-{{val_loss:.2f}}"
    checkpoint_filename = "last"
    checkpoint_path = os.path.join(trial_dir, "last.ckpt")
    print(trial_id)
    print(checkpoint_path)
    checkpoint_callback = ModelCheckpoint(dirpath=trial_dir, save_last=True)

    # Initialize a trainer
    trainer = pl.Trainer(
        max_epochs=epochs,
        progress_bar_refresh_rate=20,
        gpus=[1],
        logger=wandb_logger,
        callbacks=[early_stop_callback, lr_monitor],
        checkpoint_callback=checkpoint_callback,
        resume_from_checkpoint=checkpoint_path,
        default_root_dir=trial_dir,
        auto_scale_batch_size="power",
        auto_select_gpus=True,
    )
    # Find the largest possible batch size
    if batch_size is None:
        trainer.tune(model, dm)

    # TODO: Maybe changing epochs when resuming a trial will cause some issue
    wandb_logger.log_hyperparams(
        {
            "epochs": epochs,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "gamma": gamma,
            "weight_decay": weight_decay,
            "momentum": momentum,
        }
    )

    # Train the model ⚡🚅⚡
    trainer.fit(model, dm)

    # Evaluate the model on the held out test set ⚡⚡
    trainer.test()

    # Close wandb run
    wandb.finish()
    wandb_run = wandb.Api().run(wandb_logger.experiment.path)

    # Send back final validation error to Oríon
    report_objective(1 - wandb_run.summary["val_acc"])
    # Return the error rate as well (not the raw accuracy) so that run_hpo()
    # can pass it to experiment.observe() as `valid_error_rate`.
    return 1 - wandb_run.summary["val_acc"]


def run_hpo(working_dir=".", **kwargs):
    kwargs.setdefault("epochs", 50)
    kwargs.setdefault("batch_size", 8192)
    # Specify the database where the experiments are stored. We use a local PickleDB here.
    storage = {
        "type": "legacy",
        "database": {
            "type": "pickleddb",
            "host": "./db.pkl",
        },
    }

    # Load the data for the specified experiment
    experiment = build_experiment(
        "test-orion-hyperband-cifar10",
        space={
            "epochs": "fidelity(1, 120, base=4)",
            "learning_rate": "loguniform(1e-5, 0.5)",
            "momentum": "uniform(0.8, 0.99)",
            "weight_decay": "loguniform(1e-10, 1e-2)",
            "n_gamma": "loguniform(10, 10000)",
        },
        algorithms={
            "hyperband": {
                "seed": 1,
                "repetitions": 2,
            },
        },
        storage=storage,
        working_dir=working_dir,
    )
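    # Ask/tell optimization loop: request a trial from Oríon, train with its
    # hyperparameters, then report the resulting validation error rate back.
    # `trial.hash_params` doubles as the wandb run id and the checkpoint folder
    # name so that interrupted trials can be resumed.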
    while not experiment.is_done:
        trial = experiment.suggest()
        if trial is None:
            break

        kwargs.update(trial.params)
        valid_error_rate = main(
            **kwargs,
            project_name=f"{experiment.name}-v{experiment.version}",
            trial_dir=f"{experiment.working_dir}/{trial.hash_params}",
            trial_id=trial.hash_params,
        )
        experiment.observe(
            trial,
            [
                {
                    "name": "valid_error_rate",
                    "type": "objective",
                    "value": valid_error_rate,
                }
            ],
        )


if __name__ == "__main__":
    cli()
    # run_hpo()
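# Usage sketch (illustration only, not part of the original gist; the module
# name below is a placeholder):
#
#   # Single training run from the command line:
#   #   python cifar10_lightning.py --epochs 50 --learning-rate 0.1 --project-name my-project
#
#   # Hyperband search from Python, using the ask/tell loop above:
#   #   from cifar10_lightning import run_hpo
#   #   run_hpo(working_dir="./trials")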