Extracting ResNet Features Using PyTorch
from collections import OrderedDict

from torch import nn
from torchvision import models


def gen_feature_extractor(model, output_layer):
    # Collect the top-level child modules in order, up to and including output_layer.
    layers = OrderedDict()
    for k, v in model.named_children():
        layers[k] = v
        if k == output_layer:
            break
    return nn.Sequential(layers)


# Truncate a pretrained ResNet-18 after "layer3".
model = gen_feature_extractor(models.resnet18(pretrained=True), "layer3")