Skip to content

Instantly share code, notes, and snippets.

@juliensimon
Last active November 21, 2018 22:08
Show Gist options
  • Save juliensimon/ed9ef71ae35ae1d5048dd14bbefc552a to your computer and use it in GitHub Desktop.
Save juliensimon/ed9ef71ae35ae1d5048dd14bbefc552a to your computer and use it in GitHub Desktop.
MXNet + 3 CNNs
import mxnet as mx
import numpy as np
import cv2,sys,time
from collections import namedtuple
def loadModel(modelname, epoch=0):
    """Load an MXNet checkpoint and bind it for single-image inference.

    ``mx.model.load_checkpoint`` reads ``<modelname>-symbol.json`` and
    ``<modelname>-<epoch, 4 digits>.params``. The returned module is bound
    for inference (not training) on a batch of one 3x224x224 image fed
    through the input named ``data``.

    Args:
        modelname: checkpoint prefix, e.g. ``"vgg16"``.
        epoch: checkpoint epoch number (defaults to 0, as before).

    Returns:
        A bound ``mx.mod.Module`` with its parameters set.
    """
    t1 = time.time()
    sym, arg_params, aux_params = mx.model.load_checkpoint(modelname, epoch)
    t2 = time.time()
    print("Loaded in %2.2f milliseconds" % (1000 * (t2 - t1)))
    # Pretrained checkpoints disagree on the name of their label input
    # (e.g. 'prob_label' vs. the Module default 'softmax_label'). Inference
    # needs no label, so declare none. The label argument then has no stored
    # value, so allow_missing=True lets set_params leave it at its bound
    # default instead of failing with "... is not presented".
    mod = mx.mod.Module(symbol=sym, label_names=None)
    mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))])
    mod.set_params(arg_params, aux_params, allow_missing=True)
    return mod
def loadCategories(filename='synset.txt'):
    """Return the category names stored one per line in *filename*.

    Trailing whitespace (including the newline) is stripped from each line.
    The default path is resolved relative to the current working directory.
    """
    # 'with' guarantees the file is closed even if reading fails
    with open(filename, 'r') as synsetfile:
        return [line.rstrip() for line in synsetfile]
def prepareNDArray(filename):
    """Load an image file and convert it into a 1x3x224x224 MXNet NDArray.

    OpenCV reads the image as height x width x channel in BGR order. It is
    converted to RGB, resized to 224x224, reordered to channel x height x
    width, and given a leading batch axis of size 1.

    Raises:
        IOError: if OpenCV cannot read *filename*.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, which would otherwise fail inside cvtColor.
        raise IOError("Cannot read image file: %s" % filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    # (H, W, C) -> (C, H, W); identical to swapaxes(0, 2) then swapaxes(1, 2)
    img = np.transpose(img, (2, 0, 1))
    return mx.nd.array(img[np.newaxis, :])
def predict(filename, model, categories, n):
    """Classify an image file and return its *n* highest-scoring categories.

    Args:
        filename: path of the image to classify.
        model: a bound module as returned by ``loadModel``.
        categories: category names indexed by output class, as returned by
            ``loadCategories``.
        n: number of predictions to return.

    Returns:
        A list of ``(score, category)`` tuples, highest score first.
    """
    array = prepareNDArray(filename)
    Batch = namedtuple('Batch', ['data'])
    t1 = time.time()
    model.forward(Batch([array]))
    output = model.get_outputs()[0]
    # MXNet executes asynchronously: forward() returns before the result is
    # computed, so block on the output before stopping the clock.
    output.wait_to_read()
    t2 = time.time()
    print("Predicted in %2.2f milliseconds" % (1000 * (t2 - t1)))
    prob = np.squeeze(output.asnumpy())
    # argsort is ascending; reverse it so the best class comes first
    best = np.argsort(prob)[::-1][:n]
    return [(prob[i], categories[i]) for i in best]
def init(modelname):
    """Load the network saved under *modelname* together with the categories.

    Returns:
        A ``(module, categories)`` pair: the bound module from ``loadModel``
        followed by the list from ``loadCategories``.
    """
    # Tuple elements are evaluated left to right, so the model is loaded
    # before the category file is read, exactly as before.
    return loadModel(modelname), loadCategories()
if __name__ == "__main__":
    # Script entry point: classify the image named on the command line with
    # three pretrained networks and print each one's top-5 predictions.
    # Guarded so that importing this file does not load the models.
    if len(sys.argv) < 2:
        sys.exit("usage: %s IMAGE_FILE" % sys.argv[0])
    filename = sys.argv[1]
    # loadCategories() takes no model argument, so one read serves all three.
    c = loadCategories()
    networks = [
        ("VGG16", loadModel("vgg16")),
        ("ResNet-152", loadModel("resnet-152")),
        # Labelled after the checkpoint actually loaded (previously
        # mislabelled as "Inception v3").
        ("Inception-BN", loadModel("Inception-BN")),
    ]
    for label, model in networks:
        print("*** %s" % label)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(predict(filename, model, c, 5))
@temojin
Copy link

temojin commented Jun 3, 2017

I seem to have trouble with the VGG model. Searching Google points to apache/mxnet#5779,
where a user suggested: "A probable remedy could be to just supply a dummy variable with the name 'softmax_label'"
But I'm not familiar enough with MXNet yet to make that change. Can you help?

Here's the output:
[02:32:46] src/nnvm/legacy_json_util.cc:190: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[02:32:46] src/nnvm/legacy_json_util.cc:198: Symbol successfully upgraded!
Loaded in 384.57 microseconds
/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/base_module.py:52: UserWarning: You created Module with Module(..., label_names=['softmax_label']) but input with name 'softmax_label' is not found in symbol.list_arguments(). Did you mean one of:
data
prob_label
warnings.warn(msg)
/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/base_module.py:64: UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])
warnings.warn(msg)
Traceback (most recent call last):
File "anymodel.py", line 55, in <module>
vgg16,c = init("vgg16")
File "anymodel.py", line 50, in init
model = loadModel(modelname)
File "anymodel.py", line 14, in loadModel
mod.set_params(arg_params, aux_params)
File "/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.py", line 309, in set_params
allow_missing=allow_missing, force_init=force_init)
File "/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.py", line 274, in init_params
_impl(desc, arr, arg_params)
File "/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.py", line 265, in _impl
raise RuntimeError("%s is not presented" % name)
RuntimeError: prob_label is not presented
(mxnet) galaga:pretrained-models itsme$

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment