Created
December 27, 2019 18:08
-
-
Save mfkasim1/70af1f43bc6d723d2c3aa197ded930a2 to your computer and use it in GitHub Desktop.
Calculate Jacobian matrix in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def bjac(gy, y):
    """
    Get the batched Jacobian matrix for vector-valued output, gy, with respect to the input, y.
    The Jacobian output will take shape of (nbatch, *gy.shape[1:], *y.shape[1:])

    One reverse-mode pass is made per flattened output element of ``gy``.
    Each pass seeds that element with 1 in every batch entry at once. The
    gradient it returns is therefore the sum over batches of
    d gy[b, j] / d y. The result is only the true per-batch Jacobian if
    gy[b] depends on y[b] alone, i.e. batch entries are independent.

    Args:
        gy: output tensor whose first dimension is the batch dimension; it
            must be part of an autograd graph (requires grad).
        y: input tensor whose first dimension is the batch dimension. It may
            be a leaf or an intermediate (non-leaf) tensor of the graph.

    Returns:
        Tensor of shape (nbatch, *gy.shape[1:], *y.shape[1:]) with y's dtype
        and device. Entries are zero if gy does not depend on y.

    Notes:
        ``y.grad`` and the ``.grad`` of any other leaf are left untouched.
        The graph is retained, so the caller may still backpropagate
        through ``gy`` afterwards.
    """
    # flatten every non-batch dimension of gy so each output element has one index j
    ggy = gy.reshape(gy.shape[0], -1)
    outshape = (gy.shape[0], *gy.shape[1:], *y.shape[1:])
    nbatch, nfout = ggy.shape
    # allocate on y's dtype/device (torch.zeros would default to float32 on CPU);
    # output-element index first, transposed to batch-first at the end
    jacy = y.new_zeros((nfout, nbatch, *y.shape[1:]))
    # obtain the gradient of the j-th element of the gy for every batch
    for j in range(nfout):
        # seed vector must match gy's dtype/device
        mask = torch.zeros_like(ggy)
        mask[:, j] = 1.0
        # autograd.grad returns the gradient directly instead of accumulating
        # into .grad of every leaf (as .backward does); it also works for
        # non-leaf y. retain_graph keeps the graph alive for the next j.
        (grad,) = torch.autograd.grad(
            gy,
            y,
            grad_outputs=mask.reshape(gy.shape),
            retain_graph=True,
            allow_unused=True,
        )
        # grad is None when gy does not depend on y: leave that slice as zeros
        if grad is not None:
            jacy[j] = grad
    # the transposed tensor is non-contiguous, so use reshape rather than view
    return jacy.transpose(0, 1).reshape(outshape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment