Skip to content

Instantly share code, notes, and snippets.

@amaarora
Last active January 15, 2022 08:00
Show Gist options
  • Save amaarora/139dfa53754e64e66383630e00a0de88 to your computer and use it in GitHub Desktop.
Save amaarora/139dfa53754e64e66383630e00a0de88 to your computer and use it in GitHub Desktop.
class FeatureExtractor(nn.Module):
    """Wrap a model and capture the outputs of selected named submodules.

    A forward hook is registered on every submodule named in ``layer_names``.
    Names are the ones reported by ``model.named_modules()``, e.g. ``"0"`` or
    ``"layer1.0.conv1"``. Calling the extractor runs the wrapped model and
    returns a mapping from each layer name to the output that layer produced
    during that call. The wrapped model's own return value is discarded.

    Args:
        model: The module to run and observe.
        layer_names: Iterable of submodule names whose outputs should be
            captured.

    Raises:
        KeyError: If any name in ``layer_names`` is not a submodule of
            ``model``. No hooks are registered in that case.
    """

    def __init__(self, model, layer_names):
        super().__init__()
        self.model = model
        self.layer_names = layer_names
        self._features = defaultdict(list)
        # Handles let ``remove_hooks`` detach the hooks later; without them
        # the hooks would stay attached to ``model`` indefinitely.
        self._hook_handles = []

        layer_dict = dict(self.model.named_modules())
        # Resolve every name before registering anything, so a typo cannot
        # leave hooks attached to the layers that preceded it.
        missing = [name for name in layer_names if name not in layer_dict]
        if missing:
            raise KeyError(
                f"Unknown layer name(s) {missing}; available: {list(layer_dict)}"
            )
        for layer_name in layer_names:
            handle = layer_dict[layer_name].register_forward_hook(
                self.save_outputs_hook(layer_name)
            )
            self._hook_handles.append(handle)

    def save_outputs_hook(self, layer_name):
        """Return a forward hook that records a layer's output under ``layer_name``.

        The returned callable has the ``(module, inputs, output)`` signature
        that ``register_forward_hook`` expects. If the same layer runs more
        than once in a single forward pass, the last output wins.
        """

        def fn(_, __, output):
            # ``self._features`` is looked up at call time, so this writes into
            # whichever dict the current ``forward`` call created.
            self._features[layer_name] = output

        return fn

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the captured outputs.

        Returns:
            A ``defaultdict(list)`` mapping each hooked layer name to its
            output from this call. A new dict is created on every call, so a
            previously returned result is never overwritten, and layers that
            did not run in this call are absent. The stored outputs keep any
            autograd graph attached.
        """
        self._features = defaultdict(list)
        _ = self.model(x)
        return self._features

    def remove_hooks(self):
        """Detach every forward hook this extractor registered on ``model``.

        After this, ``forward`` still runs the model but returns an empty
        mapping. Calling it more than once is harmless.
        """
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment