Last active March 15, 2022 21:27
GitHub Gist dschaehi/22038b2c5b50941bf605dbbeb1cf44d3
Extracting ResNet Features Using PyTorch
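from collections import OrderedDict
from torch import nn
from torchvision import models


def gen_feature_extractor(model, output_layer):
    # Copy the model's top-level children, in order, up to and including output_layer.
    layers = OrderedDict()
    for name, module in model.named_children():
        layers[name] = module
        if name == output_layer:
            break
    else:
        raise ValueError(f"layer {output_layer!r} not found in model")
    return nn.Sequential(layers)


# ResNet-18 truncated after "layer3" (layer4, avgpool and fc are dropped).
# torchvision >= 0.13 uses `weights=`; older versions use `pretrained=True`.
model = gen_feature_extractor(models.resnet18(weights=models.ResNet18_Weights.DEFAULT), "layer3")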