Created
July 21, 2020 18:23
-
-
Save pr2tik1/2856e99561aec917922f88aa17c3f0d9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch import __version__
from torch.autograd import Variable
# Pretrained torchvision classifiers, constructed once when this module is
# imported (pretrained=True fetches the published weights).
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
vgg16 = models.vgg16(pretrained=True)

# Short name -> model instance; classifier() looks models up by these keys.
# NOTE: this rebinds `models`, shadowing the torchvision.models module that
# was imported under the same name; that module is unreachable as `models`
# from here on.
models = dict(resnet=resnet18, alexnet=alexnet, vgg=vgg16)

# Class-index -> label mapping, stored on disk as a Python dict literal.
# ast.literal_eval only accepts literals, so the file cannot execute code.
with open('imagenet1000_clsid_to_human.txt') as imagenet_classes_file:
    imagenet_classes_dict = ast.literal_eval(imagenet_classes_file.read())
# Preprocessing pipeline shared by every call (built once, not per image):
# resize the shorter side to 256, center-crop to 224x224, convert to a float
# tensor, then normalize each of the 3 channels with the given mean/std.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def classifier(img_path, model_name):
    """Classify the image at ``img_path`` with one of the preloaded models.

    Args:
        img_path: Path to an image file that PIL can open.
        model_name: Key into the module-level ``models`` dict
            ('resnet', 'alexnet' or 'vgg').

    Returns:
        The value stored in ``imagenet_classes_dict`` for the highest-scoring
        class index (presumably a human-readable label string).

    Raises:
        KeyError: if ``model_name`` is not a key of ``models``.
    """
    # Open the image inside a context manager so the file handle is closed,
    # and force RGB: grayscale, palette or RGBA images would otherwise not
    # match the 3-channel Normalize step in _PREPROCESS.
    with Image.open(img_path) as img_pil:
        img_tensor = _PREPROCESS(img_pil.convert('RGB'))
    # Add a leading batch dimension of size 1.
    img_tensor.unsqueeze_(0)

    # Short-circuit keeps the minor component from being parsed when the
    # major version is already >= 1 (minor may carry suffixes on some builds).
    pytorch_ver = __version__.split('.')
    modern_torch = int(pytorch_ver[0]) > 0 or int(pytorch_ver[1]) >= 4

    model = models[model_name].eval()

    if modern_torch:
        # PyTorch >= 0.4: disable autograd for the whole forward pass so no
        # graph is recorded for the (grad-requiring) model parameters.
        with torch.no_grad():
            output = model(img_tensor)
    else:
        # PyTorch < 0.4: inputs must be wrapped in a Variable; volatile=True
        # was the pre-0.4 way to run inference without tracking gradients.
        output = model(Variable(img_tensor, volatile=True))

    # With a batch of one, the argmax over the flattened scores is the
    # predicted class index. int() turns the NumPy scalar into a plain int key.
    pred_idx = int(output.data.numpy().argmax())
    return imagenet_classes_dict[pred_idx]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment