@tuananhle7
Last active February 18, 2020 04:55
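import torch
from torch import nn

# Two scalar parameters: theta = 1, phi = 2.
theta = nn.Parameter(torch.tensor(1.))
phi = nn.Parameter(torch.tensor(2.))
log_joint = theta**2  # depends only on theta
log_q = phi**2        # depends only on phi
# detach() keeps the value of log_q (4) but cuts its gradient, so
# log_weight = 1 + 4 = 5 is differentiable w.r.t. theta only.
log_weight = log_joint + log_q.detach()
# Gradient reaches theta through log_joint; nothing reaches phi.
theta_loss = 2 * log_weight
# log_weight.detach() acts as the constant 5, so the gradient
# reaches phi through log_q only; nothing reaches theta.
phi_loss = 3 * log_weight.detach() * log_q
loss = theta_loss + phi_loss
loss.backward()
# d(loss)/d(theta) comes from theta_loss alone; d(loss)/d(phi) from phi_loss alone.
print('{} should be (2 * 2 * theta = 4 * 1 = ) 4'.format(theta.grad))
print('{} should be (3 * (theta^2 + phi^2) * 2 * phi = 3 * 5 * 2 * 2 = ) 60'.format(phi.grad))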
# prints:
# 4.0 should be (2 * 2 * theta = 4 * 1 = ) 4
# 60.0 should be (3 * (theta^2 + phi^2) * 2 * phi = 3 * 5 * 2 * 2 = ) 60
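One way to confirm how the two detach() calls route gradients is to differentiate each loss on its own. The sketch below rebuilds the same two losses and uses torch.autograd.grad with allow_unused=True, which returns None for any parameter that a loss's graph never reaches (reported here as 0.0). For contrast, it also rebuilds the losses with the detach() calls removed. The helper names build_losses and grads are hypothetical, introduced only for this illustration.

```python
import torch
from torch import nn


def build_losses(use_detach):
    """Rebuild the snippet's two losses; optionally drop the detach() calls."""
    theta = nn.Parameter(torch.tensor(1.))
    phi = nn.Parameter(torch.tensor(2.))
    log_joint = theta**2
    log_q = phi**2
    if use_detach:
        log_weight = log_joint + log_q.detach()
        theta_loss = 2 * log_weight
        phi_loss = 3 * log_weight.detach() * log_q
    else:
        log_weight = log_joint + log_q
        theta_loss = 2 * log_weight
        phi_loss = 3 * log_weight * log_q
    return theta, phi, theta_loss, phi_loss


def grads(loss, params):
    """Gradient of `loss` w.r.t. each param; 0.0 where the graph does not reach it."""
    gs = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return [0.0 if g is None else g.item() for g in gs]


theta, phi, theta_loss, phi_loss = build_losses(use_detach=True)
print('with detach, theta_loss grads:', grads(theta_loss, [theta, phi]))
print('with detach, phi_loss grads:', grads(phi_loss, [theta, phi]))

theta, phi, theta_loss, phi_loss = build_losses(use_detach=False)
print('without detach, total grads:', grads(theta_loss + phi_loss, [theta, phi]))
# prints:
# with detach, theta_loss grads: [4.0, 0.0]
# with detach, phi_loss grads: [0.0, 60.0]
# without detach, total grads: [28.0, 116.0]
```

With detach, theta_loss reaches only theta and phi_loss reaches only phi, which is why the summed loss gives 4 and 60. Without detach, each loss also differentiates through the other parameter: d/d(theta) = 4 * theta + 6 * theta * phi^2 = 28 and d/d(phi) = 4 * phi + 6 * theta^2 * phi + 12 * phi^3 = 116.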