@gautierdag
Created April 14, 2021 12:04
Example of how to use different optimizers for different layers or modules with PyTorch Lightning, using manual optimization.
import torch
import torch.nn as nn
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Disable automatic optimization so each optimizer can be stepped manually
        self.automatic_optimization = False
        self.model_type = "BoringModel"
        # Layer sizes must chain: linear1 maps 10 -> 10, linear2 maps 10 -> 1
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, X, **kwargs):
        return self.linear2(self.linear1(X))

    def training_step(self, batch, batch_nb):
        # With manual optimization Lightning does not pass optimizer_idx;
        # fetch both optimizers explicitly instead.
        opt_1, opt_2 = self.optimizers()

        # Forward pass; extra batch keys (e.g. "y") are absorbed by **kwargs
        y_hat = self(**batch)

        # Loss for the train batch
        loss = self.loss_fn(y_hat, batch["y"])

        # zero_grad needs to be called before backward
        opt_1.zero_grad()
        opt_2.zero_grad()

        # Backward pass via Lightning's manual_backward
        self.manual_backward(loss)

        # Step through both optimizers
        opt_1.step()
        opt_2.step()

    def configure_optimizers(self):
        # One optimizer (with its own learning rate) per sub-module
        opt_1 = torch.optim.Adam(self.linear1.parameters(), lr=0.1)
        opt_2 = torch.optim.Adam(self.linear2.parameters(), lr=0.2)
        return [opt_1, opt_2]
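
The sketch below shows one way the model above could be trained end to end. It is not part of the original gist: the random dataset, the "X"/"y" batch keys, and the Trainer settings are assumptions made purely for illustration, matching what training_step expects.

from torch.utils.data import DataLoader, Dataset


class RandomDictDataset(Dataset):
    """Yields dicts with the "X" / "y" keys that training_step consumes."""

    def __init__(self, n: int = 256):
        self.X = torch.randn(n, 10)
        # Binary targets in {0.0, 1.0} for BCEWithLogitsLoss
        self.y = (torch.rand(n, 1) > 0.5).float()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}


if __name__ == "__main__":
    model = BoringModel()
    # The default collate function batches the dicts into {"X": (32, 10), "y": (32, 1)}
    loader = DataLoader(RandomDictDataset(), batch_size=32)
    trainer = pl.Trainer(max_epochs=2)
    trainer.fit(model, loader)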