Skip to content

Instantly share code, notes, and snippets.

@yukw777
Created August 27, 2020 15:34
Show Gist options
  • Save yukw777/77fe91f560498e1767073ff035f2fb09 to your computer and use it in GitHub Desktop.
Save yukw777/77fe91f560498e1767073ff035f2fb09 to your computer and use it in GitHub Desktop.
PyTorch Lightning Metrics Example
def training_step(self, batch: DataPoint, batch_idx: int) -> pl.TrainResult:
    """Compute the losses for one training batch and log them.

    The batch is unpacked into one model input and two targets. The model
    is called on the input to get two predictions, and ``self.loss`` turns
    the predictions and targets into three loss values.

    Args:
        batch: Three-item batch, unpacked as input, move target and value
            target. NOTE(review): presumably board planes with policy/value
            labels -- confirm against the ``DataPoint`` definition.
        batch_idx: Index of the batch; not used by this method.

    Returns:
        A ``pl.TrainResult`` that minimises the third value returned by
        ``self.loss``. That value is logged as ``train_loss`` with
        ``prog_bar=True``; ``train_mse_loss``, ``train_ce_loss`` and
        ``train_acc`` are logged through ``log_dict``.
    """
    inputs, move_target, value_target = batch
    move_pred, value_pred = self(inputs)
    mse, cross_entropy, total = self.loss(
        move_pred, value_pred, move_target, value_target
    )

    train_result = pl.TrainResult(minimize=total)
    train_result.log("train_loss", total, prog_bar=True)

    # Secondary metrics go through log_dict (not individual log calls) so
    # they keep whatever logging defaults log_dict applies.
    extra_metrics = {
        "train_mse_loss": mse,
        "train_ce_loss": cross_entropy,
        "train_acc": accuracy(move_pred, move_target),
    }
    train_result.log_dict(extra_metrics)
    return train_result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment