Skip to content

Instantly share code, notes, and snippets.

@marta-sd
Created November 19, 2022 12:57
Show Gist options
  • Save marta-sd/a19bbbc8b3322ae86425622edd2a7773 to your computer and use it in GitHub Desktop.
Save marta-sd/a19bbbc8b3322ae86425622edd2a7773 to your computer and use it in GitHub Desktop.
Example of using backward hooks
import torch
from torch import nn
class Net(nn.Module):
    """Two-branch toy network used to demonstrate backward hooks.

    Each branch is a small MLP that ends in 2 features. The two branch
    outputs are summed and passed through a final ReLU.
    """

    def __init__(self):
        super().__init__()
        # Branch applied to the first input: 8 -> 64 -> 2 features.
        self.l1 = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2))
        # Branch applied to the second input: 64 -> 64 -> 2 features.
        self.l2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 2))
        self.relu = nn.ReLU()

    def forward(self, x, y):
        """Return ``relu(l1(x) + l2(y))``.

        Args:
            x: Input to the first branch; its last dimension must be 8.
            y: Input to the second branch; its last dimension must be 64.

        Returns:
            The element-wise ReLU of the summed branch outputs
            (last dimension 2).
        """
        x = self.l1(x)
        y = self.l2(y)
        return self.relu(x + y)
# Shared model instance that every hook demo in this script attaches to.
net = Net()
def hook(module, grad_inputs, grad_outputs):
    """Print the mean of every gradient seen by a full backward hook.

    The signature matches what ``register_full_backward_hook`` expects.
    Nothing is returned, so the gradients flowing through the module are
    left unmodified.

    Args:
        module: The module the hook fired for (unused).
        grad_inputs: Iterable of gradients w.r.t. the module's inputs.
            Entries may be ``None`` and are then printed as ``None``.
        grad_outputs: Iterable of gradients w.r.t. the module's outputs.
            Entries may be ``None`` and are then printed as ``None``.
    """
    print('inputs')
    for inp in grad_inputs:
        # A gradient can be None (e.g. for an input that needs no grad).
        print(inp if inp is None else inp.mean())
    print('outputs')
    for out in grad_outputs:
        print(out if out is None else out.mean())
# Demo 1: a single hook on the root module. It fires once per backward pass
# with the gradients w.r.t. the inputs and outputs of the whole network.
handle = net.register_full_backward_hook(hook)
try:
    out = net(torch.rand((2, 8)), torch.rand((2, 64)))
    (1 - out.mean()).backward()
finally:
    # Detach the hook even if forward/backward raises, so it cannot leak
    # onto later uses of the shared `net`.
    handle.remove()
def hook_factory(name):
    """Build a backward hook that labels its output with ``name``.

    Args:
        name: Label printed first each time the returned hook fires,
            typically the name of the module it is registered on.

    Returns:
        A callable ``hook(module, grad_inputs, grad_outputs)`` suitable for
        ``register_full_backward_hook``. It prints ``name`` followed by the
        mean of each input and output gradient (``None`` entries are printed
        as ``None``) and returns nothing, leaving the gradients unmodified.
    """
    def hook(module, grad_inputs, grad_outputs):
        # `name` is captured from the enclosing call, so every hook built
        # by this factory reports its own label.
        print(name)
        print('inputs')
        for inp in grad_inputs:
            print(inp if inp is None else inp.mean())
        print('outputs')
        for out in grad_outputs:
            print(out if out is None else out.mean())

    return hook
# Demo 2: one named hook per direct child of `net` (l1, l2, relu).
handles = []
try:
    for name, module in net.named_children():
        hook = hook_factory(name)
        handle = module.register_full_backward_hook(hook)
        handles.append(handle)
        print('registered hook for', name)
    out = net(torch.rand((2, 8)), torch.rand((2, 64)))
    (1 - out.mean()).backward()
finally:
    # Remove every hook registered so far, even if registration or the
    # forward/backward pass raised, so none stays attached to `net`.
    for h in handles:
        h.remove()
# Demo 3: one named hook per module in the whole tree. named_modules() also
# yields `net` itself (under an empty name) and the layers nested inside each
# Sequential, so one label printed by the hooks is an empty line.
handles = []
try:
    for name, module in net.named_modules():
        hook = hook_factory(name)
        handle = module.register_full_backward_hook(hook)
        handles.append(handle)
        print('registered hook for', name)
    out = net(torch.rand((2, 8)), torch.rand((2, 64)))
    (1 - out.mean()).backward()
finally:
    # Remove every hook registered so far, even if registration or the
    # forward/backward pass raised, so none stays attached to `net`.
    for h in handles:
        h.remove()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment