Skip to content

Instantly share code, notes, and snippets.

@xiangze
Last active December 9, 2018 05:42
Show Gist options
  • Save xiangze/3eb3b438593daef23527216e1644ba6c to your computer and use it in GitHub Desktop.
Save xiangze/3eb3b438593daef23527216e1644ba6c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"toc": true
},
"source": [
"<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
"<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Stein-discrepancyと最適化(SVGD)\" data-toc-modified-id=\"Stein-discrepancyと最適化(SVGD)-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Stein discrepancyと最適化(SVGD)</a></span><ul class=\"toc-item\"><li><span><a href=\"#手法\" data-toc-modified-id=\"手法-1.1\"><span class=\"toc-item-num\">1.1&nbsp;&nbsp;</span>手法</a></span><ul class=\"toc-item\"><li><span><a href=\"#目的\" data-toc-modified-id=\"目的-1.1.1\"><span class=\"toc-item-num\">1.1.1&nbsp;&nbsp;</span>目的</a></span></li></ul></li><li><span><a href=\"#Stein-Operator\" data-toc-modified-id=\"Stein-Operator-1.2\"><span class=\"toc-item-num\">1.2&nbsp;&nbsp;</span>Stein Operator</a></span></li><li><span><a href=\"#Stein-disrepancy\" data-toc-modified-id=\"Stein-disrepancy-1.3\"><span class=\"toc-item-num\">1.3&nbsp;&nbsp;</span>Stein disrepancy</a></span></li><li><span><a href=\"#サンプリングの方法として(GANへの応用)\" data-toc-modified-id=\"サンプリングの方法として(GANへの応用)-1.4\"><span class=\"toc-item-num\">1.4&nbsp;&nbsp;</span>サンプリングの方法として(GANへの応用)</a></span><ul class=\"toc-item\"><li><span><a href=\"#Amortized-SVGD\" data-toc-modified-id=\"Amortized-SVGD-1.4.1\"><span class=\"toc-item-num\">1.4.1&nbsp;&nbsp;</span>Amortized SVGD</a></span></li><li><span><a href=\"#Amortized-KSD\" data-toc-modified-id=\"Amortized-KSD-1.4.2\"><span class=\"toc-item-num\">1.4.2&nbsp;&nbsp;</span>Amortized KSD</a></span></li><li><span><a href=\"#Amortized-MLE\" data-toc-modified-id=\"Amortized-MLE-1.4.3\"><span class=\"toc-item-num\">1.4.3&nbsp;&nbsp;</span>Amortized MLE</a></span></li></ul></li><li><span><a href=\"#実装\" data-toc-modified-id=\"実装-1.5\"><span class=\"toc-item-num\">1.5&nbsp;&nbsp;</span>実装</a></span></li><li><span><a href=\"#Reference\" data-toc-modified-id=\"Reference-1.6\"><span class=\"toc-item-num\">1.6&nbsp;&nbsp;</span>Reference</a></span></li></ul></li></ul></div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stein discrepancyと最適化(SVGD)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 手法"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 目的"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"KL divergenceの最小化\n",
"\n",
"$\\phi^*=\\arg\\min [ \\partial_\\epsilon KL(q_\\epsilon // p) ]$\n",
"\n",
"を行いたい\n",
"\n",
"Stein Operator $A_p$(後述)を使って勾配を\n",
"\n",
"$\\nabla_\\epsilon KL(q_\\epsilon // p)|_{\\epsilon=0}=-E_{x \\sim q}[A_p \\phi(x)]$\n",
"\n",
"$ (x' = x + \\phi(x) )$\n",
"\n",
"としてこれを用いて計算する。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stein Operator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Stein operatorを\n",
"\n",
"$A_p \\equiv f(x)\\nabla_x \\log p(x)+ \\nabla_x f(x)$\n",
"\n",
"と定義する以下のように導出される。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"分布p(x)に対して\n",
"\n",
"$\\lim_{|x| <- \\infty} f(x)p(x)=0$\n",
"\n",
"となるようなf(x)に対して恒等式(Stein identity)\n",
"\n",
"$E_{x \\sim p}[ f(x)\\nabla_x \\log p(x)+ \\nabla_x f(x) ]=0$\n",
"\n",
"が成り立つ(部分積分から)。分布p(x)と近いq(x)に対してこれを\n",
"\n",
"$E_{x\\sim q}[\\nabla_x \\log p(x)+ \\nabla_x f(x)] \\equiv E_{x\\sim q}[A_p f(x)]$\n",
"\n",
"$A_p f(x) \\equiv \\nabla_x \\log p(x)+ \\nabla_x f(x)$\n",
"\n",
"と定義する(記号もあいまって共変微分のように見える)。また\n",
"\n",
"$E_{x \\sim q}[A_p f(x)]=E_{x \\sim q}[A_p f(x)]-E_{x \\sim q}[A_q f(x)] =E_{x \\sim q}[ f(x)(\\nabla \\log p(x) -\\nabla \\log q(x))]$\n",
"\n",
"という関係式も成り立つ。"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Stein disrepancy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Stein Operator$A _p$を用いて\n",
"\n",
" $\\sqrt{S(p,q)} \\equiv \\max_{f \\in F} E_{x\\sim q} [A_p f(x)]$\n",
" \n",
" とStein disrepancyを定義する。\n",
" \n",
" これをどのように計算するのか"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"関数f(x)を\n",
"\n",
"$f(x) = \\sum_i \\omega_i f_i(x)$\n",
"\n",
"と展開すると\n",
"\n",
"$E_q[A_p f(x) ]= E_q[A_p \\sum_i \\omega_i f_i(x)]= \\sum \\omega_i \\beta_i$ \n",
"\n",
"$\\beta_i \\equiv E_{x\\sim q}[A_p f_i(x)]$\n",
"\n",
"さらにカーネルを使って\n",
"\n",
"$f(x)=\\sum_i \\omega_i k(x,x_i)$\n",
"\n",
"とかくと$ <f,g>_H =\\sum_{i,j} w_i v_j k(x_i,x_j)$という形に書けることから\n",
"\n",
"$E_{x \\sim q}[A_p f(x)]=<f(\\cdot),E_{x \\sim q}[A_p k(\\cdot,x)]>$\n",
"\n",
"添字iに関する和を有限個とすることができて(基底定理?)\n",
"\n",
"$A_p f(x)=\\frac{1}{n}\\sum_{i=0}^n [ f(x_i) k(x_i,x) \\nabla_x \\log p(x_i)+ \\nabla_{x_i}k(x_i,x) ]$\n",
"\n",
"となる。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"結局$\\phi^*=\\arg\\min [ \\partial_\\epsilon KL(q_\\epsilon // p) ]$に対しては\n",
"\n",
"$\\phi$がxの関数であるとしてparticle iに対し\n",
"\n",
"$x_i^{l+1} \\leftarrow x_i^l + \\epsilon_l \\phi^*(x_i^l)$\n",
"\n",
"$\\phi^*(x)=\\frac{1}{n}\\sum_{j=0}^n [ k(x_j^l,x) \\nabla_x \\log p(x_j^l)+ \\nabla_{x_j}k(x_j^l,x) ]$\n",
"\n",
"という更新式を使う。$k(x_i,x_j)$はカーネルで行列を使って表すことができる(未知の値$k(x_i,x)$に対しては関数として実装する)。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## サンプリングの方法として(GANへの応用)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"論文 https://arxiv.org/abs/1611.01722"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Amortized SVGD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"乱数$\\xi_i$に対し\n",
"$x_i=f_n(\\xi_i) $\n",
"であるとする更新式は\n",
"\n",
"$ \\Delta x_i =E_{ x\\sim \\{x_i\\} } [\\nabla p(x) k(x,x_i) + \\nabla_x k(x,x_i)]$\n",
"\n",
"$x_i'=x_i+\\epsilon \\Delta x_i$\n",
"\n",
"関数$f(\\eta;\\xi_i)$のパラメータ(として)$\\eta$があるとすると\n",
"\n",
"$\\eta=\\eta+\\epsilon \\sum_i \\partial_\\eta f_\\eta (\\xi_i)\\Delta x_i$\n",
"\n",
"としてfを更新していくのがSVGD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"特にKL divergence$KL(q_\\eta//p)$に関して書き直すと\n",
"\n",
"$\\nabla_\\eta KL(q_\\eta//p)=-E_{\\xi \\sim q_0}[\\partial_\\eta f(\\eta ;\\xi)(\\nabla_x \\log p(x) - \\nabla_x \\log q_\\eta(x) ] $\n",
"\n",
"となる。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"さらに$\\bar{\\Delta}x_i= \\nabla_x \\log p(x_i) - \\nabla_x \\log q_\\eta(x_i) $\n",
"\n",
"と書くと\n",
"\n",
"$\\Delta x_i \\simeq E_{x \\sim q}[k(x,x_i) \\nabla_x \\log p(x)+ \\nabla_{x}k(x,x_i) ] $\n",
"$= E_{x \\sim q}[k(x,x_i)( \\nabla_x \\log p(x)+ \\nabla_{x} \\log q(x)] $\n",
"$= E_{x \\sim q}[\\bar{\\Delta}x_i k(x,x_i]$\n",
"\n",
"という近似的関係が成り立つ($\\Delta x_i$は$\\bar{\\Delta} x_i$をカーネル重み付けしたもの"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Amortized KSD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$\\kappa_p(x,x')=\\nabla_x \\log p(x) \\nabla_x \\log p(x') k(x,x') +\\nabla_x \\log p(x) \\nabla_x k(x,x') + \\nabla_x \\log p(x') \\nabla_x k(x,x')+\n",
"\\nabla_x \\nabla_x' k(x,x')$\n",
"\n",
"これによって\n",
"$D^2(q//p)=E_{x,x'\\sim q}[\\kappa_p(x,x')]$\n",
"\n",
"の最小化を行う。\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Amortized MLE"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$p(x|\\theta)=\\exp(-\\phi(x,\\theta)-\\Phi(\\theta)$\n",
"という形の関数の最適化\n",
"\n",
"$\\xi_i \\sim q_{\\xi}, x_i=f_\\eta(\\xi_i)$ \n",
"として\n",
"\n",
"$\\eta \\leftarrow \\epsilon \\sum_i \\partial_\\eta f_\\eta (\\xi_i) \\Delta x_i$\n",
"\n",
"$\\{x_{obs}\\}$ を新たに引く\n",
"\n",
"$\\theta \\leftarrow \\theta -E_{obs}[\\nabla_\\theta \\phi(x,\\theta)]+E_\\eta[\\nabla_\\theta \\phi(x,\\theta)]$\n",
"\n",
"を交互に繰り返す。\n",
"\n",
"\n",
"GANの場合は\n",
"\n",
"$p(x|\\theta) \\propto \\exp(-||x-D(E(x;\\theta);\\theta)||)$\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## 実装"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"オリジナル実装はTheanoで書かれている。\n",
"\n",
"- https://github.com/DartML/Stein-Variational-Gradient-Descent/tree/master/python\n",
"- https://github.com/DartML/Stein-Variational-Gradient-Descent/tree/master/python\n",
"\n",
"SteinGAN もベタコード\n",
"- https://github.com/DartML/Stein-Variational-Gradient-Descent/blob/master/python/bayesian_logistic_regression.py\n",
"- https://github.com/DartML/SteinGAN/tree/master/lib\n",
"- https://github.com/ChunyuanLI/SVGD/blob/master/demo_svgd.ipynb\n",
"Edward, tensorflowのoptimizerとして組み込めるかもしれません。確率分布pとその微分を使うのでEdwardやtf.distributionsを使う必要がありそうです。particle分のパラメータを複製するなどのアイデアがありえます。\n",
"\n",
"- R package https://cran.r-project.org/web/packages/KSD/\n",
"\n",
"- Julia 実装 \n",
" - https://github.com/krisrs1128/readings/tree/master/svgd\n",
" - https://github.com/jgorham/stein_discrepancy/tree/master/src/experiments kernelは決め打ちです。自動微分を使えば任意のkernelのgradientが計算できるかと思ったのですが、Juliaのautodiffは偏微分をやるのが難しそうだったので断念しました\n",
"\n",
"- Tensorflow\n",
" - https://github.com/xiangze/SVGD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- [Stein’s Method for Practical Machine Learning](https://www.cs.dartmouth.edu/~qliu/stein.html) ここに多くの論文がまとめられている\n",
"- [Probabilistic Learning and Inference Using Stein Discrepancy](https://www.cs.dartmouth.edu/~qliu/PDF/steinslides16.pdf) 概要スライド\n",
"- [Stein's Method](https://sites.google.com/site/steinsmethod/home)\n",
"- [機械学習論文読みメモ_14](https://qiita.com/festa78/items/1813377b59fd3685c119)\n",
"- http://chunyuan.li/\n",
"- 論文\n",
" - Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm [arxiv](https://arxiv.org/abs/1608.04471)\n",
" - LEARNING TO DRAW SAMPLES: WITH APPLICATION TO AMORTIZED MLE FOR GENERATIVE ADVERSARIAL LEARNING [pdf](https://arxiv.org/pdf/1611.01722.pdf)\n",
" - A Kernelized Stein Discrepancy for Goodness-of-fit Tests and Model Evaluation[arxiv](https://arxiv.org/abs/1602.03253)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autoclose": false,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": true,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment