Created
March 30, 2019 01:29
-
-
Save regonn/d2acf5a20a1b3ec34d8e483af510b4a3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using PyCall" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <module 'pandas' from '/home/regonn/.local/lib/python3.7/site-packages/pandas/__init__.py'>" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"chainer = pyimport(\"chainer\")\n", | |
"math = pyimport(\"math\")\n", | |
"random = pyimport(\"random\")\n", | |
"np = pyimport(\"numpy\")\n", | |
"pd = pyimport(\"pandas\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:1\n", | |
"└ @ Core In[3]:1\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:2\n", | |
"└ @ Core In[3]:2\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:3\n", | |
"└ @ Core In[3]:3\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:4\n", | |
"└ @ Core In[3]:4\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:5\n", | |
"└ @ Core In[3]:5\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:6\n", | |
"└ @ Core In[3]:6\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:7\n", | |
"└ @ Core In[3]:7\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:8\n", | |
"└ @ Core In[3]:8\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <module 'chainer.training.extensions' from '/home/regonn/.local/lib/python3.7/site-packages/chainer/training/extensions/__init__.py'>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"const Chain = chainer[:Chain]\n", | |
"const Variable = chainer[:Variable]\n", | |
"const F = chainer[:functions]\n", | |
"const L = chainer[:links]\n", | |
"const iterators = chainer[:iterators]\n", | |
"const optimizers = chainer[:optimizers]\n", | |
"const training = chainer[:training]\n", | |
"const extensions = training[:extensions]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(PyObject <chainer.datasets.tuple_dataset.TupleDataset object at 0x7efdc0af2470>, PyObject <chainer.datasets.tuple_dataset.TupleDataset object at 0x7efdaa132b70>)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"const mnist = chainer.datasets.mnist\n", | |
"train, test = mnist.get_mnist()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"128" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"batchsize = 128" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <chainer.iterators.serial_iterator.SerialIterator object at 0x7efdaa1326a0>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_iter = iterators.SerialIterator(train, batchsize)\n", | |
"test_iter = iterators.SerialIterator(test, batchsize, repeat=false, shuffle=false)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <class 'MLP'>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"@pydef mutable struct MLP <: Chain\n", | |
" function __init__(self, n_mid_units=100, n_out=10)\n", | |
" pybuiltin(:super)(MLP, self)[:__init__]()\n", | |
" @pywith self.init_scope() begin\n", | |
" self.l1 = L.Linear(py\"None\", n_mid_units)\n", | |
" self.l2 = L.Linear(py\"None\", n_mid_units)\n", | |
" self.l3 = L.Linear(py\"None\", n_out)\n", | |
" end\n", | |
" end\n", | |
"\n", | |
" function forward(self, x)\n", | |
" h1 = F.relu(self.l1(x))\n", | |
" h2 = F.relu(self.l2(h1))\n", | |
" return self.l3(h2)\n", | |
" end\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = macro expansion at In[7]:3 [inlined]\n", | |
"└ @ Core ./In[7]:3\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <chainer.links.model.classifier.Classifier object at 0x7efdaa1551d0>" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"max_epoch = 10\n", | |
"model = MLP()\n", | |
"model = L.Classifier(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <chainer.optimizers.momentum_sgd.MomentumSGD object at 0x7efdaa155470>" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"optimizer = optimizers.MomentumSGD()\n", | |
"optimizer.setup(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <chainer.training.updaters.standard_updater.StandardUpdater object at 0x7efdaa155518>" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"updater = training.updaters.StandardUpdater(train_iter, optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trainer = training.Trainer(updater, (max_epoch, \"epoch\"), out=\"mnist_result\")\n", | |
"trainer.extend(extensions.LogReport())\n", | |
"trainer.extend(extensions.Evaluator(test_iter, model))\n", | |
"trainer.extend(extensions.PrintReport([\"epoch\", \"main/loss\", \"main/accuracy\", \"validation/main/loss\", \"validation/main/accuracy\", \"elapsed_time\"]))\n", | |
"trainer.extend(extensions.PlotReport([\"main/loss\", \"validation/main/loss\"], x_key=\"epoch\", file_name=\"loss.png\"))\n", | |
"trainer.extend(extensions.PlotReport([\"main/accuracy\", \"validation/main/accuracy\"], x_key=\"epoch\", file_name=\"accuracy.png\"))\n", | |
"trainer.extend(extensions.dump_graph(\"main/loss\"))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch main/loss main/accuracy validation/main/loss validation/main/accuracy elapsed_time\n", | |
"\u001b[J1 0.53921 0.850197 0.270433 0.919502 151.316 \n", | |
"\u001b[J2 0.236785 0.93122 0.199803 0.943335 302.314 \n", | |
"\u001b[J3 0.179879 0.947795 0.157251 0.953224 451.812 \n", | |
"\u001b[J4 0.143571 0.958417 0.134044 0.959652 602.499 \n", | |
"\u001b[J5 0.120292 0.965685 0.118129 0.964399 751.069 \n", | |
"\u001b[J6 0.10287 0.969716 0.101918 0.96875 893.483 \n", | |
"\u001b[J7 0.0892034 0.974264 0.0912552 0.972013 1041.57 \n", | |
"\u001b[J8 0.077978 0.977664 0.090307 0.971816 1192.46 \n", | |
"\u001b[J9 0.0701926 0.979278 0.087615 0.972903 1362.7 \n", | |
"\u001b[J10 0.0617166 0.98136 0.0808613 0.97498 1504.98 \n" | |
] | |
} | |
], | |
"source": [ | |
"trainer.run()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 1.1.0", | |
"language": "julia", | |
"name": "julia-1.1" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.1.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment