@yusuke0519
Last active February 2, 2021 10:01
Store activations during the forward pass using hooks.
from collections import OrderedDict

import torch.nn as nn


class MLPEncoder(nn.Module):
    def __init__(self):
        super(MLPEncoder, self).__init__()
        # TODO: Fix hard coding
        self.model = nn.Sequential(OrderedDict([
            ('layer1', nn.Linear(784, 400)),
            ('relu1', nn.ReLU()),
            ('layer2', nn.Linear(400, 400)),
            ('relu2', nn.ReLU()),
            ('layer3', nn.Linear(400, 200)),
            ('relu3', nn.ReLU()),
            ('layer4', nn.Linear(200, 200)),
            ('relu4', nn.ReLU())
        ]))
        print(self.model)

        # Register a forward hook on every child layer to store activations.
        self.activations = {}

        def store_activations(module, input, output):
            self.activations[module.__name__] = output

        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(store_activations)

    def forward(self, x):
        return self.model(x.view(-1, 784))

    def get_activations(self, name):
        return self.activations[name]
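As a self-contained sketch of the same hook pattern, the snippet below registers forward hooks on a small `nn.Sequential` model (not the `MLPEncoder` above) and stores each layer's output; it uses a closure to capture the layer name instead of setting `__name__` on the module:

```python
import torch
import torch.nn as nn

# Small stand-in model; layer sizes are arbitrary for illustration.
model = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
)

activations = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Store the layer's output tensor under its name.
        activations[name] = output
    return hook

for name, layer in model.named_children():
    layer.register_forward_hook(make_hook(name))

x = torch.randn(5, 4)
y = model(x)

print(sorted(activations))     # names of all hooked layers
print(activations['1'].shape)  # output of the ReLU layer
```

For an `nn.Sequential` built from positional arguments, `named_children()` yields the index strings `'0'`, `'1'`, `'2'` as names; the last stored activation equals the model output.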
yusuke0519 commented Feb 2, 2021

Use hooks to grab the output of each intermediate layer in PyTorch.
The activations are not detached, on the assumption that a loss will be computed from them. If you only use them for visualization rather than for gradients, it is better to store `self.activations[model.__name__] = output.detach()`.
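A minimal sketch of the detaching variant on a single layer (names here are illustrative, not from the gist): `.detach()` returns a tensor cut off from the autograd graph, so the stored activation no longer keeps backward buffers alive.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
activations = {}

def store_detached(module, inputs, output):
    # Detach so the stored tensor does not retain the autograd graph.
    activations['layer'] = output.detach()

layer.register_forward_hook(store_detached)

x = torch.randn(2, 4, requires_grad=True)
y = layer(x)

print(y.requires_grad)                     # True
print(activations['layer'].requires_grad)  # False
```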
