Skip to content

Instantly share code, notes, and snippets.

@ndronen
Created April 25, 2020 09:06
Show Gist options
  • Save ndronen/bf44e29e98e3d774c621e24242a0ab5c to your computer and use it in GitHub Desktop.
Save ndronen/bf44e29e98e3d774c621e24242a0ab5c to your computer and use it in GitHub Desktop.
Straw man proposal for PyTorch issue
"""
This is a straw man proposal to begin discussion of how to change the
PyTorch hooks API to support capture/inspection/modification of
keyword arguments.
https://github.com/pytorch/pytorch/issues/35643
"""
import inspect
import unittest
from collections import OrderedDict

import torch
class Module:
    """Minimal stand-in for ``torch.nn.Module`` that prototypes hook dispatch.

    Only hook registration and the ``__call__`` pipeline are modelled;
    subclasses implement :meth:`forward`.
    """

    def __init__(self):
        # Registries keyed by ``RemovableHandle.id``; ``OrderedDict`` keeps
        # hooks running in registration order.
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()

    def forward(self, *args, **kwargs):
        raise NotImplementedError()

    @staticmethod
    def _accepts_positional(hook, count):
        """Return whether ``hook`` can be called with ``count`` positional args.

        Replaces the earlier "call it, catch TypeError, inspect the message"
        approach, which misfired for bound methods (their error message
        counts ``self``) and could not distinguish a signature mismatch from
        a ``TypeError`` raised inside the hook. Callables whose signature
        cannot be inspected report False.
        """
        try:
            inspect.signature(hook).bind(*((None,) * count))
        except (TypeError, ValueError):
            return False
        return True

    def __call__(self, *args, **kwargs):
        """Run forward pre-hooks, ``forward``, then forward hooks.

        Forward pre-hooks may use either signature:

        * ``hook(module, args)`` (legacy): a non-``None`` return value
          replaces the positional arguments; a non-tuple is wrapped in a
          1-tuple.
        * ``hook(module, args, kwargs)``: may additionally return an
          ``(args, kwargs)`` pair (a tuple followed by a dict). The returned
          keyword arguments are merged into the current ones with
          ``dict.update``, so a hook cannot delete a keyword argument this
          way.

        Forward hooks may be ``hook(module, args, output)`` or
        ``hook(module, args, kwargs, output)``; a non-``None`` return value
        replaces the output.

        The signature is chosen by whether the hook can bind the extra
        positional argument, so a legacy hook that has an additional
        defaulted parameter is treated as kwargs-aware.
        """
        for hook in self._forward_pre_hooks.values():
            takes_kwargs = self._accepts_positional(hook, 3)
            if takes_kwargs:
                result = hook(self, args, kwargs)
            else:
                result = hook(self, args)
            if result is None:
                # Arguments left as they are (possibly mutated in place).
                continue
            if (takes_kwargs and isinstance(result, tuple) and len(result) == 2
                    and isinstance(result[0], tuple)
                    and isinstance(result[1], dict)):
                # Kwargs-aware hook returned an (args, kwargs) pair. Legacy
                # hooks never take this branch, so a (tuple, dict) they
                # return is still treated as new positional arguments.
                args = result[0]
                kwargs.update(result[1])
            elif isinstance(result, tuple):
                args = result
            else:
                # Single replacement positional argument.
                args = (result,)

        output = self.forward(*args, **kwargs)

        for hook in self._forward_hooks.values():
            if self._accepts_positional(hook, 4):
                hook_result = hook(self, args, kwargs, output)
            else:
                hook_result = hook(self, args, output)
            if hook_result is not None:
                output = hook_result

        # TODO: backward hooks are stored by register_backward_hook but are
        # not yet attached to the autograd graph of ``output``.
        return output

    def register_forward_pre_hook(self, hook):
        """Register ``hook`` to run before ``forward``; return its handle."""
        handle = torch.utils.hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

    def register_forward_hook(self, hook):
        """Register ``hook`` to run after ``forward``; return its handle."""
        handle = torch.utils.hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

    def register_backward_hook(self, hook):
        """Register a backward hook; return its handle.

        The handle must wrap ``_backward_hooks`` (not ``_forward_hooks``)
        so that ``handle.remove()`` actually unregisters the hook.
        """
        handle = torch.utils.hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle
class BinaryOrTernarySum(Module):
    """Toy module returning ``x + y``, plus ``z`` when ``z`` is given."""

    def forward(self, x, y, z=None):
        total = x + y
        if z is None:
            return total
        # In-place add on the fresh ``x + y`` result; inputs are untouched.
        total += z
        return total
def forward_pre_hook_with_kwargs(module, input, kwargs=None):
    """Return the positional and keyword arguments each incremented by one.

    Example of the proposed three-argument pre-hook signature
    ``hook(module, input, kwargs)``. New values are returned rather than
    updated in place. Writing ``i += 1`` on a loop variable has two
    problems: it is a no-op for immutable arguments (ints, floats), and it
    silently mutates the caller's tensors otherwise.

    Returns:
        An ``(args, kwargs)`` pair. ``kwargs`` is ``{}`` when none were
        given, so the result always has that shape.
    """
    assert isinstance(input, tuple)
    if kwargs is None:
        kwargs = {}
    assert isinstance(kwargs, dict)
    incremented_args = tuple(i + 1 for i in input)
    incremented_kwargs = {k: v + 1 for k, v in kwargs.items()}
    return incremented_args, incremented_kwargs
def forward_pre_hook_backward_compatibility(module, input):
    """Return the positional arguments each incremented by one.

    Example of the legacy two-argument pre-hook signature
    ``hook(module, input)``. A new tuple is returned instead of updating
    loop variables in place. That in-place form is a no-op for ints and
    floats, and it mutates the caller's tensors otherwise.
    """
    assert isinstance(input, tuple)
    return tuple(i + 1 for i in input)
class TestModuleHooks(unittest.TestCase):
    """Exercise forward pre-hooks on ``BinaryOrTernarySum``.

    Inputs are x = 0, y = 1, z = 2 element-wise, so the unhooked sum is 3.
    """

    def setUp(self):
        self.n = 2
        self.module = BinaryOrTernarySum()
        self.x = torch.zeros(self.n)
        self.y = torch.ones(self.n)
        self.z = torch.ones(self.n) * 2

    def _register_pre_hook(self, hook):
        # addCleanup removes the hook even if an assertion fails, which a
        # trailing ``handle.remove()`` would not.
        handle = self.module.register_forward_pre_hook(hook)
        self.addCleanup(handle.remove)
        return handle

    def _assert_sum(self, value):
        actual = self.module(self.x, self.y, z=self.z)
        expected = torch.ones(self.n) * value
        # torch.equal also checks shape; ``torch.all(actual == expected)``
        # would pass on a broadcastable mismatch such as a scalar result.
        self.assertTrue(
            torch.equal(actual, expected),
            "expected {}, got {}".format(expected, actual),
        )

    def test_baseline(self):
        # Without a forward pre-hook: 0 + 1 + 2 = 3.
        self._assert_sum(3)

    def test_forward_pre_hook_backward_compatibility(self):
        # Legacy hook increments positional args only: 1 + 2 + 2 = 5.
        self._register_pre_hook(forward_pre_hook_backward_compatibility)
        self._assert_sum(5)

    def test_forward_pre_hook_using_kwargs(self):
        # Kwargs-aware hook increments every argument: 1 + 2 + 3 = 6.
        self._register_pre_hook(forward_pre_hook_with_kwargs)
        self._assert_sum(6)

    def test_removed_pre_hook_is_not_called(self):
        # After removal the module behaves like the baseline again.
        handle = self.module.register_forward_pre_hook(
            forward_pre_hook_with_kwargs
        )
        handle.remove()
        self._assert_sum(3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment