A new way to do AC
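This gist sketches activation checkpointing (AC) built from two pieces: a saved_tensors_hooks pair that swaps every tensor autograd saves for backward with a RecomputableTensor handle, and a TorchDispatchMode that records, for every op run in the forward, how its output can be rebuilt from its inputs. During backward, unpacking a saved tensor triggers a recursive recompute(), and a simple use counter frees each recomputed value once its last consumer is done with it. As a warm-up, here is a minimal, self-contained sketch of the torch.autograd.graph.saved_tensors_hooks API the code below builds on (the prints are only for illustration):

import torch

# pack_hook sees every tensor autograd saves for backward and may return any
# placeholder object; unpack_hook must turn that placeholder back into a tensor
# when a backward node asks for it.
def pack_hook(t):
    print("packing a tensor of shape", tuple(t.shape))
    return t  # a real implementation would return something cheaper to rehydrate

def unpack_hook(obj):
    print("unpacking")
    return obj

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.sin().sum()
y.backward()  # unpack_hook fires when SinBackward0 needs its saved input
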
import torch
import functools
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree
from torch.utils.weak import WeakTensorKeyDictionary

class RecomputableTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, t, func, args):
        return torch.Tensor._make_wrapper_subclass(
            cls, t.shape, t.stride(), t.storage_offset(),
            torch.contiguous_format, t.dtype, torch.strided,
            t.device, False, t.requires_grad, "sizes", False, False, None)

    def __init__(self, t, func=None, args=None):
        super().__init__()
        if func is not None:
            # Intermediate: remember how to recompute it from its (recomputable) args.
            self.cached_value = None
            self.args = args
            self.func = func
        else:
            # Graph input, or an output the policy chose never to recompute:
            # keep the real tensor cached.
            self.cached_value = t
            self.func = None
        self.counter = 0

    def increment(self):
        self.counter += 1

    def decrement(self):
        self.counter -= 1
        if self.counter == 0:
            print("cleared cache", id(self.cached_value))
            self.cached_value = None

    def recompute(self):
        print("recompute", self.func)
        if self.cached_value is not None:
            print("already cached", id(self.cached_value))
            return self.cached_value
        new_args = pytree.tree_map_only(RecomputableTensor, lambda s: s.recompute(), self.args)
        out = self.func(*new_args)
        self.cached_value = out
        print("done recompute and cached", id(out))
        return out

    def __repr__(self):
        return f"RecomputableTensor(func={self.func}, cached={self.cached_value is not None})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        assert False

def is_plain_tensor(t):
    return isinstance(t, torch.Tensor)

def maybe_create(recomputable_tensors, t):
    if t not in recomputable_tensors:
        print("input: ", t)
        # creates a reference cycle :(
        recomputable_tensors[t] = RecomputableTensor(t, None, None)
    return recomputable_tensors[t]

class RecomputableMode(TorchDispatchMode):
    def __init__(self, recomputable_tensors):
        self.recomputable_tensors = recomputable_tensors

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)
        # Record, for every op run in the forward, how its output can be rebuilt
        # from its (recomputable) inputs.
        wrapped_args = pytree.tree_map_only(torch.Tensor, functools.partial(maybe_create, self.recomputable_tensors), args)
        if func is torch.ops.aten.exp.default:
            # selective AC, policy is don't recompute exp
            self.recomputable_tensors[out] = RecomputableTensor(out, None, None)
        else:
            self.recomputable_tensors[out] = RecomputableTensor(out, func, wrapped_args)
        return out

class CheckpointHook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, recomputable_tensors):
        def pack_hook(raw_tensor):
            # Swap the tensor autograd wants to save for its RecomputableTensor handle.
            recomputable_tensor = maybe_create(recomputable_tensors, raw_tensor)
            recomputable_tensor.increment()
            return recomputable_tensors[raw_tensor]

        def unpack_hook(recomputable_tensor):
            # Rebuild (or fetch) the real tensor when a backward node needs it.
            out = recomputable_tensor.recompute()
            recomputable_tensor.decrement()
            return out

        super().__init__(pack_hook, unpack_hook)

recomputable_tensors = WeakTensorKeyDictionary()

# Two independent chains
# ----------------------
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)

with CheckpointHook(recomputable_tensors), RecomputableMode(recomputable_tensors):
    out = a.sin().sin().sin()
    out2 = b.cos().cos().cos()
    out3 = out + out2

out3.backward()
print(a.grad, b.grad)

# input: tensor(1., requires_grad=True)
# input: tensor(2., requires_grad=True)
# V0312 22:21:14.781496 8114936896 torch/autograd/graph.py:726] Executing: <AddBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# V0312 22:21:14.781547 8114936896 torch/autograd/graph.py:726] Executing: <CosBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# recompute
# recompute
# already cached
# done recompute
# done recompute
# cleared cache
# V0312 22:21:14.781641 8114936896 torch/autograd/graph.py:726] Executing: <CosBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781690 8114936896 torch/autograd/graph.py:726] Executing: <CosBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781728 8114936896 torch/autograd/graph.py:726] Executing: <AccumulateGrad object at 0x1322318e0> with grad_outputs: [f32[]]
# V0312 22:21:14.781758 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# recompute
# recompute
# already cached
# done recompute
# done recompute
# cleared cache
# V0312 22:21:14.781829 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781869 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781971 8114936896 torch/autograd/graph.py:726] Executing: <AccumulateGrad object at 0x1322318e0> with grad_outputs: [f32[]]
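
# Note how, in the log above, recomputing the outermost saved activation walks
# back to the input (which is always cached), caching every intermediate along
# the way, so the remaining backward nodes hit "already cached"; each cached
# value is then freed as soon as its use counter drops to zero.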

# Selective AC
# ------------
a = torch.tensor(1., requires_grad=True)

with CheckpointHook(recomputable_tensors), RecomputableMode(recomputable_tensors):
    out = a.sin().sin().exp().sin().sin().exp().sin().sin()

out.backward()
print(a.grad)

# input: tensor(1., requires_grad=True)
# V0312 22:34:10.851179 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute aten.sin.default
# recompute None
# already cached 4881238672
# done recompute and cached 4881844720
# cleared cache 4881844720
# V0312 22:34:10.851357 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute None
# already cached 4881238672
# V0312 22:34:10.851450 8114936896 torch/autograd/graph.py:726] Executing: <ExpBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute None
# already cached 4881238672
# cleared cache 4881238672
# V0312 22:34:10.851538 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute aten.sin.default
# recompute None
# already cached 4881721120
# done recompute and cached 4882103248
# cleared cache 4882103248
# V0312 22:34:10.851649 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute None
# already cached 4881721120
# V0312 22:34:10.851719 8114936896 torch/autograd/graph.py:726] Executing: <ExpBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute None
# already cached 4881721120
# cleared cache 4881721120
# V0312 22:34:10.851792 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute aten.sin.default
# recompute None
# already cached 4881149360
# done recompute and cached 4881721120
# cleared cache 4881721120
# V0312 22:34:10.851898 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute None
# already cached 4881149360
# cleared cache 4881149360
# V0312 22:34:10.851972 8114936896 torch/autograd/graph.py:726] Executing: <AccumulateGrad object at 0x122fa8a60> with grad_outputs: [f32[]]
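
For contrast, the stock way to get the same recompute-in-backward behavior is torch.utils.checkpoint, which asks you to wrap a region of the forward into a function ahead of time; the hooks-plus-dispatch-mode approach above instead decides per op, at dispatch time, what to cache and what to recompute (e.g. the never-recompute-exp policy). A minimal sketch of the conventional API, with region as a stand-in for the checkpointed block:

import torch
from torch.utils.checkpoint import checkpoint

a = torch.tensor(1., requires_grad=True)

def region(x):
    # everything in here is recomputed during backward instead of being saved
    return x.sin().sin().sin()

out = checkpoint(region, a, use_reentrant=False)
out.backward()
print(a.grad)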