@zeryx
Last active March 31, 2020 17:03
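import torch as th
from torchvision import models, transforms
from PIL import Image
import numpy as np
# Load ResNet-18 with ImageNet-pretrained weights (pretrained=True is deprecated in
# torchvision >= 0.13 in favor of weights=models.ResNet18_Weights.DEFAULT).
resnet18 = models.resnet18(pretrained=True)
# Drop the last child module, the fully connected classifier `fc` (ResNet-18 has no softmax layer).
# What remains (conv stages + global average pooling) yields raw 512-d features for fine-tuning/transfer learning.
modules = list(resnet18.children())[:-1]
model = th.nn.Sequential(*modules)
# Pickle the whole module object (architecture + weights).
th.save(model, "/tmp/resnet18-headless.th")
# Let's check that the model was serialized properly: reload it and compare outputs on a random input.
# weights_only=False is required on PyTorch >= 2.6 to unpickle a full module (the flag exists since 1.13).
reloaded = th.load("/tmp/resnet18-headless.th", weights_only=False)
model.eval()
reloaded.eval()
x = th.randn(1, 3, 224, 224)
with th.no_grad():
    assert th.allclose(model(x), reloaded(x))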