Skip to content

Instantly share code, notes, and snippets.

@pierric
Created December 30, 2018 22:10
Show Gist options
  • Save pierric/b4da8783755de763bff13aea5f5e900d to your computer and use it in GitHub Desktop.
Save pierric/b4da8783755de763bff13aea5f5e900d to your computer and use it in GitHub Desktop.
Inference with a saved model in MXNet
import mxnet as mx
import itertools
import numpy as np
from collections import namedtuple
# Device context handed to the Module below; this script runs on the CPU.
ctx = mx.cpu()
def load(prefix):
    """Load a saved symbol and its parameters from files sharing *prefix*.

    The symbol is read from ``<prefix>.json`` and the parameters from
    ``<prefix>.params``.  Each saved parameter key is split once on ``':'``
    into a tag and a name: entries tagged ``arg`` go into the first
    dictionary, entries tagged ``aux`` into the second, and entries with
    any other tag are ignored.

    Returns:
        A ``(symbol, arg_params, aux_params)`` tuple.
    """
    symbol = mx.sym.load('%s.json' % prefix)
    arg_params, aux_params = {}, {}
    # Route each entry to its dictionary by tag.
    buckets = {'arg': arg_params, 'aux': aux_params}
    for key, value in mx.nd.load('%s.params' % prefix).items():
        tag, name = key.split(':', 1)
        target = buckets.get(tag)
        # Compare against None: an empty (falsy) dict is still a valid target.
        if target is not None:
            target[name] = value
    return (symbol, arg_params, aux_params)
# Restore the network and its weights from the checkpoint files
# 'epoch_0_acc_0.98_loss_0.06.json' / '.params'.
sym, arg, aux = load('epoch_0_acc_0.98_loss_0.06')

# Iterator over the validation record file, 10 images per batch.
valid_iter = mx.io.ImageRecordIter(
    path_imgrec="../dataiter/test/data/cifar10_val.rec",
    data_name="data",
    label_name="softmax_label",
    batch_size=10,
    data_shape=(3, 28, 28),
)
print(valid_iter.provide_label)

# The network's data input is named 'x'; no label names are registered.
# NOTE(review): the module is bound for (128, 3, 32, 32) inputs while the
# iterator above yields (10, 3, 28, 28) batches -- presumably this relies on
# Module.forward accepting batches of a different shape; confirm that the
# mismatch is intentional.
mod = mx.mod.Module(
    symbol=sym,
    context=ctx,
    label_names=None,
    data_names=['x'],
)
mod.bind(
    for_training=False,
    data_shapes=[('x', (128, 3, 32, 32))],
    label_shapes=[('y', (128,))],
)
mod.set_params(arg, aux, allow_missing=True)

# Only the first batch of the iterator is evaluated.
firstN = itertools.islice(valid_iter, 1)
for batch in firstN:
    mod.forward(batch)
    # First network output, copied into a NumPy array.
    prob = mod.get_outputs()[0].asnumpy()
    print(prob.shape)
    # Predicted category = index of the largest value along axis 1.
    cate = prob.argmax(axis=1)
    print(cate)
    # Labels shipped with the batch, as integers, for comparison.
    label = batch.label[0].asnumpy().astype(int)
    print(label)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment