# Last active November 17, 2021 15:38
"""InvTorch: Core Invertible Utilities"""
import itertools
import collections
import torch
from torch import nn
import torch.utils.checkpoint
__all__ = ['invertible_checkpoint', 'InvertibleModule']
def invertible_checkpoint(function, inverse, *args, preserve_rng_state=False):
    """Checkpoint a model or part of the model without saving the input

    `function` is run through `torch.utils.checkpoint.checkpoint` while
    saved-tensor hooks replace every tensor that autograd saves by its `id()`.
    Once the forward pass is done, the storage of the input tensors is
    released. The first time autograd unpacks a saved tensor, the inputs are
    recomputed with `inverse(*outputs)` and written back in-place into the
    original tensor objects.

    Args:
        function: invertible differentiable function
        inverse: inverse of `function` (doesn't need to be differentiable)
        *args: input arguments tuple to be passed to `function`
        preserve_rng_state: use same seed when calling `function` in backward

    Returns:
        Outputs of `function(*args)` with requires_grad=True for all tensors

    Raises:
        AssertionError: if no tensor in `args` requires grad, if `function`
            returns neither a tensor nor a tuple, or (during backward) if
            `inverse` returns a different number of values than `len(args)`.
    """
    def unpack(index):
        # `index` is the id() that the pack hook stored for a saved tensor.
        counts[index] -= 1
        if unpack.outputs is not None:
            # First unpack since the forward pass: rebuild all inputs at once.
            with torch.inference_mode():
                inverted = unpack.inverse(*unpack.outputs)
                # Drop the references so the inverse is computed only once.
                unpack.outputs = unpack.inverse = None
                if torch.is_tensor(inverted):
                    inverted = (inverted, )
                assert len(inverted) == len(args), 'inverse(outputs) != inputs'
                for key, value in zip(args, inverted):
                    # Non-tensor arguments are recorded as None in `args`
                    # and have no entry in `tensors`, so they are skipped.
                    if key is not None:
                        tensors[key].set_(value)
        # Forget the tensor once every saved occurrence has been unpacked.
        return tensors[index] if counts[index] else tensors.pop(index)

    preserve_rng_state = bool(preserve_rng_state)
    tensors = {id(x): x for x in args if torch.is_tensor(x)}
    assert any(x.requires_grad for x in tensors.values()), 'no input need grad'
    with torch.autograd.graph.saved_tensors_hooks(id, unpack):
        unpack.outputs = torch.utils.checkpoint.checkpoint(
            function, *args, preserve_rng_state=preserve_rng_state)
    unpack.inverse = inverse
    function = inverse = None  # release the local references
    # From here on `args` holds ids only (None for non-tensor arguments).
    args = [id(x) if torch.is_tensor(x) else None for x in args]
    counts = dict(collections.Counter(args))
    single = torch.is_tensor(unpack.outputs)
    if single:
        unpack.outputs = (unpack.outputs, )
    assert isinstance(unpack.outputs, tuple), 'function must return a tuple'
    # NOTE(review): the expression on this line was garbled in the source;
    # reconstructed as releasing the input storage, which is what the
    # `set_` call in `unpack` later restores -- confirm against upstream.
    for tensor in tensors.values():
        tensor.storage().resize_(0)
    return unpack.outputs[0] if single else unpack.outputs
