Created
July 31, 2021 21:16
-
-
Save discort/ae6c578c15adcd01eeb4b88791991141 to your computer and use it in GitHub Desktop.
PyTorch hooks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# based on https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904
from torch import Tensor, nn
class VerboseExecution(nn.Module):
    """Wrap a model and print each direct child's output shape on forward.

    A forward hook is registered on every module yielded by
    ``model.named_children()``.  After a child runs, its hook prints
    ``"<child name>: <output description>"``.  For outputs that expose a
    ``shape`` attribute the description is that shape; tuples and lists are
    described element by element; anything else is described by its type
    name.

    The hooks are attached to the wrapped model itself, so they keep firing
    even when the model is called directly.  Call :meth:`remove_hooks` to
    detach them.

    Args:
        model: The module whose direct children should be traced.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Keep the removable handles so the hooks can be detached later.
        self._hook_handles = []

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            handle = layer.register_forward_hook(self._make_hook(name))
            self._hook_handles.append(handle)

    @staticmethod
    def _describe(output) -> str:
        """Return a printable description of a layer output."""
        if hasattr(output, "shape"):
            return str(output.shape)
        if isinstance(output, (tuple, list)):
            inner = ", ".join(VerboseExecution._describe(item) for item in output)
            return f"({inner})"
        # No shape to report; fall back to the type so the hook never raises.
        return type(output).__name__

    def _make_hook(self, name: str):
        """Build a forward hook that prints ``name`` and the output description.

        The name is bound in the closure, so nothing has to be stored on the
        hooked layer itself.
        """

        def hook(_module, _inputs, output):
            print(f"{name}: {self._describe(output)}")

        return hook

    def remove_hooks(self) -> None:
        """Detach every forward hook registered by this wrapper."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped model on ``x`` and return its output unchanged."""
        return self.model(x)
import torch
from torchvision.models import resnet50

# Demo: trace the output shape of every top-level ResNet-50 child.
verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)
# Only the printed shapes matter here, so skip autograd bookkeeping.
with torch.no_grad():
    _ = verbose_resnet(dummy_input)
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])
from typing import Dict, Iterable, Callable


class FeatureExtractor(nn.Module):
    """Wrap a model and capture the outputs of selected sub-modules.

    Each name in ``layers`` is looked up in ``model.named_modules()`` and a
    forward hook is registered on the matching module.  Calling the extractor
    runs the wrapped model and returns the most recent output recorded for
    every requested layer.

    Args:
        model: The module to run.
        layers: Names (as produced by ``model.named_modules()``) of the
            sub-modules whose outputs should be captured.  Any iterable is
            accepted, including one-shot generators; it is stored on
            ``self.layers`` as a list.

    Raises:
        KeyError: If a requested name is not present in
            ``model.named_modules()``.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialise once: ``layers`` is consumed twice below, which would
        # silently register no hooks for a generator argument.
        self.layers = list(layers)
        # Placeholder tensors until the first forward pass fills them in.
        self._features = {layer: torch.empty(0) for layer in self.layers}
        # Keep the removable handles so the hooks can be detached later.
        self._hook_handles = []

        # Build the name -> module mapping once instead of once per layer.
        modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            try:
                layer = modules[layer_id]
            except KeyError:
                raise KeyError(
                    f"layer {layer_id!r} not found in model.named_modules()"
                ) from None
            handle = layer.register_forward_hook(self.save_outputs_hook(layer_id))
            self._hook_handles.append(handle)

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores the output under ``layer_id``."""

        def fn(_, __, output):
            self._features[layer_id] = output

        return fn

    def remove_hooks(self) -> None:
        """Detach every forward hook registered by this extractor."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the wrapped model on ``x`` and return the captured outputs.

        Returns:
            A new dict mapping each requested layer name to the output most
            recently recorded for it.  A shallow copy is returned so that a
            later forward pass does not overwrite results the caller still
            holds.
        """
        _ = self.model(x)
        return dict(self._features)
# Demo: capture the outputs of two ResNet-50 sub-modules and print their shapes.
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
# Only the shapes are inspected, so skip autograd bookkeeping.
with torch.no_grad():
    features = resnet_features(dummy_input)
print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment