Skip to content

Instantly share code, notes, and snippets.

@aliwaqas333
Created June 19, 2020 10:14
Show Gist options
  • Save aliwaqas333/5d53e4a85a43f32db9e2a778b7d49fb1 to your computer and use it in GitHub Desktop.
Save aliwaqas333/5d53e4a85a43f32db9e2a778b7d49fb1 to your computer and use it in GitHub Desktop.
Image Classification base class
class ImageClassificationBase(nn.Module):
    """Shared training/validation helpers for image-classification models.

    Subclasses supply ``forward``; the step methods call ``self(images)``,
    which ``nn.Module.__call__`` routes to it. A *batch* is unpacked as an
    ``(images, labels)`` pair.
    """

    def _forward_loss(self, batch):
        """Run the model on *batch* and return ``(logits, labels, loss)``."""
        images, labels = batch
        out = self(images)  # Generate predictions
        # cross_entropy takes integer class indices, so labels are cast to
        # int64 in case they arrive with another dtype.
        loss = F.cross_entropy(out, labels.long())
        return out, labels, loss

    def training_step(self, batch):
        """Return the cross-entropy loss for one training batch.

        The loss is not detached, so the caller can backpropagate through it.
        """
        _, _, loss = self._forward_loss(batch)
        return loss

    def validation_step(self, batch):
        """Evaluate one validation batch.

        Returns a dict with the detached batch loss (``'val_loss'``), the
        batch accuracy (``'val_acc'``) and the number of samples in the batch
        (``'val_batch_size'``). ``validation_epoch_end`` uses the batch size
        to weight the per-batch averages.
        """
        out, labels, loss = self._forward_loss(batch)
        # `accuracy` is a helper defined outside this class. Presumably it
        # returns a scalar tensor (validation_epoch_end stacks these values
        # with torch.stack, which needs tensors). Confirm against its definition.
        acc = accuracy(out, labels)
        return {
            'val_loss': loss.detach(),
            'val_acc': acc,
            'val_batch_size': labels.shape[0],
        }

    @staticmethod
    def _combine(values, weights=None):
        """Average a list of scalar tensors, weighting by *weights* if given.

        An empty *values* list raises the ``RuntimeError`` from ``torch.stack``.
        """
        stacked = torch.stack(values)
        if weights is None:
            return stacked.mean()
        w = torch.as_tensor(weights, dtype=stacked.dtype, device=stacked.device)
        # Normalise the weights before summing so the weighted sum stays in
        # range even for low-precision (e.g. float16) batch values.
        return (stacked * (w / w.sum())).sum()

    def validation_epoch_end(self, outputs):
        """Combine per-batch validation results into epoch-level metrics.

        When every entry of *outputs* carries ``'val_batch_size'`` (as
        ``validation_step`` produces), the batch loss and accuracy are averaged
        with weights proportional to batch size. A smaller batch then counts in
        proportion to its samples instead of as much as a full batch. Entries
        without that key fall back to the plain mean of the batch values.

        Returns ``{'val_loss': float, 'val_acc': float}``.
        """
        sizes = [x.get('val_batch_size') for x in outputs]
        weights = None if any(s is None for s in sizes) else sizes
        epoch_loss = self._combine([x['val_loss'] for x in outputs], weights)  # Combine losses
        epoch_acc = self._combine([x['val_acc'] for x in outputs], weights)  # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary for *epoch*.

        *result* must contain ``'train_loss'``, ``'val_loss'`` and ``'val_acc'``.
        """
        print(
            f"Epoch [{epoch}], train_loss: {result['train_loss']:.4f}, "
            f"val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}"
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment