Skip to content

Instantly share code, notes, and snippets.

@bamford
Last active April 21, 2024 20:50
Show Gist options
  • Save bamford/f42cebb7e71c30aa810fa8ad9d4b1ff9 to your computer and use it in GitHub Desktop.
Save bamford/f42cebb7e71c30aa810fa8ad9d4b1ff9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{"metadata":{"kernelspec":{"display_name":"Python [conda env:icl]","language":"python","name":"conda-env-icl-py"},"language_info":{"name":"python","version":"3.12.0","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"gist_info":{"gist_url":"https://gist.github.com/bamford/f42cebb7e71c30aa810fa8ad9d4b1ff9","gist_id":"f42cebb7e71c30aa810fa8ad9d4b1ff9","create_date":"2024-04-19T23:29:43Z"}},"nbformat_minor":5,"nbformat":4,"cells":[{"id":"ab842dba-9fc4-478d-9ee2-4a3bd52436e9","cell_type":"code","source":"import arviz as az\nimport pymc as pm\nimport numpy as np\nimport os\nimport matplotlib.pyplot as plt\nfrom scipy import stats\nfrom functools import partial","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"7b600b5f-6529-424f-a58b-3799870ff7b9","cell_type":"code","source":"RANDOM_SEED = 8923\nrng = np.random.default_rng(RANDOM_SEED)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"a210dc71-73cf-4b04-81c6-277491305f59","cell_type":"markdown","source":"Create a toy \"true\" background distribution, using a mixture of Normal distributions, truncated to match the range of the data.","metadata":{}},{"id":"6b0e90dc-f647-410c-8518-3ad4564bbeb0","cell_type":"code","source":"LOWER = 22\nUPPER = 26.2\nNBINS = 20\nbkgd_n_components_true = 3\nbkgd_weight_true = np.array([0.5, 0.2, 0.3])\nbkgd_mu_true = np.array([23, 25, 27])\nbkgd_sigma_true = np.array([2, 1.5, 2])\nBKGD_NSAMP = 10000 # this is the expectation of the number of background objects in the outermost region","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"ef6c1a08-7d79-4b74-b7ed-6168dae56f43","cell_type":"code","source":"def trunc_limits(mu, sigma):\n a = (LOWER - mu) / sigma\n b = (UPPER - mu) / sigma\n return a, 
b","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"aaa027f8-1a81-4aff-b0c7-4fb460220ddc","cell_type":"code","source":"def create_mixnorm_sample(nsamp, weight, mu, sigma):\n a, b = trunc_limits(mu, sigma)\n component = rng.choice(weight.size, size=stats.poisson.rvs(nsamp), p=weight)\n samples = stats.truncnorm.rvs(a[component], b[component], mu[component], sigma[component])\n return samples","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"83812f1b-5113-4bf1-a44a-02037607d818","cell_type":"code","source":"bkgd_samples = create_mixnorm_sample(BKGD_NSAMP, bkgd_weight_true, bkgd_mu_true, bkgd_sigma_true)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"a954947a-eaf1-49c6-b2ba-a32ff13135b2","cell_type":"code","source":"def create_density(weight, mu, sigma, n_points=1000):\n n_comp = len(weight)\n a, b = trunc_limits(mu, sigma)\n m = np.linspace(LOWER, UPPER, n_points)\n density_comp = []\n for i in range(n_comp):\n density_comp.append(stats.truncnorm.pdf(m, a[i], b[i], loc=mu[i], scale=sigma[i]) * weight[i])\n density = np.sum(density_comp, axis=0)\n return m, density, density_comp","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"870e6f4f-8732-40a1-b8da-c8d94da8e5cd","cell_type":"code","source":"m, bkgd_true, _ = create_density(bkgd_weight_true, bkgd_mu_true, bkgd_sigma_true)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"bbac206c-406d-4980-9c23-54f39e7307cc","cell_type":"code","source":"plt.plot(m, bkgd_true, label=\"true\", lw=3)\nbkgd_hist, bins, _ = plt.hist(bkgd_samples, bins=NBINS, density=True, histtype=\"step\", label=f\"~{BKGD_NSAMP} samples\")\nplt.legend();","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"ca3ef68d-1804-43a1-8866-3280042ac057","cell_type":"markdown","source":"Model the background with a simple Gaussian mixture model. At first I tried a general model with free means and sigmas. 
However, this (a) takes quite a long time to sample and (b) proved difficult to robustly sample with truncated distributions. Instead, I used an overlapping set of Gaussians with fixed means and sigmas and just allowed the heights to vary. This is similar to the approach [Blanton et al. (2003)](https://iopscience.iop.org/article/10.1086/375776) used to model the SDSS galaxy luminosity function.","metadata":{}},{"id":"b412268d-0058-4c46-9071-0288916f485a","cell_type":"raw","source":"# original model, not used\nn_components = 3\nwith pm.Model() as model_bkgd:\nwith model_bkgd:\n w = pm.Dirichlet(\"w\", 1.1 * np.ones(n_components))\n mu = pm.Normal(\"mu\", mu=25, sigma=2, shape=n_components,\n transform=pm.distributions.transforms.ordered, initval=np.linspace(lower, upper, n_components))\n sigma = pm.Gamma(\"sigma\", alpha=1.5, beta=1/3, shape=n_components)\n components = pm.TruncatedNormal.dist(mu=mu, sigma=sigma, lower=lower, upper=upper)\n likelihood = pm.Mixture(\"bkgd\", w=w, comp_dists=components, observed=bkgd_samples[0])\n idata_bkgd = pm.sample_prior_predictive(samples=10)","metadata":{}},{"id":"f74974d6-d50b-4545-8866-c2abf2bd3fb9","cell_type":"code","source":"def create_bkgd_model(n_components=7):\n with pm.Model() as model:\n pm.Data(\"lower\", LOWER)\n pm.Data(\"upper\", UPPER)\n pm.Data(\"bkgd_n_components\", n_components)\n magrange = UPPER - LOWER\n # margin = 0.5 * magrange / n_components\n margin = 0.0\n mu = pm.Data(\"bkgd_mu\", np.linspace(LOWER - margin, UPPER + margin, n_components))\n sigma = pm.Data(\"bkgd_sigma\", np.ones(n_components) * (magrange + 2 * margin) / n_components)\n w = pm.Dirichlet(\"bkgd_w\", 1.1 * np.ones(n_components))\n components = pm.TruncatedNormal.dist(mu=mu, sigma=sigma, lower=LOWER, upper=UPPER)\n likelihood = pm.Mixture(\"bkgd\", w=w, comp_dists=components, observed=bkgd_samples)\n return 
model","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"aedd6ee9-e0bf-472f-9459-de46463fd54e","cell_type":"code","source":"bkgd_model = create_bkgd_model()","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"d281a28e-3f9b-430d-87e1-c23152eb99c3","cell_type":"code","source":"if not os.path.exists(\"idata_bkgd.nc\"):\n with bkgd_model:\n idata_bkgd = pm.sample_prior_predictive(samples=10)\n idata_bkgd.extend(pm.sample(nuts_sampler=\"numpyro\"))\n thinned_idata_bkgd = idata_bkgd.sel(chain=[0])\n idata_bkgd.extend(pm.sample_posterior_predictive(thinned_idata_bkgd, return_inferencedata=True))\n idata_bkgd.to_netcdf(\"idata_bkgd.nc\")\nelse:\n idata_bkgd = az.from_netcdf(\"idata_bkgd.nc\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"5471f73e-6357-4bd8-a88c-0c561be8db9a","cell_type":"code","source":"az.summary(idata_bkgd)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"022148e1-7770-4b0e-be5b-16856d0373c9","cell_type":"code","source":"az.plot_pair(idata_bkgd, divergences=True)\nplt.tight_layout();","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"c9cfda77-c537-4ec1-9d4e-f379d5c9a6b4","cell_type":"code","source":"az.plot_trace(idata_bkgd)\nplt.tight_layout();","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"c71232eb-5224-43cb-913d-d1f5657ebd25","cell_type":"code","source":"def get_var_mean(var_name, n_var, df, idata, batched=True):\n if var_name in idata.constant_data:\n var = idata.constant_data[var_name].data\n else:\n template = var_name\n template += \"[{}]\" if batched else \"{}\"\n var = df.loc[[template.format(i) for i in range(n_var)], \"mean\"].values\n return var\n\n\ndef get_var_samples(var_name, n_samp, idata):\n if var_name in idata.constant_data:\n var = np.ones((n_samp, 1)) * idata.constant_data[var_name].values[None, :]\n else:\n var = idata.posterior[var_name][0].values\n return var\n\n\ndef 
hist_quantiles_from_posterior_predictive(samples, q=(0.05, 0.95)):\n histogram_samples = np.array([np.histogram(x, bins=NBINS, range=(LOWER, UPPER), density=True)[0] for x in samples])\n bin_edges = np.linspace(LOWER, UPPER, NBINS + 1) \n hist_quantiles = np.quantile(histogram_samples, q, axis=0)\n hist_quantiles = np.concatenate((hist_quantiles, hist_quantiles[:, [-1]]), axis=-1)\n return bin_edges, hist_quantiles\n\n\ndef create_mean_density(name, n_components, idata):\n df = az.summary(idata)\n w = get_var_mean(f\"{name}_w\", n_components, df, idata)\n mu = get_var_mean(f\"{name}_mu\", n_components, df, idata)\n sigma = get_var_mean(f\"{name}_sigma\", n_components, df, idata)\n return create_density(w, mu, sigma) \n\n\ndef dens_quantiles_from_posterior(name, n_draw, idata, q=(0.05, 0.95)):\n w = get_var_samples(f\"{name}_w\", n_draw, idata)\n mu = get_var_samples(f\"{name}_mu\", n_draw, idata)\n sigma = get_var_samples(f\"{name}_sigma\", n_draw, idata)\n dens_samples = np.array([create_density(w[i], mu[i], sigma[i])[1]\n for i in range(len(w))])\n dens_quantiles = np.quantile(dens_samples, q, axis=0)\n return dens_quantiles\n\n\ndef plot_background(idata, true_density):\n n_components = int(idata.constant_data.bkgd_n_components.data)\n try:\n bkgd_samples = idata.observed_data.bkgd.data\n except AttributeError:\n bkgd_samples = None \n\n m, bkgd_mean, bkgd_mean_comp = create_mean_density(\"bkgd\", n_components, idata)\n\n plt.plot(m, true_density, lw=2, color=\"firebrick\", label=\"true background\")\n plt.plot(m, bkgd_mean, lw=2, color=\"steelblue\", label=\"mean posterior\")\n for comp in bkgd_mean_comp:\n plt.plot(m, comp, lw=2, color=\"steelblue\", ls=\":\")\n if bkgd_samples is not None:\n bins, (hist_05, hist_95) = hist_quantiles_from_posterior_predictive(idata.posterior_predictive.bkgd[0])\n plt.hist(bkgd_samples, bins=bins, color=\"darkorange\", density=True, histtype=\"step\", lw=2, label=f\"~{BKGD_NSAMP} observed samples\")\n plt.fill_between(bins, 
hist_05, hist_95, step=\"post\", color=\"lightgreen\", lw=0, label=f\"95% interval on sampled\")\n dens_05, dens_95 = dens_quantiles_from_posterior(\"bkgd\", len(idata.posterior.draw), idata)\n plt.fill_between(m, dens_05, dens_95, step=\"mid\", color=\"skyblue\", lw=0, label=\"95% interval on mean\")\n plt.legend(loc=\"lower center\")\n plt.ylabel(\"normalised density\")\n plt.xlabel(\"magnitude\")\n plt.xlim(xmin=LOWER, xmax=UPPER)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"024ebc13-30e1-410e-805e-dc564c3306f5","cell_type":"code","source":"plot_background(idata_bkgd, true_density=bkgd_true)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"b1b3d6d9-5e97-4e80-a721-8d232eaa790e","cell_type":"markdown","source":"You can see that the model is flexible enough such that the mean posterior closely follows the histogram of the observed samples. This may suggest that it is \"overfitting\" the observations. However, this doesn't particularly matter, as we will not be making any use of the mean posterior. Instead, what we care about is the posterior distribution. 
You can see that the credible interval on the mean (created from the posterior samples) encompasses the \"true\" distribution and the credible interval on the histogram of the samples reflects the variation in the observed histogram from the \"truth\".","metadata":{}},{"id":"07d4c28e-6943-4be1-baed-27fd5dd7b8c6","cell_type":"markdown","source":"Now create a distribution for the GCs.","metadata":{}},{"id":"7d7805a2-e1d9-486b-8c59-1ef74a3bef9d","cell_type":"code","source":"N_REGIONS = 5\nREGION_AREAS = np.array([1.0, 0.2, 0.1, 0.05, 0.02])\ngclf_n_components_true = 2\ngclf_mu_true = np.array([26.3, 26.6])\ngclf_sigma_true = np.array([1.5, 0.6])\n#gclf_mu_true = np.array([25.0, 26.6])\n#gclf_sigma_true = np.array([1.0, 0.6])\n# The surface density of each GC population in each region, relative to the background population\ngclf_reldens_true = np.array([[0.02, 0.01], [0.5, 0.2], [2.0, 0.5], [7.0, 1.5], [25.0, 5.0]])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"682ca179-b932-4a1d-bb8a-91717e9931de","cell_type":"code","source":"ref_val_plot_pair = {\"gclf_mu[0]\": gclf_mu_true[0], \"gclf_mu[1]\": gclf_mu_true[1],\n \"gclf_sigma[0]\": gclf_sigma_true[0], \"gclf_sigma[1]\": gclf_sigma_true[1]}\nref_val_plot_posterior = {k: [{\"ref_val\": v}] for k, v in ref_val_plot_pair.items()}\ngclf_reldens_sum = gclf_reldens_true.sum(axis=-1)\ngclf_comp0_w = gclf_reldens_true[:, 0] / gclf_reldens_sum\nref_val_plot_pair.update({f\"gclf_comp0_w_region {i}\": w for i, w in enumerate(gclf_comp0_w)})\nref_val_plot_posterior.update({\"gclf_comp0_w_region\": [{f\"gclf_comp0_w_region_dim_0\": i ,\"ref_val\": w} for i, w in enumerate(gclf_comp0_w)]})\ngclf_w = gclf_reldens_sum / (gclf_reldens_sum + 1)\nref_val_plot_pair.update({f\"gclf_weight_region {i}\": w for i, w in enumerate(gclf_w)})\nref_val_plot_posterior.update({\"gclf_weight_region\": [{f\"gclf_weight_region_dim_0\": i ,\"ref_val\": w} for i, w in 
enumerate(gclf_w)]})","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"c313ce70-69ea-4ada-b725-cb6c3d3308cc","cell_type":"code","source":"gclf_samples = []\nbkgd_samples = []\nfor r in range(N_REGIONS):\n bkgd_n = int(BKGD_NSAMP * REGION_AREAS[r])\n gclf_n = int(bkgd_n * gclf_reldens_true[r].sum())\n gclf_weight = gclf_reldens_true[r] / gclf_reldens_true[r].sum()\n gclf_samples.append(create_mixnorm_sample(gclf_n, gclf_weight, gclf_mu_true, gclf_sigma_true))\n bkgd_samples.append(create_mixnorm_sample(bkgd_n, bkgd_weight_true, bkgd_mu_true, bkgd_sigma_true))","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"17d872a1-83ac-4074-9063-560361c677d8","cell_type":"code","source":"gclf_true = []\ngclf_true_comp = []\nfor r in range(N_REGIONS):\n gclf_count = (BKGD_NSAMP * REGION_AREAS[r] * gclf_reldens_true[r]).astype(int)\n gclf_count = rng.poisson(gclf_count)\n _, true, comp = create_density(gclf_count, gclf_mu_true, gclf_sigma_true)\n gclf_true.append(true)\n gclf_true_comp.append(comp)\ngclf_true = np.array(gclf_true)\ngclf_true_comp = np.array(gclf_true_comp)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"9ead8293-aba2-42b4-bec7-18d17851af98","cell_type":"code","source":"def plot_samples_in_regions(samples, true, comp=None):\n gclf_hist, bins, _ = plt.hist(samples, bins=NBINS, range=(LOWER, UPPER), density=False, histtype=\"step\")\n plt.gca().set_prop_cycle(None)\n for r in range(N_REGIONS):\n plt.plot(m, true[r] * (UPPER - LOWER) / NBINS)\n if comp is not None:\n n_components = len(comp[0])\n for i in range(n_components):\n plt.gca().set_prop_cycle(None)\n for r in range(N_REGIONS):\n plt.plot(m, comp[r][i] * (UPPER - LOWER) / NBINS, ls=\":\")\n plt.xlim(xmin=LOWER, xmax=UPPER)\n plt.xlabel(\"mag\")\n 
plt.ylabel(\"count\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"337611c1-cdbd-4c4a-8a98-654edbf17e21","cell_type":"code","source":"plot_samples_in_regions(gclf_samples, gclf_true, gclf_true_comp)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"8100189f-fd2e-452e-af96-878f2d102347","cell_type":"code","source":"bkgd_true_regions = []\nbkgd_true_regions_comp = []\nfor r in range(N_REGIONS):\n bkgd_n = int(BKGD_NSAMP * REGION_AREAS[r])\n bkgd_count = (BKGD_NSAMP * REGION_AREAS[r] * bkgd_weight_true).astype(int)\n bkgd_count = rng.poisson(bkgd_count)\n _, true, comp = create_density(bkgd_count, bkgd_mu_true, bkgd_sigma_true)\n bkgd_true_regions.append(true)\n bkgd_true_regions_comp.append(comp)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"054390b7-02cd-4706-811a-1b380e333f21","cell_type":"code","source":"plot_samples_in_regions(bkgd_samples, bkgd_true_regions)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"e4d8844d-b6fb-4018-817a-cc8bd864b33a","cell_type":"code","source":"full_true = [gclf_true[r] + bkgd_true_regions[r] for r in range(N_REGIONS)]","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"7c157e5f-9a42-4cac-91d0-7097a7b98b81","cell_type":"code","source":"full_samples = [np.concatenate((gclf_samples[r], bkgd_samples[r])) for r in range(N_REGIONS)]","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"c677f2a9-7460-4503-9ed6-7dc49a573324","cell_type":"code","source":"plot_samples_in_regions(full_samples, full_true)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"7595c572-6d04-4e24-a754-64bf6a87f36c","cell_type":"code","source":"def create_model_full(n_bkgd_comp=7, n_gclf_comp=2, gclf_mu0=None, gclf_mu1=None, gclf_sigma0=None, gclf_sigma1=None,\n truncated=True):\n if truncated:\n NormalDist = partial(pm.TruncatedNormal.dist, lower=LOWER, upper=UPPER)\n else:\n NormalDist = pm.Normal.dist\n 
with pm.Model() as model:\n # the background\n pm.Data(\"lower\", LOWER)\n pm.Data(\"upper\", UPPER)\n pm.Data(\"bkgd_n_components\", n_bkgd_comp)\n magrange = UPPER - LOWER\n # margin = 0.5 * magrange / n_bkgd_comp\n margin = 0.0\n bkgd_mu = pm.Data(\"bkgd_mu\", np.linspace(LOWER - margin, UPPER + margin, n_bkgd_comp))\n bkgd_sigma = pm.Data(\"bkgd_sigma\", np.ones(n_bkgd_comp) * (magrange + 2 * margin) / n_bkgd_comp)\n bkgd_w = pm.Dirichlet(\"bkgd_w\", 1.1 * np.ones(n_bkgd_comp))\n bkgd_components = [NormalDist(mu=bkgd_mu[i], sigma=bkgd_sigma[i]) for i in range(n_bkgd_comp)]\n\n # the signal\n pm.Data(\"gclf_n_components\", n_gclf_comp)\n if gclf_mu0 is not None:\n mu0 = pm.Data(\"gclf_mu0\", gclf_mu0)\n else:\n mu0 = pm.Normal(\"gclf_mu[0]\", mu=26.0, sigma=1.0, initval=26.0)\n if gclf_sigma0 is not None:\n sigma0 = pm.Data(\"gclf_sigma[0]\", gclf_sigma0)\n else:\n sigma0 = pm.Gamma(\"gclf_sigma[0]\", mu=1.5, sigma=0.5)\n if n_gclf_comp > 1:\n if gclf_mu1 is not None:\n mu1 = pm.Data(\"gclf_mu[1]\", gclf_mu1)\n else:\n mu1 = pm.Normal(\"gclf_mu[1]\", mu=26.5, sigma=1.0, initval=26.5)\n if gclf_sigma1 is not None:\n sigma1 = pm.Data(\"gclf_sigma[1]\", gclf_sigma1)\n else:\n sigma1 = pm.Gamma(\"gclf_sigma[1]\", mu=0.5, sigma=0.5)\n\n if n_gclf_comp > 1:\n if gclf_mu0 is None or gclf_mu1 is None:\n mu_constraint = mu0 < mu1\n pm.Potential(\"mu_constraint\", pm.math.log(pm.math.switch(mu_constraint, 1, 0)))\n if gclf_sigma0 is None or gclf_sigma1 is None:\n sigma_constraint = sigma1 < sigma0\n pm.Potential(\"sigma_constraint\", pm.math.log(pm.math.switch(sigma_constraint, 1, 0)))\n\n gclf_comp0 = NormalDist(mu=mu0, sigma=sigma0)\n if n_gclf_comp > 1:\n gclf_comp1 = NormalDist(mu=mu1, sigma=sigma1)\n\n if n_gclf_comp > 1:\n region_components = [gclf_comp0, gclf_comp1] + bkgd_components\n else:\n region_components = [gclf_comp0] + bkgd_components\n\n # Using Beta to mimic a Uniform distibution as there seems to be a bug with Uniform, at least for shape > 1\n 
region_gclf_weight = pm.Beta(\"gclf_weight_region\", alpha=1, beta=1, shape=N_REGIONS,\n initval=np.linspace(0.1, 0.9, N_REGIONS))\n # impose ordering on region weights\n #region_constraint = pm.math.min(region_gclf_weight[1:] - region_gclf_weight[:-1]) > 0\n #pm.Potential(\"region_constraint\", pm.math.log(pm.math.switch(region_constraint, 1, 0.1)))\n\n if n_gclf_comp > 1:\n region_gclf_comp0_w = pm.Beta(\"gclf_comp0_w_region\", alpha=1, beta=1, shape=N_REGIONS)\n for r in range(N_REGIONS):\n if n_gclf_comp > 1:\n region_w = pm.math.concatenate((region_gclf_weight[None, r] * region_gclf_comp0_w[r],\n region_gclf_weight[None, r] * (1 - region_gclf_comp0_w[r]),\n (1 - region_gclf_weight[r]) * bkgd_w))\n else:\n region_w = pm.math.concatenate((region_gclf_weight[None, r],\n (1 - region_gclf_weight[r]) * bkgd_w))\n region_like = pm.Mixture(f\"full_region{r}\", w=region_w, comp_dists=region_components, observed=full_samples[r])\n return model","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"02b52483-645c-4b95-a156-f47a3186f2a9","cell_type":"code","source":"def sample_model(model):\n with model:\n idata = pm.sample_prior_predictive()\n idata.extend(pm.sample(init=\"adapt_diag\", nuts_sampler=\"numpyro\"))\n thinned_idata = idata.sel(chain=[0])\n idata.extend(pm.sample_posterior_predictive(thinned_idata, return_inferencedata=True))\n return idata","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"ddc0b2d7-266f-4040-a2ad-03f9178c7ebe","cell_type":"code","source":"def plot_model(idata, priors=False):\n ax = az.plot_ppc(idata, group=\"prior\")\n plt.suptitle(\"Prior Predictive Check\")\n plt.tight_layout()\n plot_pair_kwargs = dict(filter_vars=\"like\",\n kind=\"kde\", kde_kwargs=dict(hdi_probs=[0.68, 0.95],\n contour_kwargs=dict(alpha=[0, 1, 1, 1]),\n contourf_kwargs=dict(cmap=plt.cm.Blues)),\n reference_values=ref_val_plot_pair,\n reference_values_kwargs=dict(markersize=10, color=\"orange\"))\n az.plot_trace(idata)\n 
plt.suptitle(\"MCMC Traces\")\n plt.tight_layout()\n try:\n az.plot_posterior(idata, var_names=[\"gclf_sigma\", \"gclf_mu\"], filter_vars=\"like\",\n ref_val=ref_val_plot_posterior)\n plt.suptitle(\"Posterior – GCLF Distribution Parameters\")\n plt.tight_layout()\n except ValueError:\n pass\n az.plot_posterior(idata, var_names=\"gclf_weight\", filter_vars=\"like\",\n ref_val=ref_val_plot_posterior)\n plt.suptitle(\"Posterior – GCLF Weights Vs Background\")\n plt.tight_layout()\n try:\n az.plot_posterior(idata, var_names=\"gclf_comp\", filter_vars=\"like\",\n ref_val=ref_val_plot_posterior)\n plt.suptitle(\"Posterior – GCLF Component Weights\")\n plt.tight_layout()\n except ValueError:\n pass\n try:\n if priors:\n az.plot_pair(idata, var_names=[\"gclf_sigma\", \"gclf_mu\"], group=\"prior\", **plot_pair_kwargs) \n plt.suptitle(\"Prior Pair Plot – GCLF Distribution Parameters\")\n plt.tight_layout()\n az.plot_pair(idata, var_names=[\"gclf_sigma\", \"gclf_mu\"], **plot_pair_kwargs) \n plt.suptitle(\"Posterior Pair Plot – GCLF Distribution Parameters\")\n plt.tight_layout()\n except ValueError:\n pass\n if priors:\n az.plot_pair(idata, var_names=\"gclf_weight\", group=\"prior\", **plot_pair_kwargs)\n plt.suptitle(\"Prior Pair Plot – GCLF Weights Vs Background\")\n plt.tight_layout()\n az.plot_pair(idata, var_names=\"gclf_weight\", **plot_pair_kwargs)\n plt.suptitle(\"Posterior Pair Plot – GCLF Weights Vs Background\")\n plt.tight_layout()\n try:\n if priors:\n az.plot_pair(idata, var_names=\"gclf_comp\", group=\"prior\", **plot_pair_kwargs)\n plt.suptitle(\"Prior Pair Plot – GCLF Component Weights\")\n plt.tight_layout()\n az.plot_pair(idata, var_names=\"gclf_comp\", **plot_pair_kwargs)\n plt.suptitle(\"Posterior Pair Plot – GCLF Component Weights\")\n plt.tight_layout()\n except ValueError:\n pass\n ax = az.plot_ppc(idata)\n plt.suptitle(\"Posterior Predictive Check\")\n plt.tight_layout()\n plt.figure()\n plot_background(idata, true_density=bkgd_true)\n #plot_full 
TBD","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"0f18f175-bfa5-4682-9490-d22f26b16a8a","cell_type":"markdown","source":"To be done:","metadata":{}},{"id":"cd9edca7-16f2-427a-aea8-53d5c6c1c9c3","cell_type":"raw","source":"def plot_full(idata):\n lower = idata.constant_data.lower.data\n upper = idata.constant_data.upper.data\n n_bkgd_comp = int(idata.constant_data.n_bkgd_components.data)\n n_gclf_comp = int(idata.constant_data.n_gclf_components.data)\n samples = [idata.observed_data[f\"region{r}_full\"].data for r in range(n_regions)]\n m = np.linspace(lower, upper, 1000)\n df = az.summary(idata)\n bkgd_w = get_var_mean(\"bkgd_w\", n_bkgd_comp, df, idata)\n bkgd_mu = get_var_mean(\"bkgd_mu\", n_bkgd_comp, df, idata)\n bkgd_sigma = get_var_mean(\"bkgd_sigma\", n_bkgd_comp, df, idata)\n bkgd_mean, bkgd_mean_comp = create_mean(m, n_bkgd_comp, bkgd_w, bkgd_mu, bkgd_sigma, lower, upper)\n\n gclf_mu = get_var_mean(\"gclf_mu\", n_gclf_comp, df, idata)\n gclf_sigma = get_var_mean(\"gclf_sigma\", n_gclf_comp, df, idata) \n\n region_gclf_weight = get_var_mean(\"gclf_weight_region\", n_gclf_comp, df, idata)\n if n_gclf_comp > 1:\n region_gclf_comp1_w = get_var_mean(\"gclf_comp1_w_region\", n_gclf_comp, df, idata)\n else:\n region_gclf_comp1_w = np.ones(n_regions)\n \n gclf_mean, gclf_mean_comp = zip([create_mean(m, n_gclf_comp, region_gclf_comp1_w[r], gclf_mu, gclf_sigma, lower, upper) for r in len(n_regions)])\n\n hist_quantiles = []\n for r in range(n_regions):\n samples_posterior = idata.posterior_predictive[f\"region{r}_full\"][0].T\n histogram_samples = np.array([np.histogram(x, bins=20, range=(22, 28), density=True)[0] for x in idata.posterior_predictive.like[0]])\n bin_edges = np.linspace(22, 28, 21) \n hq = np.quantile(histogram_samples, (0.05, 0.16, 0.84, 0.95), axis=0)\n hist_quantiles.append(np.concatenate((hq, hist_quantiles[:, [-1]]), axis=-1))\n\n bkgd_w = get_var_samples(\"bkgd_w\", n_bkgd_comp, idata)\n bkgd_mu = 
get_var_samples(\"bkgd_mu\", n_bkgd_comp, idata)\n bkgd_sigma = get_var_samples(\"bkgd_sigma\", n_bkgd_comp, idata)\n bkgd_dens_samples = np.array([create_mean(m, n_components, w[i], mu[i], sigma[i], lower, upper)[0]\n for i in range(len(bkgd_w))])\n bkgd_dens_quantiles = np.quantile(bkgd_dens_samples, (0.05, 0.16, 0.5, 0.84, 0.95), axis=0)\n\n gclf_mu = get_var_samples(\"gclf_mu\", n_gclf_comp, idata)\n gclf_sigma = get_var_samples(\"gclf_sigma\", n_gclf_comp, idata) \n region_gclf_weight = get_var_samples(\"gclf_weight_region\", n_gclf_comp, idata)\n if n_gclf_comp > 1:\n region_gclf_comp1_w = get_var_samples(\"gclf_comp1_w_region\", n_gclf_comp, idata)\n else:\n region_gclf_comp1_w = np.ones(region_gclf_weight.shape)\n \n gclf_dens_samples = np.array([[create_mean(m, n_gclf_comp, region_gclf_comp1_w[r][i], gclf_mu[i], gclf_sigma[i], lower, upper)[0]\n for i in range(len(gclf_mu))] for r in len(n_regions)])\n gclf_dens_quantiles = [np.quantile(gclf_dens_samples, (0.05, 0.16, 0.5, 0.84, 0.95), axis=0) for r in len(n_regions)]\n \n full_dens_samples = bkgd_dens_samples + gclf_dens_samples\n full_dens_quantiles = [np.quantile(full_dens_samples, (0.05, 0.16, 0.5, 0.84, 0.95), axis=0) for r in len(n_regions)]\n\n plt.plot(m, bkgd_true, lw=2, color=\"firebrick\", label=\"true background\")\n plt.plot(m, bkgd_mean, lw=2, color=\"steelblue\", label=\"mean posterior\")\n for comp in bkgd_mean_comp:\n plt.plot(m, comp, lw=2, color=\"steelblue\", ls=\":\")\n _, bins, _ = plt.hist(bkgd_samples, bins=20, range=(22, 28), color=\"darkorange\", density=True, histtype=\"step\", lw=2, label=f\"~{nsamp} observed samples\")\n hist_05, hist_16, hist_84, hist_95 = hist_quantiles\n plt.fill_between(bin_edges, hist_05, hist_95, step=\"post\", color=\"lightgreen\", lw=0, label=f\"95% interval on sampled\")\n dens_05, dens_16, dens_50, dens_84, dens_95 = dens_quantiles\n plt.fill_between(m, dens_05, dens_95, step=\"mid\", color=\"skyblue\", lw=0, label=\"95% interval on mean\")\n 
plt.legend(loc=\"lower center\")\n plt.ylabel(\"normalised density\")\n plt.xlabel(\"magnitude\")\n plt.xlim(xmin=lower, xmax=upper)","metadata":{}},{"id":"10cdff69-ad80-43dd-a783-107c25cf91f5","cell_type":"markdown","source":"## Sample and explore models","metadata":{}},{"id":"5a216234-fa81-4142-bcfe-0da0eb5d7de5","cell_type":"code","source":"model = {}\nidata = {}","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"c55386b6-3415-4cd7-851a-ae53cbe7c5f6","cell_type":"markdown","source":"### Two-component GCLF model with fixed means and sigmas","metadata":{}},{"id":"01e3edf4-831e-48fd-9bba-74b3dd1540d2","cell_type":"code","source":"name = \"full_2_comp_all_fixed\"\nmodel[name] = create_model_full(gclf_mu0=gclf_mu_true[0], gclf_mu1=gclf_mu_true[1],\n gclf_sigma0=gclf_sigma_true[0], gclf_sigma1=gclf_sigma_true[1])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"a0f5e8f4-42aa-4bbf-b766-90bd10a1f9e3","cell_type":"code","source":"pm.model_to_graphviz(model[name], var_names=[\"full_region0\"])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"78f9b546-8a4c-487e-8a6c-fc9e46c862e4","cell_type":"code","source":"if not os.path.exists(f\"{name}.nc\"):\n idata[name] = sample_model(model[name])\n idata[name].to_netcdf(f\"{name}.nc\")\nelse:\n idata[name] = az.from_netcdf(f\"{name}.nc\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"dff82b81-19a7-4e6c-b735-f89cf096eaf5","cell_type":"code","source":"az.summary(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"f498b41b-7448-4fa7-b2a9-aec523c5be6f","cell_type":"code","source":"plot_model(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"aeb6c25a-7a80-484c-a941-25459f9bf4d0","cell_type":"markdown","source":"### Two-component GCLF model with fixed means and free sigmas","metadata":{}},{"id":"a364ee11-2d5d-40f0-bc9e-d5a4d27accfa","cell_type":"code","source":"name = 
\"full_2comp_mean_fixed\"\nmodel[name] = create_model_full(gclf_mu0=gclf_mu_true[0], gclf_mu1=gclf_mu_true[1])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"aa6fd7b3-7105-424f-86b4-5b9c0e9ca3ab","cell_type":"code","source":"pm.model_to_graphviz(model[name], var_names=[\"full_region0\"])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"8109239f-642e-488c-a8bb-509f0a3c68e6","cell_type":"code","source":"if not os.path.exists(f\"{name}.nc\"):\n idata[name] = sample_model(model[name])\n idata[name].to_netcdf(f\"{name}.nc\")\nelse:\n idata[name] = az.from_netcdf(f\"{name}.nc\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"de59407f-04c2-495a-a847-1c20e3751da9","cell_type":"code","source":"az.summary(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"7237cc69-530e-4a22-a7a1-41d531fd07e0","cell_type":"code","source":"plot_model(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"2949a974-161b-4eb7-81f1-79493e8aff9e","cell_type":"markdown","source":"### One-component GCLF model with fixed mean and free sigma","metadata":{}},{"id":"e93df12d-014a-41a0-961c-3c5977ca8b26","cell_type":"code","source":"name = \"full_1comp_mean_fixed\"\nmodel[name] = create_model_full(n_gclf_comp=1, gclf_mu0=gclf_mu_true[0])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"1f78d602-3106-4e92-bdcd-30c6045f912d","cell_type":"code","source":"pm.model_to_graphviz(model[name], var_names=[\"full_region0\"])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"5098df29-4ef4-444a-acd0-471dc78171a6","cell_type":"code","source":"if not os.path.exists(f\"{name}.nc\"):\n idata[name] = sample_model(model[name])\n idata[name].to_netcdf(f\"{name}.nc\")\nelse:\n idata[name] = 
az.from_netcdf(f\"{name}.nc\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"9ab1a4d1-0b44-493d-9f87-c047c478e48a","cell_type":"code","source":"az.summary(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"bcab1aad-c05f-416b-b5c9-95ed59667108","cell_type":"code","source":"plot_model(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"c1f4441b-405a-4280-b36f-ec50e70d58c7","cell_type":"markdown","source":"### One-component GCLF model with free mean and sigma","metadata":{}},{"id":"1350726b-76b5-41c8-bead-5715815a7f0c","cell_type":"code","source":"name = \"full_1comp_all_free\"\nmodel[name] = create_model_full(n_gclf_comp=1)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"7763678c-d7c4-4767-bc4a-802d3595faa4","cell_type":"code","source":"pm.model_to_graphviz(model[name], var_names=[\"full_region0\"])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"e71fa1a2-7629-40ca-905a-3b6ded4eb9c8","cell_type":"code","source":"if not os.path.exists(f\"{name}.nc\"):\n idata[name] = sample_model(model[name])\n idata[name].to_netcdf(f\"{name}.nc\")\nelse:\n idata[name] = az.from_netcdf(f\"{name}.nc\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"8dd53465-6d83-48e0-8cc5-d4fe178f64ca","cell_type":"code","source":"az.summary(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"aab357ca-4009-44de-93e7-f4de184e33da","cell_type":"code","source":"plot_model(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"eb5d5644-310b-46d6-b21d-844a64c1ed1a","cell_type":"markdown","source":"### Two-component GCLF model with free means and sigmas","metadata":{}},{"id":"5dbec5ca-aa47-4f66-b64a-dd2f84413bc7","cell_type":"code","source":"name = \"full_2comp_all_free\"\nmodel[name] = 
create_model_full()","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"6792bafe-5d82-43ff-a5d8-2209556ae4e4","cell_type":"code","source":"pm.model_to_graphviz(model[name], var_names=[\"full_region0\"])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"f7db152c-2606-45a3-9c10-34105b79cd8a","cell_type":"code","source":"if not os.path.exists(f\"{name}.nc\"):\n idata[name] = sample_model(model[name])\n idata[name].to_netcdf(f\"{name}.nc\")\nelse:\n idata[name] = az.from_netcdf(f\"{name}.nc\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"71ea76a4-7ff0-4a34-b045-57d38044a608","cell_type":"code","source":"az.summary(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"7654e2ec-e8fb-48ea-8f95-ffca417e7835","cell_type":"code","source":"plot_model(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"4267ea77-aa20-49b9-acef-84c50dcbd336","cell_type":"markdown","source":"## Model comparison","metadata":{}},{"id":"23199dab-c00b-4fad-a033-d8ce48e865b6","cell_type":"code","source":"for name in model:\n with model[name]:\n pm.compute_log_likelihood(idata[name])","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"4e7c31fd-3d33-480a-ab1b-0770a91db12d","cell_type":"code","source":"df_comp_loo = az.compare(idata, var_name=\"full_region0\")\ndf_comp_loo","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"b779e31e-b019-48e5-8e23-eca18bd1086e","cell_type":"code","source":"az.plot_bf(idata[\"full_2comp_mean_fixed\"], var_name=\"gclf_sigma[0]\", ref_val=gclf_sigma_true[0]);","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"e5e2bc3b-e9af-495f-8fa3-f2999b54f770","cell_type":"markdown","source":"To be done:","metadata":{}},{"id":"d717707e-ef5b-47a2-8633-9abb5c97faf7","cell_type":"raw","source":"with full_model_1comp:\n idata_smc_full_1comp = pm.sample_smc()\n\nwith full_model_2comp:\n 
idata_smc_full_2comp = pm.sample_smc()\n\nBF_smc = np.exp(\n idata_smc_full_2comp.sample_stats[\"log_marginal_likelihood\"].mean()\n - idata_smc_full_1comp.sample_stats[\"log_marginal_likelihood\"].mean()\n)\nBF_smc","metadata":{}}]}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment