@koshian2
Last active May 15, 2022 13:12
PyTorch Lightning CSVLogger example: train a ten-layer CNN on CIFAR-10 while logging metrics to a CSVLogger and a TensorBoardLogger simultaneously.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchmetrics
import pytorch_lightning as pl

# Written against the pytorch_lightning 1.x API: the *_epoch_end hooks and the
# gpus= Trainer flag were removed in 2.0, and newer torchmetrics versions
# require a task argument for Accuracy.

class TenLayersModel(pl.LightningModule):
    """Nine 3x3 conv layers (64/128/256 channels) plus one linear head."""
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(3):        # three stages: 64, 128, 256 channels
            for j in range(3):    # three conv blocks per stage
                if i == 0 and j == 0:
                    in_ch = 3                    # RGB input
                elif j == 0:
                    in_ch = 64 * (2**(i - 1))    # first block of a stage doubles the width
                else:
                    in_ch = 64 * (2**i)
                out_ch = 64 * (2**i)
                self.layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
                self.layers.append(nn.BatchNorm2d(out_ch))
                self.layers.append(nn.ReLU())
        self.layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.fc = nn.Linear(256, 10)
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def forward(self, inputs):
        x = inputs
        for layer in self.layers:
            x = layer(x)
        x = x.view(x.shape[0], 256)  # global pooling leaves (N, 256, 1, 1)
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), 0.1, 0.9)  # lr=0.1, momentum=0.9
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [70, 90], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, train_batch, batch_idx):
        x, y_true = train_batch
        y_pred = self.forward(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        self.train_acc.update(y_pred_label, y_true)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        # Log the epoch-level accuracy, then reset the metric state.
        self.log("train_acc", self.train_acc.compute(), prog_bar=True, logger=True)
        self.train_acc.reset()

    def validation_step(self, val_batch, batch_idx):
        x, y_true = val_batch
        y_pred = self.forward(x)
        loss = F.cross_entropy(y_pred, y_true)
        y_pred_label = torch.argmax(y_pred, dim=-1)
        self.val_acc.update(y_pred_label, y_true)
        self.log("val_loss", loss, prog_bar=False, logger=True)
        return loss

    def validation_epoch_end(self, outputs):
        self.log("val_acc", self.val_acc.compute(), prog_bar=True, logger=True)
        self.val_acc.reset()

class MyDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # CIFAR-10 with light augmentation on the training split only.
        self.train_dataset = torchvision.datasets.CIFAR10(
            "./data", train=True, download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomCrop(size=(32, 32), padding=2),
                torchvision.transforms.ToTensor()
            ]))
        self.val_dataset = torchvision.datasets.CIFAR10(
            "./data", train=False, download=True,
            transform=torchvision.transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=256, num_workers=4, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=256, num_workers=4, shuffle=False)

def main():
    model = TenLayersModel()
    cifar = MyDataModule()
    # Passing a list of loggers sends every self.log() call to both backends.
    logger_csv = pl.loggers.CSVLogger("outputs", name="lightning_logs_csv")
    logger_tb = pl.loggers.TensorBoardLogger("outputs", name="lightning_logs_tb")
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        dirpath="outputs/checkpoints", save_top_k=1, monitor="val_acc",
        mode="max", filename="{epoch:03}-{val_acc:.3f}")
    trainer = pl.Trainer(gpus=[1], max_epochs=100,  # gpus=[1] pins the second GPU
                         logger=[logger_csv, logger_tb], callbacks=[checkpoint_cb])
    trainer.fit(model, cifar)

if __name__ == "__main__":
    main()
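
The CSVLogger above follows the CSVLogger(save_dir, name) layout and writes one column per logged key to outputs/lightning_logs_csv/version_<n>/metrics.csv, leaving a cell blank where a key was not logged at that step. A minimal sketch of inspecting that file with pandas, assuming the first run landed in version_0:

import pandas as pd

# version_0 assumes this is the first run under outputs/lightning_logs_csv.
df = pd.read_csv("outputs/lightning_logs_csv/version_0/metrics.csv")
# Step-level rows carry train_loss; epoch-level rows carry train_acc / val_acc.
print(df[["epoch", "train_acc", "val_acc"]].dropna(subset=["val_acc"]).tail())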