Last active: February 11, 2024 01:04
Save danesherbs/91237e0b6e1534c7248377de549c875a to your computer and use it in GitHub Desktop.
A PyTorch forward hook that is registered only for the duration of a `with` statement.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This hook is particularly useful when ablating layers: register it, run the
# model inside the `with` block, and it is removed automatically afterwards.
class ContextHook:
    """Register a forward hook on ``layer`` for the duration of a ``with`` block.

    On entry, ``self.hook`` is registered via ``layer.register_forward_hook``.
    On exit, the returned handle's ``remove()`` is called. This also happens
    when the body raises, because ``__exit__`` runs on every exit path.

    Hook logic can be supplied in either of two ways:
    * pass ``hook_fn(module, inputs, output)`` to the constructor, or
    * subclass and override :meth:`hook`.

    Per the PyTorch forward-hook contract, a hook that returns a non-``None``
    value replaces the layer's output; returning ``None`` leaves it unchanged.
    For example, ``hook_fn=lambda m, i, o: torch.zeros_like(o)`` would ablate
    the layer.

    Args:
        layer: The module to hook. It must provide ``register_forward_hook``.
        hook_fn: Optional callable ``(module, inputs, output) -> output | None``.
            The default ``None`` gives a no-op hook.

    Raises:
        RuntimeError: On entering an instance that is already active. A nested
            entry would otherwise overwrite the stored handle, and the first
            hook would never be removed.
    """

    def __init__(self, layer, hook_fn=None):
        self.layer = layer
        self.hook_fn = hook_fn
        # None means "not currently registered". This lets __exit__ be safe
        # even if __enter__ never ran or failed.
        self.handle = None

    def __enter__(self):
        if self.handle is not None:
            raise RuntimeError("ContextHook is already active; it is not re-entrant")
        self.handle = self.layer.register_forward_hook(self.hook)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Always unregister. Returning None (falsy) lets any exception from
        # the with-body propagate to the caller.
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    def hook(self, module, inputs, output):
        """Forward hook called by the layer. Override, or pass ``hook_fn``."""
        if self.hook_fn is not None:
            return self.hook_fn(module, inputs, output)
        return None
if __name__ == "__main__":
    # Example usage. Replace the `...` placeholders with a real input and model;
    # as written, `net(x)` would fail because `...` is not callable. The guard
    # keeps `ContextHook` importable from this module.
    x = ...  # some input batch
    net = ...  # some neural net with a submodule reachable as `net.my_layer`
    out = net(x)  # baseline output: no hook registered
    with ContextHook(net.my_layer):
        # The hook is registered only inside this block.
        modified_out = net(x)
    # The hook has been removed here, so later calls to `net` are unaffected.
    print(out)
    print(modified_out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.