Skip to content

Instantly share code, notes, and snippets.

@pr2tik1
Created July 21, 2020 18:23
Show Gist options
  • Save pr2tik1/2856e99561aec917922f88aa17c3f0d9 to your computer and use it in GitHub Desktop.
Save pr2tik1/2856e99561aec917922f88aa17c3f0d9 to your computer and use it in GitHub Desktop.
import ast
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.models as models
from torch import __version__
# Instantiate each pretrained architecture once, at import time
# (``pretrained=True`` asks torchvision for the stored ImageNet weights).
resnet18, alexnet, vgg16 = (
    models.resnet18(pretrained=True),
    models.alexnet(pretrained=True),
    models.vgg16(pretrained=True),
)

# NOTE: this rebinding shadows the ``torchvision.models`` module imported
# above.  From here on ``models`` is the short-name -> network lookup that
# ``classifier`` indexes, so the constructor calls must stay above this line.
models = dict(resnet=resnet18, alexnet=alexnet, vgg=vgg16)

# ImageNet class-id -> label lookup consumed by ``classifier``.  The relative
# path is resolved against the current working directory at import time;
# ``ast.literal_eval`` accepts only Python literals, so the file contents are
# parsed as data and never executed.
with open('imagenet1000_clsid_to_human.txt') as imagenet_classes_file:
    imagenet_classes_dict = ast.literal_eval(
        imagenet_classes_file.read()
    )
def classifier(img_path, model_name):
    """Classify an image with one of the pretrained networks in ``models``.

    The image is resized (shorter side 256), center-cropped to 224x224,
    converted to a tensor and normalized with the ImageNet channel statistics
    before being fed to the selected network in evaluation mode.

    Args:
        img_path: anything ``PIL.Image.open`` accepts (typically a file path).
        model_name: key of the module-level ``models`` dict, i.e. ``'resnet'``,
            ``'alexnet'`` or ``'vgg'``.

    Returns:
        The ``imagenet_classes_dict`` entry for the highest-scoring class
        index (presumably the human-readable ImageNet label).

    Raises:
        KeyError: if ``model_name`` is not a key of ``models``.
    """
    import torch  # only ``torch.__version__`` is imported at file level

    # Look the network up first so a bad ``model_name`` fails before any
    # image I/O.  ``eval()`` switches layers such as dropout to inference
    # behaviour and returns the same module.
    model = models[model_name].eval()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Force 3-channel RGB: Normalize above has exactly three per-channel
    # means/stds, so grayscale or RGBA inputs would otherwise fail there.
    # The ``with`` block closes the file handle that ``Image.open`` keeps.
    with Image.open(img_path) as img_pil:
        rgb_img = img_pil.convert('RGB')

    # Add a leading batch dimension of size 1.
    img_tensor = preprocess(rgb_img).unsqueeze(0)

    # Tensor and Variable were merged in torch 0.4; parse "major.minor" from
    # version strings such as "0.3.1.post2" or "2.1.0+cu118".
    major, minor = (int(part) for part in __version__.split('.')[:2])

    if major > 0 or minor >= 4:
        # ``no_grad`` keeps autograd from recording the forward pass.  Marking
        # only the input as not requiring grad is not enough, because the
        # pretrained parameters normally still require grad.
        with torch.no_grad():
            output = model(img_tensor)
    else:
        # Pre-0.4 API: a volatile Variable disables graph construction.
        output = model(Variable(img_tensor, volatile=True))

    # ``.data`` works for a Tensor and a legacy Variable alike.  The batch
    # holds one image, so argmax over the flattened scores is the class index;
    # ``int`` turns the NumPy integer into a plain dict key.
    pred_idx = int(output.data.numpy().argmax())
    return imagenet_classes_dict[pred_idx]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment