Skip to content

Instantly share code, notes, and snippets.

@imflash217
Last active January 13, 2021 20:51
Show Gist options
  • Save imflash217/6e412c3bf57f694009aadd3071bbe3c4 to your computer and use it in GitHub Desktop.
PyTorch Lightning Model
import torch as pt
import pytorch_lightning as pl
#######################################################################
class FlashModel(pl.LightningModule):
    """This defines a MODEL: a plain stack of ``Linear`` layers.

    The layers are registered as ``layer1`` ... ``layer{num_layers}`` (the
    default ``num_layers=3`` gives ``layer1``, ``layer2``, ``layer3``).

    Args:
        num_layers: number of ``Linear`` layers to create; must be >= 1.
        in_features: input width of the first layer (illustrative default).
        hidden_features: width of every intermediate layer (illustrative default).
        out_features: output width of the last layer (illustrative default).

    Raises:
        ValueError: if ``num_layers`` is less than 1.
    """

    def __init__(self, num_layers: int = 3, in_features: int = 28 * 28,
                 hidden_features: int = 128, out_features: int = 10):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Chain of widths: in -> hidden -> ... -> hidden -> out, one
        # (fan_in, fan_out) pair per layer.
        widths = [in_features] + [hidden_features] * (num_layers - 1) + [out_features]
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            # nn.Linear requires both feature sizes. setattr registers the
            # layer as a submodule exactly like `self.layer1 = ...` would.
            setattr(self, f"layer{idx}", pt.nn.Linear(fan_in, fan_out))
class FlashModel(pl.LightningModule):
    """This defines a SYSTEM: a LightningModule composed of other modules.

    Args:
        encoder: module stored as ``self.encoder`` (``None`` by default).
        decoder: module stored as ``self.decoder`` (``None`` by default).
    """

    def __init__(
        self,
        encoder: pt.nn.Module = None,
        decoder: pt.nn.Module = None,
    ):
        super().__init__()
        # Assigned in order: encoder first, then decoder.
        self.encoder, self.decoder = encoder, decoder
##### INIT ##################################################################
class FlashModel(pl.LightningModule):
    """DON'T DO THIS: take hyperparameters through an opaque ``params`` object.

    The model actually needs ``params.lr`` and ``params.coeff_x``, but the
    signature says nothing about that. Callers must build a matching object
    and read the body to learn which fields are required.
    """

    def __init__(self, params):
        # Even in the anti-pattern, the parent constructor must run first so
        # the nn.Module / LightningModule internals are initialised.
        super().__init__()
        self.lr = params.lr
        self.coeff_x = params.coeff_x
class FlashModel(pl.LightningModule):
    """Instead DO THIS: take each hyperparameter as an explicit, typed argument.

    Args:
        encoder: optional encoder submodule, stored as ``self.encoder``.
        coeff_x: the ``coeff_x`` hyperparameter, stored as ``self.coeff_x``
            (its meaning depends on model logic not shown here).
        lr: learning rate, stored as ``self.lr`` (presumably read by
            ``configure_optimizers``, which is not shown here).
    """

    def __init__(self,
                 encoder: pt.nn.Module = None,
                 coeff_x: float = 0.2,
                 lr: float = 1e-3):
        # Run the parent constructor first, then keep every argument so the
        # rest of the module can use it.
        super().__init__()
        self.encoder = encoder
        self.coeff_x = coeff_x
        self.lr = lr
#######################################################################
## A typical PyTorch Lightning Model looks like this =>
class FlashModel(pl.LightningModule):
    """Skeleton of the hooks a typical LightningModule overrides.

    Every body is a placeholder that returns ``None``. Fill in the ones you
    need. Signatures follow the PyTorch Lightning 1.x hook API.
    """

    def __init__(self):
        # The parent constructor sets up nn.Module bookkeeping and must run
        # before any submodule or attribute is assigned.
        super().__init__()

    def forward(self, *args, **kwargs):
        """Inference / prediction logic."""

    def training_step(self, batch, batch_idx):
        """Training logic for one batch; should return the loss."""

    def training_step_end(self, batch_parts):
        """Combine per-device outputs of ``training_step`` (e.g. under DP)."""

    def training_epoch_end(self, outputs):
        """Aggregate the outputs of all training steps in an epoch."""

    def validation_step(self, batch, batch_idx):
        """Validation logic for one batch."""

    def validation_step_end(self, batch_parts):
        """Combine per-device outputs of ``validation_step``."""

    def validation_epoch_end(self, outputs):
        """Aggregate the outputs of all validation steps in an epoch."""

    def test_step(self, batch, batch_idx):
        """Test logic for one batch."""

    def test_step_end(self, batch_parts):
        """Combine per-device outputs of ``test_step``."""

    def test_epoch_end(self, outputs):
        """Aggregate the outputs of all test steps in an epoch."""

    def configure_optimizers(self):
        """Return the optimizer(s) (and optional LR scheduler(s)) to use."""

    def any_other_custom_hooks(self, *args, **kwargs):
        """Placeholder for any additional hooks or helper methods."""
#######################################################################
#### FORWARD & TRAINING STEP ########################################################################
class FlashModel(pl.LightningModule):
    """Shows the split between ``forward`` (inference) and ``training_step``.

    Args:
        encoder: module applied to the inputs, stored as ``self.encoder``.
        decoder: module applied to the encoder output, stored as ``self.decoder``.
    """

    def __init__(self, encoder: pt.nn.Module = None, decoder: pt.nn.Module = None):
        super().__init__()
        # forward() and training_step() below read these attributes.
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Use this for inference/predictions: return the encoder output for ``x``."""
        embeddings = self.encoder(x)
        return embeddings

    def training_step(self, batch, batch_idx):
        """Use this for training only: return the loss for one ``(x, y)`` batch."""
        x, y = batch
        # Encode once, routed through forward() via self(x) rather than calling
        # self.encoder(x) directly: the call to use with data-parallel (DP/DDP).
        z = self(x)
        pred = self.decoder(z)
        # Placeholder loss: assumes class-index targets. Swap in the loss
        # your task actually needs.
        loss = pt.nn.functional.cross_entropy(pred, y)
        return loss
####################################################################################################
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment