Skip to content

Instantly share code, notes, and snippets.

@aphearin
Created February 28, 2024 19:25
Show Gist options
  • Save aphearin/c043dbfdc73d8aa46879d8ca2b52c93f to your computer and use it in GitHub Desktop.
Save aphearin/c043dbfdc73d8aa46879d8ca2b52c93f to your computer and use it in GitHub Desktop.
Demonstrate optimization of a two-component Gaussian model using a loss function based on Gaussian KDE PDF estimation
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "3d6f3934",
"metadata": {},
"source": [
"# Optimize double-Gaussian model using KDE-based loss function"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cbdb358d",
"metadata": {},
"outputs": [],
"source": [
"from jax import random as jran\n",
"ran_key = jran.PRNGKey(0)"
]
},
{
"cell_type": "markdown",
"id": "4325f3b4",
"metadata": {},
"source": [
"## Generate a target dataset from a double Gaussian with $10^5$ points"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8759d5d1",
"metadata": {},
"outputs": [],
"source": [
"from kde_experiment import mc_double_gaussian\n",
"\n",
"mu1, mu2, sig1, sig2, frac1 = -1.0, 1.0, 0.5, 0.5, 0.35\n",
"fid_params = np.array((mu1, mu2, sig1, sig2, frac1))\n",
"\n",
"NPTS_TARGET_DATA = int(1e5)\n",
"\n",
"ran_key, target_data_key = jran.split(ran_key, 2)\n",
"pop1, pop2, frac1 = mc_double_gaussian(fid_params, target_data_key, NPTS_TARGET_DATA)\n",
"\n",
"TARGET_DATA = np.where(np.random.uniform(0, 1, NPTS_TARGET_DATA) < frac1, pop1, pop2)"
]
},
{
"cell_type": "markdown",
"id": "9908a962",
"metadata": {},
"source": [
"## Build loss function and evaluate on $p_{\\rm init}$"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "601336f9",
"metadata": {},
"outputs": [],
"source": [
"from kde_experiment import build_loss_func\n",
"\n",
"NCELLS = 100\n",
"loss_func = build_loss_func(ran_key, TARGET_DATA, NCELLS)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "808fbdea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(0.02030281, dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"p_init = np.array(fid_params) + 0.2\n",
"ran_key, pred_init_key = jran.split(ran_key, 2)\n",
"\n",
"loss_func(p_init, pred_init_key)"
]
},
{
"cell_type": "markdown",
"id": "86f63771",
"metadata": {},
"source": [
"## Evaluate the loss and its gradient at $p_{\\rm init}$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "08a0efe1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Array(0.02030281, dtype=float32),\n",
" Array([-0.00609772, 0.01659485, -0.027687 , 0.04816183, 0.09830149], dtype=float32))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from jax import value_and_grad\n",
"\n",
"loss_and_grad_func = value_and_grad(loss_func)\n",
"\n",
"loss, grads = loss_and_grad_func(p_init, pred_init_key)\n",
"loss, grads"
]
},
{
"cell_type": "markdown",
"id": "5064315a",
"metadata": {},
"source": [
"## Walk down the gradient"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a54c9cb9",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"nsteps = 100\n",
"\n",
"p_best = np.copy(p_init)\n",
"loss_collector = []\n",
"\n",
"# Take some big steps\n",
"for __ in range(nsteps):\n",
" ran_key, pred_key = jran.split(ran_key, 2)\n",
" loss, grads = loss_and_grad_func(p_best, pred_key)\n",
" p_best = p_best - 1.0*grads\n",
" loss_collector.append(loss)\n",
" \n",
"# Take some small steps\n",
"for __ in range(nsteps):\n",
" ran_key, pred_key = jran.split(ran_key, 2)\n",
" loss, grads = loss_and_grad_func(p_best, pred_key)\n",
" p_best = p_best - 0.2*grads\n",
" loss_collector.append(loss)\n",
" \n",
"fig, ax = plt.subplots(1, 1)\n",
"yscale = ax.set_yscale('log')\n",
"__=ax.plot(loss_collector)\n",
"xlabel = ax.set_xlabel(r'$N_{\\rm step}$')\n",
"ylabel = ax.set_ylabel(r'${\\rm loss}$')"
]
},
{
"cell_type": "markdown",
"id": "f086e11f",
"metadata": {},
"source": [
"## Generate MC realizations of the initial and best-fit predictions"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "44b68844",
"metadata": {},
"outputs": [],
"source": [
"pop1, pop2, frac1 = mc_double_gaussian(p_init, target_data_key, NPTS_TARGET_DATA)\n",
"pred_init = np.where(np.random.uniform(0, 1, NPTS_TARGET_DATA) < frac1, pop1, pop2)\n",
"\n",
"pop1, pop2, frac1 = mc_double_gaussian(p_best, target_data_key, NPTS_TARGET_DATA)\n",
"pred_best = np.where(np.random.uniform(0, 1, NPTS_TARGET_DATA) < frac1, pop1, pop2)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "206e62a8",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"target_hist, target_bins = np.histogram(TARGET_DATA, bins=100, density=True)\n",
"target_mids = 0.5*(target_bins[:-1] + target_bins[1:])\n",
"__=ax.plot(target_mids, target_hist, drawstyle='steps', color='k', label=r'${\\rm target\\ PDF}$')\n",
"\n",
"__=ax.hist(pred_init, bins=target_bins, density=True, alpha=0.7, label=r'${\\rm initial\\ prediction}$')\n",
"__=ax.hist(pred_best, bins=target_bins, density=True, alpha=0.7, label=r'${\\rm best\\ prediction}$')\n",
"\n",
"leg =ax.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25284a7a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
"""
"""
from functools import partial
from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran
from jax.scipy.stats import gaussian_kde
@jjit
def _mse(pred, target):
diff = pred - target
return jnp.mean(diff * diff)
@partial(jjit, static_argnames=["npts"])
def mc_double_gaussian(params, ran_key, npts):
    """Draw Monte Carlo samples for each component of the double-Gaussian model.

    Parameters
    ----------
    params : array-like of length 5
        Unpacked as (mu1, mu2, sig1, sig2, frac1).
    ran_key : jax.random.PRNGKey
        Split into one independent key per component.
    npts : int
        Number of samples drawn for EACH component; static under jit.

    Returns
    -------
    pop1, pop2 : arrays of shape (npts,)
        Samples of N(mu1, sig1) and N(mu2, sig2) respectively.
    frac1 : scalar
        Returned unchanged so callers can weight/mix the two populations.
    """
    mu_a, mu_b, sigma_a, sigma_b, weight_a = params
    key_a, key_b = jran.split(ran_key, 2)

    sample_shape = (npts,)
    draws_a = jran.normal(key_a, shape=sample_shape) * sigma_a + mu_a
    draws_b = jran.normal(key_b, shape=sample_shape) * sigma_b + mu_b
    return draws_a, draws_b, weight_a
@partial(jjit, static_argnames=["npts_pred"])
def predict_pdf(params, ran_key, pdf_abscissa, npts_pred):
    """Toy model prediction for the pdf at the input abscissa.

    Each Gaussian component is sampled with ``npts_pred`` Monte Carlo points
    (via ``mc_double_gaussian``), smoothed with its own Gaussian KDE, and the
    two KDE densities are combined with weights ``frac1`` and ``1 - frac1``.

    Parameters
    ----------
    params : array-like of length 5
        (mu1, mu2, sig1, sig2, frac1). Note ``frac1`` is not clipped to
        [0, 1] here, so out-of-range values give a non-normalized mixture.
    ran_key : jax.random.PRNGKey
        Key for the Monte Carlo draws.
    pdf_abscissa : array
        Points at which both KDEs are evaluated.
    npts_pred : int
        Samples per component; static under jit.

    Returns
    -------
    array
        Mixture density evaluated at ``pdf_abscissa``.
    """
    samples_a, samples_b, weight_a = mc_double_gaussian(params, ran_key, npts_pred)

    density_a = gaussian_kde(samples_a.T).pdf(pdf_abscissa)
    density_b = gaussian_kde(samples_b.T).pdf(pdf_abscissa)

    weight_b = 1 - weight_a
    return weight_a * density_a + weight_b * density_b
def build_loss_func(ran_key, target_data, ncells, npts_pred=None):
    """Build a jit-compiled MSE loss comparing model and target PDFs.

    The target PDF is estimated once with a Gaussian KDE of ``target_data``.
    It is evaluated at ``ncells`` abscissa values drawn from that same KDE,
    which avoids hand-tuning bin edges. The returned loss re-predicts the
    model PDF at those fixed abscissa values on every call.

    Parameters
    ----------
    ran_key : jax.random.PRNGKey
        Key used to draw the abscissa values from the target KDE.
    target_data : ndarray of shape (npts,) or (1, npts)
        One-dimensional target sample. The model (``mc_double_gaussian``)
        only produces 1-d samples, so a single feature dimension is expected.
    ncells : int
        Number of points at which to evaluate the target PDF.
    npts_pred : int, optional
        Monte Carlo points per model component used by ``predict_pdf``.
        Defaults to the number of target data points ``npts``.

    Returns
    -------
    loss_func : callable
        ``loss_func(params, pred_key)`` returns the MSE between the predicted
        and target PDFs at the fixed abscissa values.
    """
    # gaussian_kde treats the LAST axis as the sample axis, so read npts from
    # there; len() would return ndim (=1) for a (1, npts) array.
    npts_data = jnp.shape(target_data)[-1]
    if npts_pred is None:
        npts_pred = npts_data
    # npts_pred is a static jit argument of predict_pdf, so keep it a plain int
    npts_pred = int(npts_pred)

    # Precompute the KDE model of the data; loss_func treats it as fixed data.
    data_kde = gaussian_kde(target_data)

    # Randomly select points at which to predict the pdf by sampling from the
    # KDE model of the data; this eliminates the need to hand-tune bin edges.
    target_abscissa = data_kde.resample(ran_key, (ncells,))

    # Evaluate the pdf of the target data at the randomly generated abscissa.
    target_pdf = data_kde.pdf(target_abscissa)

    @jjit
    def loss_func(params, pred_key):
        pred_pdf = predict_pdf(params, pred_key, target_abscissa, npts_pred)
        return _mse(pred_pdf, target_pdf)

    return loss_func
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment