"""InvTorch: Core Invertible Utilities https://github.com/xmodar/invtorch""" | |
import itertools | |
import collections | |
import torch | |
from torch import nn | |
import torch.utils.checkpoint | |
__all__ = ['invertible_checkpoint', 'InvertibleModule'] | |
def invertible_checkpoint(function, inverse, *args, preserve_rng_state=False):
    """Checkpoint a model or part of the model without saving the input

    Source: https://gist.github.com/xmodar/4deb8905ed8c294862972466f69e5d17

    Args:
        function: invertible differentiable function
        inverse: inverse of `function` (doesn't need to be differentiable)
        *args: input arguments to be passed to `function`
        preserve_rng_state: use the same RNG state when `function` is called
            again in the backward pass

    Returns:
        Outputs of `function(*args)` with requires_grad=True for all tensors
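
    Example (an illustrative sketch; the lambdas are made-up inverses):
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = invertible_checkpoint(lambda a: a * 2, lambda b: b / 2, x)
        >>> y.sum().backward()  # x's freed storage is restored here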
""" | |
    def unpack(index):
        counts[index] -= 1
        if unpack.outputs is not None:
            with torch.inference_mode():
                inverted = unpack.inverse(*unpack.outputs)
            unpack.outputs = unpack.inverse = None
            inverted = (inverted, ) if torch.is_tensor(inverted) else inverted
            assert len(inverted) == len(args), 'inverse(outputs) != inputs'
            for key, value in zip(args, inverted):
                if key is not None:  # skip arguments that weren't tensors
                    tensors[key].set_(value)  # refill the freed storage
        return tensors[index] if counts[index] else tensors.pop(index)

    preserve_rng_state = bool(preserve_rng_state)
    tensors = {id(x): x for x in args if torch.is_tensor(x)}
    assert any(x.requires_grad for x in tensors.values()), 'no input needs grad'
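    # the pack hook is the builtin `id`, so autograd saves only tensor ids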
    with torch.autograd.graph.saved_tensors_hooks(id, unpack):
        unpack.outputs = torch.utils.checkpoint.checkpoint(
            function, *args, preserve_rng_state=preserve_rng_state)
    function, inverse, unpack.inverse = None, None, inverse  # drop references
    args = [id(x) if torch.is_tensor(x) else None for x in args]
    counts = dict(collections.Counter(args))  # times each id was saved
    single = torch.is_tensor(unpack.outputs)
    unpack.outputs = (unpack.outputs, ) if single else unpack.outputs
    assert isinstance(unpack.outputs, tuple), 'function must return a tuple'
    for tensor in tensors.values():
        tensor.storage().resize_(0)  # free the input storages
    return unpack.outputs[0] if single else unpack.outputs


class InvertibleModule(torch.nn.Module):
    """Base class for invertible modules

    Subclasses must satisfy `inputs == self.inverse(*self.function(*inputs))`.
    """

    def __init__(self, invertible=True, checkpoint=True):
        super().__init__()
        self.invertible = invertible  # use inverse if checkpointing is enabled
        self.checkpoint = checkpoint  # enables or disables checkpointing

    def function(self, *inputs):
        """Compute the outputs of the function given the inputs"""
        raise NotImplementedError

    def inverse(self, *outputs):
        """Compute the inputs of the function given the outputs"""
        raise NotImplementedError

    @torch.inference_mode()
    def check_inverse(self, *inputs, atol=1e-5, rtol=1e-3):
        """Check that `self.inverse()` recovers the given input tensors"""
        outputs = self.pack(self.inverse(*self.pack(self.function(*inputs))))
        for original, recovered in itertools.zip_longest(inputs, outputs):
            is_tensor = torch.is_tensor(original)
            assert is_tensor == torch.is_tensor(recovered)
            assert not is_tensor or torch.allclose(
                original, recovered, rtol, atol)
        return True

    @property
    def checkpoint(self):
        """Whether the module is in checkpoint or pass_through mode"""
        return self._checkpoint

    @checkpoint.setter
    def checkpoint(self, value):
        if value:
            self._checkpoint = True
        else:
            self._checkpoint = self._invertible = False

    @property
    def invertible(self):
        """Whether the module is in invertible or simple checkpoint mode"""
        return self._checkpoint and self._invertible

    @invertible.setter
    def invertible(self, value):
        if value:
            self._invertible = self._checkpoint = True
        else:
            self._invertible = False

    def forward(self, *inputs):
        """Perform the forward pass"""
        if not self.checkpoint or not torch.is_grad_enabled() or not any(
                True for x in itertools.chain(self.parameters(), inputs)
                if torch.is_tensor(x) and x.requires_grad):
            return self.function(*inputs)
        if self.invertible:
            apply = invertible_checkpoint
        else:
            # plain checkpointing has no use for the inverse; drop it here
            def apply(function, _, *args):
                return torch.utils.checkpoint.checkpoint(function, *args)
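        # `zero` is a dummy grad-requiring input, so checkpointing always has
        # at least one; `_function` returns `one`, a grad-requiring scalar
        # that also encodes whether `function` returned a single tensor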
        zero = torch.zeros((), requires_grad=True)  # ensure differentiability
        grads, one, *out = apply(self._function, self._inverse, zero, *inputs)
        for outputs, requires_grad in zip(out, grads):
            if torch.is_tensor(outputs) and not requires_grad:
                outputs.detach_()
        return out[0] if one.item() else tuple(out)

    def _function(self, _, *inputs):
        """Wrap `self.function`; drop the dummy input, record output flags"""
        outputs, one = self.pack(self.function(*inputs), True)
        grads = [torch.is_tensor(x) and x.requires_grad for x in outputs]
        one = torch.tensor(float(one), requires_grad=True)
        return (grads, one, *outputs)

    def _inverse(self, _, one, *outputs):
        """Wrap `self.inverse`; `one` fills the dummy input's slot"""
        return (one, *self.pack(self.inverse(*outputs)))

    @staticmethod
    def pack(inputs, is_tensor=False):
        """Pack the inputs into a tuple if they were a single tensor"""
        one = torch.is_tensor(inputs)
        outputs = (inputs, ) if one else inputs
        return (outputs, one) if is_tensor else outputs

    @staticmethod
    def do_require_grad(*tensors, at_least_one=True):
        """Check whether input tensors have `requires_grad=True`

        If `at_least_one` is True, check whether any tensor requires grad;
        otherwise, check whether all of them do.
        """
        for tensor in tensors:
            requires_grad = torch.is_tensor(tensor) and tensor.requires_grad
            if at_least_one == requires_grad:
                return at_least_one
        return not at_least_one


class InvertibleLinear(InvertibleModule):
    """Invertible Linear Module"""

    def __init__(self, in_features, out_features):
        super().__init__(invertible=True, checkpoint=True)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def function(self, inputs):  # pylint: disable=arguments-differ
        outputs = inputs @ self.weight.T + self.bias
        # grad mode is off inside checkpoint, so flag the output manually
        requires_grad = self.do_require_grad(inputs, self.weight, self.bias)
        return outputs.requires_grad_(requires_grad)

    def inverse(self, outputs):  # pylint: disable=arguments-differ
        # the pseudo-inverse is exact only if out_features >= in_features
        # and the weight matrix has full column rank
        return (outputs - self.bias) @ self.weight.T.pinverse()


def test():
    """Test invtorch"""
    inputs = torch.randn(10, 3)
    model = InvertibleLinear(3, 5)
    print('Is invertible:', model.check_inverse(inputs))
    outputs = model(inputs)
    print('Output requires_grad:', outputs.requires_grad)
    print('Input was freed:', inputs.storage().size() == 0)
    outputs.backward(torch.randn_like(outputs))
    print('Input was restored:', inputs.storage().size() != 0)


if __name__ == '__main__':
    test()