Skip to content

Instantly share code, notes, and snippets.

@kakittwo
Last active August 26, 2018 14:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kakittwo/768e5477a5834ee84624205659e4d415 to your computer and use it in GitHub Desktop.
Save kakittwo/768e5477a5834ee84624205659e4d415 to your computer and use it in GitHub Desktop.
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
### Subroutine to calculate likelihood for hummingbird in an image
def calculateLikelihood(x, classifier=None, class_index=94):
    """Return the classifier's prediction for one ImageNet class on one image.

    Args:
        x: a single image array (e.g. from ``image.img_to_array``) WITHOUT a
            batch axis; a batch axis of size 1 is added here.
        classifier: Keras model whose ``predict`` is called. Defaults to the
            module-level ``model`` so existing one-argument calls still work.
        class_index: index into the prediction vector. The default, 94, is
            the hummingbird class in ImageNet.

    Returns:
        ``preds[0][class_index]`` -- the model's output for that class on
        this image.
    """
    if classifier is None:
        # Resolved at call time, so the global may be defined after this def.
        classifier = model
    # np.array copies x: Keras' preprocess_input may overwrite float inputs
    # in place, and expand_dims alone would only give a view of the caller's
    # array. The leading axis makes a batch of one for predict().
    batch = np.expand_dims(np.array(x), axis=0)
    batch = preprocess_input(batch)
    preds = classifier.predict(batch)
    return preds[0][class_index]
### Load model and define parameters
# ImageNet-pretrained ResNet50; calculateLikelihood calls model.predict on it.
# (This was previously fused with the next statement on one line -- a
# SyntaxError.)
model = ResNet50(weights='imagenet')
# Side length, in pixels, of the square region cropped from each frame.
dim = 861
# Nominal frame size of the input thumbnails, in pixels.
# NOTE(review): ``height`` is not read anywhere in this script.
height = 1080
width = 1920
# Square size every crop is resized to before being fed to the model.
expecteddim = 224
# Per-frame likelihoods, appended in frame order by the processing loop.
data = []
# Frames thumb0001.jpg .. thumb{N-1:04d}.jpg are processed: range(1, N)
# excludes N itself.
N = 48
### Loop through the images
def _likelihoodForCrop(img, startx, starty):
    """Score the dim x dim square of ``img`` with top-left corner (startx, starty).

    The crop is resized to expecteddim x expecteddim and scored with
    calculateLikelihood.
    """
    cropped_img = img.crop((startx, starty, startx + dim, starty + dim))
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias for LANCZOS,
    # so this is the same resampling filter.
    resize_img = cropped_img.resize((expecteddim, expecteddim), Image.LANCZOS)
    return calculateLikelihood(image.img_to_array(resize_img))


for n in range(1, N):
    # Read the image
    img_path = 'thumb{0:04d}.jpg'.format(n)
    img = image.load_img(img_path)
    # Score the top-left square of the frame.
    p1 = _likelihoodForCrop(img, 0, 0)
    # Repeat at the top-right square, close to another flower. Position it
    # from the frame's real width so non-1920px frames are not padded.
    p2 = _likelihoodForCrop(img, img.size[0] - dim, 0)
    # Take the maximum likelihood as the return likelihood
    likelihood = max(p1, p2)
    print(n, likelihood)
    data.append(likelihood)
### Compare with ground truth labels
# Split the per-frame likelihoods in ``data`` by the labels in gt.txt. The x
# coordinate is the frame index; the y coordinate is that frame's likelihood.
trueXList = []
trueYList = []
falseXList = []
falseYList = []
# Each gt.txt line is parsed as "<frame index> <label>". Only label 't' counts
# as positive; any other label is treated as negative. At most N - 1 lines are
# read, matching the number of frames scored above.
with open('gt.txt', 'r') as f:
    for lineno, line in enumerate(f, start=1):
        if lineno >= N:
            break
        info = line.split()
        if not info:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # failing with an IndexError.
            continue
        index = int(info[0])
        # Indices are 1-based; 0 or below would silently read data from the end.
        if not 1 <= index <= len(data):
            raise ValueError('gt.txt frame index %d outside 1..%d' % (index, len(data)))
        score = data[index - 1]
        if info[1] == 't':
            trueXList.append(index)
            trueYList.append(score)
        else:
            falseXList.append(index)
            falseYList.append(score)
### Generate plots
# Red points: frames not labelled 't' in gt.txt; green points: frames labelled
# 't'. The blue horizontal line at likelihood 0.05 spans x = 0..N (presumably
# the detection threshold under evaluation -- confirm).
threshold = 0.05
series = [
    (falseXList, falseYList, 'red'),
    (trueXList, trueYList, 'green'),
]
for xs, ys, colour in series:
    plt.scatter(xs, ys, color=colour)
plt.hlines(y=threshold, xmin=0, xmax=N, color='b')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment