Skip to content

Instantly share code, notes, and snippets.

@aphearin
Created February 28, 2024 19:25
Show Gist options
  • Save aphearin/c043dbfdc73d8aa46879d8ca2b52c93f to your computer and use it in GitHub Desktop.
Save aphearin/c043dbfdc73d8aa46879d8ca2b52c93f to your computer and use it in GitHub Desktop.
Demonstrate optimization of a two-component Gaussian model using a loss function based on Gaussian KDE PDF estimation
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
"""Two-component Gaussian toy model with a Gaussian-KDE-based MSE loss.

The loss compares Gaussian-KDE estimates of the model PDF and of the target
data PDF at abscissa values drawn by resampling the data KDE. Per the gist
description, this is meant to demonstrate optimization of the model
parameters with such a loss.
"""
from functools import partial
from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran
from jax.scipy.stats import gaussian_kde
@jjit
def _mse(pred, target):
diff = pred - target
return jnp.mean(diff * diff)
@partial(jjit, static_argnames=["npts"])
def mc_double_gaussian(params, ran_key, npts):
    """Draw Monte Carlo samples for two independent Gaussian populations.

    Parameters
    ----------
    params : sequence of 5 values
        ``(mu1, mu2, sig1, sig2, frac1)``: means and standard deviations of
        the two populations, plus the weight of population 1.
    ran_key : jax.random.PRNGKey
        Split into one key per population.
    npts : int
        Number of samples drawn for EACH population (static under jit).

    Returns
    -------
    pop1, pop2 : arrays of shape (npts,)
    frac1 : the weight taken from ``params``, passed through unchanged
    """
    mu1, mu2, sig1, sig2, frac1 = params
    key_a, key_b = jran.split(ran_key, 2)
    sample_shape = (npts,)
    # location-scale transform of standard-normal draws
    pop1 = mu1 + sig1 * jran.normal(key_a, shape=sample_shape)
    pop2 = mu2 + sig2 * jran.normal(key_b, shape=sample_shape)
    return pop1, pop2, frac1
@partial(jjit, static_argnames=["npts_pred"])
def predict_pdf(params, ran_key, pdf_abscissa, npts_pred):
    """Toy model prediction for the pdf at the input abscissa.

    Samples ``npts_pred`` points for each of the two model populations,
    estimates each population's PDF with a Gaussian KDE evaluated at
    ``pdf_abscissa``, and mixes the two with weights ``frac1`` and
    ``1 - frac1``.
    """
    pop1, pop2, frac1 = mc_double_gaussian(params, ran_key, npts_pred)
    # one KDE per population, each evaluated at the shared abscissa
    pdf_1, pdf_2 = (
        gaussian_kde(pop.T).pdf(pdf_abscissa) for pop in (pop1, pop2)
    )
    weight_2 = 1 - frac1
    return frac1 * pdf_1 + weight_2 * pdf_2
def build_loss_func(ran_key, target_data, ncells):
    """Build a jitted MSE loss comparing model and data PDFs at random points.

    A Gaussian KDE of ``target_data`` is built once and treated as fixed data.
    The comparison abscissa is drawn by resampling that KDE, which removes the
    need to hand-tune bin edges. The returned function evaluates the model's
    KDE-estimated PDF at the same abscissa and returns its MSE against the
    data PDF.

    Parameters
    ----------
    ran_key : jax.random.PRNGKey
        Key used to draw the abscissa at which the PDFs are compared.
    target_data : ndarray of shape (npts,) or (1, npts)
        One-dimensional target sample (the model produces 1-D samples). The
        number of points is read from the last axis, so both layouts accepted
        by ``gaussian_kde`` work.
    ncells : int
        Number of points at which to evaluate the target PDF

    Returns
    -------
    loss_func : callable
        ``loss_func(params, pred_key)`` returns the scalar MSE loss. Each model
        population is sampled with as many points as ``target_data`` holds.
    """
    # The sample count lives on the last axis for both (npts,) and (ndim, npts).
    # len() would return ndim (== 1) for the 2-D layout, so the model would
    # draw a single point per population and its KDE would be degenerate.
    npts_data = jnp.shape(target_data)[-1]

    # precompute kde model of data
    # returned loss_func will treat this kde model as fixed data
    data_kde = gaussian_kde(target_data)

    # randomly select points at which to predict pdf
    # do this by drawing samples from a KDE model of the data
    # this eliminates the need to hand-tune bin edges
    target_abscissa = data_kde.resample(ran_key, (ncells,))

    # evaluate pdf of the target data at the randomly generated abscissa
    target_pdf = data_kde.pdf(target_abscissa)

    @jjit
    def loss_func(params, pred_key):
        """MSE between the model PDF and the fixed data PDF at the abscissa."""
        pred_pdf = predict_pdf(params, pred_key, target_abscissa, npts_data)
        return _mse(pred_pdf, target_pdf)

    return loss_func
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment