@soulitzer
Created June 27, 2024 16:19
A new way to do AC (activation checkpointing): a TorchDispatchMode records, for each forward intermediate, the op that produced it and its (wrapped) inputs; saved_tensors_hooks then packs that lightweight record instead of the real tensor, and unpacking replays the recorded ops on demand, with a small reference count to free a recomputed value once its last consumer has unpacked it.
import torch
import functools
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree
from torch.utils.weak import WeakTensorKeyDictionary

class RecomputableTensor(torch.Tensor):
    # Wrapper subclass that records how a tensor was produced (func + wrapped
    # args) so the value can be rematerialized on demand during backward.
    @staticmethod
    def __new__(cls, t, func, args):
        return torch.Tensor._make_wrapper_subclass(
            cls, t.shape, t.stride(), t.storage_offset(),
            torch.contiguous_format, t.dtype, torch.strided,
            t.device, False, t.requires_grad, "sizes", False, False, None)

    def __init__(self, t, func=None, args=None):
        super().__init__()
        if func is not None:
            # Intermediate: remember the recipe, do not hold on to the data.
            self.cached_value = None
            self.args = args
            self.func = func
        else:
            # Graph input (or a tensor the policy chose to save): keep it.
            self.cached_value = t
            self.func = None
        self.counter = 0

    def increment(self):
        self.counter += 1

    def decrement(self):
        self.counter -= 1
        if self.counter == 0:
            print("cleared cache", id(self.cached_value))
            self.cached_value = None

    def recompute(self):
        print("recompute", self.func)
        if self.cached_value is not None:
            print("already cached", id(self.cached_value))
            return self.cached_value
        # Recursively rematerialize the inputs, then replay the recorded op.
        new_args = pytree.tree_map_only(RecomputableTensor, lambda s: s.recompute(), self.args)
        out = self.func(*new_args)
        self.cached_value = out
        print("done recompute and cached", id(out))
        return out

    def __repr__(self):
        return f"RecomputableTensor(func={self.func})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # The wrapper is only a bookkeeping handle; it should never hit real compute.
        assert False

def is_plain_tensor(t):
    return isinstance(t, torch.Tensor)


def maybe_create(recomputable_tensors, t):
    # Wrap t in a RecomputableTensor if we have not seen it before. Tensors we
    # never saw being produced are treated as inputs and cached directly.
    if t not in recomputable_tensors:
        print("input: ", t)
        # creates a reference cycle :(
        recomputable_tensors[t] = RecomputableTensor(t, None, None)
    return recomputable_tensors[t]

class RecomputableMode(TorchDispatchMode):
    # Runs during the forward: for every op, record the func and (wrapped)
    # inputs needed to recompute its output later.
    def __init__(self, recomputable_tensors):
        self.recomputable_tensors = recomputable_tensors

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)
        wrapped_args = pytree.tree_map_only(torch.Tensor, functools.partial(maybe_create, self.recomputable_tensors), args)
        if func is torch.ops.aten.exp.default:
            # selective AC, policy is don't recompute exp
            self.recomputable_tensors[out] = RecomputableTensor(out, None, None)
        else:
            self.recomputable_tensors[out] = RecomputableTensor(out, func, wrapped_args)
        return out

class CheckpointHook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, recomputable_tensors):
        def pack_hook(raw_tensor):
            recomputable_tensor = maybe_create(recomputable_tensors, raw_tensor)
            recomputable_tensor.increment()
            return recomputable_tensors[raw_tensor]

        def unpack_hook(recomputable_tensor):
            out = recomputable_tensor.recompute()
            recomputable_tensor.decrement()
            return out

        super().__init__(pack_hook, unpack_hook)

recomputable_tensors = WeakTensorKeyDictionary()

# Two independent chains
# ----------------------
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)

with CheckpointHook(recomputable_tensors), RecomputableMode(recomputable_tensors):
    out = a.sin().sin().sin()
    out2 = b.cos().cos().cos()
    out3 = out + out2

out3.backward()
print(a.grad, b.grad)
# input: tensor(1., requires_grad=True)
# input: tensor(2., requires_grad=True)
# V0312 22:21:14.781496 8114936896 torch/autograd/graph.py:726] Executing: <AddBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# V0312 22:21:14.781547 8114936896 torch/autograd/graph.py:726] Executing: <CosBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# recompute
# recompute
# already cached
# done recompute
# done recompute
# cleared cache
# V0312 22:21:14.781641 8114936896 torch/autograd/graph.py:726] Executing: <CosBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781690 8114936896 torch/autograd/graph.py:726] Executing: <CosBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781728 8114936896 torch/autograd/graph.py:726] Executing: <AccumulateGrad object at 0x1322318e0> with grad_outputs: [f32[]]
# V0312 22:21:14.781758 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# recompute
# recompute
# already cached
# done recompute
# done recompute
# cleared cache
# V0312 22:21:14.781829 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781869 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x1322318e0> with grad_outputs: [f32[]]
# recompute
# already cached
# cleared cache
# V0312 22:21:14.781971 8114936896 torch/autograd/graph.py:726] Executing: <AccumulateGrad object at 0x1322318e0> with grad_outputs: [f32[]]
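
# For comparison, the same full-recompute behavior for the first chain can be
# written with the stock utility; a minimal sketch, assuming
# torch.utils.checkpoint.checkpoint with use_reentrant=False (a_ckpt and
# out_ckpt are illustrative names):
from torch.utils.checkpoint import checkpoint

a_ckpt = torch.tensor(1., requires_grad=True)
out_ckpt = checkpoint(lambda x: x.sin().sin().sin(), a_ckpt, use_reentrant=False)
out_ckpt.backward()
assert torch.allclose(a.grad, a_ckpt.grad)  # same gradient as the hook-based AC above
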
# Selective AC
# ------------
a = torch.tensor(1., requires_grad=True)

with CheckpointHook(recomputable_tensors), RecomputableMode(recomputable_tensors):
    out = a.sin().sin().exp().sin().sin().exp().sin().sin()

out.backward()
print(a.grad)
# input: tensor(1., requires_grad=True)
# V0312 22:34:10.851179 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute aten.sin.default
# recompute None
# already cached 4881238672
# done recompute and cached 4881844720
# cleared cache 4881844720
# V0312 22:34:10.851357 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute None
# already cached 4881238672
# V0312 22:34:10.851450 8114936896 torch/autograd/graph.py:726] Executing: <ExpBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute None
# already cached 4881238672
# cleared cache 4881238672
# V0312 22:34:10.851538 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute aten.sin.default
# recompute None
# already cached 4881721120
# done recompute and cached 4882103248
# cleared cache 4882103248
# V0312 22:34:10.851649 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute None
# already cached 4881721120
# V0312 22:34:10.851719 8114936896 torch/autograd/graph.py:726] Executing: <ExpBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute None
# already cached 4881721120
# cleared cache 4881721120
# V0312 22:34:10.851792 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fb5670> with grad_outputs: [f32[]]
# recompute aten.sin.default
# recompute None
# already cached 4881149360
# done recompute and cached 4881721120
# cleared cache 4881721120
# V0312 22:34:10.851898 8114936896 torch/autograd/graph.py:726] Executing: <SinBackward0 object at 0x122fa80d0> with grad_outputs: [f32[]]
# recompute None
# already cached 4881149360
# cleared cache 4881149360
# V0312 22:34:10.851972 8114936896 torch/autograd/graph.py:726] Executing: <AccumulateGrad object at 0x122fa8a60> with grad_outputs: [f32[]]
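
# As a quick sanity check, the selective-AC gradient can be compared against a
# plain eager run of the same chain with no hooks or modes active; a minimal
# sketch (a_ref is an illustrative name):
a_ref = torch.tensor(1., requires_grad=True)
a_ref.sin().sin().exp().sin().sin().exp().sin().sin().backward()
assert torch.allclose(a.grad, a_ref.grad)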