-
-
Save xiangze/3eb3b438593daef23527216e1644ba6c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"toc": true | |
}, | |
"source": [ | |
"<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n", | |
"<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Stein-discrepancyと最適化(SVGD)\" data-toc-modified-id=\"Stein-discrepancyと最適化(SVGD)-1\"><span class=\"toc-item-num\">1 </span>Stein discrepancyと最適化(SVGD)</a></span><ul class=\"toc-item\"><li><span><a href=\"#手法\" data-toc-modified-id=\"手法-1.1\"><span class=\"toc-item-num\">1.1 </span>手法</a></span><ul class=\"toc-item\"><li><span><a href=\"#目的\" data-toc-modified-id=\"目的-1.1.1\"><span class=\"toc-item-num\">1.1.1 </span>目的</a></span></li></ul></li><li><span><a href=\"#Stein-Operator\" data-toc-modified-id=\"Stein-Operator-1.2\"><span class=\"toc-item-num\">1.2 </span>Stein Operator</a></span></li><li><span><a href=\"#Stein-disrepancy\" data-toc-modified-id=\"Stein-disrepancy-1.3\"><span class=\"toc-item-num\">1.3 </span>Stein disrepancy</a></span></li><li><span><a href=\"#サンプリングの方法として(GANへの応用)\" data-toc-modified-id=\"サンプリングの方法として(GANへの応用)-1.4\"><span class=\"toc-item-num\">1.4 </span>サンプリングの方法として(GANへの応用)</a></span><ul class=\"toc-item\"><li><span><a href=\"#Amortized-SVGD\" data-toc-modified-id=\"Amortized-SVGD-1.4.1\"><span class=\"toc-item-num\">1.4.1 </span>Amortized SVGD</a></span></li><li><span><a href=\"#Amortized-KSD\" data-toc-modified-id=\"Amortized-KSD-1.4.2\"><span class=\"toc-item-num\">1.4.2 </span>Amortized KSD</a></span></li><li><span><a href=\"#Amortized-MLE\" data-toc-modified-id=\"Amortized-MLE-1.4.3\"><span class=\"toc-item-num\">1.4.3 </span>Amortized MLE</a></span></li></ul></li><li><span><a href=\"#実装\" data-toc-modified-id=\"実装-1.5\"><span class=\"toc-item-num\">1.5 </span>実装</a></span></li><li><span><a href=\"#Reference\" data-toc-modified-id=\"Reference-1.6\"><span class=\"toc-item-num\">1.6 </span>Reference</a></span></li></ul></li></ul></div>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Stein discrepancyと最適化(SVGD)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 手法" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 目的" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"KL divergenceの最小化\n", | |
"\n", | |
"$\\phi^*=\\arg\\min [ \\partial_\\epsilon KL(q_\\epsilon // p) ]$\n", | |
"\n", | |
"を行いたい\n", | |
"\n", | |
"Stein Operator $A_p$(後述)を使って勾配を\n", | |
"\n", | |
"$\\nabla_\\epsilon KL(q_\\epsilon // p)|_{\\epsilon=0}=-E_{x \\sim q}[A_p \\phi(x)]$\n", | |
"\n", | |
"$ (x' = x + \\phi(x) )$\n", | |
"\n", | |
"としてこれを用いて計算する。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Stein Operator" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Stein operatorを\n", | |
"\n", | |
"$A_p \\equiv f(x)\\nabla_x \\log p(x)+ \\nabla_x f(x)$\n", | |
"\n", | |
"と定義する以下のように導出される。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"分布p(x)に対して\n", | |
"\n", | |
"$\\lim_{|x| <- \\infty} f(x)p(x)=0$\n", | |
"\n", | |
"となるようなf(x)に対して恒等式(Stein identity)\n", | |
"\n", | |
"$E_{x \\sim p}[ f(x)\\nabla_x \\log p(x)+ \\nabla_x f(x) ]=0$\n", | |
"\n", | |
"が成り立つ(部分積分から)。分布p(x)と近いq(x)に対してこれを\n", | |
"\n", | |
"$E_{x\\sim q}[\\nabla_x \\log p(x)+ \\nabla_x f(x)] \\equiv E_{x\\sim q}[A_p f(x)]$\n", | |
"\n", | |
"$A_p f(x) \\equiv \\nabla_x \\log p(x)+ \\nabla_x f(x)$\n", | |
"\n", | |
"と定義する(記号もあいまって共変微分のように見える)。また\n", | |
"\n", | |
"$E_{x \\sim q}[A_p f(x)]=E_{x \\sim q}[A_p f(x)]-E_{x \\sim q}[A_q f(x)] =E_{x \\sim q}[ f(x)(\\nabla \\log p(x) -\\nabla \\log q(x))]$\n", | |
"\n", | |
"という関係式も成り立つ。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"## Stein disrepancy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Stein Operator$A _p$を用いて\n", | |
"\n", | |
" $\\sqrt{S(p,q)} \\equiv \\max_{f \\in F} E_{x\\sim q} [A_p f(x)]$\n", | |
" \n", | |
" とStein disrepancyを定義する。\n", | |
" \n", | |
" これをどのように計算するのか" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"関数f(x)を\n", | |
"\n", | |
"$f(x) = \\sum_i \\omega_i f_i(x)$\n", | |
"\n", | |
"と展開すると\n", | |
"\n", | |
"$E_q[A_p f(x) ]= E_q[A_p \\sum_i \\omega_i f_i(x)]= \\sum \\omega_i \\beta_i$ \n", | |
"\n", | |
"$\\beta_i \\equiv E_{x\\sim q}[A_p f_i(x)]$\n", | |
"\n", | |
"さらにカーネルを使って\n", | |
"\n", | |
"$f(x)=\\sum_i \\omega_i k(x,x_i)$\n", | |
"\n", | |
"とかくと$ <f,g>_H =\\sum_{i,j} w_i v_j k(x_i,x_j)$という形に書けることから\n", | |
"\n", | |
"$E_{x \\sim q}[A_p f(x)]=<f(\\cdot),E_{x \\sim q}[A_p k(\\cdot,x)]>$\n", | |
"\n", | |
"添字iに関する和を有限個とすることができて(基底定理?)\n", | |
"\n", | |
"$A_p f(x)=\\frac{1}{n}\\sum_{i=0}^n [ f(x_i) k(x_i,x) \\nabla_x \\log p(x_i)+ \\nabla_{x_i}k(x_i,x) ]$\n", | |
"\n", | |
"となる。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"結局$\\phi^*=\\arg\\min [ \\partial_\\epsilon KL(q_\\epsilon // p) ]$に対しては\n", | |
"\n", | |
"$\\phi$がxの関数であるとしてparticle iに対し\n", | |
"\n", | |
"$x_i^{l+1} \\leftarrow x_i^l + \\epsilon_l \\phi^*(x_i^l)$\n", | |
"\n", | |
"$\\phi^*(x)=\\frac{1}{n}\\sum_{j=0}^n [ k(x_j^l,x) \\nabla_x \\log p(x_j^l)+ \\nabla_{x_j}k(x_j^l,x) ]$\n", | |
"\n", | |
"という更新式を使う。$k(x_i,x_j)$はカーネルで行列を使って表すことができる(未知の値$k(x_i,x)$に対しては関数として実装する)。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## サンプリングの方法として(GANへの応用)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"論文 https://arxiv.org/abs/1611.01722" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Amortized SVGD" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"乱数$\\xi_i$に対し\n", | |
"$x_i=f_n(\\xi_i) $\n", | |
"であるとする更新式は\n", | |
"\n", | |
"$ \\Delta x_i =E_{ x\\sim \\{x_i\\} } [\\nabla p(x) k(x,x_i) + \\nabla_x k(x,x_i)]$\n", | |
"\n", | |
"$x_i'=x_i+\\epsilon \\Delta x_i$\n", | |
"\n", | |
"関数$f(\\eta;\\xi_i)$のパラメータ(として)$\\eta$があるとすると\n", | |
"\n", | |
"$\\eta=\\eta+\\epsilon \\sum_i \\partial_\\eta f_\\eta (\\xi_i)\\Delta x_i$\n", | |
"\n", | |
"としてfを更新していくのがSVGD" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"特にKL divergence$KL(q_\\eta//p)$に関して書き直すと\n", | |
"\n", | |
"$\\nabla_\\eta KL(q_\\eta//p)=-E_{\\xi \\sim q_0}[\\partial_\\eta f(\\eta ;\\xi)(\\nabla_x \\log p(x) - \\nabla_x \\log q_\\eta(x) ] $\n", | |
"\n", | |
"となる。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"さらに$\\bar{\\Delta}x_i= \\nabla_x \\log p(x_i) - \\nabla_x \\log q_\\eta(x_i) $\n", | |
"\n", | |
"と書くと\n", | |
"\n", | |
"$\\Delta x_i \\simeq E_{x \\sim q}[k(x,x_i) \\nabla_x \\log p(x)+ \\nabla_{x}k(x,x_i) ] $\n", | |
"$= E_{x \\sim q}[k(x,x_i)( \\nabla_x \\log p(x)+ \\nabla_{x} \\log q(x)] $\n", | |
"$= E_{x \\sim q}[\\bar{\\Delta}x_i k(x,x_i]$\n", | |
"\n", | |
"という近似的関係が成り立つ($\\Delta x_i$は$\\bar{\\Delta} x_i$をカーネル重み付けしたもの" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Amortized KSD" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"$\\kappa_p(x,x')=\\nabla_x \\log p(x) \\nabla_x \\log p(x') k(x,x') +\\nabla_x \\log p(x) \\nabla_x k(x,x') + \\nabla_x \\log p(x') \\nabla_x k(x,x')+\n", | |
"\\nabla_x \\nabla_x' k(x,x')$\n", | |
"\n", | |
"これによって\n", | |
"$D^2(q//p)=E_{x,x'\\sim q}[\\kappa_p(x,x')]$\n", | |
"\n", | |
"の最小化を行う。\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Amortized MLE" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"$p(x|\\theta)=\\exp(-\\phi(x,\\theta)-\\Phi(\\theta)$\n", | |
"という形の関数の最適化\n", | |
"\n", | |
"$\\xi_i \\sim q_{\\xi}, x_i=f_\\eta(\\xi_i)$ \n", | |
"として\n", | |
"\n", | |
"$\\eta \\leftarrow \\epsilon \\sum_i \\partial_\\eta f_\\eta (\\xi_i) \\Delta x_i$\n", | |
"\n", | |
"$\\{x_{obs}\\}$ を新たに引く\n", | |
"\n", | |
"$\\theta \\leftarrow \\theta -E_{obs}[\\nabla_\\theta \\phi(x,\\theta)]+E_\\eta[\\nabla_\\theta \\phi(x,\\theta)]$\n", | |
"\n", | |
"を交互に繰り返す。\n", | |
"\n", | |
"\n", | |
"GANの場合は\n", | |
"\n", | |
"$p(x|\\theta) \\propto \\exp(-||x-D(E(x;\\theta);\\theta)||)$\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"## 実装" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"オリジナル実装はTheanoで書かれている。\n", | |
"\n", | |
"- https://github.com/DartML/Stein-Variational-Gradient-Descent/tree/master/python\n", | |
"- https://github.com/DartML/Stein-Variational-Gradient-Descent/tree/master/python\n", | |
"\n", | |
"SteinGAN もベタコード\n", | |
"- https://github.com/DartML/Stein-Variational-Gradient-Descent/blob/master/python/bayesian_logistic_regression.py\n", | |
"- https://github.com/DartML/SteinGAN/tree/master/lib\n", | |
"- https://github.com/ChunyuanLI/SVGD/blob/master/demo_svgd.ipynb\n", | |
"Edward, tensorflowのoptimizerとして組み込めるかもしれません。確率分布pとその微分を使うのでEdwardやtf.distributionsを使う必要がありそうです。particle分のパラメータを複製するなどのアイデアがありえます。\n", | |
"\n", | |
"- R package https://cran.r-project.org/web/packages/KSD/\n", | |
"\n", | |
"- Julia 実装 \n", | |
" - https://github.com/krisrs1128/readings/tree/master/svgd\n", | |
" - https://github.com/jgorham/stein_discrepancy/tree/master/src/experiments kernelは決め打ちです。自動微分を使えば任意のkernelのgradientが計算できるかと思ったのですが、Juliaのautodiffは偏微分をやるのが難しそうだったので断念しました\n", | |
"\n", | |
"- Tensorflow\n", | |
" - https://github.com/xiangze/SVGD" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Reference" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"- [Stein’s Method for Practical Machine Learning](https://www.cs.dartmouth.edu/~qliu/stein.html) ここに多くの論文がまとめられている\n", | |
"- [Probabilistic Learning and Inference Using Stein Discrepancy](https://www.cs.dartmouth.edu/~qliu/PDF/steinslides16.pdf) 概要スライド\n", | |
"- [Stein's Method](https://sites.google.com/site/steinsmethod/home)\n", | |
"- [機械学習論文読みメモ_14](https://qiita.com/festa78/items/1813377b59fd3685c119)\n", | |
"- http://chunyuan.li/\n", | |
"- 論文\n", | |
" - Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm [arxiv](https://arxiv.org/abs/1608.04471)\n", | |
" - LEARNING TO DRAW SAMPLES: WITH APPLICATION TO AMORTIZED MLE FOR GENERATIVE ADVERSARIAL LEARNING [pdf](https://arxiv.org/pdf/1611.01722.pdf)\n", | |
" - A Kernelized Stein Discrepancy for Goodness-of-fit Tests and Model Evaluation[arxiv](https://arxiv.org/abs/1602.03253)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
}, | |
"latex_envs": { | |
"LaTeX_envs_menu_present": true, | |
"autoclose": false, | |
"autocomplete": true, | |
"bibliofile": "biblio.bib", | |
"cite_by": "apalike", | |
"current_citInitial": 1, | |
"eqLabelWithNumbers": true, | |
"eqNumInitial": 1, | |
"hotkeys": { | |
"equation": "Ctrl-E", | |
"itemize": "Ctrl-I" | |
}, | |
"labels_anchors": false, | |
"latex_user_defs": false, | |
"report_style_numbering": false, | |
"user_envs_cfg": false | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"toc_cell": true, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment