@InnovArul · Created August 30, 2022 23:17
using multiple optimizers: an autoencoder (encoder + decoder) and a separate network share the encoder's bottleneck output; each gets its own SGD optimizer, and a backward hook scales the second network's gradient by lambda_g before it reaches the encoder.
import torch, torch.nn as nn
import torch.optim as optim

def print_grads(modules, string):
    print(string)
    for mod in modules:
        for p in mod.parameters():
            print(p.grad)
        print('**')
    print("-----")
def main():
    # model creation (encoder - decoder)
    enc = nn.Sequential(
        nn.Linear(3, 5),
        nn.ReLU(inplace=True)
    )
    dec = nn.Linear(5, 3)

    # other network
    othernet = nn.Sequential(
        nn.Linear(5, 4),
        nn.ReLU(inplace=True),
        nn.Linear(4, 2)
    )

    # define optimizers for autoencoder and other net
    autoencoder_optim = optim.SGD(list(enc.parameters()) + list(dec.parameters()), lr=0.0003)
    othernet_optim = optim.SGD(othernet.parameters(), lr=0.0002)
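    # note: the two optimizers hold disjoint parameter groups (encoder + decoder vs. othernet),
    # so stepping one optimizer never updates the other network's weights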
    # data for the network
    data = torch.randn(6, 3)

    # zero grad
    autoencoder_optim.zero_grad()
    othernet_optim.zero_grad()
    print_grads([enc, dec, othernet], "initial")

    # model forward
    bottleneck_out = enc(data)
    dec_out = dec(bottleneck_out)
    othernet_out = othernet(bottleneck_out)
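    # bottleneck_out feeds both the decoder and othernet, so both losses below
    # will send gradients back into the encoder through this shared tensor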
    # calculate autoencoder loss
    autoencoder_loss = ((dec_out - data) ** 2).mean()
    autoencoder_loss.backward(retain_graph=True)
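    # retain_graph=True keeps the encoder's saved activations alive, since the
    # second backward (othernet_loss) reuses the same graph through bottleneck_out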
    print_grads([enc, dec, othernet], "after loss1 backward")

    # backward othernet loss
    othernet_loss = othernet_out.mean()  # a dummy loss
    # attach backward hook for bottleneck out
    lambda_g = 0.02  # ratio of othernet loss for encoder
    bottleneck_out.register_hook(lambda g: g * lambda_g)
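    # the hook scales the gradient flowing through bottleneck_out by lambda_g;
    # it is registered after the autoencoder backward, so only the othernet loss
    # reaches the encoder scaled (its gradient accumulates on top of loss1's gradient)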
    othernet_loss.backward()
    print_grads([enc, dec, othernet], "after loss2 backward")

    # step the optimizers if needed

if __name__ == '__main__':
    main()
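To actually apply the updates, a minimal sketch (assuming the optimizers defined above, placed where the "step the optimizers if needed" comment sits inside main()) would be:

    autoencoder_optim.step()   # updates enc + dec only
    othernet_optim.step()      # updates othernet only

Because each optimizer owns its own parameter group, the two step() calls can be made independently without affecting the other network.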