@raman-r-4978
Created September 24, 2021 10:22
PL issue: a PyTorch Lightning reproduction script (multi-GPU DDP training with epoch-level metric logging).
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.trainer.trainer import Trainer
from torch.nn import functional as F

from mnist_datamodule import MNISTDataModule

pl.seed_everything(42)
class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            "train_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
            rank_zero_only=True,
        )
        return loss
    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
            rank_zero_only=True,
        )
        return loss
    def test_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            "test_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
            rank_zero_only=True,
        )
        return loss
    def aggregate_validation_metrics(self, val_outputs, loss_name):
        tot_loss: torch.Tensor = torch.tensor(0.0, device=self.device)
        # Multiple dataloaders: outputs arrive as a list of lists (one per loader).
        if isinstance(val_outputs[0], list):
            for losses in val_outputs:
                tot_loss += sum(losses) / len(losses)
            tot_loss = tot_loss / len(val_outputs)
        # Single dataloader: outputs are a flat list of per-batch losses.
        else:
            tot_loss += sum(val_outputs) / len(val_outputs)
        self.log(
            f"tot_{loss_name}",
            tot_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
            rank_zero_only=True,
        )
    def validation_epoch_end(self, val_outputs):
        self.aggregate_validation_metrics(val_outputs, loss_name="val_loss")

    def test_epoch_end(self, test_outputs):
        self.aggregate_validation_metrics(test_outputs, loss_name="test_loss")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
def main():
    model = LitClassifier()
    data_module = MNISTDataModule()
    trainer = Trainer(
        gpus=2,
        max_epochs=5,
        num_sanity_val_steps=5,
        logger=TensorBoardLogger("mnist_logs", name="mnist"),
        accelerator="ddp",  # PyTorch Lightning 1.x API; later versions use strategy="ddp"
    )
    trainer.fit(model, data_module)
    trainer.test(datamodule=data_module, ckpt_path="best")


if __name__ == "__main__":
    main()
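
The script imports MNISTDataModule from a mnist_datamodule module that is not part of the gist. For completeness, here is a minimal sketch of what such a datamodule could look like, assuming the standard torchvision MNIST dataset; the file name, batch size, and 55k/5k train/val split are illustrative assumptions rather than the gist author's actual implementation.

# mnist_datamodule.py -- minimal sketch (assumed, not included in the gist)
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # Download once; by default Lightning runs this on one process per node.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in (None, "fit"):
            full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(full, [55000, 5000])
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

With two GPUs available, the main script can then be launched directly (e.g. python repro.py); under the Lightning 1.x accelerator="ddp" mode it spawns one process per GPU and, by default, wraps each dataloader's sampler in a DistributedSampler.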