Last active
November 20, 2022 06:51
-
-
Save ybj14/7738b119768af2fe765a2d63688f5496 to your computer and use it in GitHub Desktop.
JVP
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Refer to `https://j-towns.github.io/2017/06/12/A-new-trick.html` for math details. | |
import torch
from torch import nn
from torch import autograd
def get_jvp(net, x, v):
    '''
    Compute a batched Jacobian-vector product with the "double-backward" trick.

    Reverse-mode autograd only gives vector-Jacobian products u^T J. Because
    u^T J is linear in u, differentiating it with respect to a dummy u (using
    v as the cotangent) yields J v without materializing the Jacobian.
    See https://j-towns.github.io/2017/06/12/A-new-trick.html.

    Args:
        net: pytorch model, [batch_size, n_input] -> [batch_size, n_output]
        x: [batch_size, n_input]. If it does not require grad, a detached
            copy that does is used internally.
        v: [n_input, batch_size]; column b is the tangent for sample b.
            v does not need to require grad.
    Returns:
        jvp: [n_output, batch_size]. Assuming `net` treats samples
            independently, column b equals J(x[b]) @ v[:, b].
    '''
    if not x.requires_grad:
        # autograd.grad needs x inside the graph; a detached leaf suffices.
        x = x.detach().requires_grad_()
    # Differentiate the model that was passed in, not a global.
    y = net(x)
    # Dummy cotangent. Its value never matters (zeros), only that it is a
    # differentiable leaf so we can take d(u^T J)/du afterwards.
    u = torch.zeros_like(y, requires_grad=True)
    # ujp = u^T J, shape [batch_size, n_input]; create_graph keeps it
    # differentiable with respect to u.
    (ujp,) = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)
    # Transpose to [n_input, batch_size] so its layout matches v.
    ujp_t = ujp.transpose(1, 0)
    (jvp_t,) = torch.autograd.grad(ujp_t, u, grad_outputs=v, retain_graph=True)
    # jvp_t has u's shape [batch_size, n_output]; return [n_output, batch_size].
    return jvp_t.transpose(1, 0)
def get_jacobian(net, x):
    '''
    Compute the per-sample Jacobian of `net` at every row of `x`.

    Each input row is replicated n_output times, and one backward pass is run
    with one-hot cotangents. Replica k of sample b therefore receives row k of
    that sample's Jacobian. This assumes `net` processes rows independently
    (e.g. no batch-norm in training mode) -- TODO confirm for other models.

    Unlike a `.backward()`-based version, this does not write into `x.grad`
    or into the parameters' `.grad` fields.

    Args:
        net: pytorch model, [batch_size, n_input] -> [batch_size, n_output]
        x: [batch_size, n_input]. Need not require grad.
    Returns:
        j: [n_batch, n_output, n_in] with j[b, k, i] = d net(x)[b, k] / d x[b, i].
            The result is detached from the autograd graph.
    '''
    n_batch, n_in = x.shape[0], x.shape[1]
    with torch.no_grad():
        # Throwaway forward pass only to learn n_output; no graph needed.
        n_output = net(x).shape[1]
    # xs: [n_batch * n_output, n_in]. Rows b*n_output ... b*n_output + n_output - 1
    # are copies of x[b]. They are detached so the pass neither needs nor touches x's graph.
    xs = x.detach().repeat_interleave(n_output, dim=0).requires_grad_()
    y = net(xs)
    # One-hot cotangents: replica k of each sample selects output k.
    # They must match y's dtype/device, e.g. float64 or CUDA models.
    eye = torch.eye(n_output, dtype=y.dtype, device=y.device)
    # autograd.grad returns the gradient instead of accumulating into .grad.
    (xs_grad,) = torch.autograd.grad(y, xs, grad_outputs=eye.repeat(n_batch, 1))
    return xs_grad.view(n_batch, n_output, n_in)
#############################Test Here################################# | |
class MLP(nn.Module):
    '''Two-layer perceptron: Linear(i, 10) -> ReLU -> Linear(10, o).'''

    def __init__(self, i, o):
        super(MLP, self).__init__()
        width = 10  # fixed hidden width
        self.fc1 = nn.Linear(i, width)
        self.fc2 = nn.Linear(width, o)
        # In-place ReLU overwrites fc1's output instead of allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        '''Map inputs whose last dimension is i to outputs whose last dimension is o.'''
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        return self.fc2(hidden)

    def n_params(self):
        '''Return the total number of scalar parameters in the model.'''
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
# Fixed seed so the random model and inputs (and thus the checks) are reproducible.
torch.manual_seed(0)
i = 7
o = 3
batch_size = 20
mlp = MLP(i, o)
x = torch.rand(batch_size, i).requires_grad_()
v = torch.randn(i, batch_size).requires_grad_()

# check get_jvp: it must agree with the explicit Jacobian contracted against v.
jvp = get_jvp(mlp, x, v)
j = get_jacobian(mlp, x)  # [batch_size, o, i]
# jvp_[k, b] = sum_n j[b, k, n] * v[n, b]  ->  [o, batch_size]
jvp_ = torch.einsum('ijk,ki->ji', j, v)
# Both results are float32 computed along different paths. Compare with a
# tolerance rather than an absolute 1e-7 bound that rounding alone can exceed.
assert torch.allclose(jvp, jvp_, rtol=1e-5, atol=1e-6), 'Batched JVP is wrong if assertion failed!'

# check get_jacobian: the Jacobian of sample 0 computed on its own must match
# row 0 of the batched result (reuse `j` instead of recomputing it).
a1 = get_jacobian(mlp, x[0].unsqueeze(0))
a2 = j[0].unsqueeze(0)
assert torch.allclose(a1, a2, rtol=1e-5, atol=1e-6), 'Batched Jacobian is wrong if assertion failed!'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment