Skip to content

Instantly share code, notes, and snippets.

@corneliusroemer
Created December 30, 2021 11:34
Show Gist options
  • Save corneliusroemer/2a8f383ce18bbe53f924c62e5a649731 to your computer and use it in GitHub Desktop.
Save corneliusroemer/2a8f383ce18bbe53f924c62e5a649731 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{"cells":[{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/jax/experimental/optimizers.py:28: FutureWarning: jax.experimental.optimizers is deprecated, import jax.example_libraries.optimizers instead\n"," warnings.warn('jax.experimental.optimizers is deprecated, '\n","/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/jax/experimental/stax.py:28: FutureWarning: jax.experimental.stax is deprecated, import jax.example_libraries.stax instead\n"," warnings.warn('jax.experimental.stax is deprecated, '\n"]}],"source":["import numpy as np\n","import numpyro\n","import numpyro.distributions as dist\n","from jax import random\n","from numpyro.handlers import reparam\n","from numpyro.infer import MCMC, NUTS\n","from numpyro.infer.reparam import LocScaleReparam\n"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["pupils_per_school = 10\n","n_schools = 10\n","schools = np.repeat(np.arange(0, n_schools), pupils_per_school)\n","true_outliers = np.repeat(np.random.randint(0, 2, n_schools), pupils_per_school)\n","grade = np.random.normal(true_outliers, 1)\n"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["reparam_config = {k: LocScaleReparam(0) for k in [\"mu\"]}\n","\n","# Reparametrization doesn't seem to work with discrete latent variable\n","@numpyro.handlers.reparam(config=reparam_config)\n","def model(school, grade=None):\n"," sigma = numpyro.sample(\"sigma\", dist.HalfNormal(1))\n","\n"," n_schools = len(np.unique(school))\n"," with numpyro.plate(\"plate_j\", n_schools):\n"," outlier = numpyro.sample(\"outlier\", dist.Bernoulli(0.5))\n"," mu = numpyro.sample(\"mu\", dist.Normal(outlier, sigma))\n","\n"," with numpyro.plate(\"data\", len(grade)):\n"," numpyro.sample(\"obs\", dist.Normal(loc=mu[school], scale=1), obs=grade)\n","\n","# ValueError: Incompatible shapes for broadcasting: ((1, 100), (100, 10))\n","\n"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"ename":"ValueError","evalue":"Incompatible shapes for broadcasting: ((1, 100), (100, 10))","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)","\u001b[0;32m/var/folders/1n/22rjr5g11cxf3zgt_zfn4v3r0000gn/T/ipykernel_57124/3378792922.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mmcmc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMCMC\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnuts_kernel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_warmup\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mrng_key\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPRNGKey\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mmcmc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mschools\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrade\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrade\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/mcmc.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, rng_key, extra_fields, init_params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0mmap_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_chains\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m \u001b[0mstates_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial_map_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 573\u001b[0m \u001b[0mstates\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnewaxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m...\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstates_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/mcmc.py\u001b[0m in \u001b[0;36m_single_chain_mcmc\u001b[0;34m(self, init, args, kwargs, collect_fields)\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minit_state\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 360\u001b[0;31m init_state = self.sampler.init(\n\u001b[0m\u001b[1;32m 361\u001b[0m \u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_warmup\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/hmc.py\u001b[0m in \u001b[0;36minit\u001b[0;34m(self, rng_key, num_warmup, init_params, model_args, model_kwargs)\u001b[0m\n\u001b[1;32m 694\u001b[0m \u001b[0mvmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng_key\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 695\u001b[0m )\n\u001b[0;32m--> 696\u001b[0;31m init_params = self._init_state(\n\u001b[0m\u001b[1;32m 697\u001b[0m \u001b[0mrng_key_init_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 698\u001b[0m )\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/hmc.py\u001b[0m in \u001b[0;36m_init_state\u001b[0;34m(self, rng_key, model_args, model_kwargs, init_params)\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_init_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 641\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 642\u001b[0;31m init_params, potential_fn, postprocess_fn, model_trace = initialize_model(\n\u001b[0m\u001b[1;32m 643\u001b[0m \u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 644\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/util.py\u001b[0m in \u001b[0;36minitialize_model\u001b[0;34m(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[0minit_strategy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_init_to_unconstrained_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0munconstrained_values\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 617\u001b[0m \u001b[0mprototype_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minv_transforms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconstrained_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minvert\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 618\u001b[0;31m (init_params, pe, grad), is_valid = find_valid_initial_params(\n\u001b[0m\u001b[1;32m 619\u001b[0m \u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 620\u001b[0m substitute(\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/util.py\u001b[0m in \u001b[0;36mfind_valid_initial_params\u001b[0;34m(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)\u001b[0m\n\u001b[1;32m 372\u001b[0m \u001b[0;31m# Handle possible vectorization\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrng_key\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 374\u001b[0;31m (init_params, pe, z_grad), is_valid = _find_valid_params(\n\u001b[0m\u001b[1;32m 375\u001b[0m \u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexit_early\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 376\u001b[0m )\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/util.py\u001b[0m in \u001b[0;36m_find_valid_params\u001b[0;34m(rng_key, exit_early)\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;31m# Early return if valid params found. This is only helpful for single chain,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0;31m# where we can avoid compiling body_fn in while_loop.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 360\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minit_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_grad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_valid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minit_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbody_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 361\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnot_jax_tracer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_valid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdevice_get\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_valid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/util.py\u001b[0m in \u001b[0;36mbody_fn\u001b[0;34m(state)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[0mz_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjacfwd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpotential_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 344\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 345\u001b[0;31m \u001b[0mpe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpotential_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 346\u001b[0m \u001b[0mz_grad_flat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mravel_pytree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz_grad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 347\u001b[0m \u001b[0mis_valid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfinite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpe\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfinite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz_grad_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n"," \u001b[0;31m[... skipping hidden 8 frame]\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/infer/util.py\u001b[0m in \u001b[0;36mpotential_energy\u001b[0;34m(model, model_args, model_kwargs, params, enum)\u001b[0m\n\u001b[1;32m 225\u001b[0m )\n\u001b[1;32m 226\u001b[0m \u001b[0;31m# no param is needed for log_density computation because we already substitute\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 227\u001b[0;31m log_joint, model_trace = log_density_(\n\u001b[0m\u001b[1;32m 228\u001b[0m \u001b[0msubstituted_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 229\u001b[0m )\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py\u001b[0m in \u001b[0;36mlog_density\u001b[0;34m(model, model_args, model_kwargs, params)\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mreturn\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlog\u001b[0m \u001b[0mof\u001b[0m \u001b[0mjoint\u001b[0m \u001b[0mdensity\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0ma\u001b[0m \u001b[0mcorresponding\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \"\"\"\n\u001b[0;32m--> 270\u001b[0;31m result, model_trace, _ = _enum_log_density(\n\u001b[0m\u001b[1;32m 271\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunsor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogaddexp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunsor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 272\u001b[0m )\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py\u001b[0m in \u001b[0;36m_enum_log_density\u001b[0;34m(model, model_args, model_kwargs, params, sum_op, prod_op)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msubstitute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mplate_to_enum_plate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 159\u001b[0;31m \u001b[0mmodel_trace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpacked_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mmodel_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 160\u001b[0m \u001b[0mlog_factors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0mtime_to_factors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdefaultdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# log prob factors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/handlers.py\u001b[0m in \u001b[0;36mget_trace\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mreturn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mOrderedDict\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0mcontaining\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mexecution\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \"\"\"\n\u001b[0;32m--> 165\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/var/folders/1n/22rjr5g11cxf3zgt_zfn4v3r0000gn/T/ipykernel_57124/3996660135.py\u001b[0m in \u001b[0;36mmodel\u001b[0;34m(school, grade)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mnumpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrade\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mnumpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"obs\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmu\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mschool\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrade\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;31m# ValueError: Incompatible shapes for broadcasting: ((1, 100), (100, 10))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(name, fn, obs, rng_key, sample_shape, infer, obs_mask)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# ...and use apply_stack to send it to the Messengers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mapply_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minitial_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36mapply_stack\u001b[0;34m(msg)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mpointer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mpointer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandler\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreversed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_PYRO_STACK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mhandler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;31m# When a Messenger sets the \"stop\" field of a message,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;31m# it prevents any Messengers above it on the stack from being applied.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py\u001b[0m in \u001b[0;36mprocess_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 522\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"type\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"to_funsor\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"to_data\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mOrigPlateMessenger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 525\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpostprocess_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/numpyro/primitives.py\u001b[0m in \u001b[0;36mprocess_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0moverlap_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpected_shape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdist_batch_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mtrailing_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexpected_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0moverlap_idx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m broadcast_shape = lax.broadcast_shapes(\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mtrailing_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdist_batch_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m )\n"," \u001b[0;31m[... skipping hidden 2 frame]\u001b[0m\n","\u001b[0;32m/usr/local/Caskroom/mambaforge/base/envs/numpyro/lib/python3.10/site-packages/jax/_src/lax/lax.py\u001b[0m in \u001b[0;36mbroadcast_shapes\u001b[0;34m(*shapes)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0mresult_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_try_broadcast_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult_shape\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m raise ValueError(\"Incompatible shapes for broadcasting: {}\"\n\u001b[0m\u001b[1;32m 125\u001b[0m .format(tuple(map(tuple, shapes))))\n\u001b[1;32m 126\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult_shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mValueError\u001b[0m: Incompatible shapes for broadcasting: ((1, 100), (100, 10))"]}],"source":["nuts_kernel = NUTS(model)\n","\n","mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=200)\n","rng_key = random.PRNGKey(0)\n","mcmc.run(rng_key, schools, grade=grade)"]}],"metadata":{"interpreter":{"hash":"15008f8347171a2128547983e1278a0bab7a00b4575d1c82f19dcd9f5c4d4af3"},"kernelspec":{"display_name":"Python 3.10.1 64-bit ('numpyro': conda)","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.1"},"orig_nbformat":4},"nbformat":4,"nbformat_minor":2}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment