Created January 28, 2019 11:46
-
-
Save antoniopenta/afffbd953acbb4e89f934932dceb7c45 to your computer and use it in GitHub Desktop.
Classify images with a pre-trained ResNet in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import json
import urllib.request
from io import BytesIO

import numpy as np
import requests
import torch
import torchvision.models as models
from IPython.display import display  # to display images
from PIL import Image
from torchvision import transforms
# Per-channel (R, G, B) mean and standard deviation of ImageNet, which the
# torchvision pretrained models expect their inputs to be normalized with.
normalize = transforms.Normalize(
    [0.485, 0.456, 0.406],  # mean
    [0.229, 0.224, 0.225],  # std
)

# ImageNet evaluation pipeline: scale the shorter side to 256 px, take the
# central 224x224 crop, convert the PIL image to a float tensor, normalize.
preprocess = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
)
@functools.lru_cache(maxsize=1)
def _load_resnet152():
    """Build the pretrained ResNet-152 once and reuse it for every call.

    Constructing the model and loading its weights costs far more than a
    single forward pass, so it is cached instead of rebuilt per image.
    """
    model = models.resnet152(pretrained=True, num_classes=1000)
    model.eval()  # put BatchNorm layers in inference mode
    return model


def get_classification(url, timeout=30):
    """Download the image at ``url`` and classify it with ResNet-152.

    Args:
        url: HTTP(S) URL of the image to classify.
        timeout: seconds to wait for the HTTP server before giving up.

    Returns:
        numpy array of shape (1, 1000) holding the model's raw (pre-softmax)
        score for each of the 1000 ImageNet class indices.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert("RGB")
    # display(img)  # uncomment in a notebook to show the input image
    resnet_model = _load_resnet152()
    # preprocess yields a (3, 224, 224) tensor; add a batch dimension of 1.
    img_tensor = preprocess(img).unsqueeze(0)
    with torch.no_grad():  # inference only: don't build the autograd graph
        fc_out = resnet_model(img_tensor)
    return fc_out.numpy()
# Download the ImageNet class index: a JSON object keyed by the stringified
# class index ("0", "1", ...), each value a list whose second element is the
# human-readable label (the first is presumably the WordNet id - TODO confirm).
with urllib.request.urlopen(
    'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json',
    timeout=30,  # don't hang forever if the host is unreachable
) as response:
    class_idx = json.load(response)
# idx2label[i] is the label of class index i in the model's output vector.
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
# Sample images to classify, fetched by URL at run time.
img1_url, img2_url = (
    'https://i.postimg.cc/SRRHTcgG/Screen-Shot-2019-01-25-at-11-56-52.png',
    'https://i.postimg.cc/hG9nT8Yz/Screen-Shot-2019-01-25-at-12-09-15.png',
)
def print_top_predictions(url, prediction, labels, k=10):
    """Print the ``k`` highest-scoring classes for one image.

    Args:
        url: image URL; only used in the header line.
        prediction: array of shape (1, num_classes), e.g. the output of
            ``get_classification``.
        labels: sequence mapping class index -> human-readable label.
        k: number of classes to print; ``k <= 0`` prints none.

    Classes are listed in ascending score order, so the best match is the
    last line before the closing separator.
    """
    order = np.argsort(prediction, axis=1)[0].tolist()
    # Guard k <= 0: a plain order[-0:] would return every class.
    top_k = order[-k:] if k > 0 else []
    print('Prediction for image %s' % url)
    print('*' * 20)
    for idx in top_k:
        print(labels[idx], prediction[0][idx])
    print('*' * 20)


# predict classes
predition_img1 = get_classification(img1_url)
predition_img2 = get_classification(img2_url)

# print top-10 predictions for each image
print_top_predictions(img1_url, predition_img1, idx2label)
print_top_predictions(img2_url, predition_img2, idx2label)
# Sample output (raw ResNet-152 logits; each list is in ascending score order,
# so the last entry is the model's best guess):
#
# Prediction for image https://i.postimg.cc/SRRHTcgG/Screen-Shot-2019-01-25-at-11-56-52.png
# ********************
# otter 5.5231414
# beaver 5.572069
# bubble 5.6445265
# goose 5.6843433
# fountain 5.958186
# hippopotamus 6.0290885
# platypus 6.246428
# flamingo 6.2740183
# black_swan 7.0539513
# American_alligator 7.1412888
# ********************
# Prediction for image https://i.postimg.cc/hG9nT8Yz/Screen-Shot-2019-01-25-at-12-09-15.png
# ********************
# Weimaraner 8.421744
# American_Staffordshire_terrier 8.620165
# black-and-tan_coonhound 8.720882
# Mexican_hairless 9.13973
# miniature_pinscher 9.583361
# Rhodesian_ridgeback 9.651596
# Chihuahua 9.661165
# redbone 10.895191
# German_short-haired_pointer 11.339299
# bluetick 12.9905615
# ********************
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment