@narrowlyapplicable
Last active October 5, 2019 14:52
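An attempt to reproduce <http://statmodeling.hatenablog.com/entry/stan-parallel-tempering> with ReplicaExchangeMC from TensorFlow Probability. The sampler could not escape the local mode, so the results of the original article were not reproduced. (It did succeed with an easier initial value; see "tfp-replica_excahnge_easy.ipynb".)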
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Replica exchange Monte Carlo trial with TensorFlow Probability\n",
"- Used `tfp.mcmc.ReplicaExchangeMC`, but could not reproduce the example problem."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter('ignore')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"from tensorflow_probability import edward2 as ed"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"tfd = tfp.distributions\n",
"tf.compat.v1.disable_eager_execution()\n",
"\n",
"np.random.seed(123)\n",
"plt.style.use('ggplot')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"init_g = tf.global_variables_initializer()\n",
"init_l = tf.local_variables_initializer()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"num_results = 10000\n",
"num_burnin_steps = 2000"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Data\n",
"This notebook takes the example problem from the Statmodeling Memorandum article on replica exchange Monte Carlo.\n",
"> <http://statmodeling.hatenablog.com/entry/stan-parallel-tempering> "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"N = 50\n",
"b = 0.6\n",
"s_y = 0.4\n",
"X_data = np.linspace(0.1, 4*np.pi, N)\n",
"Y_data = np.sin(b*X_data) + np.random.normal(loc=0.0, scale=s_y, size=N)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.scatter(X_data, Y_data, marker=\"o\", color=\"k\")\n",
"ax.set_xlabel('X')\n",
"ax.set_ylabel('Y')\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Model\n",
"In the original article, the model is defined in Stan as follows.\n",
"\n",
"```\n",
"data {\n",
" int<lower=1> N;\n",
" vector[N] Y;\n",
" vector[N] X;\n",
" real<lower=0> Inv_T;\n",
"}\n",
"\n",
"parameters {\n",
" real<lower=0> b;\n",
" real<lower=0> s_y;\n",
"}\n",
"\n",
"transformed parameters {\n",
" real E;\n",
" {\n",
" vector[N] mu;\n",
" for (n in 1:N)\n",
" mu[n] <- sin(b * X[n]);\n",
" E <- 0;\n",
" E <- E - normal_log(b, 0, 50);\n",
" E <- E - student_t_log(s_y, 4, 0, 5);\n",
" E <- E - normal_log(Y, mu, s_y);\n",
" }\n",
"}\n",
"\n",
"model {\n",
" increment_log_prob(-Inv_T * E);\n",
"}\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def joint_log_prob(x_data, y_data, b, s_y):\n",
" rv_b = tfd.TruncatedNormal(loc=0.0, scale=50.0, low=0.0, high=1e100, name=\"rv_b\")\n",
" #rv_sy = tfd.Uniform(low=0.0, high=1e100, name=\"rv_sy\")\n",
" rv_sy = tfd.TruncatedNormal(loc=0.0, scale=5.0, low=0.0, high=1e100, name=\"rv_sy\")\n",
" #StudentT(df=4, loc=0, scale=5, name=\"rv_sy\")\n",
" \n",
" mu = tf.sin(b*x_data)\n",
" rv_obs = tfd.Normal(loc=mu, scale=s_y, name=\"rv_obs\") \n",
" \n",
" return(\n",
" rv_b.log_prob(b)\n",
" + rv_sy.log_prob(s_y)\n",
" + tf.reduce_sum(rv_obs.log_prob(y_data))\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
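"- As far as I know, TF-P has no range constraint like Stan's `<lower=0>`, so the support is restricted through the prior distributions instead.\n",
"  - For b, a truncated normal prior with lower bound 0 is used.\n",
"  - For s_y, tfd has no truncated t distribution, so a truncated normal is substituted.\n",
"  - How to restrict the range independently of the prior is unclear. Perhaps `tf.gather(b, condition)`?\n",
"- The inverse temperature is left to `tfp.mcmc.ReplicaExchangeMC`, so it does not appear explicitly in the model."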
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def unnormalized_log_posterior(b, s_y):\n",
" return joint_log_prob(X_data, Y_data, b, s_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. HMC\n",
"### 3.1. Define the Kernel\n",
"First, define the simplest possible HMC kernel.\n",
"- No step_size adaptation.\n",
"- No bijector-based reparameterization for efficiency."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"kernel1 = tfp.mcmc.HamiltonianMonteCarlo(\n",
" target_log_prob_fn=unnormalized_log_posterior,\n",
" step_size=0.01,\n",
" num_leapfrog_steps=3 #10\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2. Inference"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"initial_state = [24.17, 0.4048] #[0.6, 0.4] #b=24.17, s_y=0.4048"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### sample chain"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /anaconda3/lib/python3.7/site-packages/tensorflow_probability/python/internal/special_math.py:154: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
]
}
],
"source": [
"states, kernel_results = tfp.mcmc.sample_chain(\n",
" num_results=num_results, \n",
" num_burnin_steps=num_burnin_steps,\n",
" kernel=kernel1,\n",
" current_state=initial_state,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### run"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"with tf.Session() as sess:\n",
" sess.run(init_g)\n",
" sess.run(init_l)\n",
" [states_, kernel_results_ ]= sess.run([states, kernel_results])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"acceptance rate : 96.66 %\n"
]
}
],
"source": [
"try:\n",
" print(f'acceptance rate : {kernel_results_.inner_results.is_accepted.mean()*100} %')\n",
"except AttributeError:\n",
" print(f'acceptance rate : {kernel_results_.is_accepted.mean()*100} %')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.3. Result"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"#states_"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x216 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1,2,figsize=(7,3))\n",
"ax[0].hist(states_[0], bins=100, color=\"C2\")\n",
"ax[0].set_title(r'posterior of $b$')\n",
"ax[1].hist(states_[1], bins=100, color=\"C2\")\n",
"ax[1].set_title(r'posterior of $\\sigma_y$')\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- The chain fails to escape the local mode.\n",
"  - When started near the true values (`initial_state = [0.6, 0.4]`), it can explore that neighborhood."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. HMC + Step Size Adaptation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.1. Define the Kernel\n",
"- Use automatic step size tuning via `tfp.mcmc.SimpleStepSizeAdaptation`.\n",
"  - As of TFP 0.7.0, combining it with `tfp.mcmc.TransformedTransitionKernel` (bijector-based transformation of the search space) or with `tfp.mcmc.ReplicaExchangeMC`, which is used later, raises an error."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"kernel2 = tfp.mcmc.SimpleStepSizeAdaptation(\n",
" inner_kernel=kernel1,\n",
" num_adaptation_steps=int(num_burnin_steps * 0.8),\n",
" adaptation_rate=0.001\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2. Inference"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"states, kernel_results = tfp.mcmc.sample_chain(\n",
" num_results=num_results, \n",
" num_burnin_steps=num_burnin_steps,\n",
" kernel=kernel2,\n",
" current_state=initial_state,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"with tf.Session() as sess:\n",
" sess.run(init_g)\n",
" sess.run(init_l)\n",
" [states_, kernel_results_ ]= sess.run([states, kernel_results])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"acceptance rate : 60.74 %\n"
]
}
],
"source": [
"try:\n",
" print(f'acceptance rate : {kernel_results_.inner_results.is_accepted.mean()*100} %')\n",
"except AttributeError:\n",
" print(f'acceptance rate : {kernel_results_.is_accepted.mean()*100} %')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.3. Result"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x216 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1,2,figsize=(7,3))\n",
"ax[0].hist(states_[0], bins=100, color=\"C2\")\n",
"ax[0].set_title(r'posterior of $b$')\n",
"ax[1].hist(states_[1], bins=100, color=\"C2\")\n",
"ax[1].set_title(r'posterior of $\\sigma_y$')\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- The acceptance rate shifts somewhat, but the results show no major change."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Replica Exchange Monte Carlo"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.1. Define the Kernel\n",
"- Use `tfp.mcmc.ReplicaExchangeMC`.\n",
"- As noted earlier, it cannot be combined with `SimpleStepSizeAdaptation` as of TFP 0.7.0, so no step size tuning is applied."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def make_kernel_fn(target_log_prob_fn, seed):\n",
" return tfp.mcmc.HamiltonianMonteCarlo(\n",
" target_log_prob_fn=target_log_prob_fn,  # tempered target supplied by ReplicaExchangeMC, not the untempered posterior\n",
" step_size=0.01, #0.1, #0.001\n",
" num_leapfrog_steps=3, #100, #2, #10,\n",
" seed=seed\n",
" )\n",
"\n",
"\n",
"kernel3 = tfp.mcmc.ReplicaExchangeMC(\n",
" target_log_prob_fn=unnormalized_log_posterior,\n",
" inverse_temperatures=(0.5**np.linspace(0, -np.log(0.002)/np.log(2), 10).astype(np.float32)).tolist(),\n",
" make_kernel_fn=make_kernel_fn,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
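"- The inverse temperatures follow the original article: 10 levels in a geometric sequence from 1.0 down to 0.02. Note that the code above passes `0.002` as the lower end."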
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.2. Inference"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /anaconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/replica_exchange_mc.py:385: setdiff1d (from tensorflow.python.ops.array_ops) is deprecated and will be removed after 2018-11-30.\n",
"Instructions for updating:\n",
"This op will be removed after the deprecation date. Please switch to tf.sets.difference().\n"
]
}
],
"source": [
"states, kernel_results = tfp.mcmc.sample_chain(\n",
" num_results=num_results, \n",
" num_burnin_steps=num_burnin_steps,\n",
" kernel=kernel3,\n",
" current_state=initial_state,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"with tf.Session() as sess:\n",
" sess.run(init_g)\n",
" sess.run(init_l)\n",
" [states_, kernel_results_ ]= sess.run([states, kernel_results])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([24.160742, 24.183947, 24.15525 , ..., 24.159927, 24.178778,\n",
" 24.16116 ], dtype=float32),\n",
" array([0.59501934, 0.53401023, 0.5231044 , ..., 0.55630195, 0.5180631 ,\n",
" 0.5327149 ], dtype=float32)]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"states_"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# try:\n",
"# print(f'acceptance rate : {kernel_results_.inner_results.is_accepted.mean()*100} %')\n",
"# except AttributeError:\n",
"# print(f'acceptance rate : {kernel_results_.is_accepted.mean()*100} %')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.3. Result"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x216 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1,2,figsize=(7,3))\n",
"ax[0].hist(states_[0], bins=100, color=\"C2\")\n",
"ax[0].set_title(r'posterior of $b$')\n",
"ax[1].hist(states_[1], bins=100, color=\"C2\")\n",
"ax[1].set_title(r'posterior of $\\sigma_y$')\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
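"- The chain still does not move away from the local mode!\n",
"  - Is the step size too small?\n",
"    - Making it any larger prevents sampling from working at all.\n",
"    - It would help to vary it per inverse temperature, but as far as the author knows this is not possible in the current tf-p.\n",
"  - Are there too few leapfrog steps?\n",
"    - Setting it to 100 did not change the result.\n",
"    - This also cannot be set per inverse temperature, as far as I can tell.\n",
"  - Is the exchange frequency too high?\n",
"    - It is unclear how to configure this on the TF-P side.\n",
"- What about using a bijector?\n",
"  - As of TF-P 0.7, passing `tfp.mcmc.TransformedTransitionKernel` to `ReplicaExchangeMC` raises an error.\n",
"    - As with `StepSizeAdaptation`, combination with other methods does not seem to have been considered.\n",
"- The NUTS implementation expected in TF-P 0.8(?) is awaited.\n",
"  - If the number of leapfrog steps were adjusted per replica, the travel-distance problem should be resolved."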
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
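The notebook leaves the tempering and the replica swaps entirely to `tfp.mcmc.ReplicaExchangeMC`, so the mechanism itself is never shown. The following is a minimal pure-Python sketch of that mechanism, not TFP's implementation. It builds a geometric inverse-temperature ladder, runs random-walk Metropolis (swapped in for HMC to keep it short) on each tempered target `beta * log p(x)`, and proposes swaps between adjacent replicas. All names here (`geometric_ladder`, `log_target`, `swap_log_accept`, `parallel_tempering`) and the toy bimodal target are assumptions made for illustration.

```python
import math
import random


def geometric_ladder(beta_min, num_replicas):
    """Inverse temperatures from 1.0 down to beta_min in a geometric sequence."""
    ratio = beta_min ** (1.0 / (num_replicas - 1))
    return [ratio ** k for k in range(num_replicas)]


def log_target(x):
    """Toy unnormalized log density (assumption): two Gaussian bumps at -4 and +4."""
    a = -0.5 * ((x + 4.0) / 0.5) ** 2
    b = -0.5 * ((x - 4.0) / 0.5) ** 2
    m = max(a, b)
    # log-sum-exp of the two components, computed stably
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def swap_log_accept(beta_i, beta_j, logp_i, logp_j):
    """Log Metropolis ratio for exchanging the states held at beta_i and beta_j."""
    return (beta_i - beta_j) * (logp_j - logp_i)


def parallel_tempering(betas, num_steps, step=1.0, x0=-4.0, seed=0):
    """Return the samples of the beta = betas[0] replica."""
    rng = random.Random(seed)
    xs = [x0] * len(betas)
    cold_samples = []
    for _ in range(num_steps):
        # Within-replica update: each replica targets beta * log_target(x).
        for k, beta in enumerate(betas):
            proposal = xs[k] + rng.gauss(0.0, step)
            log_ratio = beta * (log_target(proposal) - log_target(xs[k]))
            # 1 - random() lies in (0, 1], so the log is always defined.
            if math.log(1.0 - rng.random()) < log_ratio:
                xs[k] = proposal
        # Exchange step: propose swapping the states of adjacent replicas.
        for k in range(len(betas) - 1):
            log_ratio = swap_log_accept(
                betas[k], betas[k + 1], log_target(xs[k]), log_target(xs[k + 1])
            )
            if math.log(1.0 - rng.random()) < log_ratio:
                xs[k], xs[k + 1] = xs[k + 1], xs[k]
        cold_samples.append(xs[0])
    return cold_samples


betas = geometric_ladder(0.002, 10)  # same end points as the code in section 5.1
samples = parallel_tempering(betas, num_steps=2000)
```

The point of the sketch is that every replica's inner update must use its own tempered target. If all replicas sampled the untempered `log_target`, the hot replicas would be as stuck as the cold one and the swaps would carry no new information.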