## Gist by @imflash217 (created January 14, 2021 22:02)
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
###########################################################################################
## PyTorch Lightning version
##
class FlashModel(pl.LightningModule):
    """Wraps an arbitrary backbone model and trains it with cross-entropy loss."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        ## Required by Lightning before `Trainer.fit` can run.
        ## The choice of Adam and this lr is an assumption, not from the original gist.
        return torch.optim.Adam(self.parameters(), lr=1e-3)
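
###########################################################################################
## Usage sketch (not part of the original gist): fitting FlashModel with Lightning's
## Trainer on random tensors. The Linear backbone, dataset shapes, batch size, and
## epoch count below are all hypothetical placeholder choices.
##
from torch.utils.data import DataLoader, TensorDataset

backbone = torch.nn.Linear(32, 10)                        ## toy backbone
dataset = TensorDataset(torch.randn(64, 32),              ## 64 random inputs
                        torch.randint(0, 10, (64,)))      ## random class labels
flash_model = FlashModel(backbone)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(flash_model, DataLoader(dataset, batch_size=8))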
###########################################################################################
## Under the hood, Lightning's fit loop does roughly the following =>
##
## Step-1: Put the model in train mode and enable gradient tracking
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    ## Step-2: Forward pass => compute the loss
    loss = training_step(batch)
    losses.append(loss.detach())

    ## Step-3: Backward pass => compute gradients
    loss.backward()

    ## Step-4: Apply the optimizer step, then clear the gradients
    optimizer.step()
    optimizer.zero_grad()
###########################################################################################
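
## To make the manual loop above concrete, here is a self-contained sketch where the
## undefined names (`model`, `optimizer`, `train_dataloader`, `training_step`) are bound
## to hypothetical stand-ins; the Linear model, SGD optimizer, and random data are
## assumptions, not part of the original gist.
##
manual_model = torch.nn.Linear(32, 10)
manual_optimizer = torch.optim.SGD(manual_model.parameters(), lr=0.1)
manual_loader = DataLoader(TensorDataset(torch.randn(64, 32),
                                         torch.randint(0, 10, (64,))),
                           batch_size=8)

manual_model.train()
torch.set_grad_enabled(True)

manual_losses = []
for x, y in manual_loader:
    loss = F.cross_entropy(manual_model(x), y)   ## forward (the training_step body)
    manual_losses.append(loss.detach())
    loss.backward()                              ## backward
    manual_optimizer.step()                      ## optimizer step
    manual_optimizer.zero_grad()                 ## clear grads

print(f"mean train loss: {torch.stack(manual_losses).mean().item():.4f}")
###########################################################################################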