class InvertibleModule(torch.nn.Module):
    """Base invertible `inputs = self.inverse(*self.function(*inputs))`

    Subclasses implement `function` and `inverse`. `forward` either calls
    `function` directly or routes it through a checkpoint, depending on the
    `checkpoint` and `invertible` flags and on whether gradients are needed.
    """

    def __init__(self, invertible=True, checkpoint=True):
        super().__init__()
        self.invertible = invertible  # use inverse if checkpointing is enabled
        self.checkpoint = checkpoint  # enables or disables checkpointing

    def function(self, *inputs):
        """Compute the outputs of the function given the inputs"""
        raise NotImplementedError

    def inverse(self, *outputs):
        """Compute the inputs of the function given the outputs"""
        raise NotImplementedError

    def check_inverse(self, *inputs, atol=1e-5, rtol=1e-3):
        """Check if `self.inverse()` is correct for input tensors

        Returns True when every reconstructed tensor is close to its input;
        raises AssertionError otherwise (including on a length mismatch).
        """
        outputs = self.pack(self.function(*inputs))
        restored = self.pack(self.inverse(*outputs))
        # zip_longest pads the shorter side with None, which then fails the
        # tensor/non-tensor agreement check below.
        for given, recovered in itertools.zip_longest(inputs, restored):
            is_tensor = torch.is_tensor(given)
            assert is_tensor == torch.is_tensor(recovered)
            assert not is_tensor or torch.allclose(
                given, recovered, rtol=rtol, atol=atol)
        return True

    @property
    def checkpoint(self):
        """Whether the module is in checkpoint or pass_through mode"""
        return self._checkpoint

    @checkpoint.setter
    def checkpoint(self, value):
        if value:
            self._checkpoint = True
        else:
            # Invertible mode cannot be on while checkpointing is off.
            self._checkpoint = self._invertible = False

    @property
    def invertible(self):
        """Whether the module is in invertible or simple checkpoint mode"""
        return self._checkpoint and self._invertible

    @invertible.setter
    def invertible(self, value):
        if value:
            # Invertible mode implies checkpointing.
            self._invertible = self._checkpoint = True
        else:
            self._invertible = False

    def forward(self, *inputs):
        """Perform the forward pass"""
        if not self.checkpoint or not torch.is_grad_enabled():
            return self.function(*inputs)
        if not self.do_require_grad(*self.parameters(), *inputs):
            return self.function(*inputs)
        zero = torch.zeros((), requires_grad=True)  # ensure differentiability
        if self.invertible:
            results = invertible_checkpoint(
                self._function, self._inverse, zero, *inputs)
        else:
            # Plain checkpointing takes no inverse; passing one would shift
            # the positional arguments received by `_function`.
            results = torch.utils.checkpoint.checkpoint(
                self._function, zero, *inputs)
        grads, one, *out = results
        for output, requires_grad in zip(out, grads):
            if torch.is_tensor(output) and not requires_grad:
                # NOTE(review): the body of this branch was missing in the
                # source; reconstructed as an in-place detach so outputs
                # keep the requires_grad state recorded by `_function`.
                output.detach_()
        return out[0] if one.item() else out

    def _function(self, _, *inputs):
        """Wraps `self.function` to handle no requires_grad inputs"""
        outputs, one = self.pack(self.function(*inputs), True)
        grads = [torch.is_tensor(x) and x.requires_grad for x in outputs]
        # Flag (as a tensor) telling `forward` whether to unwrap the output.
        one = torch.tensor(float(one), requires_grad=True)
        return (grads, one, *outputs)

    def _inverse(self, _, one, *outputs):
        """Wraps `self.inverse` to handle no requires_grad inputs"""
        return (one, *self.pack(self.inverse(*outputs)))

    @staticmethod
    def pack(inputs, is_tensor=False):
        """Pack the inputs into tuple if they were a one tensor

        With `is_tensor=True`, also return whether `inputs` was one tensor.
        """
        one = torch.is_tensor(inputs)
        outputs = (inputs, ) if one else inputs
        return (outputs, one) if is_tensor else outputs

    @staticmethod
    def do_require_grad(*tensors, at_least_one=True):
        """Check whether input tensors have `requires_grad=True`

        With `at_least_one=True` any single tensor requiring grad is enough;
        otherwise every item must be a tensor that requires grad.
        """
        flags = (torch.is_tensor(x) and x.requires_grad for x in tensors)
        return any(flags) if at_least_one else all(flags)
class InvertibleLinear(InvertibleModule):
    """Invertible Linear Module

    Computes `inputs @ weight.T + bias`; the inverse multiplies by the
    pseudo-inverse of `weight.T` after subtracting the bias.
    """

    def __init__(self, in_features, out_features):
        super().__init__(invertible=True, checkpoint=True)
        # weight has shape (out_features, in_features)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def function(self, inputs):  # pylint: disable=arguments-differ
        """Apply the affine map and set requires_grad on the result"""
        needs_grad = self.do_require_grad(inputs, self.weight, self.bias)
        projected = inputs @ self.weight.T
        return (projected + self.bias).requires_grad_(needs_grad)

    def inverse(self, outputs):  # pylint: disable=arguments-differ
        """Undo the affine map using the pseudo-inverse of `weight.T`"""
        # NOTE(review): this recovers the inputs exactly only when
        # `weight.T` has a right inverse -- confirm for the shapes in use.
        centered = outputs - self.bias
        return centered @ self.weight.T.pinverse()
def test():
    """Test invtorch

    Runs an `InvertibleLinear` forward and backward pass and prints whether
    the inverse is consistent, whether the output requires grad, and whether
    the input storage was released after forward and restored by backward.
    """
    inputs = torch.randn(10, 3)
    model = InvertibleLinear(3, 5)
    print('Is invertible:', model.check_inverse(inputs))
    outputs = model(inputs)
    print('Output requires_grad:', outputs.requires_grad)
    # NOTE(review): the left-hand sides of the two comparisons below were
    # garbled in the source; reconstructed as the input's storage size.
    print('Input was freed:', inputs.storage().size() == 0)
    # The inputs are only recomputed when autograd unpacks saved tensors,
    # so a backward pass has to run before checking for restoration.
    outputs.backward(torch.randn_like(outputs))
    print('Input was restored:', inputs.storage().size() != 0)
# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    test()