Skip to content

Instantly share code, notes, and snippets.

@MasanoriYamada
Forked from sbarratt/torch_jacobian.py
Created December 27, 2019 16:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MasanoriYamada/dd7aed59c6c3e1e69ff1e5501b9a836f to your computer and use it in GitHub Desktop.
Save MasanoriYamada/dd7aed59c6c3e1e69ff1e5501b9a836f to your computer and use it in GitHub Desktop.
Get the Jacobian of a vector-valued function that takes batched inputs, in PyTorch.
def get_jacobian(net, x, noutputs):
    """Return the Jacobian of ``net`` at a single input point ``x``.

    The input is squeezed and replicated ``noutputs`` times. A single reverse-mode
    pass, with an identity matrix as the output cotangents, then yields every row
    of the Jacobian at once.

    Args:
        net: Callable (e.g. a ``torch.nn.Module``) applied to the replicated
            input of shape ``(noutputs, n)``; it is expected to return
            ``(noutputs, noutputs)``.
        x: Input tensor. After ``squeeze()`` it is expected to be 1-D of
            length ``n`` (e.g. shape ``(n,)`` or ``(1, n)``).
        noutputs: Number of outputs ``net`` produces per input row.

    Returns:
        Tensor of shape ``(noutputs, n)`` whose entry ``[i, j]`` is
        ``d net(x)[i] / d x[j]``. The result does not require grad.
    """
    # Detach so the replicated tensor is a fresh leaf even when the caller's
    # ``x`` is part of an autograd graph.
    x_rep = x.detach().squeeze().repeat(noutputs, 1)
    x_rep.requires_grad_(True)
    y = net(x_rep)
    # Row i of the identity selects output i for copy i of the input.
    eye = torch.eye(noutputs, dtype=y.dtype, device=y.device)
    # autograd.grad (unlike Tensor.backward) differentiates only w.r.t. x_rep,
    # so the parameters of ``net`` do not get their ``.grad`` accumulated.
    (jac,) = torch.autograd.grad(y, x_rep, grad_outputs=eye)
    return jac
@MasanoriYamada
Copy link
Author

MasanoriYamada commented Dec 27, 2019

import torch
from torch.autograd import grad
from torch import autograd

def get_jacobian(net, x, noutputs):
    """Return per-sample Jacobians of ``net`` for a batch of inputs.

    Each input row is replicated ``noutputs`` times along a new axis. A single
    reverse-mode pass, with one identity matrix per sample as the output
    cotangents, then yields the full Jacobian of every sample.

    Args:
        net: Callable applied to the replicated input of shape
            ``(b, noutputs, in)``; expected to return ``(b, noutputs, noutputs)``.
            NOTE(review): assumes ``net`` acts on each row independently
            (e.g. ``torch.nn.Linear``); layers that mix rows would give wrong
            results -- confirm for other models.
        x: Input batch of shape ``(b, in)``.
        noutputs: Number of outputs ``net`` produces per sample.

    Returns:
        Tensor of shape ``(b, noutputs, in)`` whose entry ``[k, i, j]`` is
        ``d net(x[k])[i] / d x[k, j]``. The result does not require grad.
    """
    n = x.size(0)
    # (b, in) -> (b, 1, in) -> (b, noutputs, in). Detach so the replicated
    # tensor is a fresh leaf even when the caller's ``x`` is part of a graph.
    x_rep = x.detach().unsqueeze(1).repeat(1, noutputs, 1)
    x_rep.requires_grad_(True)
    y = net(x_rep)  # expected (b, noutputs, noutputs)
    # One identity per sample: row i selects output i for copy i of that sample.
    eye = torch.eye(noutputs, dtype=y.dtype, device=y.device).repeat(n, 1, 1)
    # autograd.grad (unlike Tensor.backward) differentiates only w.r.t. x_rep,
    # so the parameters of ``net`` do not get their ``.grad`` accumulated.
    (jac,) = torch.autograd.grad(y, x_rep, grad_outputs=eye)
    return jac


class Net(torch.nn.Module):
    """Minimal model: a single fully connected layer from ``in_dim`` to ``out_dim``."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Apply the linear layer ``fc1`` to ``x`` and return the result."""
        layer = self.fc1
        return layer(x)

batch = 2
num_features = 3
num_outputs = 5
x = torch.ones(batch, num_features)
# Build the model before it is first called.
net = Net(num_features, num_outputs)
print(x.shape)  # (batch, num_features)
print(net(x).shape)  # (batch, num_outputs)

# A bare expression is discarded when run as a script, so print explicitly.
# Expected shape: (batch, num_outputs, num_features).
print(get_jacobian(net, x, num_outputs).shape)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment