Created
February 28, 2024 19:25
-
-
Save aphearin/c043dbfdc73d8aa46879d8ca2b52c93f to your computer and use it in GitHub Desktop.
Demonstrate optimization of a two-component Gaussian model using a loss function based on Gaussian KDE PDF estimation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
""" | |
from functools import partial | |
from jax import jit as jjit | |
from jax import numpy as jnp | |
from jax import random as jran | |
from jax.scipy.stats import gaussian_kde | |
@jjit
def _mse(pred, target):
    """Return the mean squared error between ``pred`` and ``target``.

    Parameters
    ----------
    pred : ndarray
        Predicted values
    target : ndarray
        Target values, subtracted elementwise from ``pred``

    Returns
    -------
    mse : scalar
        Mean over all elements of the squared difference
    """
    diff = pred - target
    return jnp.mean(diff * diff)
@partial(jjit, static_argnames=["npts"])
def mc_double_gaussian(params, ran_key, npts):
    """Draw Monte Carlo samples from both components of a two-Gaussian model.

    Parameters
    ----------
    params : sequence of 5 scalars
        (mu1, mu2, sig1, sig2, frac1)
    ran_key : jax.random.PRNGKey
    npts : int
        Number of samples drawn for *each* component.
        Declared static, so a new value triggers recompilation.

    Returns
    -------
    pop1 : ndarray of shape (npts,)
        Standard normal draws scaled by sig1 and shifted by mu1
    pop2 : ndarray of shape (npts,)
        Standard normal draws scaled by sig2 and shifted by mu2
    frac1 : scalar
        Returned unchanged; it is not used to set the sample counts
    """
    mu1, mu2, sig1, sig2, frac1 = params
    # each component gets its own subkey
    pop1_key, pop2_key = jran.split(ran_key, 2)
    pop1 = jran.normal(pop1_key, shape=(npts,)) * sig1 + mu1
    pop2 = jran.normal(pop2_key, shape=(npts,)) * sig2 + mu2
    return pop1, pop2, frac1
@partial(jjit, static_argnames=["npts_pred"])
def predict_pdf(params, ran_key, pdf_abscissa, npts_pred):
    """Toy model prediction for the pdf at the input abscissa

    Parameters
    ----------
    params : sequence of 5 scalars
        (mu1, mu2, sig1, sig2, frac1), forwarded to mc_double_gaussian
    ran_key : jax.random.PRNGKey
    pdf_abscissa : ndarray
        Points at which each component KDE is evaluated
    npts_pred : int
        Number of Monte Carlo samples drawn per component.
        Declared static, so a new value triggers recompilation.

    Returns
    -------
    pred_pdf : ndarray
        frac1-weighted sum of the two component KDE pdfs
    """
    pop1, pop2, frac1 = mc_double_gaussian(params, ran_key, npts_pred)

    # the samples are 1-d, so no transpose is needed before building the KDEs
    pred_kde_pop1 = gaussian_kde(pop1)
    pred_kde_pop2 = gaussian_kde(pop2)

    pred_pdf_pop1 = pred_kde_pop1.pdf(pdf_abscissa)
    pred_pdf_pop2 = pred_kde_pop2.pdf(pdf_abscissa)

    return frac1 * pred_pdf_pop1 + (1 - frac1) * pred_pdf_pop2
def build_loss_func(ran_key, target_data, ncells):
    """Build a jitted MSE loss comparing a KDE of the data to the model pdf.

    Parameters
    ----------
    ran_key : jax.random.PRNGKey
        Used to draw the abscissa at which the pdfs are compared

    target_data : ndarray of shape (npts,) or (ndim, npts)

    ncells : int
        Number of points at which to evaluate the target PDF

    Returns
    -------
    loss_func accepts (params, ran_key) returns MSE loss

    """
    # number of samples is the length of the LAST axis: len(target_data)
    # would return ndim rather than npts for data of shape (ndim, npts)
    npts_data = jnp.shape(target_data)[-1]

    # precompute kde model of data
    # returned loss_func will treat this kde model as fixed data
    data_kde = gaussian_kde(target_data)

    # randomly select points at which to predict pdf
    # do this by drawing samples from a KDE model of the data
    # this eliminates the need to hand-tune bin edges
    target_abscissa = data_kde.resample(ran_key, (ncells,))

    # evaluate pdf of the target data at the randomly generated abscissa
    target_pdf = data_kde.pdf(target_abscissa)

    @jjit
    def loss_func(params, pred_key):
        # the model draws as many samples per component as there are data points
        pred_pdf = predict_pdf(params, pred_key, target_abscissa, npts_data)
        return _mse(pred_pdf, target_pdf)

    return loss_func
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment