Last active
October 23, 2022 05:46
-
-
Save terasakisatoshi/73b872a53d72175c8832c212ec8496d3 to your computer and use it in GitHub Desktop.
Optimisers.jl 入門
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "c35bcfc1-fb1c-49b1-b919-a249c5fb442a", | |
"metadata": {}, | |
"source": [ | |
"# Optimisers.jl 入門\n", | |
"\n", | |
"ニューラルネットワークの学習などで用いられる最適化アルゴリズムを提供するパッケージ Optimisers.jl の使い方を紹介する.\n", | |
"機械学習ライブラリ Flux.jl 内部でも提供されているが, 設計思想の変更(implicit parameter から explicit parameter への移行など)により Optimisers.jl の方を用いることが望まれる動きになってるようだ.\n", | |
"\n", | |
"Julia 1.8 で動作確認" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "ab744a83-8049-4e12-8f24-2ab9ac5c2074", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using Random\n", | |
"using Statistics\n", | |
"\n", | |
"using Functors\n", | |
"using Optimisers\n", | |
"using ConcreteStructs\n", | |
"using Zygote" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6b076210-67d8-48f4-aaee-7f2af4f25e0f", | |
"metadata": {}, | |
"source": [ | |
"自作構造体 `Affine` をモデルとしたパラメータの学習を行う. \n", | |
"このモデルは 3 次元ベクトルを入力とし 2 次元ベクトルを出力するアフィン変換であるとする. \n", | |
"\n", | |
"$$\n", | |
"\\mathrm{Affine}(W, b): x \\mapsto W x + b \\quad \\textrm{for}\\ \\ x\\ \\in \\mathbb{R}^3\n", | |
"$$\n", | |
"\n", | |
"ここで, 学習パラメータは $W\\in \\mathbb{R}^{2\\times 3}, b \\in \\mathbb{R}^2$ であるとする" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "53894d58-1527-4f09-a1b1-6841e4ffd849", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@concrete struct Affine\n", | |
" W\n", | |
" b\n", | |
"end\n", | |
"\n", | |
"function Affine(rng::Random.AbstractRNG)\n", | |
" W = rand(rng, Float32, 2, 3)\n", | |
" b = rand(rng, Float32, 2)\n", | |
" Affine(W, b)\n", | |
"end\n", | |
"\n", | |
"function Affine(;seed::Int)\n", | |
" rng = Xoshiro(seed)\n", | |
" Affine(rng)\n", | |
"end\n", | |
"\n", | |
"@functor Affine # これをすることで学習すべきパラメータを把握する.\n", | |
"#= \n", | |
"特定のフィールドを学習すべきパラメータとして指定したい場合は `Optimisers.trainable`\n", | |
"を実装すれば良い\n", | |
"=#\n", | |
"\n", | |
"function (aff::Affine)(x::AbstractVector)\n", | |
" aff.W * x + aff.b\n", | |
"end\n", | |
"\n", | |
"function (aff::Affine)(x::AbstractMatrix) # バッチ処理\n", | |
" ndata = size(x, 2)\n", | |
" [aff(x[:, i]) for i in 1:ndata]\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "f76d208f-0b1f-4ea6-8e8f-67be3c5a44d7", | |
"metadata": {}, | |
"source": [ | |
"`@concrete` は `ConcreteStructs.jl` が提供するマクロである.\n", | |
"\n", | |
"```julia\n", | |
"struct Affine{T1, T2}\n", | |
" W::T1\n", | |
" b::T2\n", | |
" function Affine(a::T1, b::T2) where {T1, T2}\n", | |
" return new{T1, T2}(a, b)\n", | |
" end\n", | |
"end\n", | |
"```\n", | |
"\n", | |
"のように構造体のフィールドが具象型となるように設計する過程を `@concrete` によって実現してくれる. `@macroexpand` でマクロがどのように働いているのかを示すことができる. 例えば下記のセルを実行してみよう:\n", | |
"\n", | |
"```julia\n", | |
"@macroexpand @concrete struct Affine\n", | |
" W\n", | |
" b\n", | |
"end\n", | |
"```\n", | |
"\n", | |
"概ね下記のような出力になる. (本質的な部分のみを記載)\n", | |
"\n", | |
"```julia\n", | |
"struct Affine{__T_W, __T_b} <: Any\n", | |
" W::__T_a\n", | |
" b::__T_b\n", | |
" function Affine(a::__T_a, b::__T_b) where {__T_a, __T_b}\n", | |
" return new{__T_a, __T_b}(a, b)\n", | |
" end\n", | |
"end\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "030fd2cb-0386-462c-9a95-5f02656c0c24", | |
"metadata": {}, | |
"source": [ | |
"ここでは簡単のため, 特定の $W=W_0$, $b=b_0$ で得られたモデルの出力を学習データとして $\\mathrm{Affine}(W, b)$ が $\\mathrm{Affine}(W_0, b_0)$ を模倣するように学習をさせてみよう." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c07b5d8b-4015-4c4a-a1d6-f39263f4cedb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"traindata (generic function with 2 methods)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"function traindata(ndata=100)\n", | |
" rng = Random.Xoshiro(12345)\n", | |
" gt = Affine(rng)\n", | |
" ndata=100\n", | |
" xtrain = rand(rng, Float32, 3, ndata)\n", | |
" ytrain = gt(xtrain)\n", | |
" return xtrain, ytrain\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "31709df6-2e22-4119-a711-9839eed58c1a", | |
"metadata": {}, | |
"source": [ | |
"損失関数はナイーブに点 $y$ と $\\hat{y}$ (モデルの出力) の距離を計算し,バッチデータに関しての平均を返却するように設計してみよう." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "d86101bb-9de9-4623-ac66-5131111f58d7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"loss (generic function with 1 method)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"function loss(ŷbatch::Vector{Vector{T}}, ybatch::Vector{Vector{T}}) where T# (datadim, batchsize)\n", | |
" mean(zip(ŷbatch, ybatch)) do (ŷ, y)\n", | |
" (ŷ - y) .^ 2 |> sum\n", | |
" end\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c5e6edea-7fd7-4246-99ab-06bb1e36bbcf", | |
"metadata": {}, | |
"source": [ | |
"`rule` で最適化アルゴリズムを設定. ここでは単純な勾配法を用いる. `setup` によって最適化アルゴリズムとモデルのパラメータのステータスを紐づける.\n", | |
"\n", | |
"```julia\n", | |
"rule = Optimisers.Descent(0.1f0)\n", | |
"state = Optimisers.setup(rule, model)\n", | |
"```\n", | |
"\n", | |
"勾配計算自体は `Zygote.gradient` に任せておく.エポック数を 100 として学習ループを回し損失値が下がっている様子を観察してみよう." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "4f81773e-6b64-4da8-929e-a54e08f291b3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"compute_gradient (generic function with 1 method)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"function compute_gradient(model, xtrain, ytrain)\n", | |
" # カンマをつける理由: 戻り値がTupleとして返ってるので最初の要素を取り出したいから\n", | |
" ∇, = Zygote.gradient(model, xtrain, ytrain) do m, x, y # calculate the gradients\n", | |
" ŷ = m(x)\n", | |
" loss(ŷ, y)\n", | |
" end\n", | |
" ∇\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "0c00eb2b-6192-40f3-aa87-51c90f26ff15", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"train (generic function with 1 method)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"function train()\n", | |
" model = Affine(seed=54321)\n", | |
" \n", | |
" xtrain, ytrain = traindata()\n", | |
" \n", | |
" rule = Optimisers.Descent(0.1f0)\n", | |
" state = Optimisers.setup(rule, model)\n", | |
" \n", | |
" epochs = 500\n", | |
" for epoch in 1:epochs\n", | |
" ∇model = compute_gradient(model, xtrain, ytrain)\n", | |
" state, model = Optimisers.update(state, model, ∇model)\n", | |
" ŷ = model(xtrain)\n", | |
" mod(epoch, 50) == 0 && @info \"loss\" epoch loss(ŷ, ytrain)\n", | |
" end\n", | |
" return model\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "8cfa3ebb-2739-4244-9836-ef507819df29", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: loss\n", | |
"│ epoch = 50\n", | |
"│ loss(ŷ, ytrain) = 0.03181389\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n", | |
"┌ Info: loss\n", | |
"│ epoch = 100\n", | |
"│ loss(ŷ, ytrain) = 0.009836106\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: loss\n", | |
"│ epoch = 150\n", | |
"│ loss(ŷ, ytrain) = 0.0034513818\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n", | |
"┌ Info: loss\n", | |
"│ epoch = 200\n", | |
"│ loss(ŷ, ytrain) = 0.001299851\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: loss\n", | |
"│ epoch = 250\n", | |
"│ loss(ŷ, ytrain) = 0.00050675025\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n", | |
"┌ Info: loss\n", | |
"│ epoch = 300\n", | |
"│ loss(ŷ, ytrain) = 0.00020071033\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: loss\n", | |
"│ epoch = 350\n", | |
"│ loss(ŷ, ytrain) = 8.006227e-5\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n", | |
"┌ Info: loss\n", | |
"│ epoch = 400\n", | |
"│ loss(ŷ, ytrain) = 3.2037726e-5\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: loss\n", | |
"│ epoch = 450\n", | |
"│ loss(ŷ, ytrain) = 1.283851e-5\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n", | |
"┌ Info: loss\n", | |
"│ epoch = 500\n", | |
"│ loss(ŷ, ytrain) = 5.148223e-6\n", | |
"└ @ Main /Users/terasaki/Downloads/73b872a53d72175c8832c212ec8496d3-b7ef7c234952c3fe2d23011abcbcfce9d969ac61/study_optimisers.ipynb:14\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"Affine{Matrix{Float32}, Vector{Float32}}(Float32[0.99506706 0.32216284 0.077527404; 0.94032687 0.8622485 0.33671075], Float32[0.077655435, 0.14226691])" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"model = train()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c7cb9e99-2ae6-4ba3-b6a0-8be09e955114", | |
"metadata": {}, | |
"source": [ | |
"- `@functor Affine` を省略してコードを動かすとパラメータが更新されず結果として損失関数の変化も見られないことが観測できる.\n", | |
"- 今回の実装は自作構造体をモデルとみなしどのように学習コードを書けばよいかを説明した. Affine 程度であれば Flux.jl または Lux.jl の Dense を使えば実用上問題ない.\n", | |
"- Lux.jl は Flux.jl の次世代版という位置付けになっている. 馴染みのない読者はまず Flux.jl のドキュメントを読んで Flux.jl 自体の使い方を学ぶと良い.\n", | |
"- Zygote.gradient って結構 type-instable になりがちなんだけれどどうすればいいんですかね?(詳しい方いたら教えてクレメンス) " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "36cf3010", | |
"metadata": {}, | |
"source": [ | |
"# 結果の確認\n", | |
"\n", | |
"学習がうまくいったかを確認してみよう." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "b4bdfb72", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2×3 Matrix{Float32}:\n", | |
" 0.997473 0.324766 0.0789543\n", | |
" 0.944791 0.866895 0.339611" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"rng = Random.Xoshiro(12345)\n", | |
"gt = Affine(rng)\n", | |
"gt.W" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "b763ba37", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2×3 Matrix{Float32}:\n", | |
" 0.995067 0.322163 0.0775274\n", | |
" 0.940327 0.862248 0.336711" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"model.W" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "550fcf55", | |
"metadata": {}, | |
"source": [ | |
"数値的な誤差はあれどにたような傾向は得られると思う.\n", | |
"\n", | |
"```julia\n", | |
"# 手元だと gt.W\n", | |
"2×3 Matrix{Float32}:\n", | |
" 0.997473 0.324766 0.0789543\n", | |
" 0.944791 0.866895 0.339611\n", | |
"```\n", | |
"\n", | |
"```julia\n", | |
"# modeol.W\n", | |
"2×3 Matrix{Float32}:\n", | |
" 0.995067 0.322163 0.0775274\n", | |
" 0.940327 0.862248 0.336711\n", | |
"```" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 1.8.2", | |
"language": "julia", | |
"name": "julia-1.8" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.8.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment