Last active
August 26, 2018 14:49
-
-
Save kakittwo/768e5477a5834ee84624205659e4d415 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.applications.resnet50 import ResNet50 | |
from keras.preprocessing import image | |
from keras.applications.resnet50 import preprocess_input, decode_predictions | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
### Subroutine to calculate likelihood for hummingbird in an image
def calculateLikelihood(x, class_index=94, predictor=None):
    """Return the model's score for one ImageNet class on a single image.

    Args:
        x: one image as an array, as produced in this script by
            ``image.img_to_array`` on a crop resized to 224x224
            (presumably shape (224, 224, 3) -- TODO confirm).
        class_index: position in the prediction vector to read. Defaults to
            94, which the original comment identifies as "hummingbird" in
            ImageNet.
        predictor: object whose ``predict`` method is called. Defaults to the
            module-level ``model``. That name is looked up at call time, so it
            may be defined after this function.

    Returns:
        ``preds[0][class_index]``: the score for the single image passed in.
    """
    if predictor is None:
        predictor = model
    # Add a leading axis so the input is a batch of one image; the row for
    # that image in the prediction output is therefore preds[0].
    batch = np.expand_dims(x, axis=0)
    # Input normalisation matching ResNet50 (imported from
    # keras.applications.resnet50).
    batch = preprocess_input(batch)
    preds = predictor.predict(batch)
    return preds[0][class_index]
### Load model and define parameters
# ResNet50 with ImageNet-pretrained weights; calculateLikelihood calls
# model.predict on it by default.
model = ResNet50(weights='imagenet')
# Side length, in pixels, of the square crop taken from each frame.
dim = 861
# Frame size in pixels. `width` positions the right-hand crop.
height = 1080  # NOTE(review): not referenced anywhere in this script.
width = 1920
# Side length each crop is resized to before being passed to the model.
expecteddim = 224
# Per-frame likelihoods; entry n - 1 holds the score for thumbNNNN frame n.
data = []
# Frames thumb0001 .. thumb{N-1} are processed, because range(1, N) excludes N.
# NOTE(review): confirm whether a frame numbered N also exists and should be
# included.
N = 48
### Loop through the images
def _crop_likelihood(img, startx, starty):
    """Crop a dim x dim square with top-left corner (startx, starty), resize
    it to expecteddim x expecteddim, and return calculateLikelihood on it."""
    box = (startx, starty, startx + dim, starty + dim)
    # Image.LANCZOS is the filter Image.ANTIALIAS used to alias; ANTIALIAS
    # was removed in Pillow 10 and raises AttributeError there.
    resized = img.crop(box).resize((expecteddim, expecteddim), Image.LANCZOS)
    return calculateLikelihood(image.img_to_array(resized))


for n in range(1, N):
    # Read the frame thumbnail, e.g. thumb0001.jpg for n == 1.
    img = image.load_img('thumb{0:04d}.jpg'.format(n))
    # Score a square at the top-left corner, then one at the top-right
    # corner ("close to another flower", per the original comment).
    p1 = _crop_likelihood(img, 0, 0)
    p2 = _crop_likelihood(img, width - dim, 0)
    # Take the maximum likelihood as the frame's likelihood.
    likelihood = max(p1, p2)
    print(n, likelihood)
    data.append(likelihood)
### Compare with ground truth labels
# Split the per-frame scores into two series by ground-truth label, so they
# can be plotted in different colours.
trueXList = []
trueYList = []
falseXList = []
falseYList = []
# Each line of gt.txt is "<frame index> <label>". Label 't' puts the frame in
# the true series; any other label puts it in the false series. At most N - 1
# lines are read, one per frame scored above.
with open('gt.txt', 'r') as f:
    for _, line in zip(range(1, N), f):
        info = line.split()
        if not info:
            # Blank line (e.g. a trailing newline): skip it rather than fail
            # with IndexError on info[0].
            continue
        index = int(info[0])
        # data[0] holds the score for frame 1, hence index - 1.
        if info[1] == 't':
            trueXList.append(index)
            trueYList.append(data[index - 1])
        else:
            falseXList.append(index)
            falseYList.append(data[index - 1])
### Generate plots
# Scatter the false-labelled frames in red, then the true-labelled frames in
# green, each as (frame index, likelihood) points.
for xs, ys, colour in ((falseXList, falseYList, 'red'),
                       (trueXList, trueYList, 'green')):
    plt.scatter(xs, ys, color=colour)
# Blue horizontal reference line at likelihood 0.05, spanning x = 0 .. N.
plt.hlines(y=0.05, xmin=0, xmax=N, color='b')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment