Skip to content

Instantly share code, notes, and snippets.

@pranshuj73
Created July 3, 2020 11:12
Show Gist options
  • Save pranshuj73/bffd8588961ef8d7939965ac3ea6da15 to your computer and use it in GitHub Desktop.
Save pranshuj73/bffd8588961ef8d7939965ac3ea6da15 to your computer and use it in GitHub Desktop.
def accuracy(outputs, labels):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        outputs: Score tensor; the predicted class of each row is the arg-max
            taken over ``dim=1``.
        labels: Tensor of target class indices, compared element-wise with the
            predictions.

    Returns:
        A 0-d float tensor in ``[0, 1]``. An empty batch yields ``0.0`` rather
        than raising ``ZeroDivisionError``.
    """
    _, preds = torch.max(outputs, dim=1)
    if preds.numel() == 0:
        # The plain division below would divide by zero for an empty batch.
        return torch.tensor(0.0)
    correct = torch.sum(preds == labels).item()
    return torch.tensor(correct / len(preds))
class FERBase(nn.Module):
    """Shared training/validation helpers for a classification model.

    This class defines no ``forward``; subclasses provide it, and the helpers
    below run the model through ``self(images)`` (``nn.Module.__call__``) and
    score the result with ``F.cross_entropy``. Every batch is expected to
    unpack as ``(images, labels)``.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss for one training batch.

        Args:
            batch: An ``(images, labels)`` pair, e.g. one item from the
                training DataLoader.

        Returns:
            The loss tensor, left attached to the autograd graph so the caller
            can back-propagate through it.
        """
        images, labels = batch
        out = self(images)  # forward pass through the subclass's model
        loss = F.cross_entropy(out, labels)  # compare predictions with the true labels
        return loss

    def validation_step(self, batch):
        """Compute loss and accuracy for one validation batch.

        Args:
            batch: An ``(images, labels)`` pair, e.g. one item from the
                validation DataLoader.

        Returns:
            ``{'val_loss': loss, 'val_acc': acc}`` where ``loss`` is detached
            from the autograd graph and ``acc`` comes from ``accuracy``.
        """
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)  # module-level helper defined in this file
        # Detach so the stored per-batch losses do not reference the graph.
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results of ``validation_step`` over an epoch.

        Args:
            outputs: Non-empty list of dicts as returned by
                ``validation_step``; ``torch.stack`` rejects an empty list.

        Returns:
            ``{'val_loss': float, 'val_acc': float}``.

        Note:
            This is an unweighted mean over batches, so a short final batch
            counts as much as a full one.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()  # mean loss over the epoch's batches
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()  # mean accuracy over the epoch's batches
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of an epoch.

        Args:
            epoch: Epoch index shown in the message.
            result: Mapping with ``'lrs'`` (sequence of learning rates; the
                last one is printed), ``'train_loss'``, ``'val_loss'`` and
                ``'val_acc'``.
        """
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment