Skip to content

Instantly share code, notes, and snippets.

@wohlert
Last active May 24, 2019 15:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wohlert/e7242e708b076a3108c7602d99f31e09 to your computer and use it in GitHub Desktop.
Save wohlert/e7242e708b076a3108c7602d99f31e09 to your computer and use it in GitHub Desktop.
A lightweight, decorator-based registry approach with minimal overhead for computing input Jacobians of sequential networks (e.g. nn.Sequential) whose layers each have a registered Jacobian
"""
Registers Jacobians for basic PyTorch layers (nn.Linear and nn.Tanh).
"""
import torch
from torch import nn
from .metric import register_jacobian
@register_jacobian(nn.Linear)
def _linear_jacobian(module):
    """
    Builds the Jacobian-accumulation function for an ``nn.Linear`` layer.

    For ``y = x @ W.T + b`` the Jacobian dy/dx is the weight matrix ``W``
    itself (shape ``(out_features, in_features)``), not its transpose; the
    bias does not depend on ``x`` and contributes nothing.

    :param module: the ``nn.Linear`` instance
    :return: a function ``jacobian(x, dx)`` that right-multiplies the
        accumulated Jacobian ``dx`` by ``W``
    """
    w = module.weight

    def jacobian(_x, dx):
        # The layer input is unused: an affine map has a constant Jacobian.
        if dx.dim() == w.dim():
            # 2-D dx, (batch, out_features), holds the diagonal of the
            # accumulated Jacobian: out[b, i, j] = dx[b, i] * W[i, j],
            # i.e. diag(dx[b]) @ W.
            eq = "bi, ij -> bij"
        else:
            # 3-D dx, (batch, m, out_features), is a full Jacobian: dx[b] @ W.
            eq = "bij, jk -> bik"
        return torch.einsum(eq, dx, w)

    return jacobian
@register_jacobian(nn.Tanh)
def _tanh_jacobian(_module):
    """
    Builds the Jacobian-accumulation function for an ``nn.Tanh`` layer.

    tanh acts elementwise, so its Jacobian is diagonal with entries
    ``1 - tanh(x)**2``. Accumulating it therefore just scales ``dx``
    elementwise along its last axis.

    :param _module: the ``nn.Tanh`` instance (unused)
    :return: a function ``jacobian(x, dx)`` returning ``dx`` multiplied by
        the tanh derivative evaluated at the layer input ``x``
    """

    def jacobian(x, dx):
        # Derivative of tanh evaluated at the layer's input.
        dy = 1 - torch.tanh(x) ** 2
        if dx.dim() != dy.dim():
            # dx carries one extra axis relative to dy (presumably a full
            # (batch, m, n) Jacobian vs. a (batch, n) derivative): insert a
            # broadcast axis so every row of dx is scaled by the same dy.
            dy = dy.unsqueeze(1)
        # Accumulate: dx @ diag(dy), written as a broadcast product.
        return dx * dy

    return jacobian
import torch
from torch import nn
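# Layer Jacobians are registered as a side effect of importing the module that
# defines them (via @register_jacobian). NOTE(review): .metric alone does not appear
# to import that module — confirm the package does so before model_jacobian is called.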
from .metric import model_jacobian
# Example 1: a small MLP built only from layers that have registered
# Jacobians (nn.Linear and nn.Tanh).
model = nn.Sequential(
    nn.Linear(3, 32), nn.Tanh(),
    nn.Linear(32, 32), nn.Tanh(),
    nn.Linear(32, 2),
)

# A batch of 64 random inputs with 3 features each.
x = torch.randn(64, 3)
jac = model_jacobian(model, x)  # Jacobian of model(x) with respect to x

# Example 2: a different architecture needs no model-specific code.
model2 = nn.Sequential(nn.Linear(3, 32), nn.Linear(32, 2))
jac = model_jacobian(model2, x)
from torch import nn
import torch
# Maps an nn.Module subclass to the factory registered for it through
# register_jacobian. A factory is called with the module instance and returns
# a callable taking (input, dx).
_JACOBIAN_REGISTRY = dict()

# Per-type cache of lookups resolved from _JACOBIAN_REGISTRY by jacobian();
# cleared whenever register_jacobian adds an entry.
_JACOBIAN_MEMOIZE = dict()
def register_jacobian(module_type):
    """
    Adds a Jacobian method for the specific module type
    to the lookup.

    Intended to be used as a decorator::

        @register_jacobian(nn.Linear)
        def _linear_jacobian(module):
            def jacobian(x, dx):
                ...
            return jacobian

    :param module_type: type of module (must inherit nn.Module)
    :return: a function that decorates the Jacobian factory; the factory is
        stored in the registry and returned unchanged
    :raises TypeError: if ``module_type`` is not a subclass of ``nn.Module``
    """
    # Parenthesized so that ``not`` negates the whole condition; otherwise any
    # class (e.g. ``int``) would be accepted and non-classes would fail inside
    # issubclass() with an unrelated message.
    if not (isinstance(module_type, type) and issubclass(module_type, nn.Module)):
        raise TypeError(
            'Expected module_type to be a Module subclass but got {}'.format(module_type)
        )

    def decorator(fun):
        _JACOBIAN_REGISTRY[module_type] = fun
        # Previously cached lookups may be stale after a new registration.
        _JACOBIAN_MEMOIZE.clear()
        return fun

    return decorator
def jacobian(module, input, grad=None):
    """
    Computes the Jacobian of a module (layer) given
    an input and optionally a gradient of previous layer.

    The Jacobian of the module must be preregistered through
    the method `register_jacobian` for exactly ``type(module)``.

    :param module: layer to compute for
    :param input: input to compute Jacobian wrt.
    :param grad: Jacobian accumulated so far (from the layers applied after
        this one), passed to the registered function as its second argument
    :return: J_input, the accumulated Jacobian with this layer applied
    :raises NotImplementedError: if no Jacobian is registered for
        ``type(module)``
    """
    module_type = type(module)
    try:
        fun = _JACOBIAN_MEMOIZE[module_type]
    except KeyError:
        # Exact type match only: a subclass may override forward(), so reusing
        # its parent's Jacobian could silently give wrong results.
        # Misses are cached as NotImplemented so they are not re-resolved.
        fun = _JACOBIAN_REGISTRY.get(module_type, NotImplemented)
        _JACOBIAN_MEMOIZE[module_type] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            "No Jacobian registered for module type {}".format(module_type.__name__)
        )
    return fun(module)(input, grad)
def model_jacobian(model, x):
    """
    Compute the Jacobian of ``model``'s output with respect to its input.

    ``model`` is iterated in forward order and reversed for the backward
    sweep, so it must support both (e.g. ``nn.Sequential``). Every layer it
    contains must have a Jacobian registered via ``register_jacobian``.

    :param model: nn.Module whose forward pass is the composition of its layers
    :param x: input
    :return: J_x
    """
    # Forward sweep: remember the input each layer receives.
    inputs = []
    h = x
    for layer in model:
        inputs.append(h)
        h = layer(h)

    # Backward sweep (chain rule, output to input), starting from J_0 = 1:
    # a tensor of ones shaped like the final output.
    acc = torch.ones_like(h)
    for layer_input, layer in zip(inputs[::-1], reversed(model)):
        acc = jacobian(layer, layer_input, acc)
    return acc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment