@danesherbs
Last active February 11, 2024 01:04
A PyTorch forward hook that is registered when a `with` block is entered and removed when it exits
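import torch
import torch.nn as nn

# This hook is particularly useful when ablating layers
class ContextHook:
    """Registers a forward hook on `layer` only for the duration of a `with` block."""

    def __init__(self, layer):
        self.layer = layer

    def __enter__(self):
        self.handle = self.layer.register_forward_hook(self.hook)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always detach the hook, even if the block raised an exception
        self.handle.remove()

    def hook(self, module, inputs, output):
        # Hook logic goes here; returning a non-None value replaces the layer's output
        pass

x = ...    # some input tensor
net = ...  # some neural net
out = net(x)
with ContextHook(net.my_layer):
    modified_out = net(x)  # the hook is active only inside this block
print(out)
print(modified_out)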
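As a concrete illustration of the ablation use case, the sketch below fills in the hook so that it zeroes a layer's output. It relies on the documented behavior that a forward hook returning a value replaces the module's output. The `ZeroAblation` class name, the toy `nn.Sequential` network, and the choice to ablate its ReLU are illustrative assumptions, not part of the original gist.

```python
import torch
import torch.nn as nn


class ZeroAblation:
    """Hypothetical helper: zero out a layer's output while inside a `with` block."""

    def __init__(self, layer: nn.Module):
        self.layer = layer
        self.handle = None

    def __enter__(self):
        self.handle = self.layer.register_forward_hook(self.hook)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Remove the hook so later forward passes are unaffected
        self.handle.remove()

    def hook(self, module, inputs, output):
        # A forward hook that returns a tensor replaces the module's output
        return torch.zeros_like(output)


torch.manual_seed(0)
# Toy network standing in for "some neural net"; net[1] is the ReLU we ablate
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)

with torch.no_grad():
    out = net(x)
    with ZeroAblation(net[1]):
        ablated_out = net(x)
    restored_out = net(x)

# With the ReLU output zeroed, the final Linear sees all zeros and returns just its bias
print(torch.allclose(ablated_out, net[2].bias.expand_as(ablated_out)))
# Once the block exits, the hook is gone and the original output is reproduced
print(torch.equal(out, restored_out))
```

Tying registration to `__enter__`/`__exit__` means the hook cannot be accidentally left attached, which is the main risk when hooks are registered and removed by hand.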