@imflash217
Created January 14, 2021 22:15
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

###########################################################################################

class FlashModel(pl.LightningModule):
    """A minimal LightningModule that wraps an arbitrary model and trains it
    with cross-entropy loss."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        ## logs the metric for each training_step,
        ## and the average across the epoch, to the logger and progress-bar
        self.log("train_loss", loss,
                 on_step=True,
                 on_epoch=True,
                 logger=True,
                 prog_bar=True)
        return loss
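
###########################################################################################
## Usage sketch (not from the original gist): FlashModel wraps any nn.Module, so the
## backbone and dummy dataset below are hypothetical stand-ins for illustration.
from torch.utils.data import DataLoader, TensorDataset

backbone = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
train_dataloader = DataLoader(dataset, batch_size=32)

flash_model = FlashModel(backbone)
trainer = pl.Trainer(max_epochs=1)          ## standard Lightning training entry-point
trainer.fit(flash_model, train_dataloader)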
###########################################################################################
## Under the hood, Lightning runs (in pseudocode) the equivalent of:
outs = []
for batch in train_dataloader:
    ## Step-1: FORWARD
    loss = training_step(batch)
    outs.append(loss.detach())

    ## Step-2: BACKWARD
    loss.backward()

    ## Step-3: optimizer step and clear grads
    optimizer.step()
    optimizer.zero_grad()

## with on_epoch=True, the per-step values are averaged across the epoch
epoch_metric = torch.mean(torch.stack(outs))
###########################################################################################
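## Note (standard Lightning behavior, not stated in the original gist): when both
## on_step=True and on_epoch=True, the logger records "train_loss_step" (per batch)
## and "train_loss_epoch" (the epoch-level average computed as above).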