@johnolafenwa
Created April 27, 2018 04:34
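import torch
from PIL import Image
from torchvision import transforms


def predict_image(image_path):
    # Assumes a trained classifier named `model` is defined at module level
    print("Prediction in progress")
    # Convert to RGB so grayscale or RGBA files also yield 3 channels
    image = Image.open(image_path).convert("RGB")
    # Define transformations for the image (note that ImageNet models are trained with an image size of 224)
    transformation = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Preprocess the image
    image_tensor = transformation(image).float()
    # Add an extra batch dimension since PyTorch treats all inputs as batches
    image_tensor = image_tensor.unsqueeze(0)
    # .cuda() returns a copy rather than moving the tensor in place, so reassign it;
    # the model must also live on the GPU in this case
    if torch.cuda.is_available():
        image_tensor = image_tensor.cuda()
    # Switch to inference mode (disables dropout, uses running batch-norm statistics);
    # since PyTorch 0.4 tensors are passed directly, so Variable is no longer needed
    model.eval()
    with torch.no_grad():
        # Predict the class of the image
        output = model(image_tensor)
    # Take the highest-scoring class; .item() also works when the output is on the GPU
    index = output.argmax(dim=1).item()
    return index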
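The `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` step in `predict_image` works on the output of `ToTensor`, which scales 8-bit pixel values into [0, 1]. The arithmetic for one channel value can be sketched in plain Python as follows; the helper names `to_unit_range` and `normalize` are hypothetical and only illustrate the computation, they are not torchvision functions.

```python
def to_unit_range(pixel):
    """Hypothetical helper: ToTensor maps 8-bit pixel values (0-255) to floats in [0, 1]."""
    return pixel / 255.0


def normalize(value, mean=0.5, std=0.5):
    """Hypothetical helper: the per-channel operation Normalize applies, (value - mean) / std."""
    return (value - mean) / std


# With mean = std = 0.5, the range [0, 1] is mapped onto [-1, 1]
print(normalize(to_unit_range(0)))    # darkest pixel
print(normalize(to_unit_range(255)))  # brightest pixel
```

Because the same mean and standard deviation are used for all three channels, every channel ends up centred on 0 with values in [-1, 1]. Whatever values are used here should match the ones applied when the model was trained.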
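`transforms.CenterCrop(224)` keeps only a 224 x 224 window taken from the middle of the image and does not rescale it. The sketch below computes that window as a PIL-style `(left, top, right, bottom)` box. `center_crop_box` is a hypothetical helper, and torchvision's internal rounding may differ slightly. The sketch also does not handle images smaller than 224 pixels on a side.

```python
def center_crop_box(width, height, size=224):
    """Hypothetical helper: (left, top, right, bottom) box of a centred size x size crop."""
    left = int(round((width - size) / 2.0))
    top = int(round((height - size) / 2.0))
    return (left, top, left + size, top + size)


# A 640 x 480 photo keeps only its central 224 x 224 region
print(center_crop_box(640, 480))
```

For large photos this discards most of the frame. The torchvision ImageNet examples therefore resize first (for example `transforms.Resize(256)` before `CenterCrop(224)`), so that the crop covers most of the image.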
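Calling argmax without a dimension (NumPy's `ndarray.argmax()` with no axis) flattens the model output first. That only gives a valid class index when the batch holds a single image, which is the case in `predict_image` because of the `unsqueeze(0)` call. The plain-Python sketch below contrasts the two behaviours on made-up scores for three classes. The helpers are hypothetical stand-ins for the NumPy and PyTorch calls, and the scores are illustrative only.

```python
def argmax(values):
    """Index of the largest value; like NumPy, the first maximum wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])


def flat_argmax(batch_scores):
    """Mimics argmax with no axis: an index into the flattened (batch * classes) scores."""
    flat = [score for row in batch_scores for score in row]
    return argmax(flat)


def per_row_argmax(batch_scores):
    """Mimics argmax over dim/axis 1: one class index per image in the batch."""
    return [argmax(row) for row in batch_scores]


single = [[0.1, 2.3, -0.7]]                   # batch of one image, 3 classes
batch = [[0.1, 2.3, -0.7], [1.5, 0.2, 3.0]]  # batch of two images

print(flat_argmax(single), per_row_argmax(single))  # both agree for one image
print(flat_argmax(batch), per_row_argmax(batch))    # flat index 5 is not a class
```

Reducing over the class dimension keeps the result meaningful if the function is later changed to score several images at once.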