
@david-macleod
Last active April 21, 2020 15:45
PyTorch forward hooks
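# function version
layer_dict = {}

def inspect_layer(module, input, output):
    # Called after module.forward(); `input` is the tuple of positional inputs,
    # `output` is the return value. Mutating the dict needs no `global` statement.
    layer_dict[module] = (input, output)

# output_layers: the modules to inspect, defined elsewhere
handles = [layer.register_forward_hook(inspect_layer) for layer in output_layers]
# handles[0].remove()  # detach a hook once it is no longer needed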
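A minimal usage sketch of the function version. The toy `nn.Sequential` model and the choice of `output_layers` are hypothetical, picked only to show what ends up in `layer_dict` after one forward pass and that removing the handles stops further recording.

```python
import torch
import torch.nn as nn

# Toy model, assumed purely for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

layer_dict = {}

def inspect_layer(module, input, output):
    # `input` is a tuple of the positional args given to forward()
    layer_dict[module] = (input, output)

# Hypothetical choice of layers to inspect: the two Linear layers
output_layers = [model[0], model[2]]
handles = [layer.register_forward_hook(inspect_layer) for layer in output_layers]

x = torch.randn(3, 4)
with torch.no_grad():
    y = model(x)

for layer in output_layers:
    inp, out = layer_dict[layer]
    print(type(layer).__name__, tuple(inp[0].shape), "->", tuple(out.shape))

# Remove the hooks so later forward passes are not recorded
for h in handles:
    h.remove()
```

Each entry is keyed by the module object itself, so lookups use the same layer references that were hooked.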
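# class version (could also subclass nn.Module)
class Hooks:
    def __init__(self, modules):
        # Create the store before registering, so hook_fn can always write to it
        self.hooked_values = {}
        self.hooks = [m.register_forward_hook(self.hook_fn) for m in modules]

    def hook_fn(self, module, input, output):
        self.hooked_values[module] = (input, output)

    def remove(self):
        # Detach every registered hook
        for h in self.hooks:
            h.remove()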
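A sketch of the class version used as a context manager, so the hooks are always detached after the forward pass. The `remove()`, `__enter__` and `__exit__` methods here are assumed additions, not part of the original gist, and the small conv model is hypothetical.

```python
import torch
import torch.nn as nn

class Hooks:
    def __init__(self, modules):
        self.hooked_values = {}
        self.hooks = [m.register_forward_hook(self.hook_fn) for m in modules]

    def hook_fn(self, module, input, output):
        self.hooked_values[module] = (input, output)

    def remove(self):
        for h in self.hooks:
            h.remove()

    # Context-manager support: an assumed extension, not in the original gist
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.remove()

# Hypothetical model: 3x8x8 input -> conv (6x6x6) -> flatten (216) -> 10 logits
model = nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(6 * 6 * 6, 10),
)

x = torch.randn(2, 3, 8, 8)
with Hooks([model[0], model[3]]) as hooks, torch.no_grad():
    model(x)

# Hooks are removed on exit, but the captured values remain available
conv_out = hooks.hooked_values[model[0]][1]
print(conv_out.shape)
```

Running the forward pass under `torch.no_grad()` keeps the stored tensors from holding on to the autograd graph; when gradients are needed elsewhere, storing `output.detach()` in `hook_fn` has a similar effect.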