Skip to content

Instantly share code, notes, and snippets.

@zohaibmohammad
Forked from johnolafenwa/imagenet_inference.py
Last active May 25, 2020 12:51
Show Gist options
  • Save zohaibmohammad/75d55732e0c88dc448e9214bd20630a3 to your computer and use it in GitHub Desktop.
Save zohaibmohammad/75d55732e0c88dc448e9214bd20630a3 to your computer and use it in GitHub Desktop.
{"0": ["n01440764", "airplane"], "1": ["n01443537", "automobile"], "2": ["n01484850", "bird"], "3": ["n01491361", "cat"], "4": ["n01494475", "deer"], "5": ["n01496331", "dog"], "6": ["n01498041", "frog"], "7": ["n01514668", "horse"], "8": ["n01514859", "ship"], "9": ["n01518878", "Truck"]}
{"0": ["n01440764", "airplane"], "1": ["n01443537", "automobile"], "2": ["n01484850", "bird"], "3": ["n01491361", "cat"], "4": ["n01494475", "deer"], "5": ["n01496331", "dog"], "6": ["n01498041", "frog"], "7": ["n01514668", "horse"], "8": ["n01514859", "ship"], "9": ["n01518878", "Truck"]}
# Import needed packages
import torch
import torch.nn as nn
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
from torchvision.models import squeezenet1_1
import torch.functional as F
import requests
import shutil
from io import open
import os
from PIL import Image
import json
""" Instantiate the model; this downloads the 4.7 MB SqueezeNet weights the first time it is called.
To use your own model instead, re-define your trained network and load its weights as below:
checkpoint = torch.load("pathtosavemodel")
model = SimpleNet(num_classes=10)
model.load_state_dict(checkpoint)
model.eval()
example:
checkpoint = torch.load("cifar10model_69.model")
model = SimpleNet(num_classes=10)
model.load_state_dict(checkpoint)
model.eval()
"""
# Build SqueezeNet 1.1 with its pretrained weights and put it straight into
# evaluation (inference) mode; nn.Module.eval() returns the module itself.
model = squeezenet1_1(pretrained=True).eval()
def predict_image(image_path):
    """Classify a single image file with the module-level ``model``.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        The index of the highest-scoring entry in the model output, as the
        NumPy integer produced by ``ndarray.argmax``.
    """
    print("Prediction in progress")

    # Force three channels: grayscale or RGBA files would otherwise not match
    # the three-channel Normalize below. The context manager closes the
    # underlying file once the converted copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Define transformations for the image. Use only one of
    # transforms.CenterCrop(224) or transforms.Resize(32)
    # (note that ImageNet models are trained with image size 224).
    transformation = transforms.Compose([
        # transforms.CenterCrop(224),  # for models trained on ImageNet
        transforms.Resize(32),  # for models trained on CIFAR-10
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Preprocess the image and add a leading batch dimension, since PyTorch
    # models treat every input as a batch.
    image_tensor = transformation(image).float().unsqueeze(0)

    # Run on whatever device the model lives on. Tensor.cuda()/to() return a
    # new tensor, so the result must be kept, and moving only the input to
    # the GPU while the model stays on the CPU would fail.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(image_tensor)

    # Bring the scores back to the CPU before handing them to NumPy.
    return output.cpu().numpy().argmax()
if __name__ == "__main__":
    imagefile = "image.png"
    imagepath = os.path.join(os.getcwd(), imagefile)
    # Download the sample image if it doesn't exist
    if not os.path.exists(imagepath):
        with requests.get(
                "https://github.com/OlafenwaMoses/ImageAI/raw/master/images/3.jpg",
                stream=True, timeout=30) as response:
            # Fail before creating the file so an HTTP error page is never
            # cached on disk as if it were the image.
            response.raise_for_status()
            # Let urllib3 undo any Content-Encoding while copying the raw stream.
            response.raw.decode_content = True
            with open(imagepath, "wb") as image_out:
                shutil.copyfileobj(response.raw, image_out)

    index_file = "class_index_map.json"
    indexpath = os.path.join(os.getcwd(), index_file)
    # Download the class index if it doesn't exist
    if not os.path.exists(indexpath):
        response = requests.get(
            'https://github.com/OlafenwaMoses/ImageAI/raw/master/imagenet_class_index.json',
            timeout=30)
        response.raise_for_status()
        with open(indexpath, "w", encoding="utf-8") as index_out:
            index_out.write(response.text)

    # Read with the same encoding the file is written with, and close the
    # handle deterministically.
    with open(indexpath, encoding="utf-8") as index_in:
        class_map = json.load(index_in)

    # Run the prediction function and obtain the predicted class index
    index = predict_image(imagepath)

    # The class map must cover every index the model can emit; report a
    # mismatch (e.g. a 10-class map used with a 1000-class model) clearly.
    entry = class_map.get(str(index))
    if entry is None:
        raise SystemExit(
            "Predicted class index {} is not present in {}".format(index, index_file))
    prediction = entry[1]
    print("Predicted Class ", prediction)
@zohaibmohammad
Copy link
Author

The .json file was missing. You can download it from here.

{"0": ["n01440764", "airplane"], "1": ["n01443537", "automobile"], "2": ["n01484850", "bird"], "3": ["n01491361", "cat"], "4": ["n01494475", "deer"], "5": ["n01496331", "dog"], "6": ["n01498041", "frog"], "7": ["n01514668", "horse"], "8": ["n01514859", "ship"], "9": ["n01518878", "Truck"]}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment