# JVP: batched Jacobian-vector products and full Jacobians in PyTorch.
# Refer to https://j-towns.github.io/2017/06/12/A-new-trick.html for math details.
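# The trick, in short: reverse-mode autograd gives vector-Jacobian products
# u -> u^T J cheaply. Because u^T J is linear in u, differentiating it once
# more with respect to u, feeding v in as the incoming gradient, yields J v.
# Two reverse-mode passes thus produce a forward-mode JVP.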
import torch
from torch import nn
from torch import autograd
def get_jvp(net, x, v):
    '''
    Compute a batched Jacobian-vector product via two reverse-mode passes.
    Requires x.requires_grad == True.

    Args:
        net: pytorch model, [batch_size, n_input] -> [batch_size, n_output]
        x: [batch_size, n_input]
        v: [n_input, batch_size]

    Returns:
        jvp: [n_output, batch_size]
    '''
    y = net(x)
    # Dummy cotangent; requires_grad=True so we can differentiate through it.
    u = torch.zeros_like(y, requires_grad=True)
    # First pass: u^T J (a VJP), kept differentiable via create_graph=True.
    ujp = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)
    ujpT = ujp[0].transpose(1, 0)
    # Second pass: differentiating the (linear-in-u) VJP w.r.t. u gives J v.
    jvpT = torch.autograd.grad(ujpT, u, grad_outputs=v, retain_graph=True)
    jvp = jvpT[0].transpose(1, 0)
    return jvp
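# Optional cross-check, not part of the original gist: a minimal sketch
# assuming PyTorch >= 1.5, where torch.autograd.functional.jvp is available
# and applies the same double-backward trick. Its `v` must match x's shape
# [batch_size, n_input], so we transpose to and from get_jvp's
# [n_input, batch_size] convention.
def jvp_builtin(net, x, v):
    # v: [n_input, batch_size], as in get_jvp above
    _, out = torch.autograd.functional.jvp(net, x, v.transpose(1, 0))
    return out.transpose(1, 0)  # [n_output, batch_size]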
def get_jacobian(net, x):
    '''
    Compute the full batched Jacobian of net at x.
    Requires x.requires_grad == True.

    Args:
        net: pytorch model, [batch_size, n_input] -> [batch_size, n_output]
        x: [batch_size, n_input]

    Returns:
        j: [batch_size, n_output, n_input]
    '''
    n_batch = x.shape[0]
    n_in = x.shape[1]
    n_output = net(x).shape[1]
    # Tile each sample n_output times so a single backward pass recovers
    # every row of every per-sample Jacobian.
    # x: [n_batch, n_in] -> xs: [n_batch * n_output, n_in]
    xs = x.repeat(1, n_output).view(-1, n_in)
    # xs is a non-leaf tensor, so .grad is not populated; capture its
    # gradient with a hook instead.
    xs_grad = None
    def hook(grad):
        nonlocal xs_grad
        xs_grad = grad
    xs.register_hook(hook)
    y = net(xs)
    # Backward with a stacked identity selects one output row per copy.
    y.backward(torch.eye(n_output).repeat(n_batch, 1), retain_graph=True)
    # xs_grad: [n_batch * n_output, n_in]
    return xs_grad.view(n_batch, n_output, n_in)
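# Optional cross-check, not part of the original gist: a minimal sketch
# assuming PyTorch >= 1.5, where torch.autograd.functional.jacobian is
# available. It returns the full [batch, n_output, batch, n_input] Jacobian;
# since samples are independent, the per-sample Jacobians sit on its block
# diagonal.
def jacobian_builtin(net, x):
    full = torch.autograd.functional.jacobian(net, x)
    idx = torch.arange(x.shape[0])
    return full[idx, :, idx, :]  # [batch_size, n_output, n_input]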
############################# Test Here #############################
class MLP(nn.Module):
    def __init__(self, i, o):
        super().__init__()
        self.fc1 = nn.Linear(i, 10)
        self.fc2 = nn.Linear(10, o)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

    def n_params(self):
        return sum(p.numel() for p in self.parameters())
i = 7
o = 3
batch_size = 20
mlp = MLP(i, o)
x = torch.rand(batch_size, i).requires_grad_()
v = torch.randn(i, batch_size).requires_grad_()
# check get_jvp against an explicit Jacobian contraction
jvp = get_jvp(mlp, x)
j = get_jacobian(mlp, x)
# contract each per-sample Jacobian with its column of v
jvp_ = torch.einsum('ijk,ki->ji', j, v)
assert torch.allclose(jvp, jvp_, atol=1e-6), 'Batched JVP does not match the Jacobian contraction!'
# check get_jacobian: the batched Jacobian must match the single-sample one
a1 = get_jacobian(mlp, x[0].unsqueeze(0))
a2 = get_jacobian(mlp, x)[0].unsqueeze(0)
assert torch.allclose(a1, a2, atol=1e-6), 'Batched Jacobian does not match the single-sample Jacobian!'
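# The built-in sketches above can be cross-checked the same way, e.g.:
# assert torch.allclose(jvp, jvp_builtin(mlp, x, v), atol=1e-6)
# assert torch.allclose(j, jacobian_builtin(mlp, x), atol=1e-6)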