Created July 3, 2020 11:12
-
-
Save pranshuj73/bffd8588961ef8d7939965ac3ea6da15 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def accuracy(outputs, labels):
    """Return the fraction of samples whose arg-max prediction matches the label.

    ``outputs`` is reduced along ``dim=1``; the index of the largest value in
    each row is taken as that row's predicted class. Assumes ``outputs`` is
    (batch, num_classes) and ``labels`` is (batch,) -- TODO confirm with callers.

    The result is a 0-dim tensor built from a Python float. It is therefore
    detached from any autograd graph and created with torch's default
    dtype/device, not the device of ``outputs``. An empty batch raises
    ZeroDivisionError.
    """
    predicted = torch.max(outputs, dim=1).indices
    # .item() pulls the match count back to a Python int before dividing.
    correct = (predicted == labels).sum().item()
    return torch.tensor(correct / len(predicted))
class FERBase(nn.Module):
    """Base module supplying per-batch training and validation hooks.

    Subclasses provide ``forward``; these helpers call ``self(images)`` on an
    ``(images, labels)`` batch and score the logits with cross-entropy.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss for one training batch.

        ``batch`` is unpacked as ``(images, labels)``. The loss is returned
        without detaching, so the caller can backpropagate through it.
        """
        images, labels = batch
        return F.cross_entropy(self(images), labels)

    def validation_step(self, batch):
        """Score one validation batch.

        Returns a dict holding the detached cross-entropy loss under
        ``'val_loss'`` and the batch accuracy (from the module-level
        ``accuracy`` helper) under ``'val_acc'``.
        """
        images, labels = batch
        logits = self(images)
        return {
            'val_loss': F.cross_entropy(logits, labels).detach(),
            'val_acc': accuracy(logits, labels),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts produced by ``validation_step``.

        Each metric is stacked across batches and averaged with an unweighted
        mean (every batch counts equally, whatever its size). The result is
        returned as a plain Python float.
        """
        return {
            metric: torch.stack([record[metric] for record in outputs]).mean().item()
            for metric in ('val_loss', 'val_acc')
        }

    def epoch_end(self, epoch, result):
        """Print a one-line summary of an epoch.

        Reads the last entry of ``result['lrs']`` as the most recent learning
        rate, plus the ``'train_loss'``, ``'val_loss'`` and ``'val_acc'``
        entries of ``result``.
        """
        template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        last_lr = result['lrs'][-1]
        print(template.format(epoch, last_lr, result['train_loss'],
                              result['val_loss'], result['val_acc']))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.