[Canva] Integrating Wandb Logger with LightningTrainer
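This gist sketches a few ways to report metrics from a Ray LightningTrainer run to Weights & Biases: calling setup_wandb() inside LightningModule.setup(), doing the same from a pl.Callback (WandbInitCallback), or handing Lightning's built-in WandbLogger to the trainer config (left commented out in the __main__ block). A CompileCallback that wraps the model with torch.compile is also defined but not enabled.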
import os

import ray
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger
from ray.air.config import ScalingConfig
from ray.air.integrations.wandb import setup_wandb
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

api_key = os.getenv("WANDB_API_KEY")
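# Note (added): setup_wandb() below forwards this key from the Ray driver to
# each training worker, so set it in the driver environment before launching,
# e.g. WANDB_API_KEY=<your-key> python <this script>.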

class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr, feature_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def setup(self, stage):
        super().setup(stage)
        # Initialize a W&B session for this run (Ray's setup_wandb only creates
        # a real run on the rank-zero worker by default).
        self.wandb = setup_wandb(api_key=api_key, project="MNIST-2", name="experiment-1")

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True)
        self.wandb.log({"train_loss": loss.item()})  # log a plain float, not a tensor
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        # forward() returns raw logits, so use cross_entropy here; nll_loss
        # expects log-probabilities and would compute the wrong loss.
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    # Note: validation_epoch_end is the pre-2.0 Lightning hook (removed in PL 2.0).
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss, sync_dist=True)
        self.log("ptl/val_accuracy", avg_acc, sync_dist=True)
        self.wandb.log({"ptl/val_loss": avg_loss.item()})

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # split data into train and val sets
        if stage == "fit" or stage is None:
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])
        # assign test set for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

class WandbInitCallback(pl.Callback):
    """Alternative integration: create the W&B session from a callback instead
    of inside the LightningModule, and log only from the rank-zero worker."""

    def setup(self, trainer, pl_module, stage):
        self.wandb = setup_wandb(api_key=api_key, project="MNIST", name="experiment-1")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
        if trainer.is_global_zero:
            metrics = trainer.callback_metrics
            train_loss = metrics["train_loss"].detach().cpu().item()
            self.wandb.log({"train_loss": train_loss})

    def on_validation_end(self, trainer, pl_module):
        if trainer.is_global_zero:
            metrics = trainer.callback_metrics
            self.wandb.log({k: v.detach().cpu().item() for k, v in metrics.items()})

class CompileCallback(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        print("Compiling model")
        torch.set_float32_matmul_precision("high")
        trainer.strategy.model = torch.compile(trainer.strategy.model)
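
# Note (added): CompileCallback requires PyTorch 2.0+ for torch.compile, and it
# is defined but not enabled below; pass callbacks=[CompileCallback()] to
# .trainer(...) to try it.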

if __name__ == "__main__":
    lightning_config = (
        LightningConfigBuilder()
        .module(MNISTClassifier, feature_dim=128, lr=0.001)
        .trainer(
            max_epochs=3,
            accelerator="gpu",
            # logger=WandbLogger("canva-wandb-example", project="MNIST-single-run"),
            # callbacks=[WandbInitCallback()]
        )
        .fit_params(datamodule=MNISTDataModule(batch_size=128))
        .checkpointing(monitor="ptl/val_accuracy", mode="max", save_last=True)
        .build()
    )

    scaling_config = ScalingConfig(
        num_workers=2, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
    )

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
    )
    result = trainer.fit()
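
    # Follow-up sketch (added; assumes Ray AIR's Result API): inspect what
    # trainer.fit() returned.
    print(result.metrics)     # final reported metrics, e.g. "ptl/val_accuracy"
    print(result.checkpoint)  # latest checkpoint saved per .checkpointing(...)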