@ruslangrimov
Created May 29, 2020 19:40
Get outputs of each layer using hooks
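import torch
import torchvision


class SaveOutput:
    """Forward hook that records the output of every module it is attached to."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # Called by PyTorch right after module.forward(); keep the layer's output
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# Example model only; replace with your own torch.nn.Module
model = torchvision.models.resnet18()

save_output = SaveOutput()
hook_handles = []
for layer in model.modules():
    if isinstance(layer, torch.nn.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)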
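Registering the hooks does nothing until the model is actually run. The sketch below shows one way to use them end to end: a forward pass fills `save_output.outputs` in call order, `clear()` resets it between batches, and each handle's `remove()` detaches its hook. The small `nn.Sequential` network is a hypothetical stand-in for `model`, and the `SaveOutput` class is repeated so the block runs on its own.

```python
import torch
import torch.nn as nn


class SaveOutput:
    """Forward hook that stores every output it sees (same as the snippet above)."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# Hypothetical toy model standing in for `model`; any nn.Module works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)

save_output = SaveOutput()
hook_handles = []
for layer in model.modules():
    if isinstance(layer, nn.Conv2d):
        hook_handles.append(layer.register_forward_hook(save_output))

# One forward pass fills save_output.outputs, one entry per Conv2d call.
with torch.no_grad():
    model(torch.randn(2, 3, 32, 32))

shapes = [tuple(out.shape) for out in save_output.outputs]
print(shapes)  # [(2, 8, 32, 32), (2, 16, 16, 16)]

# Clear between batches, otherwise outputs keep accumulating.
save_output.clear()

# Detach the hooks once they are no longer needed.
for handle in hook_handles:
    handle.remove()

with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))
print(len(save_output.outputs))  # 0: no hooks fire after removal
```

Running under `torch.no_grad()` is a deliberate choice here: with gradients enabled, the stored outputs keep references to the autograd graph, so memory can grow quickly if outputs are collected over many batches.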