@bmorris3
Last active November 2, 2023 18:23
{
"cells": [
{
"cell_type": "markdown",
"id": "73ed22d7-c061-4d10-9f08-7935c3f280a5",
"metadata": {},
"source": [
"# Sample posteriors with SVI in numpyro\n",
"\n",
"Brett Morris\n",
"\n",
"**Goal**: Trade slow \"perfect\" posteriors for fast \"approximate\" posteriors\n",
"\n",
"### Overview\n",
"\n",
"In this example, we fit an observed stellar spectrum as a linear mixture of two model spectra, one cooler and one hotter. The models are two PHOENIX stellar spectra (5700 K and 5900 K) whose effective temperatures bracket the solar effective temperature, and the \"observed\" spectrum is a publicly available solar spectrum. We contrast sampling with the No U-Turn Sampler (NUTS) against stochastic variational inference (SVI)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "186c80c8-4190-4d07-9e94-8ddb91b9db54",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Set the number of cores on your machine for parallelism when \n",
"# running MCMC (SVI does not run parallel chains).\n",
"\n",
"cpu_cores = 4"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7058aac6-1007-4293-b9c7-eacefcf11dbf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"import numpyro\n",
"numpyro.set_host_device_count(cpu_cores)\n",
"\n",
"from numpyro.infer import (\n",
"    MCMC, NUTS, SVI, autoguide,\n",
"    Trace_ELBO, Predictive\n",
")\n",
"from numpyro import optim, distributions as dist\n",
"\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from jax import jit\n",
"from jax.scipy.signal import fftconvolve\n",
"from jax.random import PRNGKey, split\n",
"\n",
"rng_seed = 42\n",
"rng_keys = split(\n",
" PRNGKey(rng_seed), \n",
" cpu_cores\n",
")\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from astropy.table import Table\n",
"import astropy.units as u\n",
"from astropy.constants import sigma_sb\n",
"\n",
"from expecto import get_spectrum\n",
"from specutils import Spectrum1D\n",
"\n",
"import arviz\n",
"from corner import corner"
]
},
{
"cell_type": "markdown",
"id": "b12b0019-959c-4c9a-b943-7c74351ede71",
"metadata": {},
"source": [
"Download a solar spectrum from the [National Renewable Energy Laboratory](https://www.nrel.gov/grid/solar-resource/spectra.html):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "237f952b-e27d-4bcf-bbce-87c3e28bf736",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"url = \"https://www.nrel.gov/grid/solar-resource/assets/data/newguey2003.txt\"\n",
"\n",
"solar_spectrum = Table.read(\n",
" url, \n",
" format='ascii', \n",
" names=['wavelength', 'irradiance'], \n",
" units=[u.nm, u.Unit('W/(m2 nm)')],\n",
" \n",
" # choosing these rows gives us roughly\n",
" # from 200 nm to 2500 nm:\n",
" data_start=200, \n",
" data_end=2000, \n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9384345-8fd7-45ee-a2e8-1a07a96bcfd5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def to_uniform_sampling(wavelength, flux, delta_lambda):\n",
"    \"\"\"\n",
"    PHOENIX model spectra have a much higher spectral resolution\n",
"    than the \"observed\" solar spectrum, so we will need to convolve\n",
"    the PHOENIX model spectra with a kernel to represent\n",
"    instrumental broadening.\n",
"\n",
"    The JAX function that we will use to convolve the model spectra\n",
"    (`fftconvolve`) assumes a uniformly sampled grid, but the PHOENIX\n",
"    models and the observations are unevenly sampled.\n",
"\n",
"    This function linearly interpolates `flux` onto a uniform\n",
"    wavelength grid with spacing `delta_lambda`.\n",
"    \"\"\"\n",
"    wl = np.arange(\n",
"        wavelength.min(),\n",
"        wavelength.max(),\n",
"        delta_lambda\n",
"    )\n",
"    fl = np.interp(\n",
"        wl,\n",
"        wavelength,\n",
"        flux\n",
"    )\n",
"    return wl, fl"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e24874a-ea05-43ba-a3fe-0f9e1640c44a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# the truth is between these temperatures:\n",
"temperatures = [5700, 5900]\n",
"colors = ['r', 'b']\n",
"\n",
"fig, ax = plt.subplots()\n",
"solar_norm = float(u.R_sun / (1*u.AU))**2\n",
"\n",
"flux_unit = u.Unit('erg / s cm^3')\n",
"\n",
"spectra = []\n",
"\n",
"for i, (teff, color) in enumerate(zip(temperatures, colors)):\n",
"    # Download the PHOENIX spectrum:\n",
"    spectrum = get_spectrum(teff, log_g=4.5, cache=True)\n",
"\n",
"    # PHOENIX models are not normalized such that their\n",
"    # total emitted power exactly equals the blackbody\n",
"    # expectation (P = σT^4). This dimensionless ratio is the\n",
"    # factor that rescales the model to that expectation:\n",
"    power_norm = (\n",
"        (sigma_sb * (teff * u.K)**4) /\n",
"        np.trapz(spectrum.flux, spectrum.wavelength).to(u.W/u.m**2)\n",
"    ).to_value(u.dimensionless_unscaled)\n",
"\n",
"    # normalize both models by the peak flux of the first (cooler) model:\n",
"    if i == 0:\n",
"        model_norm = spectrum.flux.to_value(flux_unit).max()\n",
"\n",
"    # rescale to P = σT^4, then resample onto a uniform grid:\n",
"    resampled_wavelength, resampled_flux = to_uniform_sampling(\n",
"        spectrum.wavelength.to_value(u.nm),\n",
"        spectrum.flux.to_value(flux_unit) * power_norm / model_norm,\n",
"        delta_lambda=0.01\n",
"    )\n",
"\n",
"    ax.plot(\n",
"        resampled_wavelength, resampled_flux,\n",
"        zorder=-i, color=color, alpha=0.3,\n",
"        label=f'T = {teff} K'\n",
"    )\n",
"\n",
"    spectra.append(\n",
"        Spectrum1D(resampled_flux * flux_unit, spectral_axis=resampled_wavelength*u.nm)\n",
"    )\n",
"\n",
"native_solar_wavelength = solar_spectrum['wavelength'].to(u.nm).value\n",
"native_solar_flux = solar_spectrum['irradiance'].to(flux_unit).value / solar_norm / model_norm \n",
"\n",
"# resample the solar spectrum onto a uniform grid:\n",
"solar_wavelength, solar_flux = to_uniform_sampling(\n",
" native_solar_wavelength, native_solar_flux, delta_lambda=1\n",
")\n",
"\n",
"ax.semilogy(solar_wavelength, solar_flux, label='Sun', color='k')\n",
"ax.legend()\n",
"ax.set(\n",
" xlim=[300, 1000],\n",
" ylim=[0.2, 1],\n",
" xlabel='Wavelength [nm]',\n",
" ylabel=f'Irradiance [normalized]'\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b29874e-dc64-4bb9-909d-1252898fce7d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# we will only fit within a range of wavelengths [nm]:\n",
"extract_region = (solar_wavelength < 1000) & (solar_wavelength > 400)\n",
"\n",
"@jit\n",
"def convolve_spectrum(\n",
"    model_wavelength, model_flux, observed_wavelength, sigma,\n",
"):\n",
"    \"\"\"\n",
"    Convolve the high-res spectrum with a Gaussian kernel with\n",
"    standard deviation `sigma` (in the units of `model_wavelength`,\n",
"    i.e. nm). Then interpolate the result onto the wavelength grid\n",
"    of the observations.\n",
"    \"\"\"\n",
"    kernel = jnp.exp(\n",
"        -0.5 * (model_wavelength - jnp.mean(model_wavelength))**2 /\n",
"        sigma**2\n",
"    )\n",
"    # don't forget to normalize the kernel!\n",
"    kernel = kernel / jnp.sum(kernel)\n",
"\n",
"    convolved_model_flux = fftconvolve(\n",
"        model_flux, kernel, mode='same'\n",
"    )\n",
"    interp_model = jnp.interp(\n",
"        observed_wavelength,\n",
"        model_wavelength,\n",
"        convolved_model_flux\n",
"    )\n",
"    return interp_model\n",
"\n",
"@partial(jit, static_argnums=np.arange(2, 6).tolist())\n",
"def model(\n",
"    f_cool,\n",
"    sigma,\n",
"    model_wavelength=jnp.array(spectra[0].wavelength.to_value(u.nm)),\n",
"    model_cool=jnp.array(spectra[0].flux.to_value(flux_unit)),\n",
"    model_hot=jnp.array(spectra[1].flux.to_value(flux_unit)),\n",
"    observed_wavelength=jnp.array(solar_wavelength[extract_region])\n",
"):\n",
"    \"\"\"\n",
"    Compute a convolved linear combination of the model spectra,\n",
"    with weight `f_cool` on the cooler model and `1 - f_cool` on\n",
"    the hotter model.\n",
"    \"\"\"\n",
"    spectrum_mixture = (\n",
"        (1 - f_cool) * model_hot +\n",
"        f_cool * model_cool\n",
"    )\n",
"\n",
"    return convolve_spectrum(\n",
"        model_wavelength, spectrum_mixture, observed_wavelength, sigma\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0cea9a5b-534a-4b4c-86ca-02b4844d214a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def numpyro_model(\n",
"    save_model=False,\n",
"    model_wavelength=jnp.array(spectra[0].wavelength.to_value(u.nm)),\n",
"    observed_flux=jnp.array(solar_flux[extract_region]),\n",
"):\n",
"    # the weight of the cooler model in the linear combination:\n",
"    f_cool = numpyro.sample('$f_{\\\\rm cool}$', dist.Uniform(low=0, high=1))\n",
"\n",
"    # the stddev of the Gaussian convolution kernel [nm]:\n",
"    sigma = numpyro.sample('$\\\\sigma$',\n",
"        dist.TwoSidedTruncatedDistribution(\n",
"            dist.Normal(0, 0.2), low=0.01, high=10\n",
"        )\n",
"    )\n",
"\n",
"    # we don't know the uncertainties of the observations, so we\n",
"    # fit for a single per-pixel uncertainty β, sampling log10(β)\n",
"    # uniformly (i.e. a log-uniform prior on β):\n",
"    log_beta = numpyro.sample('$\\\\log_{10} \\\\beta$', dist.Uniform(low=-6, high=0))\n",
"\n",
"    # synthesize the convolved linear combination of model spectra:\n",
"    synth_spectrum = model(f_cool, sigma)\n",
"\n",
"    if save_model:\n",
"        # this gets used to produce posterior predictive samples later on:\n",
"        numpyro.deterministic(\"_synth_spectrum\", synth_spectrum)\n",
"\n",
"    # define the likelihood:\n",
"    numpyro.sample(\n",
"        'obs',\n",
"        dist.Normal(\n",
"            loc=synth_spectrum,\n",
"            scale=jnp.power(10.0, log_beta)\n",
"        ),\n",
"        obs=observed_flux\n",
"    )"
]
},
{
"cell_type": "markdown",
"id": "4c0cd8c1-db44-4461-a49e-b7e920c556bd",
"metadata": {},
"source": [
"<div class=\"alert alert-block alert-info\">\n",
"⚠️ The cell below is saved as a \"Raw\" type Jupyter cell, so it will not run the code if you execute it. If you change the type from \"Raw\" to \"Code\", you can run it.\n",
"</div>\n",
"\n",
"If you ran the cell below, you would find that the sampling speed drops very quickly during warmup, and the resulting posteriors would not have converged."
]
},
{
"cell_type": "raw",
"id": "bdfb674c-7bff-401b-803f-ac560a96f253",
"metadata": {
"tags": []
},
"source": [
"# Define a sampler, using here the No U-Turn Sampler (NUTS).\n",
"# Uncomment `dense_mass=True` below to use a dense mass matrix:\n",
"sampler = NUTS(\n",
" numpyro_model, \n",
" # dense_mass=True\n",
")\n",
"\n",
"# Monte Carlo sampling for a number of steps and parallel chains: \n",
"mcmc = MCMC(\n",
" sampler, \n",
" num_warmup=100, \n",
" num_samples=100, \n",
" num_chains=cpu_cores\n",
")\n",
"\n",
"# Run the MCMC\n",
"mcmc.run(rng_keys)\n",
"\n",
"result = arviz.from_numpyro(mcmc)\n",
"corner(result)\n",
"fig = plt.gcf()\n",
"fig.suptitle('NUTS')\n",
"plt.show()"
]
},
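{
"cell_type": "markdown",
"id": "0654b372-639e-483b-b763-170bb41c0e46",
"metadata": {},
"source": [
"If run, the above cell takes about 3.5 minutes to finish on my MacBook Pro.\n",
"\n",
"### SVI\n",
"\n",
"Now we'll use SVI to approximate the posterior of the same model. Rather than drawing samples with a Markov chain, SVI optimizes the parameters of a simpler \"guide\" distribution (here a block neural autoregressive flow, `AutoBNAFNormal`) to maximize the evidence lower bound (ELBO); the loss reported by `svi.run` is the negative ELBO."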
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5334d73c-872d-45d1-a5c9-af51f44c462d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"guide = autoguide.AutoBNAFNormal(numpyro_model)\n",
"svi = SVI(\n",
"    model=numpyro_model,\n",
"    guide=guide,\n",
"    optim=optim.Adagrad(step_size=0.1),\n",
"    loss=Trace_ELBO()\n",
")\n",
"svi_result = svi.run(\n",
"    rng_key=rng_keys[0],\n",
"    num_steps=500\n",
")"
]
},
{
"cell_type": "markdown",
"id": "417d1940-9d7e-47d7-82c6-9aa38e4157ae",
"metadata": {},
"source": [
"On my MacBook Pro, the above cell takes about 30 seconds to finish running.\n",
"\n",
"The optimization has converged when the loss stops decreasing and levels off near a minimum. Here we plot the loss minus its minimum value on a logarithmic scale:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a965443c-ef32-42f2-ac8c-5d6efde96742",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"# subtract the minimum loss so the approach to convergence is visible on a log scale:\n",
"ax.semilogy(svi_result.losses - svi_result.losses.min())\n",
"ax.set(\n",
"    xlabel='Steps',\n",
"    ylabel='Loss - min(Loss)'\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "55484a37-f828-46e4-b00b-1e606192bd02",
"metadata": {},
"source": [
"The posteriors can be extracted and inspected like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c56e9cc-3e6c-45f3-a287-2409fc77bfdb",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"params = svi_result.params\n",
"posteriors = guide.sample_posterior(PRNGKey(1), params, sample_shape=(5000,))\n",
"labels = [k for k, v in posteriors.items() if not k.startswith(\"_\")]\n",
"\n",
"corner(posteriors, labels=labels)\n",
"fig = plt.gcf()\n",
"fig.suptitle('SVI')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "e1214b9e-af85-436f-9257-83913527d36f",
"metadata": {},
"source": [
"To plot our inference, we can draw posterior samples and compute the spectra that they correspond to like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7baec419-17fb-49e4-9b08-6c74126e0f45",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"posterior_predictive = Predictive(\n",
" model=numpyro_model, \n",
" guide=guide,\n",
" params=params,\n",
" num_samples=50, \n",
" return_sites=['_synth_spectrum']\n",
")\n",
"y_pred = posterior_predictive(\n",
" rng_key=PRNGKey(1), \n",
" save_model=True\n",
")['_synth_spectrum']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1baebea0-6c09-44c9-b83a-f7b875b7905a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n",
"\n",
"for axis in ax[:2]:\n",
"    axis.plot(solar_wavelength[extract_region], model_norm * model(0, 1), color='b', label='Hotter model');\n",
"    axis.plot(solar_wavelength[extract_region], model_norm * model(1, 1), color='r', label='Cooler model');\n",
"    axis.plot(solar_wavelength[extract_region], model_norm * solar_flux[extract_region], color='k', label='Observed')\n",
"    axis.plot(solar_wavelength[extract_region], model_norm * y_pred.T, alpha=0.05, color='DodgerBlue', label=\"Posterior predictive\");\n",
"\n",
"    if axis is ax[0]:\n",
"        patches = axis.get_legend_handles_labels()[0][:4]\n",
"        axis.legend(handles=patches)\n",
"\n",
"ax[0].set(\n",
"    xlabel='Wavelength [nm]',\n",
"    ylabel=f'Irradiance [{flux_unit.to_string(\"latex\")}]',\n",
")\n",
"ax[1].set(\n",
"    xlabel='Wavelength [nm]',\n",
"    ylabel=f'Irradiance [{flux_unit.to_string(\"latex\")}]',\n",
"    xlim=[420, 500],\n",
"    ylim=np.array([0.7, 1]) * model_norm,\n",
")\n",
"\n",
"# map the cool-model weight onto a temperature by linear interpolation:\n",
"# f_cool = 1 gives the cooler model (temperatures[0]) and f_cool = 0 the\n",
"# hotter model (temperatures[1]). This is only a rough proxy, since a flux\n",
"# mixture of two spectra is not itself a single-temperature spectrum.\n",
"temps = np.array(temperatures, dtype=np.float64)\n",
"posterior_temperature = (\n",
"    temps[1] + posteriors['$f_{\\\\rm cool}$'] * (temps[0] - temps[1])\n",
")\n",
"\n",
"ax[2].hist(posterior_temperature, 30, color=\"DodgerBlue\", alpha=0.5)\n",
"ax[2].axvline(5772.0, ls='--', color='k', label=\"Convention\")\n",
"ax[2].set_xlabel('Temperature [K]')\n",
"ax[2].legend()\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "8b2bd6b6-928d-4926-977c-1084e289937f",
"metadata": {},
"source": [
"For reference, the conventional value for the solar effective temperature given in [Prša et al. (2016)](https://arxiv.org/abs/1605.09788) is $5772.0 \\pm 0.8$ K."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
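The `to_uniform_sampling` helper in the notebook above relies on `np.interp`. That function does piecewise-linear interpolation and assumes its input wavelengths are sorted in increasing order. The following minimal NumPy sketch checks this behavior on a made-up, unevenly sampled grid. The helper is copied from the notebook; the test wavelengths and the linear "spectrum" are hypothetical.

```python
import numpy as np


def to_uniform_sampling(wavelength, flux, delta_lambda):
    """Linearly interpolate (wavelength, flux) onto a uniform grid with spacing delta_lambda."""
    wl = np.arange(wavelength.min(), wavelength.max(), delta_lambda)
    fl = np.interp(wl, wavelength, flux)
    return wl, fl


# Hypothetical, unevenly sampled wavelengths [nm] and a linear "spectrum":
rng = np.random.default_rng(0)
wavelength = np.sort(rng.uniform(400, 500, size=200))
flux = 2.0 + 0.01 * wavelength

wl, fl = to_uniform_sampling(wavelength, flux, delta_lambda=0.5)

# The output grid is evenly spaced, and linear interpolation
# reproduces a linear function to within round-off:
print(np.allclose(np.diff(wl), 0.5), np.allclose(fl, 2.0 + 0.01 * wl))
```

Two details are easy to miss:

- `np.interp` does not check that `wavelength` is increasing, so unsorted input silently gives wrong values.
- `np.arange` excludes its stop value, so the uniform grid can end up to one step short of `wavelength.max()`.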
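The renormalization step in the plotting cell compares each PHOENIX spectrum with the blackbody expectation P = σT⁴. As a sanity check of that expectation, the sketch below integrates the surface flux of an ideal blackbody, πB_λ(T), over wavelength. It then compares the result with σT⁴ for the notebook's two model temperatures.

Assumptions in this sketch:

- The SI constant values are typed in by hand.
- The integration range (10 nm to 1 mm) is chosen to capture essentially all of the power at these temperatures.
- The trapezoid helper is written out only to avoid NumPy version differences.

```python
import numpy as np

# SI constants (h, c, k_B are exact; sigma_sb is the CODATA 2018 value):
h = 6.62607015e-34         # Planck constant [J s]
c = 2.99792458e8           # speed of light [m / s]
k_B = 1.380649e-23         # Boltzmann constant [J / K]
sigma_sb = 5.670374419e-8  # Stefan-Boltzmann constant [W / (m^2 K^4)]


def blackbody_surface_flux(wavelength_m, teff):
    """Surface flux density pi * B_lambda(T) of a blackbody [W / (m^2 m)]."""
    x = h * c / (wavelength_m * k_B * teff)
    return np.pi * 2 * h * c**2 / wavelength_m**5 / np.expm1(x)


def trapezoid(y, x):
    """Trapezoid rule, written out to avoid np.trapz/np.trapezoid version differences."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))


# Log-spaced grid from 10 nm to 1 mm (an assumed range that holds
# essentially all of the power for temperatures near 5800 K):
wavelength_m = np.logspace(-8, -3, 200_000)

ratios = {}
for teff in (5700.0, 5900.0):
    total = trapezoid(blackbody_surface_flux(wavelength_m, teff), wavelength_m)
    ratios[teff] = total / (sigma_sb * teff**4)
    print(f"T = {teff:.0f} K: integral / (sigma T^4) = {ratios[teff]:.6f}")
```

For an ideal blackbody the ratio is 1 up to numerical error. For a PHOENIX model, the analogous ratio measures how far the model's integrated flux departs from σT⁴.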
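The `solar_norm` factor in the notebook, (R_sun / 1 AU)², converts irradiance measured at Earth into flux at the solar surface. A quick consistency check of that scaling is to turn the bolometric irradiance into an effective temperature. This sketch assumes the IAU 2015 nominal total solar irradiance (1361 W m⁻²) and nominal solar radius (6.957 × 10⁸ m); neither value is used in the notebook itself.

```python
import math

sigma_sb = 5.670374419e-8  # Stefan-Boltzmann constant [W / (m^2 K^4)]
R_sun = 6.957e8            # IAU 2015 nominal solar radius [m]
au = 1.495978707e11        # astronomical unit [m]
S = 1361.0                 # IAU 2015 nominal total solar irradiance [W / m^2]

# Same dilution factor as `solar_norm` in the notebook:
solar_norm = (R_sun / au) ** 2

# Irradiance at Earth -> bolometric flux at the solar surface -> T_eff:
surface_flux = S / solar_norm
teff = (surface_flux / sigma_sb) ** 0.25
print(f"dilution = {solar_norm:.4e}, T_eff = {teff:.1f} K")
```

With these inputs the result lands within about a kelvin of the 5772 K value quoted at the end of the notebook.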
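The `convolve_spectrum` function builds a Gaussian on the model wavelength grid, centered on the grid's mean wavelength. It normalizes that kernel to unit sum and convolves with `mode='same'`.

The sketch below is a NumPy analogue, with `np.convolve` standing in for `jax.scipy.signal.fftconvolve`. For equal-length inputs, as here, both compute the same centered linear convolution up to round-off. On a synthetic grid it checks two properties:

- A flat continuum is preserved away from the edges.
- A single-sample spike is broadened to a Gaussian of width `sigma`.

```python
import numpy as np


def convolve_spectrum(model_wavelength, model_flux, observed_wavelength, sigma):
    """NumPy analogue of the notebook's function: Gaussian broadening, then resampling."""
    kernel = np.exp(
        -0.5 * (model_wavelength - np.mean(model_wavelength)) ** 2 / sigma**2
    )
    kernel /= kernel.sum()  # unit sum, so a flat continuum keeps its level
    convolved = np.convolve(model_flux, kernel, mode="same")
    return np.interp(observed_wavelength, model_wavelength, convolved)


# Synthetic uniform grid [nm] with an odd number of points, so the
# kernel peak falls exactly on the central sample:
model_wavelength = np.linspace(500.0, 510.0, 1001)  # spacing 0.01 nm
sigma = 0.1  # kernel standard deviation [nm]

# 1) A flat continuum is unchanged away from the grid edges:
flat = np.ones_like(model_wavelength)
flat_out = convolve_spectrum(model_wavelength, flat, model_wavelength, sigma)

# 2) An unresolved line (a single-sample spike) is broadened to width sigma:
spike = np.zeros_like(model_wavelength)
spike[500] = 1.0
line = convolve_spectrum(model_wavelength, spike, model_wavelength, sigma)
weights = line / line.sum()
center = np.sum(weights * model_wavelength)
width = np.sqrt(np.sum(weights * (model_wavelength - center) ** 2))

fwhm = 2 * np.sqrt(2 * np.log(2)) * sigma
print(f"center = {center:.3f} nm, width = {width:.4f} nm, FWHM = {fwhm:.4f} nm")
```

In the notebook the kernel spans the entire PHOENIX grid, so a direct convolution like `np.convolve` would be very slow at that resolution. That is presumably why an FFT-based convolution is used there. Because the kernel is built on the wavelength grid, `sigma` is in nm, and the equivalent FWHM is 2√(2 ln 2) σ ≈ 2.355 σ.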
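The prior on the kernel width in `numpyro_model` is a Normal(0, 0.2) truncated to [0.01, 10] nm. Because the parent normal is only 0.2 nm wide, nearly all of the prior mass sits far below the 10 nm upper bound. The standard-library sketch below quantifies this by computing the truncated-normal CDF from the error function. The thresholds are arbitrary choices for illustration.

```python
import math


def normal_cdf(x, loc=0.0, scale=1.0):
    """CDF of a normal distribution, via the error function."""
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))


def truncated_normal_cdf(x, loc, scale, low, high):
    """CDF of Normal(loc, scale) truncated to the interval [low, high]."""
    x = min(max(x, low), high)
    mass = normal_cdf(high, loc, scale) - normal_cdf(low, loc, scale)
    return (normal_cdf(x, loc, scale) - normal_cdf(low, loc, scale)) / mass


# Prior on the kernel width sigma [nm], with the notebook's parameters:
loc, scale, low, high = 0.0, 0.2, 0.01, 10.0

for threshold in (0.2, 0.4, 0.6):
    p = truncated_normal_cdf(threshold, loc, scale, low, high)
    print(f"P(sigma < {threshold} nm) = {p:.4f}")
```

So the width of the parent normal, rather than the 10 nm bound, is what keeps σ small a priori. A posterior on σ that piles up at a few tenths of a nm may therefore reflect this prior as much as the data.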
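The likelihood in `numpyro_model` uses one Gaussian scale β = 10^log_beta for every pixel. Because `log_beta` is uniform on [-6, 0], this amounts to a log-uniform prior on β.

For a fixed model spectrum, the Gaussian log-likelihood is maximized when β equals the root-mean-square residual. So β absorbs both observational noise and any model mismatch. A NumPy sketch with synthetic residuals (not the notebook's data) illustrates this by scanning `log_beta` over the prior range:

```python
import numpy as np


def gaussian_log_likelihood(residuals, beta):
    """Sum of log N(r | 0, beta^2) over all pixels."""
    n = residuals.size
    return (
        -0.5 * np.sum(residuals**2) / beta**2
        - n * np.log(beta)
        - 0.5 * n * np.log(2 * np.pi)
    )


# Synthetic residuals (observed minus model) with a known scatter of 0.01:
rng = np.random.default_rng(42)
residuals = rng.normal(0.0, 0.01, size=600)

# Scan log10(beta) over the notebook's prior range, [-6, 0]:
log_beta_grid = np.linspace(-6, 0, 6001)
log_like = np.array(
    [gaussian_log_likelihood(residuals, 10.0**lb) for lb in log_beta_grid]
)
best_log_beta = log_beta_grid[np.argmax(log_like)]

rms = np.sqrt(np.mean(residuals**2))
print(f"best log10(beta) = {best_log_beta:.3f}, log10(RMS) = {np.log10(rms):.3f}")
```

Setting the derivative with respect to β to zero gives β² = Σr²/n, which is why the grid maximum tracks the RMS residual.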
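SVI turns posterior inference into optimization. It tunes the guide's parameters to maximize the evidence lower bound, ELBO(q) = E_q[log p(x, z) - log q(z)]. The ELBO never exceeds the log evidence log p(x) and equals it only when q is the exact posterior. NumPyro's `Trace_ELBO` returns the negative ELBO as the loss, estimated by default from a single guide sample per step.

The toy model below is not the notebook's model. It is a one-parameter Gaussian model (z ~ N(0, 1), x | z ~ N(z, 1)), chosen because the posterior N(x/2, 1/2) and the evidence N(x; 0, 2) are known in closed form. That lets this NumPy sketch check the bound directly:

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)


def elbo_exact(x, m, s):
    """Closed-form ELBO for q(z) = N(m, s^2), prior N(0, 1), likelihood N(x | z, 1)."""
    expected_log_lik = -0.5 * LOG_2PI - 0.5 * ((x - m) ** 2 + s**2)
    expected_log_prior = -0.5 * LOG_2PI - 0.5 * (m**2 + s**2)
    entropy = 0.5 * np.log(2 * np.pi * np.e * s**2)
    return expected_log_lik + expected_log_prior + entropy


def elbo_monte_carlo(x, m, s, rng, num_particles):
    """Monte Carlo ELBO estimate: mean of log p(x, z) - log q(z) for z ~ q."""
    z = rng.normal(m, s, size=num_particles)
    log_joint = (-0.5 * LOG_2PI - 0.5 * (x - z) ** 2) + (-0.5 * LOG_2PI - 0.5 * z**2)
    log_q = -0.5 * LOG_2PI - np.log(s) - 0.5 * ((z - m) / s) ** 2
    return np.mean(log_joint - log_q)


x = 1.3
log_evidence = -0.5 * np.log(2 * np.pi * 2.0) - x**2 / 4  # x ~ N(0, 2) marginally

# The exact posterior N(x / 2, 1 / 2) attains the bound; using the prior as q does not:
exact = elbo_exact(x, x / 2, np.sqrt(0.5))
worse = elbo_exact(x, 0.0, 1.0)
mc = elbo_monte_carlo(x, 0.0, 1.0, np.random.default_rng(0), num_particles=100_000)
print(f"log p(x) = {log_evidence:.4f}, ELBO(exact) = {exact:.4f}, "
      f"ELBO(prior) = {worse:.4f}, MC estimate = {mc:.4f}")
```

The gap between the log evidence and the ELBO is KL(q ‖ posterior). This is why a flexible guide such as `AutoBNAFNormal` can approximate the posterior more closely than a simple Gaussian guide, at the cost of more parameters to optimize.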
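The loss plot in the notebook judges convergence by eye. One simple numeric heuristic compares the mean loss over the most recent window of steps with the window before it; it is sketched here with a hypothetical helper, `has_converged`. Because SVI estimates the ELBO stochastically, the loss is noisy, so averaging over windows is more robust than comparing single steps. The window length, tolerance, and synthetic loss curve are all arbitrary assumptions.

```python
import numpy as np


def has_converged(losses, window=50, rtol=1e-3):
    """Hypothetical heuristic: compare the mean loss in the last two windows.

    Returns True when the relative change between the windows is at most `rtol`.
    """
    losses = np.asarray(losses, dtype=float)
    if losses.size < 2 * window:
        return False
    previous = losses[-2 * window:-window].mean()
    latest = losses[-window:].mean()
    return abs(previous - latest) <= rtol * abs(previous)


# Synthetic noisy loss curve that decays toward 1000:
rng = np.random.default_rng(1)
steps = np.arange(500)
losses = 1000 + 5000 * np.exp(-steps / 40) + rng.normal(0, 0.5, steps.size)

print(has_converged(losses[:100]), has_converged(losses))
```

In the notebook this could be applied as `has_converged(svi_result.losses)`. If it returns `False`, one option is to increase `num_steps`.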
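The `model` function weights the cooler PHOENIX spectrum by `f_cool` and the hotter one by `1 - f_cool`. A temperature proxy built from `f_cool` must therefore map `f_cool = 1` to the cooler temperature (5700 K) and `f_cool = 0` to the hotter one (5900 K).

A minimal NumPy sketch of that linear mapping follows. The Beta-distributed draws are placeholders standing in for posterior samples, not results from the fit. Interpolating linearly in temperature is itself only an approximation.

```python
import numpy as np

T_COOL, T_HOT = 5700.0, 5900.0  # temperatures of the two PHOENIX models [K]


def mixture_temperature(f_cool):
    """Linearly interpolate between the model temperatures using the cool-model weight.

    f_cool = 1 -> pure cool model (T_COOL); f_cool = 0 -> pure hot model (T_HOT).
    """
    f_cool = np.asarray(f_cool, dtype=float)
    return f_cool * T_COOL + (1.0 - f_cool) * T_HOT


# Placeholder draws standing in for posterior samples of f_cool (not real results):
rng = np.random.default_rng(3)
f_cool_samples = rng.beta(20, 10, size=5000)

temps = mixture_temperature(f_cool_samples)
print(mixture_temperature([0.0, 0.5, 1.0]))  # → [5900. 5800. 5700.]
print(np.percentile(temps, [16, 50, 84]))
```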