@zeryx
Created September 24, 2019 18:03
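Takes a pretrained PyTorch ResNet-18 model, strips off its last module (the final fully connected layer), and checks that saving and reloading the headless model with torch.save / torch.load still produces the same output.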
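import torch as th
from torchvision import models, transforms
from PIL import Image

# Load a pretrained ResNet-18 and drop its last child (the final fully connected
# layer), keeping everything up to and including the global average pool.
# (torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT.)
resnet18 = models.resnet18(pretrained=True)
modules = list(resnet18.children())[:-1]
model = th.nn.Sequential(*modules)
model.eval()  # inference mode: BatchNorm uses its stored running statistics

# Standard ImageNet preprocessing: resize, center-crop, convert to a
# channel-first float tensor in [0, 1], then normalize each channel.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
image = Image.open("/tmp/dog.jpg").convert("RGB")
image_th = preprocess(image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

with th.no_grad():
    result = model(image_th)  # pooled features, shape (1, 512, 1, 1)
print(result)

# Save the whole headless module, reload it, and check that the outputs match.
th.save(model, "/tmp/resnet18-headless.th")
# weights_only=False is required on PyTorch >= 2.6 to unpickle a full module.
reloaded = th.load("/tmp/resnet18-headless.th", weights_only=False)
with th.no_grad():
    newresult = reloaded(image_th)
print(newresult)
print("outputs match:", th.allclose(result, newresult))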
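np.asarray on a PIL image returns an array in height x width x channel (HWC) order, while torchvision models expect channel-first (CHW) input. Reshaping HWC data to a (3, H, W) shape does not move the channels. It keeps the flat memory order and only relabels the axes, so the pixel values end up scrambled across channels. The small NumPy sketch below shows the difference on a made-up 2x2 RGB image; the array values are purely illustrative.

```python
import numpy as np

# Hypothetical 2x2 RGB image in HWC order, as np.asarray(pil_image) would return.
# Each pixel is (R, G, B); the red values are 10, 20, 30, 40.
hwc = np.array([[[10, 1, 2], [20, 3, 4]],
                [[30, 5, 6], [40, 7, 8]]], dtype=np.float32)

# Correct: move the channel axis to the front (same as tensor.permute(2, 0, 1)).
chw_transposed = hwc.transpose(2, 0, 1)

# Incorrect: reshape keeps the flat memory order and just relabels the axes.
chw_reshaped = hwc.reshape(3, 2, 2)

# First channel after transpose holds the red values: [[10, 20], [30, 40]].
print(chw_transposed[0])
# First "channel" after reshape mixes red, green, and blue: [[10, 1], [2, 20]].
print(chw_reshaped[0])
```

In practice, torchvision's transforms.ToTensor performs this HWC-to-CHW conversion (and the scaling to [0, 1] that Normalize's ImageNet mean and std assume), so a separate transpose is only needed when building tensors by hand.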