Skip to content

Instantly share code, notes, and snippets.

@zeryx
Created September 19, 2019 18:34
Show Gist options
  • Star (0) — you must be signed in to star a gist.
  • Fork (0) — you must be signed in to fork a gist.
  • Save zeryx/88ee0cb5d158a43db5e6abc20368b68d to your computer and use it in GitHub Desktop.
import torch as th
from torchvision import models, transforms
from PIL import Image
import numpy as np
# Load ResNet-18 with its ImageNet weights. torchvision >= 0.13 deprecated the
# `pretrained=True` flag in favour of the `weights=` enum; IMAGENET1K_V1 is the
# checkpoint the legacy flag loaded. Fall back to the flag on older releases
# that do not provide the enum.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:  # torchvision < 0.13 only understands the legacy flag
    model = models.resnet18(pretrained=True)
class Identity(th.nn.Module):
    """Pass-through module that returns its input unchanged.

    Assigned to a network's final layer to "remove" that layer while keeping
    the attribute in place, so the network outputs the features that used to
    feed it. Newer PyTorch releases ship th.nn.Identity with the same behaviour.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Return ``x`` as-is (the annotation keeps TorchScript scripting happy)."""
        return x
# Drop the classification head so forward() returns the pooled features
# that previously fed the final fully-connected layer.
model.fc = Identity()
# Inference mode: BatchNorm must use its running statistics rather than the
# statistics of this single-image batch (and must not update them).
model.eval()
model = th.jit.script(model)

# Standard torchvision ImageNet evaluation preprocessing:
# - resize the shorter side to 256;
# - take the central 224x224 crop;
# - convert the PIL image (HWC, uint8) to a CHW float tensor in [0, 1];
# - normalize each channel.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
crop = transforms.CenterCrop(224)
preprocess = transforms.Compose([
    transforms.Resize(256),
    crop,
    transforms.ToTensor(),
    normalize,
])

# Save a ScriptModule with th.jit.save, not th.save:
# th.jit.save(model, "/tmp/resnet18-headless.th")

# convert("RGB") guarantees 3 channels (drops alpha, expands grayscale) and
# returns an in-memory copy, so the file handle can be closed right away.
with Image.open("/tmp/dog.jpg") as raw_image:
    image = raw_image.convert("RGB")

# Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224). unsqueeze keeps
# the CHW layout intact, unlike reshaping raw HWC pixel data.
image_th = preprocess(image).unsqueeze(0)

with th.no_grad():  # no autograd graph is needed for feature extraction
    result = model(image_th)
print(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.