Skip to content

Instantly share code, notes, and snippets.

@regonn
Created December 2, 2017 14:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save regonn/145b1f441a34f31246b9ad0d650d3738 to your computer and use it in GitHub Desktop.
Save regonn/145b1f441a34f31246b9ad0d650d3738 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(MXNet.mx.MXDataProvider(MXNet.mx.MX_DataIterHandle(Ptr{Void} @0x00000000052fba70), Tuple{Symbol,Tuple}[(:data, (784, 100))], Tuple{Symbol,Tuple}[(:softmax_label, (100,))], 100, true, true), MXNet.mx.MXDataProvider(MXNet.mx.MX_DataIterHandle(Ptr{Void} @0x0000000004ceae60), Tuple{Symbol,Tuple}[(:data, (784, 100))], Tuple{Symbol,Tuple}[(:softmax_label, (100,))], 100, true, true))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using MXNet\n",
"\n",
"mlp = @mx.chain mx.Variable(:data) =>\n",
" mx.FullyConnected(name=:fc1, num_hidden=128) =>\n",
" mx.Activation(name=:relu1, act_type=:relu) =>\n",
" mx.FullyConnected(name=:fc2, num_hidden=64) =>\n",
" mx.Activation(name=:relu2, act_type=:relu) =>\n",
" mx.FullyConnected(name=:fc3, num_hidden=10) =>\n",
" mx.SoftmaxOutput(name=:softmax)\n",
"\n",
"# data provider\n",
"batch_size = 100\n",
"include(Pkg.dir(\"MXNet\", \"examples\", \"mnist\", \"mnist-data.jl\"))\n",
"train_provider, eval_provider = get_mnist_providers(batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mStart training on MXNet.mx.Context[CPU0]\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mInitializing parameters...\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mCreating KVStore...\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTempSpace: Total 0 MB allocated on CPU0\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mStart training...\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 001/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.7579\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 115.3491 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9537\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 002/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9575\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 114.0865 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9650\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 003/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9695\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 114.2930 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9687\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 004/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9770\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 116.2855 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9735\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 005/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9810\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 115.9846 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9694\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mFinish training on MXNet.mx.Context[CPU0]\n",
"\u001b[39m"
]
}
],
"source": [
"# setup model\n",
"model = mx.FeedForward(mlp, context=mx.cpu())\n",
"\n",
"# optimization algorithm\n",
"optimizer = mx.SGD(lr=0.1, momentum=0.9)\n",
"\n",
"# fit parameters\n",
"mx.fit(model, optimizer, train_provider, n_epoch=5, eval_data=eval_provider)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mStart training on MXNet.mx.Context[GPU0]\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mInitializing parameters...\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mCreating KVStore...\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTempSpace: Total 0 MB allocated on GPU0\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mStart training...\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 001/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.7575\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 1.8311 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9502\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 002/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9575\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 1.3912 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9611\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 003/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9696\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 1.3981 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9678\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 004/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9770\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 1.3920 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9703\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m== Epoch 005/005 ==========\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Training summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9803\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m time = 1.3830 seconds\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m## Validation summary\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36m accuracy = 0.9718\n",
"\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mFinish training on MXNet.mx.Context[GPU0]\n",
"\u001b[39m"
]
}
],
"source": [
"# setup model\n",
"model = mx.FeedForward(mlp, context=mx.gpu(0))\n",
"\n",
"# optimization algorithm\n",
"optimizer = mx.SGD(lr=0.1, momentum=0.9)\n",
"\n",
"# fit parameters\n",
"mx.fit(model, optimizer, train_provider, n_epoch=5, eval_data=eval_provider)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.6.1",
"language": "julia",
"name": "julia-0.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment