Skip to content

Instantly share code, notes, and snippets.

@regonn
Created March 30, 2019 01:30
Show Gist options
  • Save regonn/9eb7ff68ac3a5a3201523bd6267953dd to your computer and use it in GitHub Desktop.
Save regonn/9eb7ff68ac3a5a3201523bd6267953dd to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"using PyCall"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PyObject <module 'optuna' from '/home/regonn/.local/lib/python3.7/site-packages/optuna/__init__.py'>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chainer = pyimport(\"chainer\")\n",
"math = pyimport(\"math\")\n",
"random = pyimport(\"random\")\n",
"np = pyimport(\"numpy\")\n",
"pd = pyimport(\"pandas\")\n",
"optuna = pyimport(\"optuna\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:1\n",
"└ @ Core In[3]:1\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:2\n",
"└ @ Core In[3]:2\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:3\n",
"└ @ Core In[3]:3\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:4\n",
"└ @ Core In[3]:4\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:5\n",
"└ @ Core In[3]:5\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:6\n",
"└ @ Core In[3]:6\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:7\n",
"└ @ Core In[3]:7\n",
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = top-level scope at In[3]:8\n",
"└ @ Core In[3]:8\n"
]
},
{
"data": {
"text/plain": [
"PyObject <module 'chainer.training.extensions' from '/home/regonn/.local/lib/python3.7/site-packages/chainer/training/extensions/__init__.py'>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Short aliases for the Chainer classes/submodules used in later cells.\n",
"# Dot access replaces the deprecated `chainer[:name]` Symbol indexing, which\n",
"# emitted a PyCall deprecation warning for every line of this cell.\n",
"const Chain = chainer.Chain\n",
"const Variable = chainer.Variable\n",
"const F = chainer.functions\n",
"const L = chainer.links\n",
"const iterators = chainer.iterators\n",
"const optimizers = chainer.optimizers\n",
"const training = chainer.training\n",
"const extensions = training.extensions"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Experiment size used by `objective`: MNIST subset sizes, minibatch size and\n",
"# number of training epochs run for each Optuna trial.\n",
"N_TRAIN_EXAMPLES = 3_000\n",
"N_TEST_EXAMPLES = 1_000\n",
"BATCHSIZE = 128\n",
"EPOCH = 10"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PyObject <class 'MLP'>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Chainer `Chain` subclass defined from Julia via PyCall's @pydef: two ReLU hidden\n",
"# layers of width n_mid_units followed by a linear layer with n_out outputs.\n",
"@pydef mutable struct MLP <: Chain\n",
"    function __init__(self, n_mid_units=100, n_out=10)\n",
"        # Dot access replaces the deprecated `[:__init__]` Symbol indexing, which\n",
"        # triggered a PyCall deprecation warning when the class was instantiated.\n",
"        pybuiltin(:super)(MLP, self).__init__()\n",
"        # Layers are attached inside Chainer's init_scope context. Every in_size is\n",
"        # None (presumably inferred by Chainer from the first batch - confirm).\n",
"        @pywith self.init_scope() begin\n",
"            self.l1 = L.Linear(py\"None\", n_mid_units)\n",
"            self.l2 = L.Linear(py\"None\", n_mid_units)\n",
"            self.l3 = L.Linear(py\"None\", n_out)\n",
"        end\n",
"    end\n",
"\n",
"    # Forward pass: l1 -> ReLU -> l2 -> ReLU -> l3 (raw scores, no softmax here).\n",
"    function forward(self, x)\n",
"        h1 = F.relu(self.l1(x))\n",
"        h2 = F.relu(self.l2(h1))\n",
"        return self.l3(h2)\n",
"    end\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"create_optimizer (generic function with 1 method)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sample an optimizer (Adam or MomentumSGD) with a log-uniform step size, plus a\n",
"# log-uniform weight-decay coefficient, from the Optuna `trial`; attach it to\n",
"# `model` and return it. Suggestion order: optimizer, step size, weight decay.\n",
"function create_optimizer(trial, model)\n",
"    optimizer = if trial.suggest_categorical(\"optimizer\", [\"Adam\", \"MomentumSGD\"]) == \"Adam\"\n",
"        chainer.optimizers.Adam(alpha=trial.suggest_loguniform(\"adam_alpha\", 1e-5, 1e-1))\n",
"    else\n",
"        chainer.optimizers.MomentumSGD(lr=trial.suggest_loguniform(\"momentum_sgd_lr\", 1e-5, 1e-1))\n",
"    end\n",
"\n",
"    decay_rate = trial.suggest_loguniform(\"weight_decay\", 1e-10, 1e-3)\n",
"    optimizer.setup(model)\n",
"    optimizer.add_hook(chainer.optimizer.WeightDecay(decay_rate))\n",
"    return optimizer\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"objective (generic function with 1 method)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Optuna objective: train an MLP classifier on a fixed MNIST subset using the\n",
"# optimizer sampled by `create_optimizer`, record the final epoch's metrics as\n",
"# trial user attributes, and return the final validation error (1 - accuracy).\n",
"function objective(trial)\n",
"    predictor = MLP()\n",
"    model = L.Classifier(predictor)\n",
"    optimizer = create_optimizer(trial, model)\n",
"\n",
"    # A fresh RandomState(0) per call, so every trial uses the same subsets.\n",
"    rng = np.random.RandomState(0)\n",
"    train, test = chainer.datasets.get_mnist()\n",
"    train = chainer.datasets.SubDataset(train, 0, N_TRAIN_EXAMPLES, order=rng.permutation(length(train)))\n",
"    test = chainer.datasets.SubDataset(test, 0, N_TEST_EXAMPLES, order=rng.permutation(length(test)))\n",
"    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE)\n",
"    test_iter = chainer.iterators.SerialIterator(test, BATCHSIZE, repeat=false, shuffle=false)\n",
"\n",
"    updater = chainer.training.StandardUpdater(train_iter, optimizer)\n",
"    trainer = chainer.training.Trainer(updater, (EPOCH, \"epoch\"))\n",
"    trainer.extend(chainer.training.extensions.Evaluator(test_iter, model))\n",
"    # PrintReport reads LogReport's log, so LogReport is registered first. With the\n",
"    # opposite order the table lagged one epoch behind (it stopped at EPOCH - 1).\n",
"    log_report_extension = chainer.training.extensions.LogReport(log_name=py\"None\")\n",
"    trainer.extend(log_report_extension)\n",
"    trainer.extend(chainer.training.extensions.PrintReport([\"epoch\", \"main/loss\", \"validation/main/loss\", \"main/accuracy\", \"validation/main/accuracy\"]))\n",
"\n",
"    trainer.run()\n",
"\n",
"    # Read the final log entry once; store every metric on the trial.\n",
"    log_last = log_report_extension.log[end]\n",
"    for (key, value) in log_last\n",
"        trial.set_user_attr(key, value)\n",
"    end\n",
"\n",
"    return 1.0 - log_last[\"validation/main/accuracy\"]\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: `getindex(o::PyObject, s::Symbol)` is deprecated in favor of dot overloading (`getproperty`) so elements should now be accessed as e.g. `o.s` instead of `o[:s]`.\n",
"│ caller = macro expansion at In[5]:3 [inlined]\n",
"└ @ Core ./In[5]:3\n",
"[I 2019-03-30 09:11:05,273] Finished trial#0 resulted in value: 0.10854867845773697. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n",
"[I 2019-03-30 09:12:37,268] Finished trial#1 resulted in value: 0.11771333962678909. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n",
"[I 2019-03-30 09:13:53,406] Finished trial#2 resulted in value: 0.5378605797886848. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n",
"[I 2019-03-30 09:15:10,697] Finished trial#3 resulted in value: 0.2515024021267891. Current best value is 0.10854867845773697 with parameters: {'optimizer': 'Adam', 'adam_alpha': 0.04735444104799611, 'weight_decay': 3.78234622639618e-05}.\n",
"[I 2019-03-30 09:16:42,596] Finished trial#4 resulted in value: 0.10043569654226303. Current best value is 0.10043569654226303 with parameters: {'optimizer': 'MomentumSGD', 'momentum_sgd_lr': 0.012680859786712828, 'weight_decay': 0.00036674170021996}.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n",
"\u001b[J1 2.98819 0.803917 0.427409 0.747296 \n",
"\u001b[J2 0.691177 0.571618 0.778533 0.834736 \n",
"\u001b[J3 0.451954 0.534379 0.847005 0.842022 \n",
"\u001b[J4 0.390083 0.535907 0.873641 0.836914 \n",
"\u001b[J5 0.344564 0.522144 0.888997 0.856671 \n",
"\u001b[J6 0.312317 0.43479 0.897758 0.886118 \n",
"\u001b[J7 0.241218 0.524394 0.919271 0.855919 \n",
"\u001b[J8 0.225194 0.461898 0.918818 0.884465 \n",
"\u001b[J9 0.197385 0.48858 0.93784 0.873648 \n",
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n",
"\u001b[J1 2.52886 0.816131 0.450195 0.723332 \n",
"\u001b[J2 0.715267 0.646964 0.778872 0.807016 \n",
"\u001b[J3 0.502833 0.513494 0.852865 0.856671 \n",
"\u001b[J4 0.420803 0.414364 0.872622 0.889799 \n",
"\u001b[J5 0.323061 0.50816 0.899089 0.865986 \n",
"\u001b[J6 0.364097 0.531125 0.88519 0.867488 \n",
"\u001b[J7 0.351918 0.510676 0.898438 0.872371 \n",
"\u001b[J8 0.304244 0.516988 0.911005 0.865234 \n",
"\u001b[J9 0.308954 0.46725 0.904552 0.865309 \n",
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n",
"\u001b[J1 2.32498 2.30721 0.0797526 0.089393 \n",
"\u001b[J2 2.30125 2.28621 0.115829 0.124249 \n",
"\u001b[J3 2.27929 2.26394 0.151693 0.185322 \n",
"\u001b[J4 2.25497 2.24187 0.19837 0.232647 \n",
"\u001b[J5 2.23301 2.21747 0.241211 0.257061 \n",
"\u001b[J6 2.20734 2.19275 0.286685 0.29387 \n",
"\u001b[J7 2.17969 2.16544 0.319987 0.333083 \n",
"\u001b[J8 2.153 2.13685 0.356658 0.388146 \n",
"\u001b[J9 2.12269 2.10593 0.393003 0.427584 \n",
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n",
"\u001b[J1 2.28918 2.25415 0.147135 0.183894 \n",
"\u001b[J2 2.2281 2.1976 0.213995 0.250451 \n",
"\u001b[J3 2.16766 2.13286 0.27832 0.328651 \n",
"\u001b[J4 2.10197 2.05995 0.328804 0.373197 \n",
"\u001b[J5 2.02155 1.97231 0.388997 0.425105 \n",
"\u001b[J6 1.93113 1.88031 0.445652 0.478666 \n",
"\u001b[J7 1.83265 1.77844 0.499349 0.551457 \n",
"\u001b[J8 1.73413 1.67593 0.574049 0.624624 \n",
"\u001b[J9 1.62991 1.57281 0.654552 0.697716 \n",
"epoch main/loss validation/main/loss main/accuracy validation/main/accuracy\n",
"\u001b[J1 2.14668 1.80712 0.282552 0.613732 \n",
"\u001b[J2 1.32171 0.791081 0.72894 0.839618 \n",
"\u001b[J3 0.616387 0.467579 0.84375 0.876953 \n",
"\u001b[J4 0.438287 0.400863 0.882133 0.887094 \n",
"\u001b[J5 0.370513 0.358306 0.896159 0.893254 \n",
"\u001b[J6 0.341263 0.344949 0.901834 0.896409 \n",
"\u001b[J7 0.299163 0.345071 0.921224 0.896409 \n",
"\u001b[J8 0.284804 0.369142 0.920856 0.887169 \n",
"\u001b[J9 0.266375 0.33605 0.929008 0.90009 \n"
]
}
],
"source": [
"# Run 5 Optuna trials of `objective`; the study tracks the lowest value seen.\n",
"study = optuna.create_study()\n",
"study.optimize(objective; n_trials=5)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of finished trials: 5\n",
"Best trial:\n",
" Value: 0.10043569654226303\n",
" Params: \n",
" optimizer: MomentumSGD\n",
" momentum_sgd_lr: 0.012680859786712828\n",
" weight_decay: 0.00036674170021996\n"
]
}
],
"source": [
"# Summarize the study. The best value/params come from Study's named attributes\n",
"# instead of positional indexing (`trial[3]`, `trial[6]`) into the converted\n",
"# FrozenTrial, which silently depended on that type's field order.\n",
"println(\"Number of finished trials: \", length(study.trials))\n",
"println(\"Best trial:\")\n",
"println(\" Value: \", study.best_value)\n",
"println(\" Params: \")\n",
"for (key, value) in study.best_params\n",
"    println(\" \", key, \": \", value)\n",
"end"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.1.0",
"language": "julia",
"name": "julia-1.1"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.1.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment