Skip to content

Instantly share code, notes, and snippets.

@regonn
Created March 30, 2019 01:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save regonn/d2acf5a20a1b3ec34d8e483af510b4a3 to your computer and use it in GitHub Desktop.
Save regonn/d2acf5a20a1b3ec34d8e483af510b4a3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"using PyCall"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PyObject <module 'pandas' from '/home/regonn/.local/lib/python3.7/site-packages/pandas/__init__.py'>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chainer = pyimport(\"chainer\")\n",
"math = pyimport(\"math\")\n",
"random = pyimport(\"random\")\n",
"np = pyimport(\"numpy\")\n",
"pd = pyimport(\"pandas\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:1\n",
"└ @ Core In[3]:1\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:2\n",
"└ @ Core In[3]:2\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:3\n",
"└ @ Core In[3]:3\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:4\n",
"└ @ Core In[3]:4\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:5\n",
"└ @ Core In[3]:5\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:6\n",
"└ @ Core In[3]:6\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:7\n",
"└ @ Core In[3]:7\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:8\n",
"└ @ Core In[3]:8\n"
]
},
{
"data": {
"text/plain": [
"PyObject <module 'chainer.training.extensions' from '/home/regonn/.local/lib/python3.7/site-packages/chainer/training/extensions/__init__.py'>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Short aliases for the Chainer classes and submodules used in later cells.\n",
"# Attributes are read with PyCall's dot syntax (`o.s`); `o[:s]` indexing is deprecated.\n",
"const Chain = chainer.Chain\n",
"const Variable = chainer.Variable\n",
"const F = chainer.functions\n",
"const L = chainer.links\n",
"const iterators = chainer.iterators\n",
"const optimizers = chainer.optimizers\n",
"const training = chainer.training\n",
"const extensions = training.extensions"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PyObject <chainer.datasets.tuple_dataset.TupleDataset object at 0x7efdc0af2470>, PyObject <chainer.datasets.tuple_dataset.TupleDataset object at 0x7efdaa132b70>)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"const mnist = chainer.datasets.mnist\n",
"train, test = mnist.get_mnist()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"128"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batchsize = 128"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PyObject <chainer.iterators.serial_iterator.SerialIterator object at 0x7efdaa1326a0>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Iterator over `train` in batches of `batchsize`, with SerialIterator's default options.\n",
"train_iter = iterators.SerialIterator(train, batchsize)\n",
"# Iterator over `test`: repeat and shuffle are both switched off.\n",
"test_iter = iterators.SerialIterator(test, batchsize; repeat=false, shuffle=false)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PyObject <class 'MLP'>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Multi-layer perceptron with two ReLU hidden layers, declared through PyCall's\n",
"# @pydef as a Python class deriving from chainer's Chain.\n",
"@pydef mutable struct MLP <: Chain\n",
"    # Builds the three Linear links: n_mid_units is the output size of l1 and l2,\n",
"    # n_out is the output size of l3. Input sizes are passed as Python None.\n",
"    function __init__(self, n_mid_units=100, n_out=10)\n",
"        # `.__init__` uses dot access; `[:__init__]` symbol indexing is deprecated in PyCall.\n",
"        pybuiltin(:super)(MLP, self).__init__()\n",
"        @pywith self.init_scope() begin\n",
"            self.l1 = L.Linear(py\"None\", n_mid_units)\n",
"            self.l2 = L.Linear(py\"None\", n_mid_units)\n",
"            self.l3 = L.Linear(py\"None\", n_out)\n",
"        end\n",
"    end\n",
"\n",
"    # Forward pass: l1 -> relu -> l2 -> relu -> l3 (no activation on the output).\n",
"    function forward(self, x)\n",
"        h1 = F.relu(self.l1(x))\n",
"        h2 = F.relu(self.l2(h1))\n",
"        return self.l3(h2)\n",
"    end\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = macro expansion at In[7]:3 [inlined]\n",
"└ @ Core ./In[7]:3\n"
]
},
{
"data": {
"text/plain": [
"PyObject <chainer.links.model.classifier.Classifier object at 0x7efdaa1551d0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"max_epoch = 10\n",
"# Wrap a fresh MLP in chainer's Classifier link in a single expression.\n",
"model = L.Classifier(MLP())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PyObject <chainer.optimizers.momentum_sgd.MomentumSGD object at 0x7efdaa155470>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimizer = optimizers.MomentumSGD()\n",
"optimizer.setup(model)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PyObject <chainer.training.updaters.standard_updater.StandardUpdater object at 0x7efdaa155518>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"updater = training.updaters.StandardUpdater(train_iter, optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Trainer that runs for `max_epoch` epochs and writes its files under \"mnist_result\".\n",
"trainer = training.Trainer(updater, (max_epoch, \"epoch\"), out=\"mnist_result\")\n",
"\n",
"# Register the extensions one by one, in the order listed.\n",
"for ext in (\n",
"    extensions.LogReport(),\n",
"    extensions.Evaluator(test_iter, model),\n",
"    extensions.PrintReport([\"epoch\", \"main/loss\", \"main/accuracy\", \"validation/main/loss\", \"validation/main/accuracy\", \"elapsed_time\"]),\n",
"    extensions.PlotReport([\"main/loss\", \"validation/main/loss\"], x_key=\"epoch\", file_name=\"loss.png\"),\n",
"    extensions.PlotReport([\"main/accuracy\", \"validation/main/accuracy\"], x_key=\"epoch\", file_name=\"accuracy.png\"),\n",
"    extensions.dump_graph(\"main/loss\"),\n",
")\n",
"    trainer.extend(ext)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch main/loss main/accuracy validation/main/loss validation/main/accuracy elapsed_time\n",
"\u001b[J1 0.53921 0.850197 0.270433 0.919502 151.316 \n",
"\u001b[J2 0.236785 0.93122 0.199803 0.943335 302.314 \n",
"\u001b[J3 0.179879 0.947795 0.157251 0.953224 451.812 \n",
"\u001b[J4 0.143571 0.958417 0.134044 0.959652 602.499 \n",
"\u001b[J5 0.120292 0.965685 0.118129 0.964399 751.069 \n",
"\u001b[J6 0.10287 0.969716 0.101918 0.96875 893.483 \n",
"\u001b[J7 0.0892034 0.974264 0.0912552 0.972013 1041.57 \n",
"\u001b[J8 0.077978 0.977664 0.090307 0.971816 1192.46 \n",
"\u001b[J9 0.0701926 0.979278 0.087615 0.972903 1362.7 \n",
"\u001b[J10 0.0617166 0.98136 0.0808613 0.97498 1504.98 \n"
]
}
],
"source": [
"trainer.run()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.1.0",
"language": "julia",
"name": "julia-1.1"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.1.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment