@woshiyyya · Created April 18, 2023
[Canva] Integrating Wandb Logger with LightningTrainer
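
This script trains an MNIST classifier with Ray's LightningTrainer and demonstrates two ways to report metrics to Weights & Biases from the training workers: calling setup_wandb() in the LightningModule's setup() hook and logging directly from the training/validation steps, or initializing W&B from a standalone callback (WandbInitCallback) that forwards trainer.callback_metrics from the rank-zero worker. The commented-out WandbLogger line in the trainer config shows a third, logger-based option.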
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.loggers.wandb import WandbLogger
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

from ray.air.config import ScalingConfig
from ray.air.integrations.wandb import setup_wandb
from ray.train.lightning import LightningTrainer, LightningConfigBuilder

# Read the W&B API key from the environment so it is never hard-coded.
api_key = os.getenv("WANDB_API_KEY")


class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr, feature_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def setup(self, stage):
        super().setup(stage)
        # Initialize a W&B run inside each Ray Train worker.
        self.wandb = setup_wandb(api_key=api_key, project="MNIST-2", name="experiment-1")

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True)
        self.wandb.log({"train_loss": loss.item()})
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        # The model outputs raw logits, so use cross_entropy rather than nll_loss.
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss, sync_dist=True)
        self.log("ptl/val_accuracy", avg_acc, sync_dist=True)
        self.wandb.log({"ptl/val_loss": avg_loss.item()})

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        # Split the data into train and val sets.
        if stage == "fit" or stage is None:
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])
        # Assign the test set for use in dataloader(s).
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)


class WandbInitCallback(pl.Callback):
    """Alternative approach: initialize and log to W&B from a Lightning callback."""

    def setup(self, trainer, pl_module, stage):
        self.wandb = setup_wandb(api_key=api_key, project="MNIST", name="experiment-1")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
        # Log only from the rank-zero worker to avoid duplicate W&B entries.
        if trainer.is_global_zero:
            metrics = trainer.callback_metrics
            train_loss = metrics["train_loss"].detach().cpu().item()
            self.wandb.log({"train_loss": train_loss})

    def on_validation_end(self, trainer, pl_module):
        if trainer.is_global_zero:
            metrics = trainer.callback_metrics
            self.wandb.log({k: v.detach().cpu().item() for k, v in metrics.items()})


class CompileCallback(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        print("Compiling model")
        torch.set_float32_matmul_precision("high")
        # Wrap the distributed model with torch.compile before training starts.
        trainer.strategy.model = torch.compile(trainer.strategy.model)


if __name__ == "__main__":
    lightning_config = (
        LightningConfigBuilder()
        .module(MNISTClassifier, feature_dim=128, lr=0.001)
        .trainer(
            max_epochs=3,
            accelerator="gpu",
            # logger=WandbLogger("canva-wandb-example", project="MNIST-single-run"),
            # callbacks=[WandbInitCallback()],
        )
        .fit_params(datamodule=MNISTDataModule(batch_size=128))
        .checkpointing(monitor="ptl/val_accuracy", mode="max", save_last=True)
        .build()
    )

    scaling_config = ScalingConfig(
        num_workers=2, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
    )

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
    )
    result = trainer.fit()
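
trainer.fit() returns a Result object that can be inspected once training completes. Below is a minimal sketch of doing so; it assumes Ray AIR's Result API as of the Ray 2.4-era LightningTrainer used above (metrics, checkpoint, and log_dir are its standard attributes, but verify against your installed version). Remember to set WANDB_API_KEY in the environment before launching the script.

# Minimal sketch (assumption: Ray AIR Result API, Ray ~2.4);
# `result` is the value returned by trainer.fit() above.
print(result.metrics)     # final reported metrics, e.g. "ptl/val_accuracy"
print(result.checkpoint)  # checkpoint selected by the checkpointing config
print(result.log_dir)     # local directory containing logs and checkpoints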