Skip to content

Instantly share code, notes, and snippets.

@Shirataki2
Last active August 15, 2018 17:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Shirataki2/6d2e2f962cb62c0933ab9ae882620e2d to your computer and use it in GitHub Desktop.
Save Shirataki2/6d2e2f962cb62c0933ab9ae882620e2d to your computer and use it in GitHub Desktop.
PRML/notes/RVM.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-08-15T14:31:47.650837Z",
"end_time": "2018-08-15T14:31:52.886830Z"
},
"trusted": true,
"scrolled": false
},
"cell_type": "code",
"source": "import numpy as np\nimport tqdm\nimport os\nimport os.path as path\nimport matplotlib.pyplot as plt\nimport kernels\nfrom scipy.stats import multivariate_normal\nfrom IPython.display import clear_output",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-08-15T14:31:52.905795Z",
"end_time": "2018-08-15T14:31:52.920741Z"
},
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": [
 "def sin(x):\n",
 "    \"\"\"Return sin(2*pi*x), the target curve for the demos.\"\"\"\n",
 "    return np.sin(2 * np.pi * x)\n",
 "\n",
 "\n",
 "def generate_data(func, noise_scale, N):\n",
 "    \"\"\"Draw N inputs uniformly from [0, 1) with targets func(x) + noise.\n",
 "\n",
 "    The noise is Gaussian with standard deviation ``noise_scale``.\n",
 "    Returns the pair ``(x, t)``.\n",
 "    \"\"\"\n",
 "    x = np.random.uniform(size=N)\n",
 "    t = func(x) + np.random.normal(scale=noise_scale, size=x.shape)\n",
 "    return x, t\n",
 "\n",
 "\n",
 "def display_predict(truefunc,\n",
 "                    X_train,\n",
 "                    t_train,\n",
 "                    X_test,\n",
 "                    t_test_m,\n",
 "                    t_test_s,\n",
 "                    text=None,\n",
 "                    name=None,\n",
 "                    show=True):\n",
 "    \"\"\"Plot the true curve, the training points and the predictive band.\n",
 "\n",
 "    The red band is ``t_test_m`` +/- ``t_test_s``. If ``name`` is given the\n",
 "    figure is saved to ``'<name>.png'``. With ``show=False`` the figure is\n",
 "    left open (no legend, not shown) so the caller can add to it.\n",
 "    \"\"\"\n",
 "    plt.figure(figsize=(8, 6))\n",
 "    plt.fill_between(\n",
 "        X_test,\n",
 "        t_test_m - t_test_s,\n",
 "        t_test_m + t_test_s,\n",
 "        color='r',\n",
 "        alpha=0.15)\n",
 "    plt.plot(X_test, truefunc(X_test), label='Target Curve', c='blue')\n",
 "    plt.scatter(\n",
 "        X_train,\n",
 "        t_train,\n",
 "        facecolors='none',\n",
 "        edgecolors='black',\n",
 "        label='Observed Points',\n",
 "        alpha=0.6)\n",
 "    plt.plot(X_test, t_test_m, c='r', label='Predict Curve')\n",
 "    if text is not None:\n",
 "        plt.title(text)\n",
 "    plt.xlabel('$x$')\n",
 "    plt.ylabel('$t$')\n",
 "    plt.ylim((-1.75, 1.75))\n",
 "    if show:\n",
 "        # Add the legend before saving so the PNG includes it.\n",
 "        plt.legend()\n",
 "    if name is not None:\n",
 "        plt.savefig('{}.png'.format(name))\n",
 "    if show:\n",
 "        plt.show()\n",
 "\n",
 "\n",
 "def display_predict2(truefunc,\n",
 "                     X_train,\n",
 "                     t_train,\n",
 "                     model,\n",
 "                     v_rel=None,\n",
 "                     text=None,\n",
 "                     name=None,\n",
 "                     num=200):\n",
 "    \"\"\"Plot ``model.predict`` on ``num`` evenly spaced points in [0, 1].\n",
 "\n",
 "    ``v_rel`` optionally selects (boolean mask or indices) the training\n",
 "    points drawn as relevance vectors. If ``name`` is given the figure is\n",
 "    saved to ``'<name>.png'``, creating the parent directory if needed.\n",
 "    \"\"\"\n",
 "    X_test = np.linspace(0, 1, num)\n",
 "    t_m, t_s = model.predict(X_test)\n",
 "    display_predict(\n",
 "        truefunc, X_train, t_train, X_test, t_m, t_s, text=text, show=False)\n",
 "    if v_rel is not None:\n",
 "        plt.scatter(X_train[v_rel], t_train[v_rel], c='g',\n",
 "                    s=75, marker='^', label='Relevance Vector', alpha=0.7)\n",
 "    plt.legend(loc='lower left')\n",
 "    plt.xlim((0, 1))\n",
 "    if name is not None:\n",
 "        p = path.dirname(name)\n",
 "        if p:  # os.makedirs('') raises, e.g. for name='frame'\n",
 "            os.makedirs(p, exist_ok=True)\n",
 "        plt.savefig('{}.png'.format(name))\n",
 "    plt.show()\n",
 "    plt.close()"
],
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-08-15T14:31:57.013360Z",
"end_time": "2018-08-15T14:31:57.020341Z"
},
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": [
 "class RelevanceVectorMachineRegression(object):\n",
 "    \"\"\"Base class for relevance vector machine (RVM) regression.\n",
 "\n",
 "    Subclasses implement ``fit``; before ``predict`` is called they must set\n",
 "    ``self.x`` (training inputs), ``self.w_m`` / ``self.w_s`` (posterior\n",
 "    mean / covariance of the weights) and ``self.beta`` (noise precision).\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, kernel):\n",
 "        self.kernel = kernel\n",
 "        self.alpha = None\n",
 "        self.beta = None\n",
 "        self.isnan = False\n",
 "\n",
 "    def calc_design_matrix(self, x, y=None):\n",
 "        \"\"\"Return the matrix whose (i, j) entry is kernel(x_i, y_j).\n",
 "\n",
 "        ``y`` defaults to ``x``, giving the square training design matrix.\n",
 "        \"\"\"\n",
 "        if y is None:\n",
 "            y = x\n",
 "        return self.kernel(*np.meshgrid(x, y, indexing='ij'))\n",
 "\n",
 "    def fit(self, x, t):\n",
 "        \"\"\"Store the training inputs; subclasses do the actual fitting.\"\"\"\n",
 "        self.x = x\n",
 "\n",
 "    def predict(self, x):\n",
 "        \"\"\"Return the predictive mean and standard deviation at ``x``.\"\"\"\n",
 "        phi = self.calc_design_matrix(x, self.x)\n",
 "        t_m = phi.dot(self.w_m)\n",
 "        # diag(phi @ w_s @ phi.T) without forming the full matrix\n",
 "        t_v = 1 / self.beta + np.sum(phi.dot(self.w_s) * phi, axis=1)\n",
 "        return t_m, np.sqrt(t_v)"
],
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-08-15T14:31:57.803444Z",
"end_time": "2018-08-15T14:31:57.812436Z"
},
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": [
 "class RelevanceVectorMachineRegressionWithEvidence(\n",
 "        RelevanceVectorMachineRegression):\n",
 "    \"\"\"RVM regression with alpha and beta re-estimated by maximising the\n",
 "    evidence (type-II maximum likelihood, PRML section 7.2.1).\"\"\"\n",
 "\n",
 "    def __init__(self, kernel, alpha, beta):\n",
 "        super(RelevanceVectorMachineRegressionWithEvidence, self).__init__(\n",
 "            kernel=kernel)\n",
 "        self.alpha = alpha  # scalar or one precision per training point\n",
 "        self.beta = beta  # noise precision\n",
 "\n",
 "    def fit_1(self, x, t, lr=1e-2):\n",
 "        \"\"\"Run one re-estimation step of alpha and beta.\n",
 "\n",
 "        Returns True when alpha and beta did not change (within\n",
 "        atol=rtol=1e-4), i.e. the iteration has converged. ``lr`` is\n",
 "        unused and kept only for backward compatibility.\n",
 "        \"\"\"\n",
 "        if np.isscalar(self.alpha):\n",
 "            # One prior precision per basis function (= training point);\n",
 "            # force float so an integer x cannot truncate alpha.\n",
 "            self.alpha = np.full(np.shape(x), self.alpha, dtype=float)\n",
 "        phi = self.calc_design_matrix(x)\n",
 "        self.phi = phi\n",
 "        self.x = x\n",
 "        self.t = t\n",
 "        prev_param = np.hstack([self.alpha, self.beta])\n",
 "\n",
 "        # Posterior over the weights given the current alpha, beta.\n",
 "        self.pre = np.diag(self.alpha) + self.beta * phi.T.dot(phi)\n",
 "        self.w_s = np.linalg.inv(self.pre)\n",
 "        self.w_m = self.beta * self.w_s.dot(phi.T).dot(t)\n",
 "\n",
 "        # Re-estimate the hyperparameters; alpha is capped at 1e10.\n",
 "        self.gamma = 1 - self.alpha * np.diag(self.w_s)\n",
 "        self.alpha = np.clip(self.gamma / np.square(self.w_m), 0, 1e10)\n",
 "        residual = t - phi.dot(self.w_m)\n",
 "        self.beta = (len(x) - np.sum(self.gamma)) / np.sum(residual ** 2)\n",
 "\n",
 "        new_param = np.hstack([self.alpha, self.beta])\n",
 "        return np.allclose(prev_param, new_param, atol=1e-4, rtol=1e-4)\n",
 "\n",
 "    def fit(self, x, t, iter_num=10000):\n",
 "        \"\"\"Call ``fit_1`` until it converges or ``iter_num`` steps have run.\"\"\"\n",
 "        n_iter = 0  # stays defined even when iter_num == 0\n",
 "        for n_iter in range(1, iter_num + 1):\n",
 "            if self.fit_1(x, t):\n",
 "                break\n",
 "        else:\n",
 "            print('[w] hyperparameter(s) may not converge yet.')\n",
 "        print('[*] {} iteration(s) has proceeded.'.format(n_iter))"
],
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-08-13T14:27:23.638266Z",
"end_time": "2018-08-13T14:27:23.647250Z"
},
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": [
 "class SuccessiveRVM(\n",
 "        RelevanceVectorMachineRegressionWithEvidence):\n",
 "    \"\"\"Evidence-approximation RVM whose alphas all start at 1e10.\n",
 "\n",
 "    NOTE(review): despite the name, this does not implement the sequential\n",
 "    sparse Bayesian learning algorithm (PRML section 7.2.2) yet; fitting\n",
 "    reuses the batch re-estimation step (and ``fit``) of the parent class.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, kernel, beta):\n",
 "        # super(SuccessiveRVM, ...) so the parent's (kernel, alpha, beta)\n",
 "        # initialiser runs; the base class __init__ only accepts kernel.\n",
 "        super(SuccessiveRVM, self).__init__(\n",
 "            kernel=kernel, alpha=None, beta=beta)\n",
 "\n",
 "    def calc_cov_matrix(self, alpha, x, y=None):\n",
 "        \"\"\"Return C = I / beta + Phi diag(alpha)^-1 Phi^T (PRML section 7.2.1).\n",
 "\n",
 "        Phi is ``calc_design_matrix(x, y)``; ``alpha`` is a scalar or one\n",
 "        precision per column of Phi (an infinite alpha drops that column).\n",
 "        \"\"\"\n",
 "        phi = self.calc_design_matrix(x, y)\n",
 "        return np.eye(phi.shape[0]) / self.beta + (phi / alpha).dot(phi.T)\n",
 "\n",
 "    def fit_1(self, x, t, lr=1e-2):\n",
 "        \"\"\"Run one re-estimation step, starting from alpha = 1e10 everywhere.\"\"\"\n",
 "        if self.alpha is None:\n",
 "            self.alpha = np.full(np.shape(x), 1e10)\n",
 "        return super(SuccessiveRVM, self).fit_1(x, t, lr=lr)"
],
"execution_count": 222,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-08-15T14:32:45.975199Z",
"end_time": "2018-08-15T14:32:47.002465Z"
},
"trusted": true
},
"cell_type": "code",
"source": [
 "%%time\n",
 "N = 30\n",
 "\n",
 "X, t = generate_data(sin, 0.2, N)\n",
 "\n",
 "kernel = kernels.GaussianKernel([1., 50])\n",
 "model = RelevanceVectorMachineRegressionWithEvidence(kernel, 1., 1.)\n",
 "model.fit(X, t, 10000)\n",
 "v_rel = np.abs(model.w_m) > 1/kernel.p[1]\n",
 "# Number of alphas below / above 1e9 (N_0 / N_inf in the title).\n",
 "sparse = int(np.count_nonzero(model.alpha < 1e9))\n",
 "infty = int(np.count_nonzero(model.alpha > 1e9))\n",
 "title = r'$N={:d},\\alpha:(N_0={:d},N_\\infty={:d}),\\beta={:.2f}$'.format(\n",
 "    N, sparse, infty, model.beta)\n",
 "display_predict2(sin, X, t, model, v_rel, text=title)"
],
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"text": "[*] 2897 iteration(s) has proceeded.\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"metadata": {},
"data": {
"text/plain": "<Figure size 576x432 with 1 Axes>",
"image/png": "\n"
}
},
{
"name": "stdout",
"text": "Wall time: 1.02 s\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-08-11T13:05:23.623902Z",
"end_time": "2018-08-11T13:12:21.074209Z"
},
"trusted": true
},
"cell_type": "code",
"source": [
 "N = 25\n",
 "\n",
 "X, t = generate_data(sin, 0.3, N)\n",
 "\n",
 "kernel = kernels.GaussianKernel([1., 50.])\n",
 "model = RelevanceVectorMachineRegressionWithEvidence(kernel, 1., 1.)\n",
 "for i in tqdm.tnrange(10000):\n",
 "    converged = model.fit_1(X, t)\n",
 "    v_rel = np.abs(model.w_m) > 1/kernel.p[1]\n",
 "    # Number of alphas below / above 1e9 (N_0 / N_inf in the title).\n",
 "    sparse = int(np.count_nonzero(model.alpha < 1e9))\n",
 "    infty = int(np.count_nonzero(model.alpha > 1e9))\n",
 "    title = r'$N={:d},\\alpha:(N_0={:d},N_\\infty={:d}),\\beta={:.2f}$'.format(\n",
 "        N, sparse, infty, model.beta)\n",
 "    # One frame per step in ./1-1/, including the final converged state.\n",
 "    display_predict2(sin, X, t, model, v_rel, text=title,\n",
 "                     name='./1-1/{:04d}'.format(i))\n",
 "    if converged:\n",
 "        break"
],
"execution_count": 261,
"outputs": [
{
"output_type": "display_data",
"metadata": {},
"data": {
"text/plain": "HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "72fb6399ebf44bab8a5985bc490250cd"
}
}
},
{
"name": "stdout",
"text": "\n",
"output_type": "stream"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/6d2e2f962cb62c0933ab9ae882620e2d"
},
"kernelspec": {
"name": "conda-env-tensorflow-py",
"display_name": "Python [conda env:tensorflow]",
"language": "python"
},
"gist": {
"id": "6d2e2f962cb62c0933ab9ae882620e2d",
"data": {
"description": "PRML/notes/RVM.ipynb",
"public": true
}
},
"language_info": {
"version": "3.5.5",
"name": "python",
"pygments_lexer": "ipython3",
"file_extension": ".py",
"nbconvert_exporter": "python",
"codemirror_mode": {
"version": 3,
"name": "ipython"
},
"mimetype": "text/x-python"
},
"varInspector": {
"window_display": false,
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"library": "var_list.py",
"delete_cmd_prefix": "del ",
"delete_cmd_postfix": "",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"library": "var_list.r",
"delete_cmd_prefix": "rm(",
"delete_cmd_postfix": ") ",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
]
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment