Skip to content

Instantly share code, notes, and snippets.

@jgoodson
Last active September 11, 2018 00:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jgoodson/b008f92e002617b84d378ee5ab7f207f to your computer and use it in GitHub Desktop.
Save jgoodson/b008f92e002617b84d378ee5ab7f207f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pymc3 as pm\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"ml_model = pm.Model()\n",
"k = np.array([[[10,20,10,0], [0,10,20,50]],\n",
" [[10,20,10,0], [0,10,20,50]]])\n",
"s=k.sum(axis=(0,1))\n",
"gates = np.log10(np.array([0, 500, 3000, 7000, np.inf]))\n",
"sq2 = np.sqrt(2)\n",
"\n",
"def f(u, sd):\n",
" p = (1+pm.math.erf((gates-u[:,np.newaxis])/(sd[:,np.newaxis]*sq2)))/2\n",
" return p[:,1:]-p[:,:-1]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Sequential sampling (1 chains in 1 job)\n",
"NUTS: [start_prop, sdF, meanF]\n",
"100%|██████████| 1500/1500 [00:04<00:00, 328.19it/s]\n",
"Only one chain was sampled, this makes it impossible to run some convergence checks\n"
]
}
],
"source": [
"with ml_model:\n",
" mu = pm.Normal('meanF', mu=2.5, sd=0.5, shape=k.shape[1])\n",
" sigma = pm.TruncatedNormal('sdF', mu=1, lower=0.1, shape=k.shape[1])\n",
" sprop = pm.Dirichlet('start_prop', a=np.ones(k.shape[1]))\n",
" \n",
" q = f(mu, sigma)*sprop[:,np.newaxis]\n",
" u = pm.Deterministic('spprop2', q*s)\n",
" \n",
" trace = pm.sample(1000, njobs=1, nchains=1)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[5.86377329, 6.15639611, 1.9004066 , 8.0784005 ],\n",
" [5.57336786, 6.2781665 , 1.80904149, 7.82920926]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trace['spprop2'].mean(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Sequential sampling (1 chains in 1 job)\n",
"NUTS: [Counts, start_prop, sdF, meanF]\n",
" 6%|▌ | 89/1500 [00:00<00:05, 253.75it/s]\n"
]
},
{
"ename": "ValueError",
"evalue": "Mass matrix contains zeros on the diagonal. \nThe derivative of RV `Counts`.ravel()[0] is zero.\nThe derivative of RV `Counts`.ravel()[1] is zero.\nThe derivative of RV `Counts`.ravel()[2] is zero.\nThe derivative of RV `Counts`.ravel()[3] is zero.\nThe derivative of RV `Counts`.ravel()[4] is zero.\nThe derivative of RV `Counts`.ravel()[5] is zero.\nThe derivative of RV `Counts`.ravel()[6] is zero.\nThe derivative of RV `Counts`.ravel()[7] is zero.\nThe derivative of RV `start_prop_stickbreaking__`.ravel()[0] is zero.\nThe derivative of RV `sdF_lowerbound__`.ravel()[0] is zero.\nThe derivative of RV `sdF_lowerbound__`.ravel()[1] is zero.\nThe derivative of RV `meanF`.ravel()[0] is zero.\nThe derivative of RV `meanF`.ravel()[1] is zero.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-a86f37f9a2df>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mK\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Counts'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnjobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnchains\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[0m_log\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Sequential sampling ({} chains in 1 job)'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchains\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 468\u001b[0m \u001b[0m_print_step_hierarchy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 469\u001b[0;31m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_sample_many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0msample_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 470\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 471\u001b[0m \u001b[0mdiscard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtune\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdiscard_tuned_samples\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py\u001b[0m in \u001b[0;36m_sample_many\u001b[0;34m(draws, chain, chains, start, random_seed, step, **kwargs)\u001b[0m\n\u001b[1;32m 513\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchains\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 514\u001b[0m trace = _sample(draws=draws, chain=chain + i, start=start[i],\n\u001b[0;32m--> 515\u001b[0;31m step=step, random_seed=random_seed[i], **kwargs)\n\u001b[0m\u001b[1;32m 516\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 517\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraces\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py\u001b[0m in \u001b[0;36m_sample\u001b[0;34m(chain, progressbar, random_seed, start, draws, step, trace, tune, model, live_plot, live_plot_kwargs, **kwargs)\u001b[0m\n\u001b[1;32m 557\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[0mstrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 559\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrace\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msampling\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 560\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlive_plot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 561\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlive_plot_kwargs\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/tqdm/_tqdm.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 931\u001b[0m \"\"\", fp_write=getattr(self.fp, 'write', sys.stderr.write))\n\u001b[1;32m 932\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 933\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 934\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 935\u001b[0m \u001b[0;31m# Update and possibly print the progressbar.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/sampling.py\u001b[0m in \u001b[0;36m_iter_sample\u001b[0;34m(draws, step, start, trace, chain, tune, model, random_seed)\u001b[0m\n\u001b[1;32m 653\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstop_tuning\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 654\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerates_stats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 655\u001b[0;31m \u001b[0mpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstates\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpoint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 656\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msupports_sampler_stats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 657\u001b[0m \u001b[0mstrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, point)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerates_stats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 247\u001b[0;31m \u001b[0mapoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstats\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 248\u001b[0m \u001b[0mpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_logp_dlogp_func\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray_to_full_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapoint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 249\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstats\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py\u001b[0m in \u001b[0;36mastep\u001b[0;34m(self, q0)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfinite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menergy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpotential\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_ok\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_logp_dlogp_func\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ordering\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvmap\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m raise ValueError('Bad initial energy: %s. The model '\n\u001b[1;32m 117\u001b[0m 'might be misspecified.' % start.energy)\n",
"\u001b[0;32m~/anaconda/envs/py36/lib/python3.6/site-packages/pymc3/step_methods/hmc/quadpotential.py\u001b[0m in \u001b[0;36mraise_ok\u001b[0;34m(self, vmap)\u001b[0m\n\u001b[1;32m 199\u001b[0m errmsg.append('The derivative of RV `{}`.ravel()[{}]'\n\u001b[1;32m 200\u001b[0m ' is zero.'.format(*name_slc[ii]))\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'\\n'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merrmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m~\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfinite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Mass matrix contains zeros on the diagonal. \nThe derivative of RV `Counts`.ravel()[0] is zero.\nThe derivative of RV `Counts`.ravel()[1] is zero.\nThe derivative of RV `Counts`.ravel()[2] is zero.\nThe derivative of RV `Counts`.ravel()[3] is zero.\nThe derivative of RV `Counts`.ravel()[4] is zero.\nThe derivative of RV `Counts`.ravel()[5] is zero.\nThe derivative of RV `Counts`.ravel()[6] is zero.\nThe derivative of RV `Counts`.ravel()[7] is zero.\nThe derivative of RV `start_prop_stickbreaking__`.ravel()[0] is zero.\nThe derivative of RV `sdF_lowerbound__`.ravel()[0] is zero.\nThe derivative of RV `sdF_lowerbound__`.ravel()[1] is zero.\nThe derivative of RV `meanF`.ravel()[0] is zero.\nThe derivative of RV `meanF`.ravel()[1] is zero."
]
}
],
"source": [
"ml_model2 = pm.Model()\n",
"with ml_model2:\n",
" #same model as above\n",
" mu = pm.Normal('meanF', mu=2.5, sd=0.5, shape=k.shape[1])\n",
" sigma = pm.TruncatedNormal('sdF', mu=1, lower=0.1, shape=k.shape[1])\n",
" sprop = pm.Dirichlet('start_prop', a=np.ones(k.shape[1]))\n",
" \n",
" q = f(mu, sigma)*sprop[:,np.newaxis]\n",
" u = pm.Deterministic('spprop2', q*s)\n",
" \n",
" #Added multi-distribution using the above deterministic as a paramter\n",
" K = pm.Normal('Counts', mu=u, shape=k.shape[1:])\n",
" \n",
" trace = pm.sample(1000, njobs=1, nchains=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (3.6)",
"language": "python",
"name": "py36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment