@kylemcdonald
Last active March 12, 2023 18:37
PyTorch ACAI (1807.07543): a PyTorch implementation of Adversarially Constrained Autoencoder Interpolation, from "Understanding and Improving Interpolation in Autoencoders via an Adversarial Regularizer" (Berthelot et al., arXiv:1807.07543).
@stamate commented Apr 16, 2022

In a pipeline like the following:

G = Model1()  # e.g. the autoencoder / generator
D = Model2()  # e.g. the critic / discriminator
optim1 = optim.Adam(G.parameters())
optim2 = optim.Adam(D.parameters())

recons, z = G(input)
loss1 = loss_func1(recons)
diff = D(z)
loss2 = loss_func2(diff)
loss3 = loss_func3(diff)
loss_G = loss1 + loss2  # we don't want to update D's parameters with this loss
loss_D = loss3

Solution #1
optim1.zero_grad()
loss_G.backward(retain_graph=True)  # keep the graph alive: loss_D backprops through diff next
optim2.zero_grad()                  # drop the gradients loss_G just accumulated in D via loss2
loss_D.backward()
optim1.step()
optim2.step()

Solution #2
optim1.zero_grad()
# inputs= limits gradient accumulation to G's parameters, so D's grads stay untouched
loss_G.backward(retain_graph=True, inputs=list(G.parameters()))
optim1.step()
optim2.zero_grad()
loss_D.backward(inputs=list(D.parameters()))  # likewise, only D's parameters receive gradients
optim2.step()

Both of the solutions come from here.
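
For reference, below is a minimal, self-contained sketch of Solution #2. Model1, Model2, the batch, and the three loss functions are hypothetical stand-ins (toy linear layers and placeholder losses, not the ACAI models from this gist); only the backward/step pattern is the point.

import torch
import torch.nn as nn
import torch.optim as optim

class Model1(nn.Module):  # toy stand-in for the autoencoder / generator G
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(8, 4)
        self.dec = nn.Linear(4, 8)
    def forward(self, x):
        z = self.enc(x)
        return self.dec(z), z

class Model2(nn.Module):  # toy stand-in for the critic / discriminator D
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)
    def forward(self, z):
        return self.net(z)

G, D = Model1(), Model2()
optim1 = optim.Adam(G.parameters())
optim2 = optim.Adam(D.parameters())

x = torch.randn(16, 8)  # placeholder batch
recons, z = G(x)
diff = D(z)

loss1 = nn.functional.mse_loss(recons, x)  # placeholder reconstruction loss
loss2 = diff.pow(2).mean()                 # placeholder adversarial term for G
loss3 = (diff - 1).pow(2).mean()           # placeholder critic loss
loss_G = loss1 + loss2
loss_D = loss3

# Solution #2: restrict each backward pass to the parameters its optimizer owns.
optim1.zero_grad()
loss_G.backward(retain_graph=True, inputs=list(G.parameters()))
optim1.step()

optim2.zero_grad()
loss_D.backward(inputs=list(D.parameters()))
optim2.step()

With inputs=, each backward call only accumulates gradients into the listed parameters, so neither loss leaks gradients into the other model and the two optimizers' zero_grad()/backward()/step() sequences stay independent. (The inputs= argument to backward() was added around PyTorch 1.8.)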
