@gregchu
Created March 10, 2017 06:51
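
import numpy as np
from keras.preprocessing import image
# Assumption: the model is Keras's ImageNet ResNet50. For another architecture,
# import preprocess_input and decode_predictions from its matching module instead.
from keras.applications.resnet50 import preprocess_input, decode_predictions


def predict(model, img, target_size, top_n=3):
    """Run model prediction on an image.

    Args:
        model: Keras model
        img: PIL format image
        target_size: (w, h) tuple, the input size the model expects
        top_n: number of top predictions to return

    Returns:
        list of (class_id, class_name, score) tuples, highest score first
    """
    # PIL's img.size and img.resize() both use (width, height) order
    if img.size != target_size:
        img = img.resize(target_size)
    x = image.img_to_array(img)    # float32 array of shape (h, w, 3)
    x = np.expand_dims(x, axis=0)  # add a batch axis: (1, h, w, 3)
    x = preprocess_input(x)        # model-specific input normalization
    preds = model.predict(x)       # shape (1, 1000): ImageNet class probabilities
    # decode_predictions maps class indices to labels; [0] selects the single image
    return decode_predictions(preds, top=top_n)[0]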
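The shape bookkeeping and the top-n step can be checked without Keras or downloaded weights. The NumPy-only sketch below mirrors the batch-axis step and replaces decode_predictions with a hypothetical `top_n_decode` helper that ranks scores against a caller-supplied label list; the real function looks up ImageNet class names and returns (class_id, class_name, score) tuples. The model-specific preprocess_input step is omitted, and the image, scores, and labels are made-up values for illustration.

```python
import numpy as np


def to_batch(img_array):
    """Add a leading batch axis, (h, w, 3) -> (1, h, w, 3), like np.expand_dims in predict()."""
    return np.expand_dims(img_array.astype("float32"), axis=0)


def top_n_decode(preds, labels, top_n=3):
    """Hypothetical stand-in for decode_predictions on a one-image batch.

    preds: array of shape (1, num_classes) with one score per class.
    labels: list of num_classes class names (made up here, not ImageNet's).
    Returns a list of (label, score) tuples, highest score first.
    """
    scores = preds[0]  # drop the batch axis
    # argsort sorts ascending, so reverse it and keep the first top_n indices
    best = np.argsort(scores)[::-1][:top_n]
    return [(labels[i], float(scores[i])) for i in best]


# Fake 4x4 RGB image and a fake 5-class model output
fake_img = np.zeros((4, 4, 3), dtype=np.uint8)
batch = to_batch(fake_img)
print(batch.shape)  # (1, 4, 4, 3)

fake_preds = np.array([[0.05, 0.6, 0.1, 0.2, 0.05]])
labels = ["cat", "dog", "fox", "owl", "eel"]
print(top_n_decode(fake_preds, labels, top_n=3))
# [('dog', 0.6), ('owl', 0.2), ('fox', 0.1)]
```

A full argsort keeps the sketch short; with very many classes, np.argpartition could select the top_n indices first and sort only those.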