Skip to content

Instantly share code, notes, and snippets.

@antoniopenta
Created January 28, 2019 11:46
Show Gist options
  • Save antoniopenta/afffbd953acbb4e89f934932dceb7c45 to your computer and use it in GitHub Desktop.
Save antoniopenta/afffbd953acbb4e89f934932dceb7c45 to your computer and use it in GitHub Desktop.
Classify images with a pre-trained ResNet in PyTorch
import functools
import json
import urllib.request
from io import BytesIO

import numpy as np
import requests
import torch
import torchvision.models as models
from IPython.display import display  # to display images
from PIL import Image
from torchvision import transforms
# Per-channel ImageNet statistics expected by torchvision's pretrained models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)

# Resize the shorter side to 256, take the central 224x224 crop, convert the
# PIL image to a float tensor in [0, 1], then normalize each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
@functools.lru_cache(maxsize=1)
def _load_resnet152():
    """Build the pretrained ResNet-152 once and reuse it for every image.

    Loading the pretrained weights is expensive, so the model is created on
    first use and cached for subsequent calls.
    """
    model = models.resnet152(pretrained=True, num_classes=1000)
    # Inference mode (e.g. BatchNorm layers use their running statistics).
    model.eval()
    return model


def get_classification(url, timeout=30):
    """Download the image at *url* and return ResNet-152 class scores for it.

    Args:
        url: HTTP(S) URL of an image that PIL can open.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A numpy array of shape (1, 1000) holding the model's raw output
        score for each class. No softmax is applied, so these are not
        probabilities.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of letting PIL choke on an error page.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert("RGB")
    # Add a leading batch dimension of size 1 to the preprocessed image.
    img_tensor = preprocess(img).float().unsqueeze(0)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        fc_out = _load_resnet152()(img_tensor)
    return fc_out.numpy()
# Download the ImageNet class-index mapping. Keys are stringified class
# indices ("0", "1", ...); element [1] of each value is the class label.
IMAGENET_CLASS_INDEX_URL = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'
with urllib.request.urlopen(IMAGENET_CLASS_INDEX_URL, timeout=30) as response:
    # json.load reads the binary stream directly (UTF-8 is auto-detected).
    class_idx = json.load(response)
# idx2label[i] is the label of class i, aligned with the model's output columns.
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
# Images to classify.
img1_url = 'https://i.postimg.cc/SRRHTcgG/Screen-Shot-2019-01-25-at-11-56-52.png'
img2_url = 'https://i.postimg.cc/hG9nT8Yz/Screen-Shot-2019-01-25-at-12-09-15.png'


def print_top_predictions(image_url, scores, labels, k=10):
    """Print the *k* highest-scoring labels for one image.

    Args:
        image_url: URL of the image; used only in the header line.
        scores: Array as returned by ``get_classification``; row 0 holds one
            raw (non-probability) score per class.
        labels: Sequence mapping a class index to its label.
        k: Number of top classes to print.

    Classes are printed in ascending order of score, so the most likely
    class is the last line before the closing separator.
    """
    ranked = np.argsort(scores[0]).tolist()  # class indices, lowest score first
    # Slice from an explicit start so k == 0 prints nothing (a [-0:] slice
    # would print every class).
    top_k = ranked[max(len(ranked) - k, 0):]
    print('Prediction for image %s' % (image_url))
    print('*' * 20)
    for idx in top_k:
        print(labels[idx], scores[0][idx])
    print('*' * 20)


# Predict classes for every image first, then print the top 10 of each.
predictions = [(url, get_classification(url)) for url in (img1_url, img2_url)]
for image_url, scores in predictions:
    print_top_predictions(image_url, scores, idx2label)
# Example output of the script above (raw scores, highest last):
# Prediction for image https://i.postimg.cc/SRRHTcgG/Screen-Shot-2019-01-25-at-11-56-52.png
# ********************
# otter 5.5231414
# beaver 5.572069
# bubble 5.6445265
# goose 5.6843433
# fountain 5.958186
# hippopotamus 6.0290885
# platypus 6.246428
# flamingo 6.2740183
# black_swan 7.0539513
# American_alligator 7.1412888
# ********************
#Prediction for image https://i.postimg.cc/hG9nT8Yz/Screen-Shot-2019-01-25-at-12-09-15.png
# ********************
# Weimaraner 8.421744
# American_Staffordshire_terrier 8.620165
# black-and-tan_coonhound 8.720882
# Mexican_hairless 9.13973
# miniature_pinscher 9.583361
# Rhodesian_ridgeback 9.651596
# Chihuahua 9.661165
# redbone 10.895191
# German_short-haired_pointer 11.339299
# bluetick 12.9905615
# ********************
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment