@narrowlyapplicable
Last active September 25, 2021 09:37
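A Python implementation of structural change estimation based on the Dirichlet process (Chinese restaurant process). It reproduces the example in §6 of *Nonparametric Bayes* (Sato), MLP series.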
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
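"# Structural change estimation based on the Dirichlet process (from the MLP nonparametric Bayes book)\n",
"- A Python implementation of §6 of the MLP series book *Nonparametric Bayes*."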
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model\n",
"- A linear regression model is assumed for each cluster, and a CRP is used as the generative process for the regression coefficients and the observation noise.\n",
"  - The equations below may not render correctly on Gist.\n",
"  - Linear regression model\n",
"  $$y_t \\sim {\\cal{N}} (\\theta_t^{\\mathrm{T}}x_t, \\sigma_t^2) $$\n",
"  - Generative process of the per-cluster coefficients\n",
"    - Let $K_t^+$ be the total number of clusters at step $t$, and let $n_{t,k}$ be the number of times the $k$-th cluster has appeared.\n",
"    - For $k\\in\\{1,...,K_{t-1}^+\\}$, i.e. an existing cluster:\n",
"    $$P\\big((\\theta_t, \\sigma_t^2) = (\\theta^{(k)}, \\sigma^{(k)2})\\big) = \\frac{n_{t-1,k}}{t-1+\\alpha}$$\n",
"    - For $k=K_{t-1}^+ + 1$, i.e. a new cluster:\n",
"    $$(\\theta^{(K_{t-1}^+ +1)}, \\sigma^{(K_{t-1}^+ +1)2}) \\sim {\\cal{N}}(\\mu,\\Sigma)IG(\\frac{n_0}{2},\\frac{\\tau}{2})$$\n",
"    $$P\\big((\\theta_t, \\sigma_t^2) = (\\theta^{(K_{t-1}^+ +1)}, \\sigma^{(K_{t-1}^+ +1)2})\\big) = \\frac{\\alpha}{t-1+\\alpha}$$\n",
"    - The following priors are placed on $\\mu, \\Sigma, \\tau$:\n",
"    $$\\mu \\sim {\\cal{N}}(\\mu_0, V_0), \\qquad \\Sigma^{-1} \\sim W(\\nu_0, \\Sigma_0), \\qquad \\tau \\sim Ga(\\frac{m_0}{2}, \\frac{\\tau_0}{2})$$\n",
"  "
]
},
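{
"cell_type": "markdown",
"metadata": {},
"source": [
"The CRP assignment rule above can be illustrated by simulating the seating process on its own. This is a minimal sketch and not part of the original notebook. The helper `simulate_crp` is hypothetical. At each step the observation joins an existing cluster with probability proportional to its size $n_{t-1,k}$, or opens a new cluster with probability proportional to $\\alpha$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def simulate_crp(T, alpha, rng):\n",
"    # hypothetical helper: cluster labels z_1..z_T drawn from a CRP with concentration alpha\n",
"    z = np.zeros(T, dtype=int)\n",
"    counts = [1]  # the first observation opens cluster 0\n",
"    for t in range(1, T):\n",
"        # t observations are already seated (0-based t), so the denominator is t + alpha\n",
"        probs = np.array(counts + [alpha]) / (t + alpha)\n",
"        k = rng.choice(len(probs), p=probs)\n",
"        if k == len(counts):\n",
"            counts.append(1)  # new cluster\n",
"        else:\n",
"            counts[k] += 1  # existing cluster\n",
"        z[t] = k\n",
"    return z\n",
"\n",
"rng = np.random.default_rng(0)\n",
"z_demo = simulate_crp(90, alpha=1.0, rng=rng)\n",
"print('number of clusters:', z_demo.max() + 1)"
]
},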
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup\n",
"### 1.1. Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.stats as st\n",
"import seaborn as sns\n",
"\n",
"plt.style.use('ggplot')\n",
"np.random.seed(1234)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.2. Simulated data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Generated under the same conditions as §6.4 (Experimental example) of the book."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$$ y_t = 1 + 0.5t + u_t, \\quad u_t \\sim {\\cal{N}} (0, 0.3) \\qquad (1 \\leqq t \\leqq30) $$ \n",
"\n",
"$$ y_t = 25 - 0.3t + u_t, \\quad u_t \\sim {\\cal{N}} (0, 0.1) \\qquad (31 \\leqq t \\leqq 60) $$ \n",
"\n",
"$$ y_t = 1 + 0.1t + u_t, \\quad u_t \\sim {\\cal{N}} (0, 0.2) \\qquad (61 \\leqq t \\leqq90) $$"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"x = np.arange(1,91,1)\n",
"y = np.zeros(90)\n",
"y_tru = np.zeros(90)\n",
"\n",
"y[:30] = 1 + 0.5*x[:30] + st.norm.rvs(loc=0, scale=np.sqrt(0.3), size=30)\n",
"y[30:60] = 25 - 0.3*x[30:60] + st.norm.rvs(loc=0, scale=np.sqrt(0.1), size=30)\n",
"y[60:] = 1 +0.1*x[60:] + st.norm.rvs(loc=0, scale=np.sqrt(0.2), size=30)\n",
"y_tru[:30] = 1 + 0.5*x[:30]\n",
"y_tru[30:60] = 25 - 0.3*x[30:60]\n",
"y_tru[60:] = 1 +0.1*x[60:]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(figsize = (7,5))\n",
"ax.plot(x, y, marker=\".\", color=\"k\", linewidth=0)\n",
"ax.plot(x, y_tru, color=\"b\")\n",
"ax.set_xlabel('x'); ax.set_ylabel('y')\n",
"ax.set_title('simulation data')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
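"## 2. Sampling functions\n",
"- Define the functions used for Gibbs sampling.\n",
"  - Wrapping them in a class would be cleaner, but since this is an experimental notebook that is omitted.\n",
"\n",
"We start with sampling the latent variable z."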
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1. Latent variable z\n",
"- Let $K^+$ be the number of existing clusters.\n",
"For $k \\in \\{1,...,K^+\\}$:\n",
"$$p(z_t = k | y_t,x_t,\\theta^{(k)},z_{1:T}^{\\backslash t}) \\propto {\\cal{N}} (y_t | \\theta^{(k)\\mathrm{T}}x_t , \\sigma^{(k)2}) \\times \\frac{n_k^{\\backslash t}}{T-1+\\alpha}$$\n",
"For a new cluster $k = K^+ + 1$:\n",
"$$p(z_t = K^+ + 1 | y_t,x_t,z_{1:T}^{\\backslash t}) \\propto {\\cal{N}} (y_t | \\theta^{(new)\\mathrm{T}}x_t , \\sigma^{(new)2}) \\times \\frac{\\alpha}{T-1+\\alpha}$$\n",
"$$(\\theta^{(new)}, \\sigma^{(new)2}) \\sim {\\cal{N}} (\\mu, \\Sigma) IG(\\frac{n_0}{2}, \\frac{\\tau}{2})$$\n",
"- Here $n_k^{\\backslash t}$ is the size of cluster $k$ excluding observation $t$, and $x_t$ is augmented with a constant 1 so that $\\theta$ includes the intercept."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def sampler_z(t, yt, xt, theta, sigma_y, z, alpha, T, mu, sigma, n0, tau):\n",
"    k_plus, n_t = np.unique(z, return_counts=True)\n",
"    prob_z = np.empty(k_plus.shape[0]+1)  # unnormalized P(z_t = k); last entry = new cluster\n",
"    for k in k_plus:\n",
"        ## n_k^{-t}: size of cluster k with observation t removed\n",
"        if(z[t]==k):\n",
"            n_kt = n_t[k_plus==k] - 1\n",
"        else:\n",
"            n_kt = n_t[k_plus==k]\n",
"        ## unnormalized probability that z_t = k (labels are assumed to be 0, 1, ..., K-1)\n",
"        mean_k = np.dot(theta[k_plus==k], np.r_[xt,1])\n",
"        prob_z[k] = st.norm.pdf(x=yt, loc=mean_k, scale=sigma_y[k_plus==k])*(n_kt/(T-1+alpha))\n",
"\n",
"    ## candidate new cluster: draw its parameters from the base measure\n",
"    theta_new = st.multivariate_normal.rvs(mean=mu, cov=sigma, size=1)\n",
"    sigma_new = np.sqrt(st.invgamma.rvs(a=n0/2, scale=tau/2, size=1))\n",
"    prob_z[-1] = st.norm.pdf(x=yt, loc=np.dot(theta_new, np.r_[xt,1]), scale=sigma_new)*(alpha/(T-1+alpha))\n",
"\n",
"    ## normalize and sample the new z[t]\n",
"    prob_z /= np.sum(prob_z)\n",
"    z_sample = np.random.choice(np.r_[k_plus, k_plus[-1]+1], size=1, p=prob_z)[0]\n",
"    return z_sample, theta_new, sigma_new  # new z[t] and the candidate new-cluster parameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
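"- When a cluster is added, $\\theta^{new}$ and $\\sigma^{new}$ must also be returned.\n",
"  - For now the function always returns $\\theta^{new}$ and $\\sigma^{new}$, and the main loop appends them to `theta` and `sigma_y` when the sampled $z_t$ is a new cluster.\n",
"  - Turning this into a class would let it be handled internally, but that has not been done here."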
]
},
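{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sampler_z` multiplies densities directly. If every `st.norm.pdf` value underflows to 0 (an observation far from all cluster means), the normalization divides 0 by 0. A common remedy is to work with log-probabilities and normalize with `scipy.special.logsumexp`. The sketch below assumes the predictive means $\\theta^{(k)\\mathrm{T}}x_t$, standard deviations and CRP weights have already been collected into arrays. The helper `assignment_probs` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.stats as st\n",
"from scipy.special import logsumexp\n",
"\n",
"def assignment_probs(yt, means, sds, weights):\n",
"    # hypothetical helper: normalized P(z_t = k) for each candidate cluster k\n",
"    # means, sds: predictive mean and std of y_t under each cluster (last entry = new cluster)\n",
"    # weights: CRP weights n_k^{-t} for existing clusters and alpha for the new one\n",
"    log_p = st.norm.logpdf(yt, loc=means, scale=sds) + np.log(weights)\n",
"    return np.exp(log_p - logsumexp(log_p))\n",
"\n",
"# far-away observation: plain pdf values underflow to 0, the log-space version does not\n",
"p = assignment_probs(100.0, means=np.array([0.0, 1.0]), sds=np.array([0.1, 0.1]),\n",
"                     weights=np.array([3.0, 1.0]))\n",
"print(p)"
]
},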
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.2. Parameter $\\theta$\n",
"- This is determined per cluster. For ${\\cal{T}}_k = \\{t|z_t=k\\}$,\n",
"$$p(\\theta^{(k)} | y,x,z,\\sigma^{(k)}, \\mu, \\Sigma) = {\\cal{N}}(\\theta^{(k)} | \\mu_k, \\Sigma_k),$$\n",
"where ($x_t$ includes the constant 1 for the intercept)\n",
"$$\\Sigma_k^{-1} = \\sum_{t \\in {\\cal{T}}_k}{\\frac{x_t x_t^{\\mathrm{T}}}{\\sigma^{(k)2}}} + \\Sigma^{-1}$$\n",
"$$\\mu_k = \\Sigma_k(\\sum_{t \\in {\\cal{T}}_k}{\\frac{y_t x_t}{\\sigma^{(k)2}}} + \\Sigma^{-1}\\mu)$$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def sampler_theta(k, y, x, z, sigma_y, mu, sigma_inv):\n",
"    sigma_yk = sigma_y[k]\n",
"    t_k = np.where(z == k)[0]  # T_k = {t | z_t = k}\n",
"    x = np.c_[x.copy(), np.ones(x.shape[0])]  # X_t = [X_t, 1]\n",
"    sigma_k_inv = sigma_inv.copy()  # starts from Sigma^-1\n",
"    mu_k_tmp = np.dot(sigma_inv, mu)  # starts from Sigma^-1 * mu\n",
"    for tt in t_k:  # sum over t in T_k\n",
"        x_tt = x[tt][:,np.newaxis]\n",
"        sigma_k_inv += np.dot(x_tt,x_tt.T) / (sigma_yk**2)\n",
"        mu_k_tmp += (y[tt]*x[tt]) / (sigma_yk**2)\n",
"    sigma_k = np.linalg.inv(sigma_k_inv)\n",
"    mu_k = np.dot(sigma_k, mu_k_tmp)\n",
"    return st.multivariate_normal.rvs(mean=mu_k, cov=sigma_k, size=1)  # new theta[k]"
]
},
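{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop in `sampler_theta` accumulates $\\Sigma_k^{-1}$ and the bracketed term of $\\mu_k$ one observation at a time. The same quantities can be computed with matrix products. This is a minimal sketch: the helper `theta_posterior` is hypothetical and returns the posterior mean and covariance rather than a sample. It is checked against a per-observation loop on toy data, and the `_demo` names avoid overwriting the notebook's `mu` and `sigma_inv`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def theta_posterior(y_k, x_k, sigma_yk, mu, sigma_inv):\n",
"    # hypothetical helper: mean and covariance of the posterior N(mu_k, Sigma_k) of theta^(k)\n",
"    X = np.c_[x_k, np.ones(len(x_k))]  # augment x_t with a constant 1 (intercept)\n",
"    prec = sigma_inv + X.T @ X / sigma_yk**2  # Sigma_k^{-1}\n",
"    cov = np.linalg.inv(prec)  # Sigma_k\n",
"    mean = cov @ (X.T @ y_k / sigma_yk**2 + sigma_inv @ mu)  # mu_k\n",
"    return mean, cov\n",
"\n",
"# toy data for one cluster (illustrative values only)\n",
"rng_demo = np.random.default_rng(1)\n",
"x_k = rng_demo.uniform(0, 30, size=20)\n",
"y_k = 1 + 0.5 * x_k + rng_demo.normal(0, 0.5, size=20)\n",
"mu_demo, sigma_inv_demo, s_yk = np.zeros(2), np.eye(2) / 100.0, 0.5\n",
"mean, cov = theta_posterior(y_k, x_k, s_yk, mu_demo, sigma_inv_demo)\n",
"\n",
"# reference: the per-observation accumulation used in sampler_theta\n",
"prec_loop, b_loop = sigma_inv_demo.copy(), sigma_inv_demo @ mu_demo\n",
"for xt, yt in zip(x_k, y_k):\n",
"    v = np.array([xt, 1.0])\n",
"    prec_loop += np.outer(v, v) / s_yk**2\n",
"    b_loop += yt * v / s_yk**2\n",
"print(np.allclose(mean, np.linalg.solve(prec_loop, b_loop)), np.allclose(cov, np.linalg.inv(prec_loop)))"
]
},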
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3. Observation noise variance $\\sigma^2$\n",
"- This is also determined per cluster.\n",
"$$p(\\sigma^{(k)2} | y,x,z,\\theta^{(k)},n_0, \\tau) = IG(\\sigma^{(k)2} | \\frac{n_0+n_k}{2}, \\frac{\\tau_k}{2})$$\n",
"where $n_k = |{\\cal{T}}_k|$ and\n",
"$$\\tau_k = \\tau + \\sum_{t \\in {\\cal{T}}_k}{||y_t - \\theta^{(k)\\mathrm{T}}x_t||^2}$$\n",
"- The function below takes the square root, so it samples the standard deviation $\\sigma^{(k)}$ rather than the variance."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def sampler_sigma_y(k, y, x, z, theta, n0, tau):\n",
"    t_k = np.where(z == k)[0]\n",
"    n_k = t_k.shape[0]\n",
"    tau_k = tau\n",
"    for tt in t_k:\n",
"        resid = y[tt] - np.dot(theta[k], np.r_[x[tt],1])\n",
"        tau_k += resid**2  # squared residual of observation tt\n",
"    return np.sqrt(st.invgamma.rvs(a=(n0+n_k)/2, scale=tau_k/2))  # sigma^(k), not its square"
]
},
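{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sampler_sigma_y` relies on SciPy's `invgamma(a, scale=b)` having density proportional to $x^{-a-1}e^{-b/x}$, so that it matches $IG(a, b)$ with shape $a$ and scale $b$ as used above. A quick sanity check of that convention compares the sample mean with the analytic mean $b/(a-1)$ (valid for $a>1$). The values of `a_ig` and `b_ig` are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import scipy.stats as st\n",
"\n",
"a_ig, b_ig = 6.0, 2.0  # arbitrary shape and scale for the check\n",
"draws = st.invgamma.rvs(a=a_ig, scale=b_ig, size=200000, random_state=0)\n",
"print(draws.mean(), b_ig / (a_ig - 1))  # sample mean vs analytic mean b/(a-1) = 0.4\n",
"print(st.invgamma.mean(a_ig, scale=b_ig))  # SciPy's own analytic mean"
]
},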
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.4. $\\mu$ (mean of $\\theta^{(new)}$)\n",
"- The sampling distribution is\n",
"$$p(\\mu | \\theta^{(1:K^+)}, \\Sigma, \\mu_0, V_0) = {\\cal{N}}(\\mu | \\mu_+, V_+)$$ \n",
"  where $\\mu_+$ and $V_+$ are given by\n",
"$$V_+^{-1} = K^+\\Sigma^{-1}+V_0^{-1}$$\n",
"$$\\mu_+ = V_+(\\Sigma^{-1}\\sum_{k=1}^{K^+}{\\theta^{(k)}}+V_0^{-1}\\mu_0)$$"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def sampler_mu(theta, sigma_inv, mu0, v0_inv):\n",
"    vp = np.linalg.inv(theta.shape[0] * sigma_inv + v0_inv)  # V_+\n",
"    mup = np.dot(sigma_inv, np.sum(theta, axis=0)) + np.dot(v0_inv, mu0)\n",
"    mup = np.dot(vp, mup)  # mu_+\n",
"    return st.multivariate_normal.rvs(mean=mup, cov=vp)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.5. $\\Sigma$ (covariance matrix of $\\theta^{(new)}$)\n",
"- The sampling distribution is\n",
"$$p(\\Sigma^{-1} | \\theta^{(1:K^+)}, \\mu, \\nu_0, \\Sigma_0) = W(\\Sigma^{-1} | \\nu_+, \\Sigma_+)$$\n",
"where\n",
"$$\\nu_+ = \\nu_0 + K^+$$\n",
"$$\\Sigma_+^{-1} = \\Sigma_0^{-1} + \\sum_{k=1}^{K^+}{(\\theta^{(k)}-\\mu)(\\theta^{(k)}-\\mu)^\\mathrm{T}}$$"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def sampler_Sigma_inv(theta, mu, nu0, sigma0_inv):\n",
"    nup = nu0 + theta.shape[0]  # nu_+\n",
"    sigmap_inv = sigma0_inv.copy()  # copy so that += does not modify the caller's array\n",
"    for ii in range(theta.shape[0]):\n",
"        tmp = (theta[ii] - mu)[:,np.newaxis]\n",
"        sigmap_inv += np.dot(tmp, tmp.T)\n",
"    return st.wishart.rvs(df=nup, scale=np.linalg.inv(sigmap_inv))"
]
},
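{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sampler_Sigma_inv` passes `scale=np.linalg.inv(sigmap_inv)`, i.e. $\\Sigma_+$, to SciPy's `wishart`. SciPy's Wishart has mean `df * scale`, so the draw has mean $\\nu_+ \\Sigma_+$, as expected for $W(\\nu_+, \\Sigma_+)$ in this parameterization. A quick numerical check of the SciPy convention, using arbitrary values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.stats as st\n",
"\n",
"df_w = 7\n",
"scale_w = np.array([[2.0, 0.3], [0.3, 0.5]])  # arbitrary positive-definite scale matrix\n",
"draws_w = st.wishart.rvs(df=df_w, scale=scale_w, size=50000, random_state=0)\n",
"print(draws_w.mean(axis=0))  # close to df * scale\n",
"print(df_w * scale_w)"
]
},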
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.6. $\\tau$ (parameter of the generative process of $\\sigma_y$)\n",
"- The sampling distribution is\n",
"$$p(\\tau | \\sigma^{(1:K^+)}, n_0, m_0, \\tau_0) = Ga(\\tau | \\frac{m_+}{2}, \\frac{\\tau_+}{2})$$\n",
"where\n",
"$$m_+ = m_0 + n_0K^+$$\n",
"$$\\tau_+ = \\tau_0 + \\sum_{k=1}^{K^+}{\\frac{1}{\\sigma^{(k)2}}}$$"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def sampler_tau(sigma_y, n0, m0, tau0):\n",
"    mp = m0 + n0*sigma_y.shape[0]  # m_+\n",
"    taup = tau0 + np.sum(1/sigma_y**2)  # tau_+\n",
"    # Ga(m_+/2, tau_+/2) with tau_+/2 as the rate; SciPy's scale is 1/rate\n",
"    return st.gamma.rvs(a=mp/2, scale=2/taup)"
]
},
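{
"cell_type": "markdown",
"metadata": {},
"source": [
"SciPy's `gamma` is parameterized by shape `a` and `scale` = 1/rate. Combining a $Ga(\\frac{m_0}{2}, \\frac{\\tau_0}{2})$ prior on $\\tau$ with the $IG(\\frac{n_0}{2}, \\frac{\\tau}{2})$ densities of the $\\sigma^{(k)2}$ gives the update in §2.6 only if $Ga(a, b)$ has rate $b$. Under that assumption, a draw from $Ga(m_+/2, \\tau_+/2)$ needs `scale=2/taup`. The sketch below checks the SciPy convention against the analytic mean $a/b$, using illustrative values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import scipy.stats as st\n",
"\n",
"shape_g, rate_g = 5.0, 4.0  # stand-ins for m_+/2 and tau_+/2 (illustrative values only)\n",
"draws = st.gamma.rvs(a=shape_g, scale=1.0 / rate_g, size=200000, random_state=0)\n",
"print(draws.mean(), shape_g / rate_g)  # both close to 1.25\n",
"print(st.gamma.mean(shape_g, scale=rate_g))  # passing the rate as scale gives shape*rate = 20 instead"
]
},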
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Sampling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.1. Initial values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"alpha = 1\n",
"n0 = 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Parameters of the prior distributions"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"mu0 = np.array([0, 0])\n",
"v0_tmp = np.random.uniform(-10, 10, (2, 2))#np.array([[0.5, 0], [0,4]])\n",
"v0 = np.dot(v0_tmp, v0_tmp.T)\n",
"v0_inv = np.linalg.inv(v0)\n",
"\n",
"nu0 = 2\n",
"sigma0_tmp = np.random.uniform(-10, 10, (2, 2))#np.array([[0.5, 0], [0,10]])#np.random.uniform(-1, 1, (2, 2))\n",
"sigma0 = np.dot(sigma0_tmp, sigma0_tmp.T) # positive definite\n",
"sigma0_inv = np.linalg.inv(sigma0)\n",
"\n",
"m0 = 0.5\n",
"tau0 = 2 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Initial parameter values"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
"mu : [-9.56246008 7.33209689]\n",
"Sigma : [[120.50695055 -49.34876924]\n",
" [-49.34876924 23.14239448]]\n",
"tau : 0.24918271440471546\n"
]
}
],
"source": [
"T = x.shape[0]\n",
"\n",
"theta = np.array([[0.5,1], [-0.3, 25]])#, [0.1, 1]])\n",
"sigma_y = np.sqrt(np.array([0.2, 0.2]))#, 0.2]))\n",
"z = np.repeat(np.array([0,1]), 45)#np.repeat(np.array([0,1,2]), 30) #np.zeros(T, dtype=\"int\")\n",
"print(z)\n",
"mu = st.multivariate_normal.rvs(mean=mu0, cov=v0)\n",
"sigma = st.wishart.rvs(df=nu0, scale=sigma0)\n",
"sigma_inv = np.linalg.inv(sigma)\n",
"tau = st.gamma.rvs(a=m0/2, scale=tau0/2)\n",
"print(\"mu : \", mu)\n",
"print(\"Sigma : \", sigma)\n",
"print(\"tau :\", tau)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2. Gibbs sampling\n",
"- The book runs at least 12,000 sampling steps; here 2,000 steps are used, the first 1,000 of which are discarded as burn-in."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4min 27s, sys: 1.26 s, total: 4min 29s\n",
"Wall time: 4min 31s\n"
]
}
],
"source": [
"%%time\n",
"n_step = 2000\n",
"burnin = 1000\n",
"z_sample = np.zeros((n_step, T))\n",
"theta_sample = np.zeros((n_step, T, theta.shape[1]))\n",
"\n",
"n_cluster = theta.shape[0]\n",
"for step in range(n_step):\n",
"    for tt in range(T):\n",
"        z_new, theta_new, sigma_new = sampler_z(tt, y[tt], x[tt], theta, sigma_y, z, alpha, T, mu, sigma, n0, tau)\n",
"        # new cluster opened (labels are kept as 0..K-1, so its label is K): append its parameters\n",
"        if z_new == theta.shape[0]:\n",
"            theta = np.r_[theta, [theta_new]]\n",
"            sigma_y = np.r_[sigma_y, sigma_new]\n",
"        z[tt] = z_new\n",
"        # drop clusters that became empty and relabel z to 0, 1, 2, ...\n",
"        used = np.unique(z)\n",
"        theta = theta[used]\n",
"        sigma_y = sigma_y[used]\n",
"        z = np.searchsorted(used, z)\n",
"    n_cluster = theta.shape[0]\n",
"\n",
"    for kk in range(n_cluster):\n",
"        theta[kk] = sampler_theta(kk, y, x, z, sigma_y, mu, sigma_inv.copy())\n",
"        sigma_y[kk] = sampler_sigma_y(kk, y, x, z, theta, n0, tau)\n",
"    mu = sampler_mu(theta, sigma_inv, mu0, v0_inv)\n",
"    sigma_inv = sampler_Sigma_inv(theta, mu, nu0, sigma0_inv.copy())\n",
"    sigma = np.linalg.inv(sigma_inv)  # covariance of theta^(new), used by sampler_z\n",
"    tau = sampler_tau(sigma_y, n0, m0, tau0)\n",
"\n",
"    z_sample[step] = z\n",
"    # pad with NaN so that every step stores T cluster slots\n",
"    theta_sample[step] = np.r_[theta, np.full((T - n_cluster, theta.shape[1]), np.nan)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
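"- Even on an old MacBook Pro (Mid 2012), this ran in about 4 minutes.\n",
"- The `sigma_inv` passed to `sampler_theta` has to be passed as a `copy()`; otherwise its value changed inside the function.\n",
"  - Nothing is assigned to it inside the function, so I suspected that taking a dot product with another vector changed it. `np.dot` does not modify its arguments, though. A more likely cause is that `+=` updates a NumPy array in place, so a name inside a function that is bound to the caller's array without `copy()` changes the caller's array too.\n",
"  - This cost me a whole day."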
]
},
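{
"cell_type": "markdown",
"metadata": {},
"source": [
"The aliasing behaviour described above can be reproduced in isolation. In this minimal sketch (the function names are hypothetical), `np.dot` returns a new array. By contrast, `+=` on a NumPy array updates it in place, so a plain assignment inside a function shares the caller's array unless `copy()` is taken."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def add_in_place(a):\n",
"    b = a  # b is just another name for the caller's array\n",
"    b += np.eye(2)  # in-place update: the caller's array changes too\n",
"    return b\n",
"\n",
"def add_on_copy(a):\n",
"    b = a.copy()  # independent array\n",
"    b += np.eye(2)\n",
"    return b\n",
"\n",
"arr = np.zeros((2, 2))\n",
"_ = np.dot(arr, np.ones(2))  # np.dot does not modify arr\n",
"add_on_copy(arr)\n",
"print(arr.sum())  # 0.0: unchanged\n",
"add_in_place(arr)\n",
"print(arr.sum())  # 2.0: modified through the alias"
]
},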
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 1, 2])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.unique(z)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Results\n",
"### 4.1. Distribution of the number of clusters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The distribution of the number of clusters obtained by sampling is as follows."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(array([2, 3, 4]), array([ 12, 957, 31]))\n"
]
}
],
"source": [
"print(np.unique(np.max(z_sample[burnin:].astype('int')+1, axis=1), return_counts=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is plotted below."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
mwaa4LKKYary6NaVdq7Uqd0G6tAwdUROwJfAl4Y2b+MiJ21HXBJG1jExsycxWwanz+6OjoQPUNDw8z6DLaYq0zryt11mSq8erSmHal1q7UCTNT68jISF/9BrqKLyJ2pQmnz2XmhaV50/ihu/Jzc2nfACzp+fhi4JZB1i9Jmr8GuYpvAXAucG1mfrBn1lpgOXBm+XlRT/trI+I84KnAHeOHAiVJmmiQQ3xHACcCP4mIa0rb22mCKSNiBXATcHyZdzHNJebX0VxmfvIA65YkzXPTDqjM/D6Tn1cCOHKS/mPAKdNdnyTp/sU7SUiSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASZKqZEBJkqo0NNcFSJKmds+rj53rEhpf/vfWVuUelCSpSgaUJKlKBpQkqUqtn4OKiOcDZwO7AOdk5pmzub5NL3zabC6+b7t8eu1clyBJndLqHlRE7AJ8DDgKeCzwkoh4bJs1SJK6oe1DfIcB12Xm9Zl5N3AesKzlGiRJHdD2Ib5FwPqe6Q3AU3s7RMRKYCVAZjIyMjLYGr921WCfb9nA/94WdaXWrtTZpW21M2NKd2qdss6Kto+2xrTtPagFk7SN9U5k5qrMPDQzDy39B3pFxNUzsZw2XtZ6/62zS7V2pc4u1dqVOme41im1HVAbgCU904uBW1quQZLUAW0f4vsRsDQiDgJuBk4AXtpyDZKkDmh1DyoztwGvBS4Brm2a8qezvNpVs7z8mWStM68rdUJ3au1KndCdWrtSJ7RY64KxsbGpe0mS1DLvJCFJqpIBJUmqUmcftxERS4A1wP7AvcCqzDx7Qp8FNLdVOhq4CzgpM9eVecuBd5Su783M1XNc68uA08rkncBrMvPHZd6NwK+Ae4Bt5RL8uarzmcBFwA2l6cLMfHeZ19ptrPqs9S3Ay8rkEPAYYN/M3NLimD4Q+C6we6nhgsw8Y0Kf3cu/5SnAbcCLM/PGMu9twIpS5+sz85LZqPM+1Hoq8CpgG3Ar8MrM/L8y7x7gJ6XrTZk5a8+H6LPWk4AP0FyQBfDRzDynzGvl+99nnWcBzyqTewAPy8yFZV5rY9pTzy7AVcDNmXnMhHmtbqtd3oPaBrw5Mx8DHA6cMsltk44ClpbXSuATABGxN3AGzR8JHwacERF7zXGtNwDPyMzHA+9h+xORz8rMJ87WL9L7UCfA90otT+wJp7ZvYzVlrZn5gfE6gbcB38nMLT1d2hjT3wHPzswnAE8Enh8Rh0/oswLYmpkHA2cB7wco/54TgMcBzwc+XsZ5Lmv9D+DQsp1eAPx9z7zf9GwXs/2LtJ9aAc7vqWk8nNr8/k9ZZ2a+qWc7/QhwYc/sNsd03BtoLmKbTKvbamf3oDJzI7CxvP9VRFxLc6eKn/V0Wwasycwx4IqIWBgRBwDPBC4d/2UVEZfSDOoX5qrWzOx9CtgVNH8j1qo+x3RHfn8bK4CIGL+NVT+fbaPWlzBL/313pmx7d5bJXctr4pVJy4B3lfcXAB8te//LgPMy83fADRFxHc04/2Cuas3Mb/VMXgG8fDZqmUqf47ojz6Ol7/806nwJTXjOiYhYDLwAeB9w6iRdWt1WOxtQvSLiQOBJwJUTZk12a6VFO2mfdTuptdcK4F97pseAb0TEGPCpzJz1yzynqPNPI+LHNH9k/dflTwWmvI3VbJlqTCNiD5pfQK/taW5tTMv/SV4NHAx8LDN3uJ1m5raIuAPYp7Rf0dNv1rfTPmrtNXE7fWBEXEWzd3tmZv7L7FXad61/ERFPB34OvCkz19Py97/fMY2IRwAHAZf3NLc6psCHgLcCD97B/Fa31S4f4gMgIvYEvg
S8MTN/OWH2ZLfTGNtJ+6yaotbxPs+i+eKf1tN8RGY+mebw2SnlCzdXda4DHlEOWXwEGP/CVDumwJ8D/zbh8F5rY5qZ95TDN4uBwyLikAldqtlO+6gVgIh4OXAozTmecQ8vh0tfCnwoIv54jmv9CnBgORz5TWD8PFOr49rvmNIcIrsgM+/paWttTCPiGGBzZl69k26tbqudDqiI2JXml9PnMvPCSbrs6NZKrd9yqY9aiYjHA+cAyzLztvH2zLyl/NwMfJlm13lO6szMX2bmneX9xcCuETFMpWNanMCEwzdtjmnPOm8Hvk2zN9fr92MXEUPAQ4EtzOGtwXZSKxHxHOBvgGPLIZ3xz4yP6fXls0+ay1oz87ae+j5Nc2If5mhcdzamxc620zbG9Ajg2HIB0XnAsyPisxP6tLqtdjagynHPc4FrM/ODO+i2FnhFRCwoJybvKOcuLgGeGxF7lZOjzy1tc1ZrRDyc5uToiZn58572P4qIB4+/L7X+1xzWuX/pR0QcRrMN3UbPbawiYjeaL9usPaWxz//+RMRDgWfQXHk43tbmmO4bEeNXZD0IeA7w3xO6rQWWl/cvAi4v5y7WAidExO7R3B5sKfDD2aiz31oj4knAp2jCaXNP+17lCi/K/7AcwSydf7wPtR7QM3ksfzjx39r3v8///kTEo4C96Dln0/aYZubbMnNxZh5I8/29PDMnnmNsdVvt8jmoI4ATgZ9ExDWl7e3AwwEy85PAxTSXmF9Hc5n5yWXeloh4D80vVYB3Tzj8Mxe1vpPmWO7HIwL+cOnzfsCXS9sQ8PnM/Poc1vki4DURsQ34DXBC2UC3RcT4bax2AT6Ts3sbq35qBXgh8I3M/HXPZ9sc0wOA1eU8xAOa0vKrEfFu4KrMXEsTtP9cTixvofnlQGb+NCKS5pfSNuCUCYd/5qLWDwB7Al8s4zd+6fNjgE9FxL3ls2dm5qz9Mu2z1tdHxLE0Y7cFOAla//73Uyc0F0ecV75L49oe00nN5bbqrY4kSVXq7CE+SdL8ZkBJkqpkQEmSqmRASZKqZEBJkqpkQEmSqmRASS2KiG9HxKvmug6pCwwoqWMi4qSI+P5c1yHNNgNKup8p91CTqueGKk2h3Dzzo8ArgEcAXweWZ+Zvd/KZZcDfAo+kefLsKRNvpxQR7wIOHr/fWXlsyA3AruVRBifR3AJrX2CU5gmw64BP0tyk906aW2ItLPdsex8QNE9v/TLN4yV+E81TkD9Lc/f5NwGX0twmSqqae1BSf4LmLtQHAY+n3Ndt0o7NTXTXAG8BFgJPB268TytrbmL7YeCozHww8DTgmsy8FvhL4AeZuWeWR4PTPNn0T2ie2nowzbN43tmzyP2BvWkCduV9qUWaK+5BSf358PijDyLiKzRBsCMraG6We2mZvnma67wXOCQibup9gvBE5c7urwYe3/OU2L8DPk/zqPvxZZ3R+3gMqXbuQUn9+UXP+7to7ui9I0uA/x1kZeXu6y+m2VvaGBFfi4hH76D7vsAewNURcXtE3E5zGHLfnj637uyQpFQj96Ckmbce6OfJp7+mCZZx+/fOzMxLgEvKc4TeS/PQvT9j+yeVjtI8+uRxmbmjvTUfW6DOMaCkmXcu8I2I+CrwLZpnAj04Myc+qO4a4LTysMo7+MPhOCJiP+CpwGU04XMnMP58nU3A4ojYLTPvzsx7I+LTwFkR8drM3BwRi4BDSshJneQhPmmGZeYPaR6OeRZN8HyH5uKEif0uBc4H/hO4Gvhqz+wHAG+meWz2FpqnAv9VmXc58FPgFxExWtpOo3kw5xUR8Uvgm8CjZvQfJrXMBxZKkqrkHpQkqUqeg5KmISLeDrx9klnfy8yj2q5Hmo88xCdJqpKH+CRJVTKgJElVMqAkSVUyoCRJVfp/lsmmY/b1tg0AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.hist(np.max(z_sample[burnin:].astype('int'), axis=1)+1)\n",
"ax.set_xlabel('n_cluster')\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
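"- 3 clusters, matching the three regimes of the simulated data, account for the large majority of the post-burn-in samples (957 of 1000); 2 and 4 clusters appear only occasionally.\n",
"  - Since the initial number of clusters was 2, the sampler moved away from its initial value and recovered the correct number of clusters.\n",
"- Unlike the figure in the MLP book (p. 96, Fig. 6.4), no samples with 5 or more clusters appeared.\n",
"  - Might they appear with more sampling steps?"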
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2. Posterior distribution of the slope (gradient)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"z_after_burnin = z_sample[burnin:,:].astype('int')\n",
"theta_after_burnin = theta_sample[burnin:,:,0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The posterior distribution of the slope at the first time point (index 0, i.e. $x=1$) can be obtained as follows."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000,)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad = [theta_after_burnin[ii,z_after_burnin[ii,0]] for ii in range(z_after_burnin.shape[0])]\n",
"grad = np.array(grad)\n",
"grad.shape"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda3/lib/python3.7/site-packages/scipy/stats/stats.py:1713: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n",
" return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x1a1c917630>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sns.distplot(grad)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, combine this over all time points."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"grad_all = np.empty((T, n_step - burnin))\n",
"for tt in range(T):\n",
"    grad = [theta_after_burnin[ii, z_after_burnin[ii, tt]] for ii in range(z_after_burnin.shape[0])]\n",
"    grad_all[tt] = np.array(grad)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.plot(x, np.mean(grad_all, axis=1))\n",
"ax.fill_between(x, np.mean(grad_all, axis=1) +np.std(grad_all, axis=1), np.mean(grad_all, axis=1) -np.std(grad_all, axis=1),\n",
" alpha=0.3, color=\"purple\")\n",
"ax.set_xlabel('x')\n",
"ax.set_ylabel('grad')\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- This nearly reproduces the figure in the MLP book (p. 96, Fig. 6.3).\n",
"  - It is apparent that changes occur around x = 30 and x = 60."
]
},
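{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-time loop that builds `grad_all` can also be written with `np.take_along_axis`, which picks, for each sample `i` and time `t`, the entry `theta[i, z[i, t]]`. Below is a self-contained sketch on small toy arrays; the `_demo` names are hypothetical stand-ins for `theta_after_burnin` and `z_after_burnin`. It checks the vectorized version against the loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# toy stand-ins: theta_demo is (samples x cluster slots), z_idx_demo is (samples x time)\n",
"rng = np.random.default_rng(0)\n",
"n_samples, T_demo = 5, 8\n",
"theta_demo = rng.normal(size=(n_samples, T_demo))\n",
"z_idx_demo = rng.integers(0, 3, size=(n_samples, T_demo))\n",
"\n",
"# loop version, as in the cell that builds grad_all\n",
"grad_loop = np.empty((T_demo, n_samples))\n",
"for tt in range(T_demo):\n",
"    grad_loop[tt] = [theta_demo[ii, z_idx_demo[ii, tt]] for ii in range(n_samples)]\n",
"\n",
"# vectorized: entry [i, t] is theta_demo[i, z_idx_demo[i, t]]; transpose to (time, samples)\n",
"grad_vec = np.take_along_axis(theta_demo, z_idx_demo, axis=1).T\n",
"print(np.array_equal(grad_loop, grad_vec))  # True"
]
},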
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}