Skip to content

Instantly share code, notes, and snippets.

@donalod
Last active December 13, 2018 10:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save donalod/15fd99a43669a6cb98a17a40fc3b8679 to your computer and use it in GitHub Desktop.
Save donalod/15fd99a43669a6cb98a17a40fc3b8679 to your computer and use it in GitHub Desktop.
ai_ml_mxnet_image_identification.py
#!/usr/bin/env python
## 'pycodestyle' and 'autopep8 --in-place'
# Labs from https://github.com/drandrewkane/AI_ML_Workshops
# All slides from all sessions https://www.slideshare.net/AmazonWebServices/tag/aws-builders-days
# AWS Builders Day 2018 in Dublin at the Aviva Stadium
# wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-symbol.json
# wget -O Inception-BN-0000.params http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-0126.params
# wget http://data.dmlc.ml/models/imagenet/synset.txt
# wget -O image0.jpeg https://cdn-images-1.medium.com/max/1600/1*sPdrfGtDd_6RQfYvD5qcyg.jpeg
# wget -O image1.jpeg http://kidszoo.org/wp-content/uploads/2015/02/clownfish3-1500x630.jpg
# wget http://data.dmlc.ml/models/imagenet/vgg/vgg16-symbol.json
# wget http://data.dmlc.ml/models/imagenet/vgg/vgg16-0000.params
# wget http://data.dmlc.ml/models/imagenet/resnet/152-layers/resnet-152-symbol.json
# wget http://data.dmlc.ml/models/imagenet/resnet/152-layers/resnet-152-0000.params
import sys
from collections import namedtuple

import cv2
import mxnet as mx
import numpy as np
# Classify one image with the pre-trained ResNet-152 checkpoint and print the
# top-1 score and its category label.
#
# Load the checkpoint saved under the 'resnet-152' prefix at epoch 0, i.e. the
# 'resnet-152-symbol.json' / 'resnet-152-0000.params' files fetched above.
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)

# Bind the network for inference only, with a single-image input of shape
# (batch, channels, height, width) = (1, 3, 224, 224).
mod = mx.mod.Module(symbol=sym)
mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))])
mod.set_params(arg_params, aux_params)

# The image to classify may be passed as the first command-line argument;
# it defaults to the 'image1.jpeg' fetched by the wget commands above.
image_path = sys.argv[1] if len(sys.argv) > 1 else 'image1.jpeg'

# cv2.imread returns None (instead of raising) when the file is missing or
# unreadable; fail early with a clear message rather than letting
# cv2.cvtColor fail with an obscure error.
img = cv2.imread(image_path)
if img is None:
    raise IOError("Could not read image file: " + image_path)

# OpenCV delivers pixels in BGR order; convert to RGB for the network.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224,))
# Reorder (height, width, channels) -> (channels, height, width), then add a
# leading batch axis so the input matches the bound (1, 3, 224, 224) shape.
img = np.transpose(img, (2, 0, 1))
img = img[np.newaxis, :]
array = mx.nd.array(img)
print("Array shape: " + str(array.shape))

# Module.forward takes a batch-like object whose `.data` attribute is a list
# of input arrays; a lightweight namedtuple is enough for that.
Batch = namedtuple('Batch', ['data'])
mod.forward(Batch([array]))
prob = mod.get_outputs()[0].asnumpy()
print("Probability shape: " + str(prob.shape))
# Drop the size-1 batch axis so prob holds just the per-class scores.
prob = np.squeeze(prob)
print("Array shape: " + str(prob.shape))
print("Probabilities: " + str(prob))

# Class indices ordered from highest to lowest score.
sortedprob = np.argsort(prob)[::-1]
print("Sorted probability: " + str(sortedprob.shape))
# NOTE: this line prints the top-1 *score*; the matching label is printed
# at the end once synset.txt has been read.
print("It's probably a: " + str(prob[int(sortedprob[0])]))

# synset.txt holds one category label per line; line i is assumed to name
# class index i of the model's output.
with open('synset.txt', 'r') as synsetfile:
    categorylist = [line.rstrip() for line in synsetfile]
print("We think it's: " + str(categorylist[sortedprob[0]]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment