Last active
November 21, 2018 22:08
-
-
Save juliensimon/ed9ef71ae35ae1d5048dd14bbefc552a to your computer and use it in GitHub Desktop.
MXNet + 3 CNNs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mxnet as mx | |
import numpy as np | |
import cv2,sys,time | |
from collections import namedtuple | |
def loadModel(modelname, epoch=0):
    """Load a pre-trained MXNet checkpoint and bind it for single-image inference.

    :param modelname: checkpoint prefix passed to ``mx.model.load_checkpoint``
        (e.g. "vgg16", "resnet-152", "Inception-BN").
    :param epoch: checkpoint epoch to load; defaults to 0 as before.
    :return: an ``mx.mod.Module`` bound for inference on a batch of one
        3x224x224 input, with the checkpoint's parameters set.
    """
    t1 = time.time()
    sym, arg_params, aux_params = mx.model.load_checkpoint(modelname, epoch)
    t2 = time.time()
    print("Loaded in %2.2f milliseconds" % (1000 * (t2 - t1)))
    # Checkpoints do not agree on the name of their output label: the default
    # 'softmax_label' does not exist in the VGG16 symbol, which uses
    # 'prob_label' instead. Declaring the symbol's own *_label inputs as labels
    # keeps Module from treating them as weights that set_params must supply.
    # (Previously this was patched with a dummy 'prob_label' array, which also
    # injected an unknown parameter into checkpoints that have no such input.)
    label_names = [name for name in sym.list_arguments() if name.endswith('_label')]
    mod = mx.mod.Module(symbol=sym, label_names=label_names)
    mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))])
    mod.set_params(arg_params, aux_params)
    return mod
def loadCategories(filename='synset.txt'):
    """Read category labels, one per line, from a synset file.

    :param filename: path of the label file; defaults to 'synset.txt' in the
        current directory, as before.
    :return: list of labels in file order with trailing whitespace stripped;
        callers index it by the model's output class.
    """
    # Use a context manager so the file handle is closed even on error.
    with open(filename, 'r') as synsetfile:
        return [line.rstrip() for line in synsetfile]
def prepareNDArray(filename):
    """Load an image file and convert it into a (1, 3, 224, 224) MXNet NDArray.

    The image is converted from OpenCV's BGR channel order to RGB, resized to
    224x224, rearranged from height-width-channel to channel-height-width,
    and given a leading batch axis of size 1.

    :param filename: path of the image file to load.
    :return: ``mx.nd.array`` of shape (1, 3, 224, 224).
    :raises IOError: if OpenCV cannot read ``filename``.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of inside cvtColor.
        raise IOError("Cannot read image file: %s" % filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    # HWC -> CHW, matching the ('data', (1, 3, 224, 224)) shape the models are bound with.
    img = np.transpose(img, (2, 0, 1))
    img = img[np.newaxis, :]
    return mx.nd.array(img)
def predict(filename, model, categories, n):
    """Classify an image and return the ``n`` most probable categories.

    :param filename: path of the image file to classify.
    :param model: a Module bound for inference, as returned by ``loadModel``.
    :param categories: labels indexed by output class, as returned by
        ``loadCategories``.
    :param n: number of top predictions to return.
    :return: list of ``(probability, category)`` tuples, most probable first.
    """
    array = prepareNDArray(filename)
    Batch = namedtuple('Batch', ['data'])
    t1 = time.time()
    model.forward(Batch([array]))
    output = model.get_outputs()[0]
    # MXNet executes asynchronously: forward() only queues the computation.
    # Block until the output is ready, otherwise the timing below measures
    # little more than the time to enqueue the work.
    output.wait_to_read()
    t2 = time.time()
    print("Predicted in %2.2f milliseconds" % (1000 * (t2 - t1)))
    prob = np.squeeze(output.asnumpy())
    sortedprobindex = np.argsort(prob)[::-1]
    return [(prob[i], categories[i]) for i in sortedprobindex[:n]]
def init(modelname):
    """Load one model checkpoint together with the category labels.

    :param modelname: checkpoint prefix forwarded to ``loadModel``.
    :return: ``(model, categories)`` tuple.
    """
    # The model is loaded first, then the labels (same order as before).
    return loadModel(modelname), loadCategories()
if __name__ == "__main__":
    # Validate the command line before spending time loading three models.
    if len(sys.argv) < 2:
        sys.exit("Usage: %s IMAGE_FILE" % sys.argv[0])
    filename = sys.argv[1]
    # All three models are paired with the same label file, so read it once.
    c = loadCategories()
    vgg16 = loadModel("vgg16")
    resnet152 = loadModel("resnet-152")
    inceptionv3 = loadModel("Inception-BN")
    # The third checkpoint is Inception-BN, so label its output accordingly.
    # print(...) with a single argument behaves the same on Python 2 and 3.
    for title, model in (("VGG16", vgg16),
                         ("ResNet-152", resnet152),
                         ("Inception-BN", inceptionv3)):
        print("*** %s" % title)
        print(predict(filename, model, c, 5))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I'm having trouble with the VGG model. Searching Google points to this issue: apache/mxnet#5779
There, a user suggested: "A probable remedy could be to just supply a dummy variable with the name 'softmax_label'"
But I'm not familiar enough with MXNet yet to make that change. Can you help?
Here's the output:
[02:32:46] src/nnvm/legacy_json_util.cc:190: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[02:32:46] src/nnvm/legacy_json_util.cc:198: Symbol successfully upgraded!
Loaded in 384.57 microseconds
/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/base_module.py:52: UserWarning: You created Module with Module(..., label_names=['softmax_label']) but input with name 'softmax_label' is not found in symbol.list_arguments(). Did you mean one of:
data
prob_label
warnings.warn(msg)
/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/base_module.py:64: UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])
warnings.warn(msg)
Traceback (most recent call last):
File "anymodel.py", line 55, in
vgg16,c = init("vgg16")
File "anymodel.py", line 50, in init
model = loadModel(modelname)
File "anymodel.py", line 14, in loadModel
mod.set_params(arg_params, aux_params)
File "/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.py", line 309, in set_params
allow_missing=allow_missing, force_init=force_init)
File "/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.py", line 274, in init_params
_impl(desc, arr, arg_params)
File "/Users/itsme/.virtualenvs/mxnet/lib/python2.7/site-packages/mxnet/module/module.py", line 265, in _impl
raise RuntimeError("%s is not presented" % name)
RuntimeError: prob_label is not presented
(mxnet) galaga:pretrained-models itsme$