PyTorch ACAI (1807.07543).
Consider a pipeline like the following:
G = Model1()
D = Model2()
optim1 = optim.Adam(G.parameters())  # the models must exist before their optimizers
optim2 = optim.Adam(D.parameters())

recons, z = G(input)
loss1 = loss_func1(recons)
diff = D(z)
loss2 = loss_func2(diff)
loss3 = loss_func3(diff)

loss_G = loss1 + loss2  # we don't want to update D's parameters with this loss
loss_D = loss3
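
For concreteness, here is a minimal runnable sketch of such a pipeline. Model1, Model2, and the losses below are hypothetical stand-ins chosen for illustration, not the actual ACAI models:

import torch
import torch.nn as nn
import torch.optim as optim

# hypothetical stand-in for Model1: an autoencoder that returns
# both the reconstruction and the latent code
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(8, 4)
        self.dec = nn.Linear(4, 8)
    def forward(self, x):
        z = self.enc(x)
        return self.dec(z), z

# hypothetical stand-in for Model2: maps the latent code to a score
class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)
    def forward(self, z):
        return self.net(z)

G, D = Model1(), Model2()
optim1 = optim.Adam(G.parameters())
optim2 = optim.Adam(D.parameters())

x = torch.randn(16, 8)
recons, z = G(x)
diff = D(z)
loss_G = (recons - x).pow(2).mean() + diff.mean()  # plays the role of loss1 + loss2
loss_D = diff.pow(2).mean()                        # plays the role of loss3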
Solution #1
optim1.zero_grad()
loss_G.backward(retain_graph=True)  # keep the graph alive for loss_D's backward
optim2.zero_grad()                  # clear the D gradients produced by loss_G
loss_D.backward()
optim1.step()
optim2.step()
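
Two things make this ordering work: retain_graph=True keeps the intermediate buffers alive so loss_D can backpropagate through the same graph, and calling optim2.zero_grad() after loss_G.backward() discards the gradients that loss_G deposited into D. One caveat worth noting (an observation about the code above, not something stated in the original thread): loss_D.backward() also accumulates into G's .grad via z, so optim1.step() applies the sum of both contributions. Continuing the toy sketch from above:

optim1.zero_grad()
loss_G.backward(retain_graph=True)
assert D.net.weight.grad is not None  # loss_G reached D through diff
optim2.zero_grad()                    # wipe that contamination before loss_D
loss_D.backward()                     # D's grads now come from loss_D alone
optim1.step()  # note: G's grads include a loss_D term here
optim2.step()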
Solution #2
optim1.zero_grad()
loss_G.backward(retain_graph=True, inputs=list(G.parameters()))  # grads go to G only
optim1.step()
optim2.zero_grad()
loss_D.backward(inputs=list(D.parameters()))  # grads go to D only
optim2.step()
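
Solution #2 relies on the inputs= argument to backward() (added around PyTorch 1.8), which restricts gradient accumulation to the listed leaf tensors, so neither loss leaks into the other model's .grad. It also makes it safe to run optim1.step() before loss_D.backward(): with inputs=list(D.parameters()) the backward pass is pruned before it ever reaches G, so the in-place parameter update doesn't trip autograd's version checks. A quick way to see the effect on fresh modules:

import torch
import torch.nn as nn

g, d = nn.Linear(4, 4), nn.Linear(4, 1)
loss = d(g(torch.randn(8, 4))).mean()
loss.backward(inputs=list(g.parameters()))
print(d.weight.grad)  # None: nothing was accumulated into d
print(g.weight.grad)  # populated as usual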
Both solutions come from here.