Skip to content

Instantly share code, notes, and snippets.

@yukw777
Created August 27, 2020 15:30
Show Gist options
  • Save yukw777/d184198f2748b1b237100e9a580ac425 to your computer and use it in GitHub Desktop.
Save yukw777/d184198f2748b1b237100e9a580ac425 to your computer and use it in GitHub Desktop.
PyTorch Lightning Result Example
def training_step(self, batch: DataPoint, batch_idx: int) -> pl.TrainResult:
    """Compute the training losses for one batch and log them.

    The batch is unpacked into planes, a move target and a value target.
    The planes are fed through the model, and its two outputs plus both
    targets are passed to ``self.loss``.

    Args:
        batch: a ``DataPoint`` batch, unpacked as
            ``(planes, target_move, target_val)``.
        batch_idx: batch index; not used by this method.

    Returns:
        A ``pl.TrainResult`` whose ``minimize`` is the combined loss.
        ``train_loss`` is logged to the progress bar. ``train_mse_loss``,
        ``train_ce_loss`` and ``train_acc`` are logged via ``log_dict``.
    """
    planes, move_target, value_target = batch
    move_pred, value_pred = self(planes)
    # NOTE(review): the (mse, cross-entropy, combined) order follows the
    # names used when this was written; confirm against self.loss.
    mse, ce, total = self.loss(move_pred, value_pred, move_target, value_target)

    extra_metrics = {
        "train_mse_loss": mse,
        "train_ce_loss": ce,
        "train_acc": accuracy(move_pred, move_target),
    }

    train_result = pl.TrainResult(minimize=total)
    # Only the combined loss goes on the progress bar; the components are
    # logged without it.
    train_result.log("train_loss", total, prog_bar=True)
    train_result.log_dict(extra_metrics)
    return train_result
def validation_step(self, batch: DataPoint, batch_idx: int) -> pl.EvalResult:
    """Evaluate one validation batch and log its losses and accuracy.

    Args:
        batch: a ``DataPoint`` batch, unpacked as
            ``(planes, target_move, target_val)``.
        batch_idx: batch index; not used by this method.

    Returns:
        A ``pl.EvalResult`` that checkpoints on the combined loss. It logs
        ``val_loss``, ``val_mse_loss``, ``val_ce_loss`` and ``val_acc``,
        in that order.
    """
    planes, move_target, value_target = batch
    move_pred, value_pred = self(planes)
    # NOTE(review): the (mse, cross-entropy, combined) order follows the
    # names used when this was written; confirm against self.loss.
    mse, ce, total = self.loss(move_pred, value_pred, move_target, value_target)

    # Names and values are paired positionally -- keep the two tuples in
    # step. test_step relies on exactly these four "val_*" names.
    metric_names = ("val_loss", "val_mse_loss", "val_ce_loss", "val_acc")
    metric_values = (total, mse, ce, accuracy(move_pred, move_target))

    eval_result = pl.EvalResult(checkpoint_on=total)
    eval_result.log_dict(dict(zip(metric_names, metric_values)))
    return eval_result
def test_step(self, batch: DataPoint, batch_idx: int) -> pl.EvalResult:
    """Evaluate one test batch by reusing ``validation_step``.

    The result of ``validation_step`` is returned with its ``val_*``
    metric keys renamed to the matching ``test_*`` keys.

    Args:
        batch: a ``DataPoint`` batch, forwarded unchanged.
        batch_idx: batch index, forwarded unchanged.

    Returns:
        The ``pl.EvalResult`` from ``validation_step``, after
        ``rename_keys`` has been applied.
    """
    # Old and new names are paired positionally; these must match the
    # keys that validation_step logs.
    old_keys = ("val_loss", "val_mse_loss", "val_ce_loss", "val_acc")
    new_keys = ("test_loss", "test_mse_loss", "test_ce_loss", "test_acc")

    eval_result = self.validation_step(batch, batch_idx)
    eval_result.rename_keys(dict(zip(old_keys, new_keys)))
    return eval_result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment