@mfkasim1
Created December 27, 2019 18:08
Calculate Jacobian matrix in PyTorch
import torch


def bjac(gy, y):
    """
    Get the batched Jacobian matrix of a batched output, gy, with respect to the input, y.
    Assumes each batch entry of gy depends only on the matching batch entry of y.
    y must be a leaf tensor with requires_grad=True; its .grad is overwritten.
    The output Jacobian has shape (nbatch, *gy.shape[1:], *y.shape[1:]).
    """
    # flatten gy to a batched 1D vector to count the output elements
    ggy = gy.reshape(gy.shape[0], -1)
    outshape = (gy.shape[0], *gy.shape[1:], *y.shape[1:])
    nbatch, nfout = ggy.shape
    # allocated as (nfout, nbatch, ...); transposed before returning
    jacy = torch.zeros((nfout, nbatch, *y.shape[1:]), dtype=y.dtype, device=y.device)
    # obtain the gradient of the j-th element of gy for every batch
    for j in range(nfout):
        # mask selecting the j-th output element, matching gy's dtype and device
        mask = torch.zeros((nbatch, nfout), dtype=gy.dtype, device=gy.device)
        mask[:, j] = 1.0
        mask = mask.view(gy.shape)
        if y.grad is not None:
            y.grad.zero_()
        gy.backward(mask, retain_graph=True)
        # store the grad in the Jacobian matrix
        jacy[j] = y.grad.detach()
    return jacy.transpose(0, 1).reshape(outshape)
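
The following is a sketch for cross-checking the result and is not part of the original gist. It uses a hypothetical variant, `bjac_grad`, that swaps `backward` for `torch.autograd.grad`, so `y.grad` is left untouched. It makes the same assumption as `bjac`: each batch entry of `gy` depends only on the matching batch entry of `y`. The test function `tanh(y @ W.T)` and all sizes are arbitrary choices for illustration.

```python
import torch


def bjac_grad(gy, y):
    """Batched Jacobian via torch.autograd.grad (hypothetical variant of bjac).

    Assumes sample b of gy depends only on sample b of y.
    Returns a tensor of shape (nbatch, *gy.shape[1:], *y.shape[1:]).
    """
    nbatch = gy.shape[0]
    ggy = gy.reshape(nbatch, -1)
    rows = []
    for j in range(ggy.shape[1]):
        # Summing the j-th output over the batch still yields per-sample
        # gradients, because samples do not interact with each other.
        (grad_j,) = torch.autograd.grad(ggy[:, j].sum(), y, retain_graph=True)
        rows.append(grad_j)
    jac = torch.stack(rows, dim=1)  # (nbatch, nfout, *y.shape[1:])
    return jac.reshape(nbatch, *gy.shape[1:], *y.shape[1:])


torch.manual_seed(0)
W = torch.randn(3, 4, dtype=torch.double)
y = torch.randn(5, 4, dtype=torch.double, requires_grad=True)
gy = torch.tanh(y @ W.T)  # shape (5, 3)

jac = bjac_grad(gy, y)  # shape (5, 3, 4)

# Analytic Jacobian: d tanh(W y)_i / d y_j = (1 - tanh(W y)_i ** 2) * W_ij
expected = (1 - gy.detach() ** 2).unsqueeze(-1) * W.unsqueeze(0)
print(torch.allclose(jac, expected))
```

Calling `bjac(gy, y)` on the same tensors should return the same values, with the difference that it zeroes and overwrites `y.grad` on every iteration.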