@bmorris3
Last active October 18, 2022 14:54
Understanding Cross Validation: Leave-One-Out
{
"cells": [
{
"cell_type": "markdown",
"id": "ac58a442-1485-474c-949e-491dd230e044",
"metadata": {
"tags": []
},
"source": [
"# Pareto smoothed importance sampling leave-one-out cross validation\n",
"\n",
"##### Brett Morris\n",
"\n",
"In this tutorial, we will construct a linear model with numpyro and carry out Pareto smoothed importance sampling leave-one-out cross validation (PSIS-LOO) on the results. \n",
"\n",
"To begin, let's import the necessary packages. In the example I'll assume there is only one CPU core available – you can increase this number to fit your machine."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7ac4a0b0-8e77-4d0f-9805-4c03d49e4448",
"metadata": {},
"outputs": [],
"source": [
"number_cpu_cores = 1\n",
"\n",
"%matplotlib inline\n",
"import warnings\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.special import logsumexp\n",
"\n",
"import numpyro\n",
"numpyro.set_host_device_count(number_cpu_cores)\n",
"\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"from numpyro import distributions as dist\n",
"\n",
"import jax\n",
"from jax import random, numpy as jnp\n",
"\n",
"import arviz\n",
"from arviz.stats.stats import psislw, _gpdfit"
]
},
{
"cell_type": "markdown",
"id": "6cdd7639-cd99-4cba-91c8-8d22686ffd8a",
"metadata": {},
"source": [
"Now let's generate some simulated observations $y(x)$ drawn from \n",
"$$y \\sim {\\cal N}(a x^2 + b x + c, \\sigma)$$\n",
"where $\\sigma$ is the uncertainty on each observation, and the mean is a quadratic function.\n",
"\n",
"We'll begin by including no significant outliers, but you can adjust the number of outliers with the ``n_outliers`` parameter."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "67926f21-6bd2-4532-8c8b-2dbadb37a8bf",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"np.random.seed(42) \n",
"n = 25\n",
"true_slope = 2\n",
"true_int = 0.5\n",
"err_scale = 0.5\n",
"\n",
"x = np.linspace(-1, 1, n)\n",
"y = 2 * true_slope * x ** 2 + true_slope * x + true_int\n",
"y += np.random.normal(0, err_scale, n)\n",
"\n",
"# Set the number of outliers in this dataset:\n",
"n_outliers = 0\n",
"\n",
"if n_outliers > 0:\n",
" outliers = np.random.randint(0, n, size=n_outliers)\n",
" y[outliers] += np.random.normal(0, 5 * err_scale, len(outliers))\n",
"\n",
"yerr = 1 * err_scale\n",
"plt.errorbar(x, y, yerr, fmt='o', color='k', ecolor='silver')\n",
"plt.gca().set(xlabel='x', ylabel='y');"
]
},
{
"cell_type": "markdown",
"id": "ab2852b1-274e-43f5-bec6-65a248280042",
"metadata": {},
"source": [
"Let's construct a simple linear model with ``numpyro``. The model computes the linear regression coefficients $\\beta$ which give $y = {\\bf X}\\beta$, where ${\\bf X} = [{\\bf x}^2~~{\\bf x}~~ {\\bf 1}]$ is the Vandermonde matrix."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "65c2562a-7051-4189-bc5d-c5e3ae1a9353",
"metadata": {},
"outputs": [],
"source": [
"X = np.vander(x, 3)\n",
"\n",
"def model():\n",
" # linear regression coefficients beta\n",
" beta = numpyro.sample(\n",
" 'beta', dist.Uniform(\n",
" low=np.array([0, 0, 0]), high=np.array([5, 5, 5])\n",
" )\n",
" )\n",
" # uncertainty on y\n",
" sigma = numpyro.sample(\n",
" 'sigma', dist.HalfNormal(0.1)\n",
" )\n",
" # define the linear mean model\n",
" mean_model = numpyro.deterministic(\n",
" 'mean_model', jnp.dot(X, beta)\n",
" )\n",
" # define the Gaussian likelihood of observations y\n",
" numpyro.sample(\n",
" 'obs', dist.Normal(mean_model, sigma), obs=y\n",
" )"
]
},
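{
"cell_type": "markdown",
"id": "added-vandermonde-sketch",
"metadata": {},
"source": [
"As a worked illustration of the design matrix above (a sketch based on the default column ordering of ``np.vander``, which puts the highest power first), the mean model expands to\n",
"$$ {\\bf X}\\beta = \n",
"\\begin{bmatrix} x_1^2 & x_1 & 1 \\\\ \\vdots & \\vdots & \\vdots \\\\ x_n^2 & x_n & 1 \\end{bmatrix}\n",
"\\begin{bmatrix} \\beta_0 \\\\ \\beta_1 \\\\ \\beta_2 \\end{bmatrix}\n",
"= \\beta_0 {\\bf x}^2 + \\beta_1 {\\bf x} + \\beta_2 {\\bf 1}. $$\n",
"The index labels $\\beta_0, \\beta_1, \\beta_2$ are introduced here only for illustration. With the simulated data generated above, the coefficients used to create $y$ correspond to $\\beta = (4, 2, 0.5)$, which lies inside the uniform prior bounds."
]
},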
{
"cell_type": "markdown",
"id": "fee942a8-e86b-4063-b74f-3438ae8aa67a",
"metadata": {},
"source": [
"With our model defined, we are ready to sample from the posterior distributions for $\\beta$ and $\\sigma$. We'll sample with the No U-Turn Sampler (NUTS) for 5000 steps, after 100 warmup steps. By default the number of chains will be equal to the number of CPU cores you set in the uppermost cell of this notebook. \n",
"\n",
"We'll then run the sampler."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cf421342-26d8-4674-92ee-974a38ef1881",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 5100/5100 [00:02<00:00, 2184.59it/s, 15 steps of size 4.77e-01. acc. prob=0.95]\n"
]
}
],
"source": [
"rng_key = random.PRNGKey(0)\n",
"\n",
"mcmc = MCMC(\n",
" sampler=NUTS(\n",
" model,\n",
" ),\n",
" num_warmup=100,\n",
" num_samples=5000,\n",
" chain_method='parallel',\n",
" num_chains=len(jax.devices()),\n",
")\n",
"mcmc.run(\n",
" rng_key,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "16908ccb-d6dc-4b43-b3e4-146476301e8c",
"metadata": {},
"source": [
"For convenience we'll extract the log likelihood from the ``mcmc`` object using ``arviz``. We'll also define a variable ``b`` equal to the reciprocal of the total number of posterior draws, $1/S$, which turns sums over draws into averages:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d7ff774f-879a-40f1-9e73-128767cf0d02",
"metadata": {},
"outputs": [],
"source": [
"idata = arviz.from_numpyro(mcmc)\n",
"b = 1 / (mcmc.num_samples * mcmc.num_chains)\n",
"\n",
"# Stack the chain and draw dimensions of the log likelihood\n",
"log_likelihood = idata.log_likelihood['obs'].stack(\n",
" __sample__=['chain', 'draw']\n",
")"
]
},
{
"cell_type": "markdown",
"id": "cf65e833-e4d1-41c4-9916-1faa922f7c88",
"metadata": {},
"source": [
"The log pointwise predictive density ${\\rm lppd}$, computed from the full posterior (and named ``lppd_loo`` in the code below), is given by\n",
"$$ \\log \\prod_{i=1}^{n} \\int p(y_i | \\theta) p(\\theta | y) d\\theta \n",
"\\approx \\sum_{i=1}^n \\log \\frac{\\sum_s p(y_i | \\theta^{s})}{S} \\\\\n",
"= \\sum_{i=1}^n \\left(\\log \\sum_s p(y_i | \\theta^{s}) - \\log(S)\\right)\n",
"$$\n",
"for draws $\\theta^{s}$ from the posterior $p(\\theta | y)$ and samples $s = 1, 2, \\dots, S$. \n",
"\n",
"We can succinctly implement this log of the sum of the exponentials of the log-likelihoods with the ``logsumexp`` function:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "24b07369-75bd-4531-ac6f-f1b51c1f78a5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-14.095360571471577"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# evaluates the lppd for every point\n",
"lppd_loo_i = logsumexp(log_likelihood, axis=1, b=b)\n",
"\n",
"# the sum of the lppd's is the lppd_loo:\n",
"lppd_loo = lppd_loo_i.sum()\n",
"\n",
"lppd_loo"
]
},
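{
"cell_type": "markdown",
"id": "added-logsumexp-sketch",
"metadata": {},
"source": [
"The ``b`` argument of ``logsumexp`` multiplies each exponential before summing, so with $b = 1/S$ the call above evaluates, for each data point $i$ and writing $\\ell_i^s = \\log p(y_i | \\theta^s)$ (a shorthand introduced here),\n",
"$$ {\\rm lppd}_i = \\log \\sum_{s=1}^{S} \\frac{1}{S} \\exp(\\ell_i^s) = m_i + \\log \\sum_{s=1}^{S} \\exp(\\ell_i^s - m_i) - \\log S, $$\n",
"where $m_i = \\max_s \\ell_i^s$. The second form is a sketch of the usual shift-by-the-maximum trick that keeps the exponentials from underflowing; the exact internals of ``logsumexp`` may differ in detail."
]
},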
{
"cell_type": "markdown",
"id": "9f37e1ff-eeab-40b3-97d2-b69bc668e44a",
"metadata": {},
"source": [
"The expected log predictive density ${\\rm elpd}_{\\rm LOO}$ is given by\n",
"\\begin{equation}\n",
"{\\rm elpd}_{\\rm LOO} = \\sum_{i=1}^n \\log p(y_i | y_{-i}),\n",
"\\end{equation}\n",
"where the leave-one-out predictive density without data point $i$ is \n",
"\\begin{equation}\n",
"p(y_i | y_{-i}) = \\int p(y_i|\\theta) p(\\theta|y_{-i}) ~ d\\theta. \n",
"\\end{equation}\n",
"\n",
"Importance sampling reweights the full-posterior draws with the importance ratios \n",
"\\begin{equation}\n",
"r_i^s \\propto \\frac{\\prod_{j \\neq i} p(y_j | \\theta^s) p(\\theta^s)}{\\prod_{j=1}^{n} p(y_j | \\theta^s) p(\\theta^s)} = \\frac{1}{p(y_i|\\theta^s)} \\Rightarrow \\log r_i^s = -\\log p(y_i|\\theta^s),\n",
"\\end{equation}\n",
"where the product in the numerator runs over $j \\in \\{1, \\dots, i-1, i+1, \\dots, n\\}$, so the log ratio is the negative pointwise log likelihood.\n",
"<!-- Posterior predictive density is the empirical average over posterior draws\n",
"\\begin{equation}\n",
"p(\\tilde{y}_i | \\theta) \\approx \\frac{1}{S} \\sum_S p(\\tilde{y}_i|\\theta^s). \n",
"\\end{equation}\n",
" -->"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "eaece49c-dafe-465e-94d5-1b2985c3fbbd",
"metadata": {},
"outputs": [],
"source": [
"log_weights = (\n",
" - log_likelihood - logsumexp(-log_likelihood, axis=1, b=b)[:, None]\n",
")"
]
},
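{
"cell_type": "markdown",
"id": "added-raw-weights-sketch",
"metadata": {},
"source": [
"Written out (as a sketch, with $\\ell_i^s = \\log p(y_i|\\theta^s)$ used as shorthand), the cell above computes\n",
"$$ \\log w_i^s = -\\ell_i^s - \\log\\left(\\frac{1}{S} \\sum_{s'=1}^{S} \\exp(-\\ell_i^{s'})\\right), $$\n",
"so that the raw weights $w_i^s$ average to one over the draws for each data point: $\\frac{1}{S}\\sum_s w_i^s = 1$. No Pareto smoothing has been applied at this stage."
]
},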
{
"cell_type": "markdown",
"id": "181896db-8ae6-4df5-a562-c94f107e5aab",
"metadata": {},
"source": [
"Let's plot the tail of the pointwise importance weights for each observation, and compare them to the weight $1/S$ that every sample would have if all samples had equal weight (vertical line), remembering that the generalized Pareto distribution is given by: \n",
"$$p(y | u, \\sigma, k) = \n",
"\\begin{cases}\n",
"\\frac{1}{\\sigma} \\left(1 + k \\left(\\frac{y-u}{\\sigma}\\right)\\right)^{-\\frac{1}{k} - 1 } & k \\neq 0\\\\\n",
"\\frac{1}{\\sigma} \\exp{\\left(-\\frac{y-u}{\\sigma}\\right)} & k = 0\\\\\n",
"\\end{cases}$$"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1352cbbd-c293-4e27-b4b6-94cbacc484d0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"reff = 0.8030332158992549\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x864 with 25 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# effective sample size\n",
"ess = arviz.ess(idata, method='mean')\n",
"# relative mcmc efficiency\n",
"reff = np.mean([ess[v].mean() for v in ['beta', 'sigma']]) * b\n",
"print(f'reff = {reff}')\n",
"# choose the length of the tail to consider based on n_samples, reff:\n",
"tail_len = int(3 * ((1/b) / reff)**0.5)\n",
"sorted_probs = np.sort(np.exp(log_weights.values), axis=1) * b\n",
"\n",
"def khat(k, M):\n",
" # Regularization for khat\n",
" return (M * k + 10 * 0.5)/(M + 10)\n",
"\n",
"gpd_k, gpd_sigma = np.array([\n",
" _gpdfit(p[-tail_len:] - p[-tail_len]) for p in sorted_probs\n",
"]).T \n",
"u = np.finfo(float).tiny\n",
"elws = np.array([p[-tail_len:] - p[-tail_len] for p in sorted_probs])\n",
"\n",
"panels_per_side = int(gpd_k.shape[0] ** 0.5)\n",
"fig, ax = plt.subplots(\n",
" panels_per_side, panels_per_side, figsize=(10, 12)\n",
")\n",
"counter = -1\n",
"\n",
"for i, (k, s, elw) in enumerate(\n",
" zip(gpd_k, gpd_sigma, elws)\n",
"):\n",
" weight_grid = np.linspace(0, elw.max())\n",
" ax = fig.axes[i]\n",
" ax.axvline(b, color='k')\n",
" counter += 1\n",
" ax.hist(\n",
" elw, log=True, density=True, \n",
" bins=10, alpha=0.5, color=f'C{i}'\n",
" )\n",
"\n",
" warn = 'r' if k > 0.7 else 'k'\n",
"\n",
" ax.set_title(f'$\\\\hat{{k}} = {k:.2f}$', color=warn)\n",
" ax.set_xlabel('$w_i$')\n",
" \n",
" gpd = 1 / s * (\n",
" 1 + k * (weight_grid - u)/s\n",
" ) ** (-1/k - 1)\n",
" ax.plot(weight_grid, gpd, color=f'C{i}')\n",
" plt.setp(ax.get_xticklabels(), rotation=25, ha='right', color=warn)\n",
" \n",
"while counter < panels_per_side**2 - 1: \n",
" counter += 1\n",
" fig.axes[counter].axis('off')\n",
"fig.suptitle('$\\\\leftarrow$low lnlike \\t\\t\\t high lnlike$\\\\rightarrow$')\n",
"fig.tight_layout();"
]
},
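{
"cell_type": "markdown",
"id": "added-tail-length-sketch",
"metadata": {},
"source": [
"As a worked check of the tail length used above, with the values printed in this run ($S = 1/b = 5000$ draws and $r_{\\rm eff} \\approx 0.803$):\n",
"$$ 3\\sqrt{S / r_{\\rm eff}} = 3 \\sqrt{5000 / 0.803} \\approx 3 \\times 78.9 \\approx 236.7, $$\n",
"which ``int`` truncates to 236, so the generalized Pareto distribution is fit to the 236 largest weights for each data point. These numbers depend on the sampler output and will change with a different seed or chain count."
]
},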
{
"cell_type": "markdown",
"id": "1e6ae487-ce66-48aa-94a4-54d9f4af4c93",
"metadata": {},
"source": [
"Let's plot the posterior draws in the data space for all samples (left) and where the opacity of the line is weighted by the likelihood weighting (right): "
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b5fcf704-5c68-4c42-879c-64c80b009215",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"pred = Predictive(model, mcmc.get_samples(), return_sites=['mean_model'])\n",
"\n",
"predictions = pred(rng_key)\n",
"\n",
"skip = 100\n",
"fig, ax = plt.subplots(1, 2, figsize=(8, 3))\n",
"\n",
"ax[0].errorbar(x, y, yerr, fmt='o', color='k', ecolor='silver')\n",
"ax[0].plot(x, predictions['mean_model'][::skip].T, alpha=1, color='DodgerBlue');\n",
"ax[0].set(\n",
" xlabel='x', ylabel='y', title='all samples'\n",
")\n",
"\n",
"ax[1].errorbar(x, y, yerr, fmt='o', color='k', ecolor='silver')\n",
"\n",
"weights = np.exp(logsumexp(log_weights, 0, b=b))\n",
"weights /= weights.sum()\n",
"\n",
"for m, a in zip(predictions['mean_model'][::skip],\n",
" weights[::skip]):\n",
" ax[1].plot(x, m, alpha=600 * a, color='DodgerBlue');\n",
"ax[1].set(\n",
" xlabel='x', ylabel='y', title='lnlike-weighted'\n",
");"
]
},
{
"cell_type": "markdown",
"id": "104068f5-8132-4937-854b-cfea4ec32bc0",
"metadata": {},
"source": [
"The expected log predictive density from LOO-CV is given by: "
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3e17ab90-f92d-4289-8a69-131d7c07b0e1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-18.73730239689501"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"elpd_loo = logsumexp(log_weights + log_likelihood, axis=1, b=b).sum()\n",
"\n",
"elpd_loo"
]
},
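{
"cell_type": "markdown",
"id": "added-harmonic-mean-sketch",
"metadata": {},
"source": [
"To see what this expression evaluates (a derivation sketch using the raw, unsmoothed ratios $r_i^s = 1/p(y_i|\\theta^s)$), the self-normalized importance sampling estimate of the leave-one-out density is\n",
"$$ p(y_i | y_{-i}) \\approx \\frac{\\sum_s r_i^s \\, p(y_i|\\theta^s)}{\\sum_s r_i^s} = \\frac{S}{\\sum_s 1 / p(y_i|\\theta^s)} = \\left[\\frac{1}{S} \\sum_{s=1}^{S} \\frac{1}{p(y_i|\\theta^s)}\\right]^{-1}, $$\n",
"the harmonic mean of the pointwise likelihoods over the posterior draws. Because a harmonic mean is dominated by the smallest likelihoods, a few draws can carry most of the weight, which is the instability that Pareto smoothing of the largest ratios is meant to reduce."
]
},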
{
"cell_type": "markdown",
"id": "247605b9-256b-484f-b515-9a8d20fdb9c0",
"metadata": {},
"source": [
"The effective number of parameters, $p_{\\rm LOO}$, is simply the difference between the two quantities computed above:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "83e31def-f476-44dc-af47-1d83d4095c7f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4.641941825423434"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"p_loo = lppd_loo - elpd_loo\n",
"\n",
"p_loo"
]
},
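{
"cell_type": "markdown",
"id": "added-p-loo-arithmetic",
"metadata": {},
"source": [
"Plugging in the values printed above for this run gives\n",
"$$ p_{\\rm LOO} = {\\rm lppd} - {\\rm elpd}_{\\rm LOO} \\approx -14.095 - (-18.737) = 4.642, $$\n",
"which is close to the four free parameters of the model (three regression coefficients plus $\\sigma$)."
]
},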
{
"cell_type": "markdown",
"id": "eaeaf098-e256-4de3-af9e-91132b796bef",
"metadata": {},
"source": [
"Now let's use the PSIS log weighting built into ``arviz``. We supply the (negative) log likelihood and it returns the Pareto smoothed log weights, and the shape parameter $\\hat{k}$ for each $i$:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "3e41eb23-fb54-405c-94de-f99113bccb71",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"corrected_log_weights, pareto_shape = psislw(-log_likelihood, reff)\n",
"\n",
"ax = plt.gca()\n",
"ax.plot(x, pareto_shape, 'ok');\n",
"\n",
"ax.set(xlabel='x', ylabel='$\\\\hat{k}$')\n",
"\n",
"# Label the places where the pareto smoothing is invalid\n",
"for k, ls, label in zip([0.7, 1], [':', '-'], ['ok', 'bad']):\n",
" ax.axhline(k, ls=ls, color='r')\n",
" ax.annotate(label, (0, k), color='r', va='bottom')"
]
},
{
"cell_type": "markdown",
"id": "c8ff21ab-502e-4f83-acbf-958665f7ab75",
"metadata": {},
"source": [
"We can also compute all of the LOO stats with the ``loo`` function:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "18dd0d43-d027-432f-beb7-7684951eb5da",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Computed from 5000 by 25 log-likelihood matrix\n",
"\n",
" Estimate SE\n",
"elpd_loo -18.74 5.27\n",
"p_loo 4.64 -\n",
"------\n",
"\n",
"Pareto k diagnostic values:\n",
" Count Pct.\n",
"(-Inf, 0.5] (good) 24 96.0%\n",
" (0.5, 0.7] (ok) 1 4.0%\n",
" (0.7, 1] (bad) 0 0.0%\n",
" (1, Inf) (very bad) 0 0.0%"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loo = arviz.loo(idata, pointwise=True)\n",
"loo"
]
},
{
"cell_type": "markdown",
"id": "b1cc4f97-1eaf-41b8-b2d7-c4d39581583e",
"metadata": {},
"source": [
"Now let's compare a poorly-fitting model ``less_complex_model`` with the original ``model`` using LOO: "
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9ceebfcb-08f4-4236-af60-201cf4325440",
"metadata": {},
"outputs": [],
"source": [
"X_2 = np.vander(x, 2)\n",
"\n",
"def less_complex_model():\n",
" # linear regression coefficients beta\n",
" beta = numpyro.sample(\n",
" 'beta', dist.Uniform(\n",
" low=np.array([-1, -1]), high=np.array([5, 5])\n",
" )\n",
" )\n",
" # uncertainty on y\n",
" sigma = numpyro.sample(\n",
" 'sigma', dist.HalfNormal(0.1)\n",
" )\n",
" # define the linear mean model\n",
" mean_model = numpyro.deterministic(\n",
" 'mean_model', jnp.dot(X_2, beta)\n",
" )\n",
" # define the Gaussian likelihood of observations y\n",
" numpyro.sample(\n",
" 'obs', dist.Normal(mean_model, sigma), obs=y\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cc3f24f1-680b-47d4-a85b-1f973c04f02f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 5100/5100 [00:02<00:00, 2299.17it/s, 3 steps of size 6.59e-01. acc. prob=0.95]\n"
]
}
],
"source": [
"rng_key = random.PRNGKey(0)\n",
"\n",
"less_complex_mcmc = MCMC(\n",
" sampler=NUTS(\n",
" less_complex_model,\n",
" ),\n",
" num_warmup=100,\n",
" num_samples=5000,\n",
" chain_method='parallel',\n",
" num_chains=len(jax.devices()),\n",
")\n",
"less_complex_mcmc.run(\n",
" rng_key,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "77e3c797-53f4-4202-a108-35e98f7ce0f1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"//anaconda3/envs/pymc/lib/python3.8/site-packages/arviz/stats/stats.py:694: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>rank</th>\n",
" <th>loo</th>\n",
" <th>p_loo</th>\n",
" <th>d_loo</th>\n",
" <th>weight</th>\n",
" <th>se</th>\n",
" <th>dse</th>\n",
" <th>warning</th>\n",
" <th>loo_scale</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>original</th>\n",
" <td>0</td>\n",
" <td>-18.736161</td>\n",
" <td>4.640812</td>\n",
" <td>0.000000</td>\n",
" <td>1.000000e+00</td>\n",
" <td>5.268723</td>\n",
" <td>0.000000</td>\n",
" <td>False</td>\n",
" <td>log</td>\n",
" </tr>\n",
" <tr>\n",
" <th>simpler</th>\n",
" <td>1</td>\n",
" <td>-66.038448</td>\n",
" <td>9.981362</td>\n",
" <td>47.302287</td>\n",
" <td>3.542056e-12</td>\n",
" <td>9.458131</td>\n",
" <td>10.686855</td>\n",
" <td>True</td>\n",
" <td>log</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" rank loo p_loo d_loo weight se \\\n",
"original 0 -18.736161 4.640812 0.000000 1.000000e+00 5.268723 \n",
"simpler 1 -66.038448 9.981362 47.302287 3.542056e-12 9.458131 \n",
"\n",
" dse warning loo_scale \n",
"original 0.000000 False log \n",
"simpler 10.686855 True log "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idata_less_complex = arviz.from_numpyro(less_complex_mcmc)\n",
"\n",
"arviz.compare(\n",
" dict(original=idata, simpler=idata_less_complex)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b3b55eae-68a2-4242-bcf6-28017961a2ab",
"metadata": {},
"source": [
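"The warning raised by the comparison indicates that at least one Pareto shape parameter $\\hat{k}$ for the simpler model exceeds 0.7, so its PSIS-LOO estimate is unreliable. Together with its much lower ``loo`` value, this suggests that the simpler model isn't a good fit to the data, and its LOO comparison shouldn't be trusted."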
]
},
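{
"cell_type": "markdown",
"id": "added-compare-arithmetic",
"metadata": {},
"source": [
"Reading the comparison table as a worked example (values rounded from this run): the ``d_loo`` column is the difference in ``loo`` relative to the top-ranked model,\n",
"$$ d_{\\rm LOO} = -18.736 - (-66.038) = 47.302, $$\n",
"and dividing by the standard error of that difference, ``dse`` $\\approx 10.687$, gives $47.302 / 10.687 \\approx 4.4$. The difference is therefore several standard errors in size, though the $\\hat{k}$ warning means the simpler model's estimate should still be treated with caution."
]
},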
{
"cell_type": "code",
"execution_count": null,
"id": "6090f520-b315-4749-828b-dd0b13a4f4bc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}