Created
July 14, 2020 08:40
import torch.nn.functional as F  # provides nll_loss


def training_step(self, batch, batch_idx):
    '''Run one training step: compute the NLL loss for the batch, log it,
    and return it to the Trainer, which uses it for backpropagation.
    '''
    # Unpack inputs and target labels from the batch
    x, labels = batch
    # Forward pass; nll_loss expects log-probabilities (e.g. log_softmax output)
    prediction = self.forward(x)
    loss = F.nll_loss(prediction, labels)
    # Log the training loss
    logs = {'train_loss': loss}
    output = {
        # Required: the Trainer backpropagates this value
        'loss': loss,
        # Optional: for logging purposes
        'log': logs
    }
    return output
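
To make the loss in `training_step` concrete, here is a minimal pure-Python sketch of what an NLL loss with mean reduction computes: it takes per-class log-probabilities and averages the negative log-probability assigned to each target label. The helpers `log_softmax` and `nll_loss` below are hypothetical stand-ins written for illustration, not the PyTorch functions, and the batch values are made up.

```python
import math


def log_softmax(logits):
    """Convert one row of raw scores into log-probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def nll_loss(log_probs, labels):
    """Mean negative log-likelihood of the target labels over a batch."""
    picked = [row[label] for row, label in zip(log_probs, labels)]
    return -sum(picked) / len(picked)


# Hypothetical batch: two samples, three classes.
logits = [[2.0, 0.5, 0.1], [0.2, 0.2, 3.0]]
labels = [0, 2]

log_probs = [log_softmax(row) for row in logits]
loss = nll_loss(log_probs, labels)
print(f"train_loss = {loss:.4f}")
```

The sketch suggests why the network's `forward` would normally end in a log-softmax when paired with an NLL loss: the loss indexes directly into log-probabilities and does no normalization of its own.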