@albanD
Last active August 8, 2023 07:49
PyTorch optimizer as hook
import torch
from torch import nn
from torch.optim.sgd import sgd
import gc
import objgraph  # unused below; handy for debugging object reference graphs
import weakref


def all():
    # Only a subset of the args you could have
    def set_sgd_hook(mod, p, lr, weight_decay, momentum):
        buff_list = [None]  # momentum buffer, lazily filled in by the functional sgd()
        # Grab the AccumulateGrad node for this parameter
        acc_grad = p.view_as(p).grad_fn.next_functions[0][0]
        # The grad accumulator is a weak ref, so we need to keep it
        # alive for as long as the Tensor is alive.
        # Store it on the module to avoid an uncollectable ref-cycle
        if not hasattr(mod, "_acc_grads"):
            mod._acc_grads = []
        mod._acc_grads.append(acc_grad)

        def sgd_hook(*_unused):
            # Update the params
            sgd([p], [p.grad], buff_list, has_sparse_grad=False, foreach=False,
                weight_decay=weight_decay, momentum=momentum, lr=lr, dampening=0,
                nesterov=False, maximize=False)
            # Free up grad memory
            p.grad = None

        # We should have an API for post hooks... But we don't have one right now
        acc_grad.register_hook(sgd_hook)

    print("Startup", torch.cuda.memory_allocated())
    mod = torch.nn.Linear(4, 1).cuda()
    crit = nn.MSELoss()
    for p in mod.parameters():
        set_sgd_hook(mod, p, lr=.01, weight_decay=0., momentum=0.9)
    # Make sure the keepalive works well
    gc.collect()

    inp = torch.rand(10, 4, device="cuda")
    target = torch.rand(10, 1, device="cuda")

    for i in range(11):
        def eval_one():
            print(f"It {i}, {torch.cuda.memory_allocated()}")
            pred = mod(inp)
            loss = crit(pred, target)
            print("Before backward", torch.cuda.memory_allocated())
            loss.backward()
            print(f"Loss: {loss.item()}")

        eval_one()
        if i == 0:
            print("No memory decrease due to optimizer state lazy initialization")
        print("End of iteration", torch.cuda.memory_allocated())

    return weakref.ref(mod.weight)


w = all()
print("Done, final memory", torch.cuda.memory_allocated())
albanD commented Aug 8, 2022

Output from the script:

$ python opt_as_hook.py
Startup 0
It 0, 2048
Before backward 3072
Loss: 0.795136034488678
No memory decrease due to optimizer state lazy initialization
End of iteration 3072
It 1, 3072
Before backward 4096
Loss: 0.7405098080635071
End of iteration 3072
It 2, 3072
Before backward 4096
Loss: 0.6442667841911316
End of iteration 3072
It 3, 3072
Before backward 4096
Loss: 0.5232660174369812
End of iteration 3072
It 4, 3072
Before backward 4096
Loss: 0.3950752913951874
End of iteration 3072
It 5, 3072
Before backward 4096
Loss: 0.2753334045410156
End of iteration 3072
It 6, 3072
Before backward 4096
Loss: 0.17588858306407928
End of iteration 3072
It 7, 3072
Before backward 4096
Loss: 0.1038241907954216
End of iteration 3072
It 8, 3072
Before backward 4096
Loss: 0.06134021282196045
End of iteration 3072
It 9, 3072
Before backward 4096
Loss: 0.046344172209501266
End of iteration 3072
It 10, 3072
Before backward 4096
Loss: 0.05353296920657158
End of iteration 3072
Done, final memory 0
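
The final memory of 0 shows that nothing kept the model alive once all() returned. The weakref the script returns can be used to double-check that; a hedged one-liner (not in the original script):

assert w() is None, "mod.weight is unexpectedly still alive"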
