Skip to content

Instantly share code, notes, and snippets.

@ThomasDelteil
Created May 31, 2018 23:19
Show Gist options
  • Save ThomasDelteil/f3258a9c21a1ba330eee7ed8bb1c71fa to your computer and use it in GitHub Desktop.
Save ThomasDelteil/f3258a9c21a1ba330eee7ed8bb1c71fa to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Saving / Loading models in MXNet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample Network -> Saving"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"import mxnet as mx\n",
"import numpy as np\n",
"from mxnet import gluon\n",
"ctx = mx.cpu()\n",
"\n",
"save_params = 'save_params.params'\n",
"collect_params = 'collect_params.params'\n",
"export_params = 'export-0000.params'\n",
"sym_symbol = 'sym.json'\n",
"export_symbol = 'export-symbol.json'\n",
"\n",
"# Create network\n",
"def get_net(prefix=\"test_\"):\n",
" net = gluon.nn.HybridSequential(prefix=prefix)\n",
" with net.name_scope():\n",
" net.add(gluon.nn.Conv2D(10, (3, 3)))\n",
" net.add(gluon.nn.Dense(50))\n",
" net.add(gluon.nn.BatchNorm())\n",
" net.initialize()\n",
" return net\n",
"\n",
"net = get_net()\n",
"# Save network \n",
"net.hybridize()\n",
"data = mx.nd.ones((1,1,50,50))\n",
"out = net(data).asnumpy()\n",
"net.export('export', epoch=0)\n",
"net.save_params(save_params)\n",
"net.collect_params().save(collect_params)\n",
"\n",
"sym = net(mx.sym.Variable('data'))\n",
"sym.save(sym_symbol)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.1.0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mx.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading in Python / Gluon Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `save_params`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# save_params / load_params\n",
"net = get_net()\n",
"net.load_params(save_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"test_ (\n",
" Parameter test_conv0_weight (shape=(10L, 0L, 3L, 3L), dtype=<type 'numpy.float32'>)\n",
" Parameter test_conv0_bias (shape=(10L,), dtype=<type 'numpy.float32'>)\n",
" Parameter test_dense0_weight (shape=(50, 0), dtype=<type 'numpy.float32'>)\n",
" Parameter test_dense0_bias (shape=(50,), dtype=<type 'numpy.float32'>)\n",
" Parameter test_batchnorm0_gamma (shape=(0,), dtype=<type 'numpy.float32'>)\n",
" Parameter test_batchnorm0_beta (shape=(0,), dtype=<type 'numpy.float32'>)\n",
" Parameter test_batchnorm0_running_mean (shape=(0,), dtype=<type 'numpy.float32'>)\n",
" Parameter test_batchnorm0_running_var (shape=(0,), dtype=<type 'numpy.float32'>)\n",
")"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.collect_params()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# save_params / collect_params\n",
"net = get_net(\"\")\n",
"net.collect_params().load(save_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `collect_params`"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# collect_params / load_params\n",
"net = get_net()\n",
"net._prefix = \"\"\n",
"net.load_params(collect_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# collect_params / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(collect_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `export`"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# export / load_params\n",
"net = get_net()\n",
"net._prefix = \"\"\n",
"net.load_params(export_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# export / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(export_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading in Python / Symbol Block - export json"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"137c137\n",
"< \"inputs\": [[6, 0, 0], [7, 0, 0], [8, 0, 0], [9, 0, 2], [10, 0, 2]]\n",
"---\n",
"> \"inputs\": [[6, 0, 0], [7, 0, 0], [8, 0, 0], [9, 0, 1], [10, 0, 1]]\n"
]
}
],
"source": [
"!diff sym.json export-symbol.json"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"def get_net():\n",
" sym = mx.sym.load_json(open(export_symbol, 'r').read())\n",
" net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))\n",
" return net"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `save_params`"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "Parameter test_conv0_weight is missing in file save_params.params",
"output_type": "error",
"traceback": [
"\u001b[0;31m\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-30-bc3892642c64>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# save_params / load_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/block.pyc\u001b[0m in \u001b[0;36mload_params\u001b[0;34m(self, filename, ctx, allow_missing, ignore_extra)\u001b[0m\n\u001b[1;32m 315\u001b[0m \"\"\"\n\u001b[1;32m 316\u001b[0m self.collect_params().load(filename, ctx, allow_missing, ignore_extra,\n\u001b[0;32m--> 317\u001b[0;31m self.prefix)\n\u001b[0m\u001b[1;32m 318\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_child\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/parameter.pyc\u001b[0m in \u001b[0;36mload\u001b[0;34m(self, filename, ctx, allow_missing, ignore_extra, restore_prefix)\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 668\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 669\u001b[0;31m \u001b[0;34m\"Parameter %s is missing in file %s\"\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlprefix\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 670\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 671\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_params\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: Parameter test_conv0_weight is missing in file save_params.params"
]
}
],
"source": [
"# save_params / load_params\n",
"net = get_net()\n",
"net.load_params(save_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "Parameter test_conv0_weight is missing in file save_params.params",
"output_type": "error",
"traceback": [
"\u001b[0;31m\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-31-2378c5c25d8c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# save_params / collect_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/parameter.pyc\u001b[0m in \u001b[0;36mload\u001b[0;34m(self, filename, ctx, allow_missing, ignore_extra, restore_prefix)\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 668\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 669\u001b[0;31m \u001b[0;34m\"Parameter %s is missing in file %s\"\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlprefix\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 670\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 671\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_params\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: Parameter test_conv0_weight is missing in file save_params.params"
]
}
],
"source": [
"# save_params / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(save_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `collect_params`"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# collect_params / load_params\n",
"net = get_net()\n",
"net.load_params(collect_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# collect_params / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(collect_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `export`"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"# export / load_params\n",
"net = get_net()\n",
"net.load_params(export_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# export / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(export_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading in Python / Symbol Block - mx.sym.var('data') json"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"def get_net():\n",
" sym = mx.sym.load_json(open(sym_symbol, 'r').read())\n",
" net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))\n",
" return net"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `save_params`"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "Parameter test_conv0_weight is missing in file save_params.params",
"output_type": "error",
"traceback": [
"\u001b[0;31m\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-47-bc3892642c64>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# save_params / load_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/block.pyc\u001b[0m in \u001b[0;36mload_params\u001b[0;34m(self, filename, ctx, allow_missing, ignore_extra)\u001b[0m\n\u001b[1;32m 315\u001b[0m \"\"\"\n\u001b[1;32m 316\u001b[0m self.collect_params().load(filename, ctx, allow_missing, ignore_extra,\n\u001b[0;32m--> 317\u001b[0;31m self.prefix)\n\u001b[0m\u001b[1;32m 318\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_child\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/parameter.pyc\u001b[0m in \u001b[0;36mload\u001b[0;34m(self, filename, ctx, allow_missing, ignore_extra, restore_prefix)\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 668\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 669\u001b[0;31m \u001b[0;34m\"Parameter %s is missing in file %s\"\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlprefix\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 670\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 671\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_params\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: Parameter test_conv0_weight is missing in file save_params.params"
]
}
],
"source": [
"# save_params / load_params\n",
"net = get_net()\n",
"net.load_params(save_params, ctx=mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "Parameter test_conv0_weight is missing in file save_params.params",
"output_type": "error",
"traceback": [
"\u001b[0;31m\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-48-dc38e98a2978>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# save_params / collect_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/gluon/parameter.pyc\u001b[0m in \u001b[0;36mload\u001b[0;34m(self, filename, ctx, allow_missing, ignore_extra, restore_prefix)\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 668\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 669\u001b[0;31m \u001b[0;34m\"Parameter %s is missing in file %s\"\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlprefix\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 670\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0marg_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 671\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_params\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: Parameter test_conv0_weight is missing in file save_params.params"
]
}
],
"source": [
"# save_params / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(save_params, mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `collect_params`"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# collect_params / load_params\n",
"net = get_net()\n",
"net.load_params(collect_params, mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# collect_params / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(collect_params, mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `export`"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"# export / load_params\n",
"net = get_net()\n",
"net.load_params(export_params, mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"# export / collect_params\n",
"net = get_net()\n",
"net.collect_params().load(export_params, mx.cpu())\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading with module API"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"def get_module():\n",
" sym = mx.sym.load_json(open(sym_symbol, 'r').read())\n",
" mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)\n",
" mod.bind(for_training=False, data_shapes=[('data', (1,1,50,50))], \n",
" label_shapes=mod._label_shapes)\n",
" return mod"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `save_params`"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "need more than 1 value to unpack",
"output_type": "error",
"traceback": [
"\u001b[0;31m\u001b[0m",
"\u001b[0;31mValueError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-57-f179d8f58eab>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mu'cp $sym_symbol test-symbol.json'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mu'cp $save_params test-0000.params'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0msym\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_checkpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'test'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/model.pyc\u001b[0m in \u001b[0;36mload_checkpoint\u001b[0;34m(prefix, epoch)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[0maux_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msave_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 425\u001b[0;31m \u001b[0mtp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m':'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 426\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtp\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'arg'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0marg_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: need more than 1 value to unpack"
]
}
],
"source": [
"!cp $sym_symbol test-symbol.json\n",
"!cp $save_params test-0000.params\n",
"sym, arg_params, aux_params = mx.model.load_checkpoint('test', 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `collect_params`"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "need more than 1 value to unpack",
"output_type": "error",
"traceback": [
"\u001b[0;31m\u001b[0m",
"\u001b[0;31mValueError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-58-ae8bdf923005>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mu'cp $sym_symbol test-symbol.json'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mu'cp $collect_params test-0000.params'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0msym\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_checkpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'test'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/home/ubuntu/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/model.pyc\u001b[0m in \u001b[0;36mload_checkpoint\u001b[0;34m(prefix, epoch)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[0maux_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msave_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 425\u001b[0;31m \u001b[0mtp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m':'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 426\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtp\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'arg'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0marg_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: need more than 1 value to unpack"
]
}
],
"source": [
"!cp $sym_symbol test-symbol.json\n",
"!cp $collect_params test-0000.params\n",
"sym, arg_params, aux_params = mx.model.load_checkpoint('test', 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Parameters saved with `export`"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"sym, arg_params, aux_params = mx.model.load_checkpoint('export', 0)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"mod = get_module()\n",
"mod.set_params(arg_params, aux_params, allow_missing=True)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"mod.forward(mx.io.DataBatch([data]))"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"out = mod.get_outputs()[0].asnumpy()\n",
"assert np.array_equal(out, net(data).asnumpy())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Environment (conda_mxnet_p27)",
"language": "python",
"name": "conda_mxnet_p27"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment