Skip to content

Instantly share code, notes, and snippets.

@ferrine
Last active November 3, 2018 10:37
Show Gist options
  • Save ferrine/1c5eae72717b7f3014e7ca36b7db6fb6 to your computer and use it in GitHub Desktop.
Save ferrine/1c5eae72717b7f3014e7ca36b7db6fb6 to your computer and use it in GitHub Desktop.
Horseshoe prior fail
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pymc3 as pm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import theano.tensor as tt\n",
"import theano\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"def gen_batch(n, w, beta):\n",
" d = len(w)\n",
" X = np.random.uniform(-10, 10, (n, 1))\n",
" X = np.sort(X, axis=0)\n",
" X = np.hstack([X ** i for i in range(d)])\n",
" t = X.dot(w) + np.random.normal(size=n) / beta ** 0.5\n",
" return X, t\n",
"\n",
"n = 200\n",
"d = 21\n",
"w_true = np.zeros(d) # lots of irrelevant features\n",
"w_true[1] = 100 # one coef is super large\n",
"w_true[3] = -1 # one is small\n",
"beta_true = 1e-4\n",
"\n",
"X_train, t_train = gen_batch(n, w_true, beta_true)\n",
"X_test, t_test = gen_batch(n, w_true, beta_true)\n",
"\n",
"plt.scatter(X_train[:, 1], t_train, s=3, label='Train data', alpha=0.3)\n",
"plt.scatter(X_test[:, 1], t_test, s=3, label='Test data', alpha=0.3)\n",
"plt.plot(X_train[:, 1], X_train.dot(w_true), label='Ground truth')\n",
"\n",
"plt.xlabel('x')\n",
"plt.ylabel('y')\n",
"plt.legend(ncol=3, loc=9, bbox_to_anchor=(0.5, 1.15))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"local_scale_log__ -24.040001\n",
"global_scale_log__ -1.140000\n",
"weight -19.299999\n",
"sigma_log__ -5.060000\n",
"obs -2015.530029\n",
"Name: Log-probability of test_point, dtype: float64"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_storage = theano.shared(X_train.astype(dtype='float32'))\n",
"t_storage = theano.shared(t_train.astype(dtype='float32'))\n",
"with pm.Model() as model:\n",
" # Horseshoe prior\n",
" local_shrincage = pm.HalfCauchy('local_scale', 1., shape=(21,))\n",
" global_shrincage = pm.HalfCauchy('global_scale', 1., shape=())\n",
" scale = local_shrincage * global_shrincage\n",
" weight = pm.Normal('weight', mu=0., sd=scale, shape=(21, ))\n",
" sigma = pm.HalfCauchy('sigma', 1., shape=(), testval=100)\n",
" obs = pm.Normal('obs', x_storage.dot(weight), sigma, observed=t_storage)\n",
"model.check_test_point()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (2 chains in 2 jobs)\n",
"NUTS: [sigma, weight, global_scale, local_scale]\n",
"Sampling 2 chains: 10%|█ | 206/2000 [00:05<00:48, 37.05draws/s]\n"
]
},
{
"ename": "RuntimeError",
"evalue": "Chain 1 failed.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRemoteTraceback\u001b[0m Traceback (most recent call last)",
"\u001b[0;31mRemoteTraceback\u001b[0m: \n\"\"\"\nTraceback (most recent call last):\n File \"/Users/ferres/dev/pymc3/pymc3/parallel_sampling.py\", line 82, in run\n self._start_loop()\n File \"/Users/ferres/dev/pymc3/pymc3/parallel_sampling.py\", line 123, in _start_loop\n point, stats = self._compute_point()\n File \"/Users/ferres/dev/pymc3/pymc3/parallel_sampling.py\", line 154, in _compute_point\n point, stats = self._step_method.step(self._point)\n File \"/Users/ferres/dev/pymc3/pymc3/step_methods/arraystep.py\", line 247, in step\n apoint, stats = self.astep(array)\n File \"/Users/ferres/dev/pymc3/pymc3/step_methods/hmc/base_hmc.py\", line 135, in astep\n self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)\n File \"/Users/ferres/dev/pymc3/pymc3/step_methods/hmc/quadpotential.py\", line 231, in raise_ok\n raise ValueError('\\n'.join(errmsg))\nValueError: Mass matrix contains zeros on the diagonal. \nThe derivative of RV `sigma_log__`.ravel()[0] is zero.\nThe derivative of RV `weight`.ravel()[1] is zero.\nThe derivative of RV `weight`.ravel()[3] is zero.\n\"\"\"",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;31mValueError\u001b[0m: Mass matrix contains zeros on the diagonal. \nThe derivative of RV `sigma_log__`.ravel()[0] is zero.\nThe derivative of RV `weight`.ravel()[1] is zero.\nThe derivative of RV `weight`.ravel()[3] is zero.",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-47-e5af4f687374>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_point\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjitter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e-4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/dev/pymc3/pymc3/sampling.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)\u001b[0m\n\u001b[1;32m 438\u001b[0m \u001b[0m_print_step_hierarchy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 440\u001b[0;31m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_mp_sample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0msample_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 441\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPickleError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0m_log\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarning\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Could not pickle model, sampling singlethreaded.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/dev/pymc3/pymc3/sampling.py\u001b[0m in \u001b[0;36m_mp_sample\u001b[0;34m(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 990\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0msampler\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 991\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mdraw\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msampler\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 992\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtraces\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchain\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mchain\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 993\u001b[0m if (trace.supports_sampler_stats\n",
"\u001b[0;32m~/dev/pymc3/pymc3/parallel_sampling.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 344\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_active\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 345\u001b[0;31m \u001b[0mdraw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mProcessAdapter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecv_draw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_active\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 346\u001b[0m \u001b[0mproc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_last\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuning\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwarns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 347\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_progress\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/dev/pymc3/pymc3/parallel_sampling.py\u001b[0m in \u001b[0;36mrecv_draw\u001b[0;34m(processes, timeout)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 248\u001b[0m \u001b[0merror\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Chain %s failed.\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mproc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 249\u001b[0;31m \u001b[0msix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_from\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mold_error\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 250\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"writing_done\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[0mproc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_readable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.pyenv/versions/pymc4/lib/python3.6/site-packages/six.py\u001b[0m in \u001b[0;36mraise_from\u001b[0;34m(value, from_value)\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Chain 1 failed."
]
}
],
"source": [
"with model:\n",
" trace = pm.sample(start=model.test_point, jitter=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc4",
"language": "python",
"name": "pymc4"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment