import torch
import torch.nn.functional as F
import pytorch_lightning as pl

##########################################################################################
class FlashModel(pl.LightningModule):
    """Wraps a backbone model and defines the Lightning training hooks."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        preds = ...  ## compute predictions here, e.g. y_hat.argmax(dim=-1)
        return {"loss": loss, "preds": preds}

    def training_epoch_end(self, training_step_outputs):
        ## Called once per epoch with the dicts returned by every training_step.
        for pred in training_step_outputs:
            ## do something with each step's output
            pass

    def configure_optimizers(self):
        ## Required for a trainable LightningModule; Adam is an assumed choice,
        ## not part of the original gist.
        return torch.optim.Adam(self.parameters(), lr=1e-3)
##########################################################################################
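## Usage sketch (an illustration, not part of the original gist): a hypothetical
## linear backbone trained on random tensors via the standard Trainer.fit call.
from torch.utils.data import DataLoader, TensorDataset

backbone = torch.nn.Linear(32, 10)  ## hypothetical backbone; any nn.Module works
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 10, (64,)))
train_loader = DataLoader(dataset, batch_size=8)

flash_model = FlashModel(backbone)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(flash_model, train_loader)
##########################################################################################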
## Under the hood, Lightning runs pseudocode equivalent to:
outs = []
for batch in train_dataloader:
    ## Step-1: FORWARD
    out = training_step(batch)
    outs.append(out)
    ## Step-2: BACKWARD (on the "loss" key returned by training_step)
    out["loss"].backward()
    ## Step-3: Optim step and zero-grad
    optimizer.step()
    optimizer.zero_grad()
training_epoch_end(outs)
##########################################################################################
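## For comparison, a runnable plain-PyTorch version of the same loop. The model,
## optimizer, and random data below are assumptions for illustration only.
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(32, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 10, (64,))),
                    batch_size=8)

outs = []
for x, y in loader:
    loss = F.cross_entropy(model(x), y)  ## Step-1: FORWARD
    loss.backward()                      ## Step-2: BACKWARD
    optimizer.step()                     ## Step-3: Optim step and zero-grad
    optimizer.zero_grad()
    outs.append({"loss": loss.detach()})
## epoch-end analogue of training_epoch_end(outs)
print(torch.stack([o["loss"] for o in outs]).mean())
##########################################################################################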