Skip to content

Instantly share code, notes, and snippets.

@yukw777
Created February 9, 2021 14:23
Show Gist options
  • Save yukw777/37ddfe8226e421a4948805b98db39358 to your computer and use it in GitHub Desktop.
Save yukw777/37ddfe8226e421a4948805b98db39358 to your computer and use it in GitHub Desktop.
PyTorch Lightning Metrics Example
class NetworkLightningModule(..., pl.LightningModule):
def __init__(self, ...):
super().__init__(...)
self.save_hyperparameters()
# metrics
self.train_accuracy = pl.metrics.Accuracy()
self.val_accuracy = pl.metrics.Accuracy()
self.test_accuracy = pl.metrics.Accuracy()
def training_step(self, batch: DataPoint, batch_idx: int) -> torch.Tensor:
...
self.log_dict(
{
"train_mse_loss": mse_loss,
"train_ce_loss": cross_entropy_loss,
"train_acc": self.train_accuracy(pred_move, target_move),
}
)
...
def validation_step(self, batch: DataPoint, batch_idx: int) -> None:
...
self.val_accuracy(pred_move, target_move),
self.log_dict(
{
"val_loss": loss,
"val_mse_loss": mse_loss,
"val_ce_loss": cross_entropy_loss,
"val_acc": self.val_accuracy,
}
)
def test_step(self, batch: DataPoint, batch_idx: int) -> None:
...
self.test_accuracy(pred_move, target_move)
self.log_dict(
{
"test_loss": loss,
"test_mse_loss": mse_loss,
"test_ce_loss": cross_entropy_loss,
"test_acc": self.test_accuracy,
}
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment