## Gradient estimators
import torch
from torch import nn
import torch.distributions as dist

## REINFORCE

adjacencies, num_edges, targets = load_data(...)
opt = ...
# parametrize edge weights by normal dist
ew_mean = nn.Parameter(torch.randn(num_edges))
ew_std = nn.Parameter(torch.ones(num_edges))
for _ in range(num_epochs):
    opt.zero_grad()

    # reparametrized sample from the normal dist
    ew_dist = dist.Normal(ew_mean, ew_std)
    ew_sample = ew_dist.rsample()
    # -- we take a reparametrized sample from the normal distribution (see the VAE lecture)

    edge_weights = softmax(ew_sample)
    # -- or spherical normalization or whatever

    # compute two loss terms
    gcn_loss = loss(gcn(adjacencies, edge_weights), targets) # regular GCN computation (simple backpropagation is fine here)

    with torch.no_grad():
        kemeny_loss = - alpha * kemeny(adjacencies, edge_weights)
        # -- the computation of the Kemeny constant is not easily differentiable, so we do it under torch.no_grad()
        #    and use REINFORCE to estimate the gradient.

    # actual loss
    # total_loss = gcn_loss + kemeny_loss
    # -- we won't get a gradient over this, because kemeny_loss is detached from the comp graph

    # estimated loss
    total_loss = gcn_loss + ew_dist.log_prob(ew_sample.detach()).sum() * kemeny_loss
    # -- the sample is detached inside log_prob so that only the score function (the gradient of the log-probability
    #    w.r.t. ew_mean and ew_std) contributes here; the pathwise gradient already flows through gcn_loss.
    # -- To see what happens here, write down the expected gradient of the actual loss under the normal distribution
    #    above. The expectation over the first term can be estimated simply by letting pytorch work out the gradient:
    #    the reparametrization results in a gradient on ew_mean (and ew_std).
    # -- The expectation over the second term we rewrite using the score function (so this becomes a score-function
    #    estimate with a single sample). The value `kemeny_loss` is just a constant, but the log-probability of the
    #    (detached) sample gets a gradient with respect to ew_mean and ew_std.
    # -- Note that the _derivative_ of this second term is the score function times the constant kemeny_loss: exactly
    #    the REINFORCE gradient estimate we're looking for. By adding this term to the loss, we're tricking pytorch
    #    into computing the gradient estimate and backpropagating it. (A small standalone sanity check of this trick
    #    follows right after this loop.)

    total_loss.backward()
    opt.step()
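
# -- A minimal, self-contained sanity check of the score-function trick above (a sketch: the toy loss f(x) = x**2 and
#    all names below are illustrative only). For x ~ Normal(mu, 1), the true gradient of E[f(x)] w.r.t. mu is 2 * mu,
#    and backpropagating through log_prob(x) * f(x), averaged over many samples, should land close to that value.
def score_function_check(mu_value=1.5, num_samples=100_000):
    mu = nn.Parameter(torch.tensor(mu_value))
    d = dist.Normal(mu, 1.0)
    x = d.sample((num_samples,))             # plain (non-reparametrized) samples, treated as constants
    with torch.no_grad():
        f = x ** 2                           # pretend this is a non-differentiable black box, like the Kemeny term
    surrogate = (d.log_prob(x) * f).mean()   # score-function surrogate, analogous to log_prob * kemeny_loss above
    surrogate.backward()
    return mu.grad.item()                    # should be close to 2 * mu_value = 3.0
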
## SPSA
# -- We proceed in the same way: compute the gradient estimate in a detached way (under torch.no_grad()) and add its
# integrand to the loss so that pytorch sets the gradient estimate as the gradient of the relevant nodes and
# backpropagates from there.
raw_weights = nn.Parameter(torch.randn(num_edges))
opt = ... # optimizer over raw_weights (and the GCN parameters)
STD = 1e-7 # size of the perturbation
for _ in range(num_epochs):
    opt.zero_grad()

    # -- sparse softmax or spherical normalization or whatever
    edge_weights = softmax(raw_weights)

    # compute the GCN loss once
    gcn_loss = loss(gcn(adjacencies, edge_weights), targets) # regular GCN computation (simple backpropagation is fine here)

    # and the Kemeny loss twice
    with torch.no_grad():
        perturbation = torch.randn(num_edges) * STD
        # -- This should be Bernoulli (random +/- signs) for a proper SPSA implementation.

        edge_weights0, edge_weights1 = edge_weights + perturbation, edge_weights - perturbation
        normalized0, normalized1 = softmax(edge_weights0), softmax(edge_weights1)

        kemeny_loss0 = - alpha * kemeny(adjacencies, normalized0)
        kemeny_loss1 = - alpha * kemeny(adjacencies, normalized1)
        # -- the computation of the Kemeny constant is not easily differentiable, so we do it under torch.no_grad()
        #    and use SPSA to estimate the gradient.

        spsa_grad = (kemeny_loss0 - kemeny_loss1) / (2.0 * perturbation)
        # -- elementwise two-sided SPSA estimate of the gradient of the Kemeny term w.r.t. edge_weights

    # actual loss
    # total_loss = gcn_loss - alpha * kemeny(adjacencies, edge_weights)
    # -- we won't get a gradient over this, because the Kemeny computation is detached from the comp graph

    # estimated loss
    total_loss = gcn_loss + (edge_weights * spsa_grad).sum()
    # -- Note that when we take the derivative of the second term with respect to edge_weights, the gradient is exactly
    #    the SPSA estimate of the gradient, which is then backpropagated through the softmax to raw_weights.
    #    (A small standalone sanity check of this trick follows right after this loop.)

    total_loss.backward()
    opt.step()
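
# -- A minimal, self-contained sanity check of the SPSA surrogate used above (a sketch: the quadratic toy loss and all
#    names below are illustrative only). For f(theta) = (theta ** 2).sum() the true gradient is 2 * theta, and averaging
#    the surrogate's gradient over many random +/- 1 perturbations should land close to it.
def spsa_check(num_params=4, num_rounds=10_000, c=1e-3):
    theta = nn.Parameter(torch.randn(num_params))
    for _ in range(num_rounds):
        with torch.no_grad():
            delta = torch.randint(0, 2, (num_params,)).float() * 2.0 - 1.0 # Bernoulli +/- 1 perturbation
            f_plus = ((theta + c * delta) ** 2).sum()                      # black-box loss at theta + c * delta
            f_minus = ((theta - c * delta) ** 2).sum()                     # black-box loss at theta - c * delta
            spsa_grad = (f_plus - f_minus) / (2.0 * c * delta)             # elementwise two-sided SPSA estimate
        (theta * spsa_grad).sum().backward() # the surrogate term: its gradient w.r.t. theta is exactly spsa_grad
    return theta.grad / num_rounds, 2 * theta.detach()                     # the two tensors should be close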