Skip to content

Instantly share code, notes, and snippets.

@aletheia
Created July 14, 2020 08:45
Show Gist options
  • Save aletheia/1e2445bb3b466a928bb804f323edc32c to your computer and use it in GitHub Desktop.
Save aletheia/1e2445bb3b466a928bb804f323edc32c to your computer and use it in GitHub Desktop.
def validation_step(self, batch, batch_idx):
    """Perform one validation step: cross-entropy of predictions vs. labels.

    Args:
        batch: A pair ``(x, labels)``; ``x`` is fed to the model and
            ``labels`` are the target class indices passed to
            ``F.cross_entropy`` (presumably shape ``(batch,)`` — confirm
            against the data loader).
        batch_idx: Index of the batch within the epoch (unused here).

    Returns:
        dict: ``{'val_loss': loss}`` where ``loss`` is the scalar tensor
        returned by ``F.cross_entropy(prediction, labels)``.
    """
    x, labels = batch
    # Call the module instance rather than ``self.forward`` so that any
    # registered forward (pre-)hooks are executed, as PyTorch recommends.
    prediction = self(x)
    return {
        'val_loss': F.cross_entropy(prediction, labels)
    }
def validation_epoch_end(self, outputs):
    """Average the per-batch validation losses collected over an epoch.

    Args:
        outputs: Iterable of dicts, each holding a ``'val_loss'`` tensor
            (the values returned by ``validation_step``).

    Returns:
        dict: ``{'val_loss': mean}`` where ``mean`` is the mean of the
        stacked per-batch losses.

    Note:
        ``torch.stack`` raises if ``outputs`` is empty.
    """
    batch_losses = []
    for step_output in outputs:
        batch_losses.append(step_output['val_loss'])
    mean_loss = torch.stack(batch_losses).mean()
    return {'val_loss': mean_loss}
def validation_end(self, outputs):
    """Average all validation losses once validation completes, print and log it.

    Args:
        outputs: Iterable of dicts, each holding a ``'val_loss'`` tensor
            (the values returned by ``validation_step``).

    Returns:
        dict: ``{'avg_val_loss': mean, 'log': {'val_loss': mean}}``. The
        ``'log'`` entry is presumably consumed by the (older) Lightning
        logging mechanism — confirm against the Lightning version in use.

    Note:
        ``torch.stack`` raises if ``outputs`` is empty.
    """
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    # This is the validation loss; the message previously mislabelled it as
    # the training loss.
    print(f'Average validation loss: {avg_loss.item()}')
    logs = {'val_loss': avg_loss}
    return {
        'avg_val_loss': avg_loss,
        'log': logs,
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment