@sbarratt
Created May 9, 2019 19:40
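Get the Jacobian of a vector-valued function that takes batch inputs, in PyTorch.
import torch

def get_jacobian(net, x, noutputs):
    # x: a single input of shape [n] (or [1, n]); squeeze() drops the size-1 batch dim
    x = x.squeeze()
    n = x.size()[0]  # input dimension (not used below)
    # Tile the input: one copy of x per output -> shape [noutputs, n]
    x = x.repeat(noutputs, 1)
    x.requires_grad_(True)
    y = net(x)  # shape [noutputs, noutputs]
    # Seed with the identity: copy i only receives the gradient of output i
    y.backward(torch.eye(noutputs))
    return x.grad.data  # shape [noutputs, n]; row i is dy_i/dx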
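A quick sanity check of the gist: for a linear layer y = Wx + b the Jacobian dy/dx is exactly W. This is a minimal sketch that assumes `net` maps a batch of shape [noutputs, n] to [noutputs, noutputs]; the function is restated (without the unused `n`) so the block runs on its own, and the `nn.Linear` test network is a hypothetical choice.

```python
import torch

def get_jacobian(net, x, noutputs):
    x = x.squeeze()
    x = x.repeat(noutputs, 1)          # [noutputs, n]: one copy of x per output
    x.requires_grad_(True)
    y = net(x)                         # [noutputs, noutputs]
    y.backward(torch.eye(noutputs))    # copy i only receives gradient from output i
    return x.grad.data                 # [noutputs, n], row i = dy_i/dx

torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)          # hypothetical test network
x = torch.randn(4)
jac = get_jacobian(layer, x, 3)
print(torch.allclose(jac, layer.weight.detach()))  # True: a linear layer's Jacobian is its weight
```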
@yiyuezhuo

I tried to extend the original to support multiple dimensions:

import torch

def compute_jacobian(f, x, output_dims):
    '''
    Normal:
        f: input_dims -> output_dims
    Jacobian mode:
        f: output_dims x input_dims -> output_dims x output_dims
    '''
    # One copy of x per output element: shape output_dims + x.shape
    repeat_dims = tuple(output_dims) + (1,) * len(x.shape)
    jac_x = x.detach().repeat(*repeat_dims)
    jac_x.requires_grad_()
    jac_y = f(jac_x)

    # Generalized identity: 1 at [k, k] for every output multi-index k, 0 elsewhere
    ml = torch.meshgrid(*[torch.arange(dim) for dim in output_dims], indexing="ij")
    index = tuple(m.flatten() for m in ml)
    gradient = torch.zeros(tuple(output_dims) * 2)
    gradient[index * 2] = 1

    jac_y.backward(gradient)

    return jac_x.grad

Usage:

w = torch.randn(4,3)
f = lambda x: x @ w
x = torch.randn(2,3,4)
jac = compute_jacobian(f, x, [2,3,3])

'''
>>> w
tensor([[-0.2295,  1.4252,  2.2714],
        [ 0.5877, -2.4398,  0.0136],
        [ 0.3254, -0.3380,  0.1785],
        [ 0.5455,  0.9089, -0.3134]])

>>> jac[1,1,1]
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 1.4252, -2.4398, -0.3380,  0.9089],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]])
'''
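One way to cross-check `compute_jacobian` is against `torch.autograd.functional.jacobian` (the experimental API mentioned later in this thread), which returns a tensor of shape output shape + input shape. A minimal sketch, assuming PyTorch 1.10 or later (for `meshgrid(..., indexing="ij")`); the function is restated so the block is self-contained.

```python
import torch
from torch.autograd.functional import jacobian

def compute_jacobian(f, x, output_dims):
    repeat_dims = tuple(output_dims) + (1,) * x.dim()
    jac_x = x.detach().repeat(*repeat_dims).requires_grad_()
    jac_y = f(jac_x)
    # Seed with 1 where the copy's output index equals the value's output index
    grids = torch.meshgrid(*[torch.arange(d) for d in output_dims], indexing="ij")
    index = tuple(g.flatten() for g in grids)
    gradient = torch.zeros(tuple(output_dims) * 2)
    gradient[index * 2] = 1
    jac_y.backward(gradient)
    return jac_x.grad

torch.manual_seed(0)
w = torch.randn(4, 3)
f = lambda x: x @ w
x = torch.randn(2, 3, 4)

ours = compute_jacobian(f, x, [2, 3, 3])
reference = jacobian(f, x)             # shape: output shape + input shape
print(ours.shape, reference.shape)
print(torch.allclose(ours, reference))
```

Both results are indexed as [output index..., input index...], so `ours[1, 1, 1]` is the same slice shown in the usage example.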

@nimaous

nimaous commented Aug 2, 2019

Here is my way to calculate the Hessian for a batch.

Continuing the discussion from "Pytorch most efficient Jacobian / Hessian calculation":

This is the fastest way I can tell. Suppose you have a per-sample loss function of x, where x has shape [batch_size, x_dim] and requires_grad=True:

        loss = f(x)  # shape [batch_size]
        first_drv = torch.zeros(batch_size, x_dim)
        hessian = torch.zeros(batch_size, x_dim, x_dim)
        for n in range(batch_size):
            # Gradient of loss[n] w.r.t. sample n; create_graph=True keeps it differentiable
            first_drv[n] = torch.autograd.grad(loss[n], x,
                                               create_graph=True, retain_graph=True)[0][n]
            for i in range(x_dim):
                # Row i of sample n's Hessian: derivative of the i-th gradient entry
                hessian[n][i] = torch.autograd.grad(first_drv[n][i], x,
                                                    create_graph=True, retain_graph=True)[0][n]
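If each sample's loss depends only on its own row of x (no batch norm or other cross-sample coupling), the loop over the batch can be dropped: the gradient of the summed loss already contains every per-sample gradient, and the same trick gives one Hessian row for all samples per backward pass. This is a sketch of that alternative, not the method above; `batch_hessian` is a hypothetical helper name.

```python
import torch

def batch_hessian(f, x):
    """Per-sample Hessians of a per-sample loss f: [B, D] -> [B].

    Assumes loss[n] depends only on x[n], so d(sum of losses)/dx[n] = d loss[n]/dx[n].
    """
    x = x.detach().requires_grad_(True)
    loss = f(x)
    # One backward pass gives every per-sample gradient at once: [B, D]
    grad = torch.autograd.grad(loss.sum(), x, create_graph=True)[0]
    rows = []
    for i in range(x.shape[1]):
        # Row i of each sample's Hessian, shape [B, D]
        rows.append(torch.autograd.grad(grad[:, i].sum(), x, retain_graph=True)[0])
    return torch.stack(rows, dim=1)    # [B, D, D]

torch.manual_seed(0)
x = torch.randn(5, 3)
H = batch_hessian(lambda z: (z ** 3).sum(dim=1), x)
print(torch.allclose(H, torch.diag_embed(6 * x)))  # Hessian of sum(z^3) is diag(6 z)
```

This still loops over x_dim, but the cost no longer grows with batch_size.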

@dohmatob

Minibatch version of original get_jacobian code:

import torch

def get_jacobian(net, x, num_outputs, batch_size=None, verbose=0):
    """
    Compute jacobian matrix of network outputs w.r.t input x.

    Parameters
    ----------
    net: A pytorch callable (e.g a network instance)

    x: torch.Tensor
        A single input instance (a leading dimension of size 1 is squeezed out)

    num_outputs: int
        Number of outputs produced by net (per input instance)

    batch_size: int, optional
        If None, then run in full-batch mode. Else run in minibatch mode
        with mini-batches of size `batch_size`
    """
    from sklearn.utils import gen_batches
    if batch_size is None:
        batch_size = num_outputs
    # Integer ceil(num_outputs / batch_size)
    num_batches = num_outputs // batch_size + (num_outputs % batch_size != 0)
    x.requires_grad_(False)
    x = x.squeeze(0)
    shape = list(x.shape)
    ones = [1] * len(shape)
    jacs = torch.zeros([num_outputs] + shape)

    for b, batch in enumerate(gen_batches(num_outputs, batch_size)):
        this_batch_size = len(jacs[batch])
        # One copy of x per output row handled in this mini-batch
        x_ = x.repeat(this_batch_size, *ones).requires_grad_(True)
        output = net(x_)
        assert (len(output.shape) == 2 and len(output) == this_batch_size)
        # Copy k only receives the gradient of output row batch.start + k
        output.backward(torch.eye(num_outputs)[batch, :])
        jacs[batch] = x_.grad
        if verbose and num_batches > 1:
            print("Batch %02i / %02i" % (b + 1, num_batches))

    return jacs.data

# Worked example
num_features = 2
num_outputs = 10
x = torch.ones(num_features)
W = torch.randn(num_features, num_outputs)
y = lambda z: z @ W + 2019
batch_size = 3
jacs = get_jacobian(y, x, num_outputs, batch_size, verbose=1)
print("dy/dx:\n%s" % jacs)
print("W.T:\n%s" % W.T)
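The same minibatching works without scikit-learn: `gen_batches` only yields consecutive slices, so plain `range` stepping does the job. A sketch under that assumption; `get_jacobian_chunked` is a hypothetical name, and the check reuses the worked example (the Jacobian of z @ W + 2019 is W.T).

```python
import torch

def get_jacobian_chunked(net, x, num_outputs, batch_size=None):
    """Jacobian of net at a single input x, using at most batch_size copies of x per pass."""
    if batch_size is None:
        batch_size = num_outputs
    x = x.detach()
    eye = torch.eye(num_outputs)
    jacs = torch.zeros(num_outputs, *x.shape)
    for start in range(0, num_outputs, batch_size):
        stop = min(start + batch_size, num_outputs)
        # One copy of x per output row in this chunk
        x_ = x.unsqueeze(0).repeat(stop - start, *([1] * x.dim())).requires_grad_(True)
        output = net(x_)                   # [stop - start, num_outputs]
        output.backward(eye[start:stop])   # copy k only receives gradient from output start + k
        jacs[start:stop] = x_.grad
    return jacs

torch.manual_seed(0)
W = torch.randn(2, 10)
jacs = get_jacobian_chunked(lambda z: z @ W + 2019, torch.ones(2), 10, batch_size=3)
print(torch.allclose(jacs, W.T))
```

Smaller `batch_size` lowers peak memory at the cost of more forward/backward passes.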

@tsauri

tsauri commented Nov 4, 2019

I need the Jacobian with respect to a weight (an nn.Parameter of an nn.Module).
Is there a way for weight.grad to 'expand' according to the number of dimensions of y?
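`weight.grad` always has the parameter's own shape because it holds a vector-Jacobian product, so one way to get the full Jacobian is one `torch.autograd.grad` call per output element, reusing a single forward pass. A minimal sketch; `jacobian_wrt_param` is a hypothetical helper and the `nn.Linear` module is only for illustration.

```python
import torch

def jacobian_wrt_param(module, x, param):
    """Jacobian of module(x) (flattened) w.r.t. param: shape [num_outputs, *param.shape]."""
    y = module(x).reshape(-1)
    rows = []
    for i in range(y.numel()):
        # d y_i / d param, same shape as param
        (g,) = torch.autograd.grad(y[i], param, retain_graph=True)
        rows.append(g)
    return torch.stack(rows)

torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)                  # hypothetical example module
x = torch.randn(4)
J = jacobian_wrt_param(layer, x, layer.weight)  # [3, 3, 4]
# For y = W x + b, d y_i / d W[k, :] is x when k == i and 0 otherwise
expected = torch.einsum("ik,j->ikj", torch.eye(3), x)
print(torch.allclose(J, expected))
```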

@RylanSchaeffer

Can you clarify what the appropriate input shapes are? I'm currently trying to use the gist, but on the line x = x.repeat(noutputs, 1), I receive the following error:

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
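This error comes from `Tensor.repeat`, which needs at least as many repeat counts as the tensor has dimensions. `x.repeat(noutputs, 1)` passes two, so the gist effectively assumes a single input of shape [n] (or [1, n]). A small illustration, assuming an input with two non-batch dimensions; flattening first, or passing one repeat count per dimension, avoids the error.

```python
import torch

x = torch.randn(2, 3, 4)              # still 3-D after squeeze()
try:
    x.squeeze().repeat(5, 1)          # only 2 repeat counts for a 3-D tensor
    raised = False
except RuntimeError:
    raised = True
print(raised)

flat = x.reshape(-1)                  # flatten to shape [24] first
tiled = flat.repeat(5, 1)             # [5, 24]: one copy per output
stacked = x.unsqueeze(0).repeat(5, 1, 1, 1)  # [5, 2, 3, 4]: one repeat count per dim
print(tiled.shape, stacked.shape)
```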

@sbarratt
Author

sbarratt commented Dec 2, 2019 via email

@huiminzeng

Genius!

@RylanSchaeffer

Where is n used?

@harpone

harpone commented Dec 14, 2019

Here's a differentiable batch version that doesn't loop over the batch dim (there's still one for loop left over the input dimension though... not sure if it's possible to get rid of that):

import torch
from torch.autograd import grad

def jacobian(f, x):
    """Computes the Jacobian of f w.r.t x.

    This is according to the reverse mode autodiff rule,

    sum_i v^b_i dy^b_i / dx^b_j = ((v^b)^T J^b)_j,

    where:
    - b is the batch index from 0 to B - 1
    - i, j are the vector indices from 0 to N-1
    - J^b is the Jacobian of sample b, J^b_ij = dy^b_i / dx^b_j
    - v^b_i is a "test vector", which is set to the i-th unit vector to obtain
        row i of J^b (for all b at once) out of the above expression.

    Assumes y^b depends only on x^b (samples are processed independently).

    :param f: function R^N -> R^N
    :param x: torch.tensor of shape [B, N]
    :return: Jacobian matrix (torch.tensor) of shape [B, N, N]
    """

    B, N = x.shape
    y = f(x)
    jacobian = list()
    for i in range(N):
        v = torch.zeros_like(y)
        v[:, i] = 1.
        dy_i_dx = grad(y,
                       x,
                       grad_outputs=v,
                       retain_graph=True,
                       create_graph=True,
                       allow_unused=True)[0]  # shape [B, N]: row i of every J^b
        jacobian.append(dy_i_dx)

    # Stack rows along dim=1 so that jacobian[b, i, j] = dy^b_i / dx^b_j
    jacobian = torch.stack(jacobian, dim=1).requires_grad_()

    return jacobian

EDIT: x needs to have requires_grad=True.
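A quick check of the batched version against a closed form, assuming f acts row-wise: for y^b = tanh(A x^b), J^b = diag(1 - tanh^2(A x^b)) A. The sketch stacks the per-output gradients along dim=1, so J[b, i, j] = dy^b_i / dx^b_j, the row-per-output layout used by the other snippets in this thread; the function is restated so the block runs on its own, and the tanh test function is a hypothetical choice.

```python
import torch
from torch.autograd import grad

def jacobian(f, x):
    # x: [B, N] with requires_grad=True; f: R^N -> R^N applied row-wise
    B, N = x.shape
    y = f(x)
    rows = []
    for i in range(N):
        v = torch.zeros_like(y)
        v[:, i] = 1.0
        # [B, N]: for every sample b, row i of its Jacobian
        rows.append(grad(y, x, grad_outputs=v, retain_graph=True, create_graph=True)[0])
    return torch.stack(rows, dim=1)    # [B, N (output i), N (input j)]

torch.manual_seed(0)
A = torch.randn(3, 3)
x = torch.randn(4, 3, requires_grad=True)
J = jacobian(lambda z: torch.tanh(z @ A.T), x)
# Per-sample reference: d tanh(A x)/dx = diag(1 - tanh^2(A x)) A
reference = torch.stack([(1 - torch.tanh(A @ xb) ** 2)[:, None] * A for xb in x.detach()])
print(torch.allclose(J, reference, atol=1e-6))
```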

@MasanoriYamada

MasanoriYamada commented Dec 27, 2019

Here is a version of get_jacobian that takes a batched input x and avoids looping over the batch.

import torch

def get_batch_jacobian(net, x, noutputs):
    x = x.unsqueeze(1) # b, 1 ,in_dim
    n = x.size()[0]
    x = x.repeat(1, noutputs, 1) # b, out_dim, in_dim
    x.requires_grad_(True)
    y = net(x)
    input_val = torch.eye(noutputs).reshape(1,noutputs, noutputs).repeat(n, 1, 1)
    y.backward(input_val)
    return x.grad.data

Usage:

class Net(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(in_dim, out_dim)
    
    def forward(self, x):
        return torch.nn.functional.relu(self.fc1(x))

batch = 2
num_features = 3
num_outputs = 5
x = torch.randn(batch, num_features)
print(x.shape)
net = Net(num_features, num_outputs)
print(net(x).shape)
result = get_batch_jacobian(net, x, num_outputs)
print(result.shape)

Tensor shapes from my code:

input x: torch.Size([2, 3])
output y: torch.Size([2, 5])
jacobian dy/dx: torch.Size([2, 5, 3])

I checked consistency with the original gist code (https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa#file-torch_jacobian-py):

ret0 = get_batch_jacobian(net, x, num_outputs) # my code
ret1 = get_jacobian(net, x[0], num_outputs) # sbarratt code
ret2 = get_jacobian(net, x[1], num_outputs) # sbarratt code

torch.allclose(ret0[0], ret1)
torch.allclose(ret0[1], ret2)

@RylanSchaeffer

@MasanoriYamada, can you explain how this works? What does torch.eye(noutputs).reshape(1,noutputs, noutputs).repeat(n, 1, 1) do?
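`torch.eye(noutputs)` is the [noutputs, noutputs] identity; `.reshape(1, noutputs, noutputs)` adds a leading batch axis and `.repeat(n, 1, 1)` copies it once per sample, giving a [n, noutputs, noutputs] stack of identities. Used as the `y.backward` seed, entry [b, i, k] weights output k of copy i of sample b, so copy i only propagates the gradient of output i, which is row i of that sample's Jacobian. A small check with a hypothetical linear map (the `expand` variant is a suggestion of mine, not from the thread):

```python
import torch

torch.manual_seed(0)
n, noutputs = 2, 3
seed = torch.eye(noutputs).reshape(1, noutputs, noutputs).repeat(n, 1, 1)
print(seed.shape)                                    # torch.Size([2, 3, 3])

# A view-based equivalent that avoids copying the identity n times
seed_view = torch.eye(noutputs).expand(n, noutputs, noutputs)
print(torch.equal(seed, seed_view))

# With y = x W^T, row i of each sample's gradient should be W[i]
W = torch.randn(noutputs, 4)
x_rep = torch.randn(n, 1, 4).repeat(1, noutputs, 1).requires_grad_(True)
y = x_rep @ W.T                                      # [n, noutputs, noutputs]
y.backward(seed)
print(torch.allclose(x_rep.grad, W.expand(n, noutputs, 4)))
```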

@RylanSchaeffer

@MasanoriYamada, how can I make this work when my inputs are 4-D tensors? When I try calling the model's forward method, I receive the following error:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 6 3 5 5, but got 5-dimensional input of size [batch_size, output_dimension, num_channels, image_width, image_height] instead

@MasanoriYamada

MasanoriYamada commented Jan 16, 2020

@RylanSchaeffer, my code only supports flattened tensors with a batch dimension. Please show me the input and output tensors and the network in your case (a simple network is best).

@MasanoriYamada

@RylanSchaeffer, how about this https://gist.github.com/MasanoriYamada/d1d8ca884d200e73cca66a4387c7470a
Disclaimer: Value not tested for correctness
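For inputs with extra dimensions (e.g. images), one approach is to fold the per-output copies into the batch dimension so the network still sees [N, C, H, W], then unfold the result. A sketch, assuming `net` maps [N, *in_shape] to [N, noutputs] and treats samples independently (e.g. no batch norm in training mode); `get_batch_jacobian_nd` and the small conv net are hypothetical, and the result is checked against `torch.autograd.functional.jacobian` on one sample.

```python
import torch

def get_batch_jacobian_nd(net, x, noutputs):
    """x: [B, *in_shape]; returns [B, noutputs, *in_shape]."""
    B, in_shape = x.shape[0], x.shape[1:]
    # Rows b*noutputs ... b*noutputs + noutputs - 1 are the copies of sample b
    x_rep = x.detach().repeat_interleave(noutputs, dim=0).requires_grad_(True)
    y = net(x_rep).reshape(B, noutputs, noutputs)    # [sample, copy, output]
    # Copy i of each sample is seeded with the i-th unit vector
    y.backward(torch.eye(noutputs).repeat(B, 1, 1))
    return x_rep.grad.reshape(B, noutputs, *in_shape)

torch.manual_seed(0)
net = torch.nn.Sequential(                 # hypothetical small conv net
    torch.nn.Conv2d(3, 2, kernel_size=3),  # [N, 3, 4, 4] -> [N, 2, 2, 2]
    torch.nn.Flatten(),
    torch.nn.Linear(8, 5),
)
x = torch.randn(2, 3, 4, 4)
jac = get_batch_jacobian_nd(net, x, 5)     # [2, 5, 3, 4, 4]

ref0 = torch.autograd.functional.jacobian(lambda z: net(z.unsqueeze(0)).squeeze(0), x[0])
print(jac.shape, torch.allclose(jac[0], ref0, atol=1e-5))
```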

@RylanSchaeffer

RylanSchaeffer commented Jan 16, 2020 via email

@RylanSchaeffer

Question for anyone: why do we need to tile the input before passing it through the graph (net, in sbarratt's original code)? Why can't we tile the input and the output after the forward pass?

@RylanSchaeffer

RylanSchaeffer commented Jan 20, 2020

I'm trying to do this currently, but I'm receiving the error: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Here's what I'm doing. Let x be the input to the graph with shape (batch size, input dimension) and let y be the output of the graph with shape (batch size, output dimension). I then select a subset of N random unit vectors. I stack x with itself and y with itself as follows:

x = torch.cat([x for _ in range(N)], dim=0)

and

y = torch.cat([y for _ in range(N)], dim=0)

x then has shape (N * batch size, input dim) and y has shape (N * batch size, output dim). But then, when I try to use autograd, I receive the aforementioned error.

        jacobian = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=subset_unit_vectors,
            retain_graph=True,
            only_inputs=True)[0]

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Does anyone know why this is, and is there a way to make this post-forward pass tiling work?

@Jeff1995

@RylanSchaeffer
I was trying the same thing with unsqueeze().expand(), but it leads to the same autograd error. I suppose it's because the tiled y is not computed from the tiled x: both new nodes hang off the original graph without any dependency between them, so autograd cannot differentiate one with respect to the other.
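One way to get the "tile after the forward pass" behaviour is to keep differentiating with respect to the original `x` and batch the seed vectors instead. In recent PyTorch versions (1.11 and later, as far as I know) `torch.autograd.grad` accepts `is_grads_batched=True`, which treats the first dimension of `grad_outputs` as a batch of seeds and vectorizes the backward pass (it is documented as experimental). A sketch under those assumptions, with a hypothetical small network and a per-row loop as the reference:

```python
import torch

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Tanh(), torch.nn.Linear(4, 2))
x = torch.randn(3, requires_grad=True)
y = net(x)                                   # single forward pass, shape [2]

# Each row of grad_outputs is one seed vector (here the unit vectors e_0, e_1)
seeds = torch.eye(y.numel())
(J,) = torch.autograd.grad(y, x, grad_outputs=seeds,
                           is_grads_batched=True, retain_graph=True)
print(J.shape)                               # expected [2, 3]: row i is dy_i/dx

# Reference: one ordinary backward pass per output row
rows = [torch.autograd.grad(y[i], x, retain_graph=True)[0] for i in range(y.numel())]
print(torch.allclose(J, torch.stack(rows)))
```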

@RylanSchaeffer

RylanSchaeffer commented Apr 29, 2020 via email

@ChenAo-Phys

I came across this page about a year ago. This is a really nice trick, but it's a pity that it needs a forward pass over a large batch, which becomes a huge challenge for my GPU memory. Recently I found an interesting way to bypass this problem, which finally solves a problem I encountered a year ago: https://github.com/ChenAo-Phys/pytorch-Jacobian

@justinblaber

justinblaber commented May 31, 2020

If I'm understanding this correctly, this code runs the forward pass on noutputs copies of the input just to compute the Jacobian once (albeit in a vectorized way). The 1.5.0 autograd Jacobian computation seems to compute the output once but then for-loops over it and calls backward one by one (@rjeli's first comment), which will certainly be slow. Both tradeoffs seem suboptimal.

Anyone know if there's an update on this? Or is pytorch really not meant to compute jacobians?

@sbarratt
Author

sbarratt commented May 31, 2020 via email

@RylanSchaeffer

@justinblaber, autodiff computes either matrix-vector products or vector-matrix products (depending on forward mode / reverse mode). The Jacobian is a matrix, and there's no easy way to recover it directly. Either you perform multiple backward passes, using a different elementary basis vector on each pass, or you blow the batch size up and do one massive backward pass. There's no way around this.
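The two options can be put side by side. Both build the Jacobian row by row from vector-Jacobian products with unit vectors; they differ in whether the rows come from one pass over a tiled batch (memory grows with the number of outputs) or from repeated backward passes over a single forward pass (time grows with it). A sketch, assuming a hypothetical network mapping [batch, n] to [batch, m] that treats samples independently:

```python
import torch

def jacobian_tiled(net, x, m):
    """One forward and one backward pass over m copies of x."""
    x_rep = x.detach().repeat(m, 1).requires_grad_(True)   # [m, n]
    net(x_rep).backward(torch.eye(m))
    return x_rep.grad                                       # [m, n]

def jacobian_looped(net, x, m):
    """One forward pass, then m backward passes."""
    x = x.detach().unsqueeze(0).clone().requires_grad_(True)  # [1, n]
    y = net(x).squeeze(0)                                      # [m]
    rows = [torch.autograd.grad(y[i], x, retain_graph=True)[0].squeeze(0) for i in range(m)]
    return torch.stack(rows)                                   # [m, n]

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 4))
x = torch.randn(3)
J_tiled = jacobian_tiled(net, x, 4)
J_looped = jacobian_looped(net, x, 4)
print(torch.allclose(J_tiled, J_looped, atol=1e-6))
```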

@a-z-e-r-i-l-a

how about this experimental api for jacobian: https://pytorch.org/docs/stable/_modules/torch/autograd/functional.html#jacobian
is it good?

@justinblaber

how about this experimental api for jacobian: https://pytorch.org/docs/stable/_modules/torch/autograd/functional.html#jacobian
is it good?

I took a look and:

for j in range(out.nelement()):
    vj = _autograd_grad((out.reshape(-1)[j],), inputs, retain_graph=True, create_graph=create_graph)

It's just for-looping over the output and computing the gradient one by one (i.e. each row of the jacobian one by one). This will for sure be slow as hell if you have a lot of outputs. I actually think it's a tad bit deceiving that they advertise this functionality, because really the functionality just isn't there.
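The loop quoted above is the default path of `torch.autograd.functional.jacobian`. Newer releases also accept a `vectorize=True` flag (documented as experimental) that uses vmap to run those backward passes as one batched computation instead of a Python loop; whether it is faster depends on the model. A minimal comparison, assuming a PyTorch version that has the flag and a hypothetical small network:

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 5))
x = torch.randn(3)

J_loop = jacobian(net, x)                    # loops over the 5 outputs
J_vec = jacobian(net, x, vectorize=True)     # batched backward via vmap
print(J_loop.shape)                          # output shape + input shape: [5, 3]
print(torch.allclose(J_loop, J_vec, atol=1e-6))
```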

And actually, to be honest, I wanted the Jacobian earlier to do some Gauss-Newton type optimization, but I've since discovered that the optim.LBFGS optimizer (now built into PyTorch) might work well for my problem. I think it even has some backtracking-type stuff built into it. So for now I don't think I even need the Jacobian anymore.
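For reference, `torch.optim.LBFGS` differs from the other optimizers in that `step` takes a closure that re-evaluates the loss, since the algorithm may evaluate the function several times per step. Its `line_search_fn` argument accepts `"strong_wolfe"`, a line search satisfying the strong Wolfe conditions. A minimal least-squares sketch on hypothetical synthetic data:

```python
import torch

torch.manual_seed(0)
X = torch.randn(50, 3)
w_true = torch.tensor([1.0, -2.0, 0.5])
y = X @ w_true                               # noise-free synthetic targets

w = torch.zeros(3, requires_grad=True)
opt = torch.optim.LBFGS([w], max_iter=100, line_search_fn="strong_wolfe")

def closure():
    # LBFGS calls this whenever it needs a fresh loss and gradient
    opt.zero_grad()
    loss = ((X @ w - y) ** 2).mean()
    loss.backward()
    return loss

opt.step(closure)
print(w.detach())                            # should be close to w_true
```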
