Skip to content

Instantly share code, notes, and snippets.

@tanutarou
Last active March 3, 2020 14:28
Show Gist options
  • Save tanutarou/543e4e1e53b63f9b56e2bb68df2f475b to your computer and use it in GitHub Desktop.
Save tanutarou/543e4e1e53b63f9b56e2bb68df2f475b to your computer and use it in GitHub Desktop.
ベイズ統計の理論と方法 6.2.3 例25 再現コード
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"import numpy as np\n",
"from pystan import StanModel\n",
"import scipy as sp\n",
"import scipy.stats  # load the stats submodule explicitly so that sp.stats resolves\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ベイズ統計の理論と方法 6.2.3 例25 再現コード\n",
"以下のモデルについて汎化誤差, WAIC, DIC1, DIC2を計算してみる\n",
"\n",
"\n",
"* 確率モデル: \n",
"$p(x,y|w) = \\frac{s(x)}{(2 \\pi \\sigma^2)^{2/2}} \\exp(- \\frac{\\| y - R_H(x, w) \\| ^2}{2 \\sigma^2})$ \n",
"$s(x) = Normal(0, 2^2 I)$, $x \\in R^3$, $y \\in R^2$, $\\sigma = 0.1$\n",
"\n",
"* 回帰関数: \n",
"$R_H(x, w) = \\sum_{h=1}^H \\frac{a_h}{1 + \\exp(-b_h \\cdot x)}$\n",
"\n",
"* パラメータ: \n",
"$w = \\{(a_h \\in R^2, b_h \\in R^3); h = 1, 2, \\cdots, H \\} \\in R^{5H}$\n",
"* 事前分布: \n",
"$\\varphi(w) = Normal(0, 10^2 I)$\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"code_waic = \"\"\"\n",
"data {\n",
" int N;\n",
" int H;\n",
" vector[3] X[N];\n",
" vector[2] Y[N];\n",
"}\n",
"\n",
"parameters {\n",
" vector[2] a[H];\n",
" vector[3] b[H];\n",
"}\n",
"\n",
"model {\n",
" for (h in 1:H){\n",
" a[h] ~ normal(0, 10);\n",
" b[h] ~ normal(0, 10);\n",
" }\n",
" for (n in 1:N) {\n",
" vector[2] reg;\n",
" reg[1] = 0;\n",
" reg[2] = 0;\n",
" \n",
" for (h in 1:H){\n",
" real tmp = 0;\n",
" for(i in 1:3){\n",
" tmp += b[h][i] * X[n][i];\n",
" }\n",
" for(i in 1:2){\n",
" reg[i] += a[h][i] / (1 + exp(-tmp));\n",
" }\n",
" }\n",
" Y[n] ~ normal(reg, 0.1);\n",
" }\n",
"}\n",
"\n",
"generated quantities {\n",
" vector[N] log_likelihood;\n",
" for (n in 1:N) {\n",
" vector[2] reg;\n",
" reg[1] = 0;\n",
" reg[2] = 0;\n",
" \n",
" for (h in 1:H){\n",
" real tmp = 0;\n",
" for(i in 1:3){\n",
" tmp += b[h][i] * X[n][i];\n",
" }\n",
" for(i in 1:2){\n",
" reg[i] += a[h][i] / (1 + exp(-tmp));\n",
" }\n",
" }\n",
" log_likelihood[n] = normal_lpdf(Y[n]|reg, 0.1);\n",
" }\n",
"}\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" with open('./nn_model.pkl', \"rb\") as f:\n",
" stanmodel_waic = pickle.load(f)\n",
"except FileNotFoundError:\n",
" stanmodel_waic = StanModel(model_code=code_waic, model_name=\"nn_model\")\n",
" with open('./nn_model.pkl', \"wb\") as f:\n",
" pickle.dump(stanmodel_waic, f) "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def reg(x, a, b, h=1):\n",
" return np.sum([a[i] / (1 + np.exp(-b[i].dot(x))) for i in range(h)], axis=0)\n",
"\n",
"def p_model(x, a, b, h=1, sigma=0.1):\n",
" return np.random.multivariate_normal(reg(x, a, b, h), (sigma**2)*np.eye(2))\n",
"\n",
"def p_model_pdf(x, y, a, b, h=1, sigma=0.1):\n",
" return sp.stats.multivariate_normal.pdf(y, reg(x, a, b, h), (sigma**2)*np.eye(2))\n",
" \n",
"def data_gen(gt_a, gt_b, n=200):\n",
" # dataの作成\n",
" xs = np.random.multivariate_normal([0, 0, 0], 4*np.eye(3), n)\n",
" ys = np.array([p_model(x, gt_a, gt_b) for x in xs])\n",
" return xs, ys"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def Ln(xs, ys, a, b, h=1):\n",
" ps = np.array([p_model_pdf(xs[i], ys[i], a, b, h=h) for i in range(len(ys))])\n",
" return -np.mean(np.log(ps))\n",
"\n",
"def calc_Gn0(samples, xs, ys, true_w):\n",
" true_a, true_b = true_w\n",
" ps = np.mean(np.exp(samples), axis=0)\n",
" qs = np.array([p_model_pdf(xs[i], ys[i], true_a, true_b) for i in range(len(ys))])\n",
" return np.mean(np.log(qs/ps) + ps/qs - 1)\n",
"\n",
"def calc_WAIC(samples):\n",
" Tn = - np.mean(np.log(np.mean(np.exp(samples), axis=0)))\n",
" Vn = np.sum(np.mean(samples**2, axis=0) - np.mean(samples, axis=0)**2)\n",
" waic = Tn + Vn/samples.shape[1]\n",
" return waic\n",
"\n",
"def calc_DIC1(samples, xs, ys, a, b, h):\n",
" Tn = - np.mean(np.log(np.mean(np.exp(samples), axis=0)))\n",
" mean_a = np.mean(a, axis=0)\n",
" mean_b = np.mean(b, axis=0)\n",
" \n",
" ps = np.array([p_model_pdf(xs[i], ys[i], mean_a, mean_b, h=h) for i in range(len(ys))])\n",
" Deff = 2*np.sum(-np.mean(samples, axis=0) + np.log(ps))\n",
" \n",
" return Tn + Deff/samples.shape[1]\n",
"\n",
"def calc_DIC2(samples):\n",
" Tn = - np.mean(np.log(np.mean(np.exp(samples), axis=0)))\n",
" Deff2 = 2 * (np.mean((np.sum(samples, axis=1)) ** 2, axis=0) - (np.mean(np.sum(samples, axis=1), axis=0) ** 2))\n",
" return Tn + Deff2/samples.shape[1]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def model_eval(H, num_iter=200, show_plot=True):\n",
" ms_list = []\n",
"\n",
" # 真のパラメータ\n",
" gt_a = np.array([[-0.1, 0.1]]) \n",
" gt_b = np.array([[0.1, -0.1, 0.3]])\n",
"\n",
" waic_list = []\n",
" dic1_list = []\n",
" dic2_list = []\n",
" Gn0_list = []\n",
" for i in range(0, num_iter):\n",
" # データ生成\n",
" xs, ys = data_gen(gt_a, gt_b, n=400)\n",
"\n",
" # 事後分布からのサンプリング\n",
" standata = {\"N\":ys.shape[0], \"H\":H, \"X\":xs, \"Y\":ys}\n",
" fit = stanmodel_waic.sampling(data = standata, iter=9000, warmup=4000, thin=10, chains=3)\n",
" #print(fit)\n",
" ms = fit.extract()\n",
" ms_list.append(ms)\n",
"\n",
" # WAICの計算\n",
" waic =calc_WAIC(ms['log_likelihood'])\n",
" # DIC1の計算\n",
" dic1 = calc_DIC1(ms['log_likelihood'], xs, ys, ms['a'], ms['b'], h=H)\n",
" # DIC2の計算\n",
" dic2 = calc_DIC2(ms['log_likelihood'])\n",
"\n",
" # 経験対数損失関数の計算\n",
" Ln_w0 = Ln(xs, ys, gt_a, gt_b)\n",
" waic_list.append(waic-Ln_w0)\n",
" dic1_list.append(dic1-Ln_w0)\n",
" dic2_list.append(dic2-Ln_w0)\n",
"\n",
" Gn0 = calc_Gn0(ms['log_likelihood'], xs, ys, (gt_a, gt_b))\n",
" Gn0_list.append(Gn0)\n",
" print(f'{i:3} Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: {Gn0:.07f}, {waic-Ln_w0:.07f}, {dic1-Ln_w0:.07f}, {dic2-Ln_w0:.07f}')\n",
" with open(f'out/{i}.pkl', 'wb') as f:\n",
" pickle.dump((i, Gn0, waic - Ln_w0, dic1 - Ln_w0, dic2 - Ln_w0), f) \n",
" if num_iter > 1 and show_plot:\n",
" show_result(Gn0_list, waic_list, dic1_list, dic2_list)\n",
" \n",
"def show_result(Gn0_list, waic_list, dic1_list, dic2_list, min_x=-0.05, max_x=0.20, show_plot=True):\n",
" # print result\n",
" print('###############################result###############################')\n",
" print(f'E[Gn0 - L(w0)] = {np.mean(Gn0_list):.07f}')\n",
" print(f'E[WAIC - Ln(w0)] = {np.mean(waic_list):.07f}, diff = {np.mean(Gn0_list) - np.mean(waic_list):.07f}')\n",
" print(f'E[DIC1 - Ln(w0)] = {np.mean(dic1_list):.07f}, diff = {np.mean(Gn0_list) - np.mean(dic1_list):.07f}')\n",
" print(f'E[DIC2 - Ln(w0)] = {np.mean(dic2_list):.07f}, diff = {np.mean(Gn0_list) - np.mean(dic2_list):.07f}')\n",
"\n",
" # plot\n",
" if show_plot:\n",
" f, axes = plt.subplots(2, 2, figsize=(7, 7), sharex=True, sharey=True)\n",
" axes[0,0].set_xlim([min_x, max_x])\n",
"\n",
" bin_size = 100\n",
" bins=[min_x + (max_x - min_x)* i / bin_size for i in range(bin_size + 1)]  # bin_size + 1 edges -> bin_size bins covering [min_x, max_x]\n",
" axes[0, 0].hist(Gn0_list, bins=bins, color=\"skyblue\")\n",
" axes[0, 1].hist(waic_list, bins=bins, color=\"olive\")\n",
" axes[1, 0].hist(dic1_list, bins=bins, color=\"gold\")\n",
" axes[1, 1].hist(dic2_list, bins=bins, color=\"teal\")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0037873, 0.0102097, 0.0097870, 0.0131894\n",
" 1 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0010726, 0.0123298, 0.0131563, 0.0170095\n",
" 2 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0084154, 0.0073628, 0.0066426, 0.0080363\n",
" 3 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0081765, 0.0057835, 0.0050549, 0.0071082\n",
" 4 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0074267, 0.0081477, 0.0066685, 0.0073407\n",
" 5 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0026439, 0.0097422, 0.0099751, 0.0115195\n",
" 6 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0038766, 0.0080024, 0.0088971, 0.0095032\n",
" 7 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0015165, 0.0100378, 0.0114041, 0.0114947\n",
" 8 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0095293, 0.0081587, 0.0062285, 0.0090910\n",
" 9 Gn0 - L(w0), waic-Ln_w0, dic1-Ln_w0, dic2-Ln_w0: 0.0097062, 0.0040909, 0.0032489, 0.0048985\n",
"###############################result###############################\n",
"E[Gn0 - L(w0)] = 0.0056151\n",
"E[WAIC - Ln(w0)] = 0.0083865, diff = -0.0027714\n",
"E[DIC1 - Ln(w0)] = 0.0081063, diff = -0.0024912\n",
"E[DIC2 - Ln(w0)] = 0.0099191, diff = -0.0043040\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAa8AAAGbCAYAAABzgB+6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASoklEQVR4nO3cf4zkd33f8de7dwbSBjU4t00RxjmQCJVBVaxeiVTUVnUTxZAUKiV/GBUUtUin/ohE1EgRCFVh+5ebSlH6R6ToRChUSePQErUICUUm4EZIAXJnG+MfUBxCFBCNz9AU3FaODO/+seN0Od/uzdzO7M6bezyk1c3sfOe7b2b24yffme9OdXcAYJK/cNIDAMCqxAuAccQLgHHEC4BxxAuAcU5vasdnzpzps2fPbmr3cCwuXbr0ZHfvnPQcz7KumG5da2pj8Tp79mwuXry4qd3DsaiqPzrpGfazrphuXWvKy4YAjCNeAIwjXgCMs3K8qupUVT1QVR/axEAAcC3Xc+T1tiSPrXsQAFjWSvGqqluS/FiSd29mHAC4tlWPvH4pyc8l+dYGZgGApSwdr6r68SRPdPelQ7Y5X1UXq+ri5cuX1zLgJt39wJO5+4EnT3oMONSkdbW7W9ndrZMegxvAKkder03yhqr6YpJ7ktxRVb+2f4PuvtDd57r73M7O1nwoAYxmXcFzLR2v7n5Hd9/S3WeT3JXko9395o1NBgAH8HdeAIxzXZ9t2N33JblvrZMAwJIceQEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATCOeAEwjngBMI54ATDO0vGqqhdU1aeq6tNV9UhV7W5yMAA4yOkVtn06yR3d/VRV3ZTk41X14e7+xIZmA4CrWjpe3d1JnlpcvWnx1ZsYCgAOs9J7XlV1qqoeTPJEknu7+5NX3H6+qi5W1cXLly+vc861ufuBJ3P3A0+e9BiwtAnr6kq7u5Xd3TrpMfgOtlK8uvub3f2DSW5J8pqqevUVt1/o7nPdfW5nZ2edc8INy7qC57qusw27+0+TfCzJnesdBwCubZWzDXeq6nsWl78ryY8k+eymBgOAg6xytuGLk7yvqk5lL3rv7+4PbWYsADjYKmcbPpTk9g3OAgBL8QkbAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMs3S8quqlVfWxqnq0qh6pqrdtcjAAOMjpFbZ9JsnPdvf9VfXCJJeq6t7ufnRDswHAVS195NXdX+nu+xeXv5HksSQv2dRgAHCQVY68/lxVnU1ye5JPXvH980nOJ8mtt956xNHW6+4HnjzpEeC6bPO62m93t056BG4gK5+wUVXfneQDSX6mu7++/7buvtDd57r73M7OzrpmhBuadQXPtVK8quqm7IXr17v7tzYzEgAcbpWzDSvJryZ5rLt/cXMjAcDhVjnyem2StyS5o6oeXHy9fkNzAcCBlj5ho7s/nsQ7sgCcOJ+wAcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOEvHq6reU1VPVNXDmxwIAK5llSOv9ya5c0NzAMDSlo5Xd/9ukq9tcBYAWMrpde6sqs4nOZ8kt9566z
p3vXZ3P/DkSY8AS9n2dbW7Wyc9AjegtZ6w0d0Xuvtcd5/b2dlZ567hhmVdwXM52xCAccQLgHFWOVX+N5L8XpJXVtWXquqtmxsLAA629Akb3f2mTQ4CAMvysiEA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOIFwDjiBcA44gXAOOsFK+qurOqPldVj1fV2zc1FAAcZul4VdWpJL+c5HVJbkvypqq6bVODAcBBVjnyek2Sx7v7C939Z0nuSfLGzYwFAAc7vcK2L0nyx/uufynJD+3foKrOJzm/uPp0VT18tPHW6kySJ692wzuOeZAcMssJ2aZ5tmmWJHnlSQ+wxevqms/Vu95VxzRKku363dmmWZLtmmcta2qVeF1Td19IciFJqupid59b5/6PYpvm2aZZku2aZ5tmSfbmOekZtnVdbdMsyXbNs02zJNs1z7rW1CovG345yUv3Xb9l8T0AOFarxOv3k7yiql5WVc9LcleSD25mLAA42NIvG3b3M1X100l+O8mpJO/p7kcOucuFow63Zts0zzbNkmzXPNs0S2Kew2zTLMl2zbNNsyTbNc9aZqnuXsd+AODY+IQNAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGES8AxhEvAMYRLwDGOb2pHZ85c6bPnj27qd3Dsbh06dKT3b1z0nM8y7piunWtqY3F6+zZs7l48eKmdg/Hoqr+6KRn2M+6Yrp1rSkvGwIwjngBMI54ATDOyvGqqlNV9UBVfWgTAwHAtVzPkdfbkjy27kEAYFkrxauqbknyY0nevZlxAODaVj3y+qUkP5fkW1e7sarOV9XFqrp4+fLlIw+3UZ+tk54AljJqXcExWTpeVfXjSZ7o7ksHbdPdF7r7XHef29nZmr/rhNGsK3iuVY68XpvkDVX1xST3JLmjqn5tI1MBwCGWjld3v6O7b+nus0nuSvLR7n7zxiYDgAP4Oy8Axrmuzzbs7vuS3LfWSQBgSY68ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGEe8ABhHvAAYZ+l4VdULqupTVfXpqnqkqnY3ORgAHOT0Cts+neSO7n6qqm5K8vGq+nB3f2JDswHAVS0dr+7uJE8trt60+OpNDAUAh1npPa+qOlVVDyZ5Ism93f3JK24/X1UXq+ri5cuX1znnZny29r5gi01aV7W7m9r1jgKbt1K8uvub3f2DSW5J8pqqevUVt1/o7nPdfW5nZ2edc8INy7qC57qusw27+0+TfCzJnesdBwCubZWzDXeq6nsWl78ryY8k+eymBgOAg6xytuGLk7yvqk5lL3rv7+4PbWYsADjYKmcbPpTk9g3OAgBL8QkbAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMI14AjCNeAIwjXgCMs3S8quqlVfWxqnq0qh6pqrdtcjAAOMjpFbZ9JsnPdvf9VfXCJJeq6t7ufnRDswHAVS195NXdX+nu+xeXv5HksSQv2dRgAHCQ63rPq6rOJrk9ySev+P75qrpYVRcvX7589OmAMeuqdne/7fL+67BuK8
erqr47yQeS/Ex3f33/bd19obvPdfe5nZ2ddc0INzTrCp5rpXhV1U3ZC9evd/dvbWYkADjcKmcbVpJfTfJYd//i5kYCgMOtcuT12iRvSXJHVT24+Hr9huYCgAMtfap8d388SW1wFgBYik/YAGAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYBzxAmAc8QJgHPECYJyl41VV76mqJ6rq4U0OBADXssqR13uT3LmhOQBgaUvHq7t/N8nXNjgLACzl9Dp3VlXnk5xPkltvvXWdu16fz9a1v/fX+nhmgSVs+7qq3d2lbuuf//njGIcbxFpP2OjuC919rrvP7ezsrHPXcMOyruC5nG0IwDjiBcA4q5wq/xtJfi/JK6vqS1X11s2NBQAHW/qEje5+0yYHAYBledkQgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxxAuAccQLgHHEC4BxVopXVd1ZVZ+rqser6u2bGgoADrN0vKrqVJJfTvK6JLcleVNV3bapwQDgIKsceb0myePd/YXu/rMk9yR542bGAoCDnV5h25ck+eN917+U5If2b1BV55OcX1x9uqoePtp4a3UmyZPLbVobHSQrzXIstmmebZolSV550gNs8bpa6bmqd71rc5Ps2abfnW2aJdmuedayplaJ1zV194UkF5Kkqi5297l17v8otmmebZol2a55tmmWZG+ek55hW9fVNs2SbNc82zRLsl3zrGtNrfKy4ZeTvHTf9VsW3wOAY7VKvH4/ySuq6mVV9bwkdyX54GbGAoCDLf2yYXc/U1U/neS3k5xK8p7ufuSQu1w46nBrtk3zbNMsyXbNs02zJOY5zDbNkmzXPNs0S7Jd86xllurudewHAI6NT9gAYBzxAmCcI8Wrqm6uqnur6vOLf190wHY/tdjm81X1U/u+f9/i46YeXHz9leuc49CPraqq51fVby5u/2RVnd132zsW3/9cVf3o9fz8dcxSVWer6v/ueyx+5aizLDnP36mq+6vqmar6yStuu+rzdkKzfHPfY7OWE4WWmOdfVtWjVfVQVf1OVX3/vtvW+tjs2681taZZrKnv8DXV3df9leQXkrx9cfntSf7NVba5OckXFv++aHH5RYvb7kty7ogznEryB0lenuR5ST6d5LYrtvnnSX5lcfmuJL+5uHzbYvvnJ3nZYj+nTmiWs0kePspjcZ3znE3y15P8hyQ/uczzdtyzLG576gQem7+X5C8uLv+zfc/VWh+bK36mNbW+Wayp7+A1ddSXDd+Y5H2Ly+9L8g+vss2PJrm3u7/W3f8zyb1J7jziz91vmY+t2j/nf07y96uqFt+/p7uf7u4/TPL4Yn8nMcsmXHOe7v5idz+U5FtX3Hfdz9tRZtmEZeb5WHf/n8XVT2TvbxuTzf5OW1Prm2UTrKmjzbO2NXXUeH1fd39lcfl/JPm+q2xztY+Vesm+6/9+cdj6r67zF+5a+/+2bbr7mST/K8n3Lnnf45olSV5WVQ9U1X+rqr99hDlWmWcT993E/l5QVRer6hNVdbX/oG96nrcm+fB13ncV1tT6ZkmsqcOMXlPX/DuvqvpIkr96lZveuf9Kd3dVrXre/T/q7i9X1QuTfCDJW7J3eHsj+kqSW7v7q1X1N5L8l6p6VXd//aQH2xLfv/hdeXmSj1bVZ7r7D47jB1fVm5OcS/J317Q/a+p4WFOHG72mrnnk1d0/3N2vvsrXf03yJ1X14sUwL07yxFV2ceDHSnX3s/9+I8l/zPW9vLDMx1b9+TZVdT
rJX07y1SXveyyzLF5m+WqSdPel7L12/ANHmGXZeTZx37Xvb9/vyhey977O7UeYZel5quqHsxeVN3T306vc9yDW1PHMYk0dbvyaut435xZvsv3bfPuby79wlW1uTvKH2XsT7kWLyzdn76jvzGKbm7L3WvU/vY4ZTmfvzb2X5f+/SfiqK7b5F/n2N3Tfv7j8qnz7m8tfyNHeXD7KLDvP/uzsveH55SQ3H/H5ueY8+7Z9b5775vJznrcTmuVFSZ6/uHwmyedzxRvBG3qubs/ef/Besczv9FHmsaasKWtqtcfmqAvte5P8zuJ/+Eee/WHZOxx8977t/kn23rh9PMk/XnzvLyW5lOShJI8k+XfX+0ue5PVJ/vviQXnn4nv/OntlT5IXJPlPi5//qSQv33ffdy7u97kkrzvK43GUWZL8xOJxeDDJ/Un+wVFnWXKev5m915f/d/b+n/Mjhz1vJzFLkr+V5DOLxfCZJG89psfmI0n+ZPGcPJjkg5t6bKwpa8qaWu2x8fFQAIzjEzYAGEe8ABhHvAAYR7wAGEe8ABhHvAAYR7wAGOf/AdPgg79fFECTAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 504x504 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"## 真の分布が確率モデルで実現可能かつ正則な場合\n",
"model_eval(H=1, num_iter=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## 真の分布が確率モデルで実現可能かつ正則でない場合\n",
"model_eval(H=3, num_iter=400)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"純粋な平均値\n",
"###############################result###############################\n",
"E[Gn0 - L(w0)] = 0.0338410\n",
"E[WAIC - Ln(w0)] = 0.0147883, diff = 0.0190527\n",
"E[DIC1 - Ln(w0)] = -0.1509355, diff = 0.1847766\n",
"E[DIC2 - Ln(w0)] = 0.0263153, diff = 0.0075257\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x504 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"idx_list = []\n",
"Gn0_list = []\n",
"waic_list = []\n",
"dic1_list = []\n",
"dic2_list = []\n",
"min_x, max_x = (-0.05, 0.20)\n",
"\n",
"print(\"純粋な平均値\")\n",
"for i in range(400):\n",
" with open(f'out/{i}.pkl', 'rb') as f:\n",
" res = pickle.load(f)\n",
" idx, Gn0, waic, dic1, dic2 = res\n",
" idx_list.append(idx)\n",
" Gn0_list.append(Gn0)\n",
" waic_list.append(waic)\n",
" dic1_list.append(dic1)\n",
" dic2_list.append(dic2)\n",
" \n",
"show_result(Gn0_list, waic_list, dic1_list, dic2_list, min_x=min_x, max_x=max_x)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"外れ値的な値を除去した場合\n",
"###############################result###############################\n",
"E[Gn0 - L(w0)] = 0.0208351\n",
"E[WAIC - Ln(w0)] = 0.0147883, diff = 0.0060468\n",
"E[DIC1 - Ln(w0)] = -0.0189318, diff = 0.0397669\n",
"E[DIC2 - Ln(w0)] = 0.0263153, diff = -0.0054803\n"
]
}
],
"source": [
"idx_list = []\n",
"Gn0_list = []\n",
"waic_list = []\n",
"dic1_list = []\n",
"dic2_list = []\n",
"min_x, max_x = (-0.05, 0.20)\n",
"print(\"外れ値的な値を除去した場合\")\n",
"def is_outlier(x, min_x, max_x):\n",
" return False if min_x < x < max_x else True\n",
" \n",
"for i in range(400):\n",
" with open(f'out/{i}.pkl', 'rb') as f:\n",
" res = pickle.load(f)\n",
" idx, Gn0, waic, dic1, dic2 = res\n",
" idx_list.append(idx)\n",
" if not is_outlier(Gn0, min_x, max_x):\n",
" Gn0_list.append(Gn0)\n",
" if not is_outlier(waic, min_x, max_x):\n",
" waic_list.append(waic)\n",
" if not is_outlier(dic1, min_x, max_x):\n",
" dic1_list.append(dic1)\n",
" if not is_outlier(dic2, min_x, max_x):\n",
" dic2_list.append(dic2)\n",
"show_result(Gn0_list, waic_list, dic1_list, dic2_list, min_x=min_x, max_x=max_x, show_plot=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment