Created
March 30, 2019 01:30
-
-
Save regonn/9eb7ff68ac3a5a3201523bd6267953dd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using PyCall" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <module 'optuna' from '/home/regonn/.local/lib/python3.7/site-packages/optuna/__init__.py'>" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"chainer = pyimport(\"chainer\")\n", | |
"math = pyimport(\"math\")\n", | |
"random = pyimport(\"random\")\n", | |
"np = pyimport(\"numpy\")\n", | |
"pd = pyimport(\"pandas\")\n", | |
"optuna = pyimport(\"optuna\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:1\n", | |
"└ @ Core In[3]:1\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:2\n", | |
"└ @ Core In[3]:2\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:3\n", | |
"└ @ Core In[3]:3\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:4\n", | |
"└ @ Core In[3]:4\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:5\n", | |
"└ @ Core In[3]:5\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:6\n", | |
"└ @ Core In[3]:6\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:7\n", | |
"└ @ Core In[3]:7\n", | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = top-level scope at In[3]:8\n", | |
"└ @ Core In[3]:8\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <module 'chainer.training.extensions' from '/home/regonn/.local/lib/python3.7/site-packages/chainer/training/extensions/__init__.py'>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Short aliases for the Chainer classes and submodules used by this notebook.\n", | |
"# Attributes are read with dot syntax (`o.s`); PyCall deprecated the `o[:s]`\n", | |
"# form, which emitted one deprecation warning per line of this cell.\n", | |
"const Chain = chainer.Chain\n", | |
"const Variable = chainer.Variable\n", | |
"const F = chainer.functions\n", | |
"const L = chainer.links\n", | |
"const iterators = chainer.iterators\n", | |
"const optimizers = chainer.optimizers\n", | |
"const training = chainer.training\n", | |
"const extensions = training.extensions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"10" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Experiment knobs: dataset subset sizes, minibatch size, and epochs per trial.\n", | |
"N_TRAIN_EXAMPLES, N_TEST_EXAMPLES = 3000, 1000\n", | |
"BATCHSIZE = 128\n", | |
"EPOCH = 10" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PyObject <class 'MLP'>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Three-layer perceptron, defined through PyCall as a Python subclass of chainer.Chain.\n", | |
"@pydef mutable struct MLP <: Chain\n", | |
"    # n_mid_units: output size of the two hidden layers; n_out: output size of the last layer.\n", | |
"    function __init__(self, n_mid_units=100, n_out=10)\n", | |
"        # Dot access replaces the deprecated `[:__init__]` lookup, which logged a\n", | |
"        # PyCall deprecation warning when the first model was constructed.\n", | |
"        pybuiltin(:super)(MLP, self).__init__()\n", | |
"        @pywith self.init_scope() begin\n", | |
"            # The input size of every layer is passed as Python None.\n", | |
"            self.l1 = L.Linear(py\"None\", n_mid_units)\n", | |
"            self.l2 = L.Linear(py\"None\", n_mid_units)\n", | |
"            self.l3 = L.Linear(py\"None\", n_out)\n", | |
"        end\n", | |
"    end\n", | |
"\n", | |
"    # Forward pass: two ReLU-activated hidden layers, then the linear output layer.\n", | |
"    function forward(self, x)\n", | |
"        h1 = F.relu(self.l1(x))\n", | |
"        h2 = F.relu(self.l2(h1))\n", | |
"        return self.l3(h2)\n", | |
"    end\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"create_optimizer (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Build a Chainer optimizer for `model` from hyperparameters sampled on `trial`:\n", | |
"# the optimizer family, its step size, and a weight-decay coefficient.\n", | |
"function create_optimizer(trial, model)\n", | |
"    choice = trial.suggest_categorical(\"optimizer\", [\"Adam\", \"MomentumSGD\"])\n", | |
"\n", | |
"    # Only the step-size parameter of the chosen family is sampled.\n", | |
"    opt = if choice == \"Adam\"\n", | |
"        chainer.optimizers.Adam(alpha=trial.suggest_loguniform(\"adam_alpha\", 1e-5, 1e-1))\n", | |
"    else\n", | |
"        chainer.optimizers.MomentumSGD(lr=trial.suggest_loguniform(\"momentum_sgd_lr\", 1e-5, 1e-1))\n", | |
"    end\n", | |
"\n", | |
"    decay = trial.suggest_loguniform(\"weight_decay\", 1e-10, 1e-3)\n", | |
"    opt.setup(model)\n", | |
"    opt.add_hook(chainer.optimizer.WeightDecay(decay))\n", | |
"    opt\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"objective (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Optuna objective: train an MLP classifier on an MNIST subset with the\n", | |
"# hyperparameters sampled for `trial`, store the last log entry on the trial as\n", | |
"# user attributes, and return the final validation error (1 - accuracy).\n", | |
"function objective(trial)\n", | |
"    model = L.Classifier(MLP())\n", | |
"    optimizer = create_optimizer(trial, model)\n", | |
"\n", | |
"    # The generator is re-seeded with 0 on every call, so each trial draws the\n", | |
"    # same train/test permutations.\n", | |
"    rng = np.random.RandomState(0)\n", | |
"    train, test = chainer.datasets.get_mnist()\n", | |
"    train = chainer.datasets.SubDataset(train, 0, N_TRAIN_EXAMPLES, order=rng.permutation(length(train)))\n", | |
"    test = chainer.datasets.SubDataset(test, 0, N_TEST_EXAMPLES, order=rng.permutation(length(test)))\n", | |
"    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE)\n", | |
"    test_iter = chainer.iterators.SerialIterator(test, BATCHSIZE, repeat=false, shuffle=false)\n", | |
"\n", | |
"    updater = chainer.training.StandardUpdater(train_iter, optimizer)\n", | |
"    trainer = chainer.training.Trainer(updater, (EPOCH, \"epoch\"))\n", | |
"    trainer.extend(chainer.training.extensions.Evaluator(test_iter, model))\n", | |
"\n", | |
"    # LogReport is registered before PrintReport. With the opposite order the\n", | |
"    # printed table lagged one entry behind the log: the saved run shows only\n", | |
"    # epochs 1-9 of 10, because PrintReport ran before the current epoch's\n", | |
"    # entry had been appended.\n", | |
"    log_report_extension = chainer.training.extensions.LogReport(log_name=py\"None\")\n", | |
"    trainer.extend(log_report_extension)\n", | |
"    trainer.extend(chainer.training.extensions.PrintReport([\"epoch\", \"main/loss\", \"validation/main/loss\", \"main/accuracy\", \"validation/main/accuracy\"]))\n", | |
"\n", | |
"    trainer.run()\n", | |
"\n", | |
"    # Attach every metric of the final log entry to the trial.\n", | |
"    log_last = log_report_extension.log[end]\n", | |
"    for (key, value) in log_last\n", | |
"        trial.set_user_attr(key, value)\n", | |
"    end\n", | |
"\n", | |
"    # Reuse the entry fetched above instead of converting the log a second time.\n", | |
"    val_err = 1.0 - log_last[\"validation/main/accuracy\"]\n", | |
"    return val_err\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n", | |
"│ caller = macro expansion at In[5]:3 [inlined]\n", | |
"└ @ Core ./In[5]:3\n", | |
"[I 2019-03-30 09:11:05,273] Finished trial#0 resulted in value: 0.10854867845773697. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n", | |
"[I 2019-03-30 09:12:37,268] Finished trial#1 resulted in value: 0.11771333962678909. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n", | |
"[I 2019-03-30 09:13:53,406] Finished trial#2 resulted in value: 0.5378605797886848. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n", | |
"[I 2019-03-30 09:15:10,697] Finished trial#3 resulted in value: 0.2515024021267891. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n", | |
"[I 2019-03-30 09:16:42,596] Finished trial#4 resulted in value: 0.10043569654226303. Current best value is 0.10043569654226303 with parameters: {'optimizer': 'MomentumSGD', 'momentum_sgd_lr': 0.012680859786712828, 'weight_decay': 0.00036674170021996}.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n", | |
"\u001b[J1 2.98819 0.803917 0.427409 0.747296 \n", | |
"\u001b[J2 0.691177 0.571618 0.778533 0.834736 \n", | |
"\u001b[J3 0.451954 0.534379 0.847005 0.842022 \n", | |
"\u001b[J4 0.390083 0.535907 0.873641 0.836914 \n", | |
"\u001b[J5 0.344564 0.522144 0.888997 0.856671 \n", | |
"\u001b[J6 0.312317 0.43479 0.897758 0.886118 \n", | |
"\u001b[J7 0.241218 0.524394 0.919271 0.855919 \n", | |
"\u001b[J8 0.225194 0.461898 0.918818 0.884465 \n", | |
"\u001b[J9 0.197385 0.48858 0.93784 0.873648 \n", | |
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n", | |
"\u001b[J1 2.52886 0.816131 0.450195 0.723332 \n", | |
"\u001b[J2 0.715267 0.646964 0.778872 0.807016 \n", | |
"\u001b[J3 0.502833 0.513494 0.852865 0.856671 \n", | |
"\u001b[J4 0.420803 0.414364 0.872622 0.889799 \n", | |
"\u001b[J5 0.323061 0.50816 0.899089 0.865986 \n", | |
"\u001b[J6 0.364097 0.531125 0.88519 0.867488 \n", | |
"\u001b[J7 0.351918 0.510676 0.898438 0.872371 \n", | |
"\u001b[J8 0.304244 0.516988 0.911005 0.865234 \n", | |
"\u001b[J9 0.308954 0.46725 0.904552 0.865309 \n", | |
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n", | |
"\u001b[J1 2.32498 2.30721 0.0797526 0.089393 \n", | |
"\u001b[J2 2.30125 2.28621 0.115829 0.124249 \n", | |
"\u001b[J3 2.27929 2.26394 0.151693 0.185322 \n", | |
"\u001b[J4 2.25497 2.24187 0.19837 0.232647 \n", | |
"\u001b[J5 2.23301 2.21747 0.241211 0.257061 \n", | |
"\u001b[J6 2.20734 2.19275 0.286685 0.29387 \n", | |
"\u001b[J7 2.17969 2.16544 0.319987 0.333083 \n", | |
"\u001b[J8 2.153 2.13685 0.356658 0.388146 \n", | |
"\u001b[J9 2.12269 2.10593 0.393003 0.427584 \n", | |
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n", | |
"\u001b[J1 2.28918 2.25415 0.147135 0.183894 \n", | |
"\u001b[J2 2.2281 2.1976 0.213995 0.250451 \n", | |
"\u001b[J3 2.16766 2.13286 0.27832 0.328651 \n", | |
"\u001b[J4 2.10197 2.05995 0.328804 0.373197 \n", | |
"\u001b[J5 2.02155 1.97231 0.388997 0.425105 \n", | |
"\u001b[J6 1.93113 1.88031 0.445652 0.478666 \n", | |
"\u001b[J7 1.83265 1.77844 0.499349 0.551457 \n", | |
"\u001b[J8 1.73413 1.67593 0.574049 0.624624 \n", | |
"\u001b[J9 1.62991 1.57281 0.654552 0.697716 \n", | |
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n", | |
"\u001b[J1 2.14668 1.80712 0.282552 0.613732 \n", | |
"\u001b[J2 1.32171 0.791081 0.72894 0.839618 \n", | |
"\u001b[J3 0.616387 0.467579 0.84375 0.876953 \n", | |
"\u001b[J4 0.438287 0.400863 0.882133 0.887094 \n", | |
"\u001b[J5 0.370513 0.358306 0.896159 0.893254 \n", | |
"\u001b[J6 0.341263 0.344949 0.901834 0.896409 \n", | |
"\u001b[J7 0.299163 0.345071 0.921224 0.896409 \n", | |
"\u001b[J8 0.284804 0.369142 0.920856 0.887169 \n", | |
"\u001b[J9 0.266375 0.33605 0.929008 0.90009 \n" | |
] | |
} | |
], | |
"source": [ | |
"# Create a study with Optuna's default settings and evaluate `objective` five times.\n", | |
"study = optuna.create_study()\n", | |
"study.optimize(objective; n_trials=5)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Number of finished trials: 5\n", | |
"Best trial:\n", | |
" Value: 0.10043569654226303\n", | |
" Params: \n", | |
" optimizer: MomentumSGD\n", | |
" momentum_sgd_lr: 0.012680859786712828\n", | |
" weight_decay: 0.00036674170021996\n" | |
] | |
} | |
], | |
"source": [ | |
"# Summarize the study. The best value and parameters are read through the\n", | |
"# Study properties instead of by tuple position on the converted best_trial\n", | |
"# (`trial[3]`, `trial[6]`), which silently depends on the trial's field order.\n", | |
"println(\"Number of finished trials: \", length(study.trials))\n", | |
"println(\"Best trial:\")\n", | |
"println(\" Value: \", study.best_value)\n", | |
"println(\" Params: \")\n", | |
"for (key, value) in study.best_params\n", | |
"    println(\" \", key, \": \", value)\n", | |
"end" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 1.1.0", | |
"language": "julia", | |
"name": "julia-1.1" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.1.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment