@aphearin
Created April 14, 2022 15:58
Demonstration of fitting a toy model for subhalo mass loss to itself
{
"cells": [
{
"cell_type": "markdown",
"id": "annual-helmet",
"metadata": {},
"source": [
"# Demo notebook for toy model of subhalo mass loss\n",
"\n",
"Executing this notebook requires the toy_model.py module"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "excellent-niagara",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "hydraulic-development",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from toy_model import get_adam_opt_funcs\n",
"\n",
"opt_state, opt_update, get_params = get_adam_opt_funcs()"
]
},
{
"cell_type": "markdown",
"id": "boring-fitness",
"metadata": {},
"source": [
"## Define our target data\n",
"\n",
"First choose a random point in parameter space for the target data, and a second random point for the initial guess"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "lesser-standard",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Target parameters = [3.07721998 1.72745557 0.08644907]\n"
]
}
],
"source": [
"from toy_model import get_default_params\n",
"p_default = get_default_params()\n",
"\n",
"rng = np.random.RandomState(seed=42)\n",
"p_init = rng.normal(loc=p_default, scale=0.3)\n",
"\n",
"rng = np.random.RandomState(seed=43)\n",
"p_target = rng.normal(loc=p_default, scale=0.3)\n",
"\n",
"print(\"Target parameters = {}\".format(p_target))"
]
},
{
"cell_type": "markdown",
"id": "neural-moses",
"metadata": {},
"source": [
"Now define the time array and the target data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "comparable-collapse",
"metadata": {},
"outputs": [],
"source": [
"from toy_model import predict_frac_mass_loss\n",
"tarr = np.linspace(0, 15, 50)\n",
"mass_loss_target = predict_frac_mass_loss(p_target, tarr)"
]
},
{
"cell_type": "markdown",
"id": "through-washington",
"metadata": {},
"source": [
"### Calculate the predictions of the initial guess"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "turned-cliff",
"metadata": {},
"outputs": [],
"source": [
"frac_mass_loss_init = predict_frac_mass_loss(p_init, tarr)"
]
},
{
"cell_type": "markdown",
"id": "executed-algorithm",
"metadata": {},
"source": [
"### Compare predictions of initial guess to target"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "rubber-optics",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"ylim = ax.set_ylim(-0.05, 1.1)\n",
"xlim = ax.set_xlim(-0.1, 15)\n",
"__=ax.plot(tarr, mass_loss_target, label=r'${\\rm target}$')\n",
"__=ax.plot(tarr, frac_mass_loss_init, '--', label=r'${\\rm initial\\ guess}$')\n",
"\n",
"xlabel = ax.set_xlabel(r'${\\rm time}$')\n",
"ylabel = ax.set_ylabel(r'${\\rm fractional\\ mass\\ loss}$')\n",
"leg = ax.legend()"
]
},
{
"cell_type": "markdown",
"id": "unavailable-chair",
"metadata": {},
"source": [
"Now define the `train_step` function based on the previously-retrieved Adam functions"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "compressed-lesson",
"metadata": {},
"outputs": [],
"source": [
"from toy_model import _loss_and_grad_func\n",
"\n",
"def train_step(step_i, state, train_step_data):\n",
"    \"\"\"Take one Adam step; return the loss and the updated optimizer state.\"\"\"\n",
"    params = get_params(state)\n",
"    t_target, m_target = train_step_data\n",
"    loss_data = t_target, m_target\n",
"    loss, grads = _loss_and_grad_func(params, loss_data)\n",
"\n",
"    # Update the state passed in, not the global opt_state\n",
"    return loss, opt_update(step_i, grads, state)"
]
},
{
"cell_type": "markdown",
"id": "textile-conservation",
"metadata": {},
"source": [
"Finally, set up a simple loop that takes one step down the gradient per iteration, with a step size set adaptively by the Adam algorithm:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "living-madagascar",
"metadata": {},
"outputs": [],
"source": [
"tstep_data = (tarr, mass_loss_target)\n",
"\n",
"n_steps = 400\n",
"loss_history = []\n",
"for istep in range(n_steps):\n",
"    loss, opt_state = train_step(istep, opt_state, tstep_data)\n",
"    loss_history.append(float(loss))"
]
},
{
"cell_type": "markdown",
"id": "cross-vehicle",
"metadata": {},
"source": [
"### Inspect the loss curve\n",
"\n",
"The loss should flatten over the course of training; if it does not, training has not converged and should continue"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "stunning-flash",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.plot(loss_history)\n",
"xlabel = ax.set_xlabel(r'${\\rm train\\ step}$')\n",
"ylabel = ax.set_ylabel(r'${\\rm MSE\\ loss\\ history}$')"
]
},
{
"cell_type": "markdown",
"id": "beautiful-scratch",
"metadata": {},
"source": [
"### Calculate the predictions of the best-fitting model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "stable-chair",
"metadata": {},
"outputs": [],
"source": [
"p_best = get_params(opt_state)\n",
"frac_mass_loss_best_fit = predict_frac_mass_loss(p_best, tarr)"
]
},
{
"cell_type": "markdown",
"id": "meaningful-skating",
"metadata": {},
"source": [
"### Check the quality of the fit"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "strange-coach",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"ylim = ax.set_ylim(-0.05, 1.1)\n",
"xlim = ax.set_xlim(-0.1, 15)\n",
"__=ax.plot(tarr, mass_loss_target, label=r'${\\rm target}$')\n",
"__=ax.plot(tarr, frac_mass_loss_init, '--', label=r'${\\rm initial\\ guess}$')\n",
"__=ax.plot(tarr, frac_mass_loss_best_fit, ':', color='k', label=r'${\\rm best-fit}$')\n",
"\n",
"xlabel = ax.set_xlabel(r'${\\rm time}$')\n",
"ylabel = ax.set_ylabel(r'${\\rm fractional\\ mass\\ loss}$')\n",
"leg = ax.legend()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bibliographic-development",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Target parameters = [3.07721998 1.72745557 0.08644907]\n",
"Best-fit parameters = [3.075733 1.7490137 0.08698298]\n"
]
}
],
"source": [
"print(\"Target parameters = {}\".format(p_target))\n",
"print(\"Best-fit parameters = {}\".format(p_best))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "absolute-period",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "ordered-round",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
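
The fitting loop in the notebook above delegates the update rule to JAX's Adam implementation. As a rough, NumPy-only sketch of what such a loop does, the block below re-implements the toy sigmoid model, estimates gradients by central finite differences instead of JAX autodiff, and applies a hand-written Adam update with bias correction. The names `adam_fit` and `numerical_grad`, the start and target parameter values, and the step size of 1e-2 are illustrative assumptions and are not part of toy_model.py.

```python
import numpy as np


def sigmoid(x, x0, k, ylo, yhi):
    # Same functional form as jax_sigmoid in toy_model.py, written with NumPy
    return ylo + (yhi - ylo) / (1.0 + np.exp(-k * (x - x0)))


def predict(params, tarr):
    t0, k, frac_loss_final = params
    return sigmoid(tarr, t0, k, 1.0, frac_loss_final)


def mse_loss(params, tarr, target):
    diff = predict(params, tarr) - target
    return np.mean(diff * diff)


def numerical_grad(params, tarr, target, eps=1e-6):
    # Central finite differences; a stand-in for jax.value_and_grad
    grad = np.zeros_like(params)
    for i in range(params.size):
        step = np.zeros_like(params)
        step[i] = eps
        hi = mse_loss(params + step, tarr, target)
        lo = mse_loss(params - step, tarr, target)
        grad[i] = (hi - lo) / (2.0 * eps)
    return grad


def adam_fit(p_init, tarr, target, n_steps=400, step_size=1e-2,
             b1=0.9, b2=0.999, eps=1e-8):
    params = np.array(p_init, dtype=float)
    m = np.zeros_like(params)  # running mean of the gradient
    v = np.zeros_like(params)  # running mean of the squared gradient
    loss_history = []
    for i in range(n_steps):
        # Record the loss at the current parameters, before the update
        loss_history.append(mse_loss(params, tarr, target))
        g = numerical_grad(params, tarr, target)
        m = b1 * m + (1.0 - b1) * g
        v = b2 * v + (1.0 - b2) * g * g
        mhat = m / (1.0 - b1 ** (i + 1))  # bias-corrected first moment
        vhat = v / (1.0 - b2 ** (i + 1))  # bias-corrected second moment
        params = params - step_size * mhat / (np.sqrt(vhat) + eps)
    return params, loss_history


# Illustrative values only (assumed, not taken from the notebook output)
tarr = np.linspace(0, 15, 50)
p_target = np.array([3.1, 1.7, 0.1])
p_start = np.array([3.0, 2.0, 0.2])
target = predict(p_target, tarr)

p_best, loss_history = adam_fit(p_start, tarr, target)
print("first loss = {:.3e}, last loss = {:.3e}".format(loss_history[0], loss_history[-1]))
```

Finite differences are adequate here because the model has only three parameters; for larger models, automatic differentiation as used in toy_model.py is the more practical choice.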
"""
"""
from jax import jit as jjit
from jax import numpy as jnp
import numpy as np
from collections import OrderedDict
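try:
    from jax.example_libraries import optimizers as jax_opt
except ImportError:  # older JAX releases kept this module under jax.experimental
    from jax.experimental import optimizers as jax_opt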
from jax import value_and_grad
DEFAULT_PARAMS = OrderedDict(t0=3.0, k=2.0, frac_loss_final=0.2)


def get_default_params(**kwargs):
    pars = OrderedDict(
        [(key, kwargs.get(key, DEFAULT_PARAMS[key])) for key in DEFAULT_PARAMS.keys()]
    )
    p_init = np.array(list(pars.values()))
    return p_init


def get_adam_opt_funcs(step_size=1e-3, **kwargs):
    """Retrieve the three functions used by the JAX implementation of Adam

    Parameters
    ----------
    step_size : float, optional
        Step size parameter defining the Adam configuration.
        Default is 1e-3

    **kwargs : floats, optional
        All parameters of the DEFAULT_PARAMS dictionary are accepted

    Returns
    -------
    opt_state : state of the optimizer
        Used to carry the parameters from one step to the next

    opt_update : update function
        Used by the train_step function to take the next step down the gradient

    get_params : function to retrieve the parameter array
        Operates on opt_state, so that get_params(opt_state) returns the parameter array

    """
    opt_init, opt_update, get_params = jax_opt.adam(step_size)
    p_init = get_default_params(**kwargs)
    opt_state = opt_init(p_init)
    return opt_state, opt_update, get_params


@jjit
def jax_sigmoid(x, x0, k, ylo, yhi):
    return ylo + (yhi - ylo) / (1 + jnp.exp(-k * (x - x0)))


@jjit
def predict_frac_mass_loss(params, tarr):
    t0, k, frac_loss_final = params
    frac_mass_loss = jax_sigmoid(tarr, t0, k, 1, frac_loss_final)
    return frac_mass_loss


@jjit
def _mse(pred, target):
    diff = pred - target
    return jnp.mean(diff * diff)


@jjit
def _loss_function(params, loss_data):
    (tarr, target_mass_loss) = loss_data
    pred = predict_frac_mass_loss(params, tarr)
    return _mse(pred, target_mass_loss)


_loss_and_grad_func = jjit(value_and_grad(_loss_function))
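
`predict_frac_mass_loss` passes `ylo=1` and `yhi=frac_loss_final` to `jax_sigmoid`, so the predicted curve starts near 1 at early times, passes through the midpoint of the two levels at `t = t0`, and settles at `frac_loss_final` at late times. A small NumPy check of those three regimes, assuming the default parameter values and a hypothetical `sigmoid` helper that mirrors `jax_sigmoid`:

```python
import numpy as np


def sigmoid(x, x0, k, ylo, yhi):
    # NumPy stand-in for jax_sigmoid (hypothetical helper for this check)
    return ylo + (yhi - ylo) / (1.0 + np.exp(-k * (x - x0)))


# Default parameter values from DEFAULT_PARAMS
t0, k, frac_loss_final = 3.0, 2.0, 0.2

early = sigmoid(-50.0, t0, k, 1.0, frac_loss_final)  # far before t0: close to ylo
mid = sigmoid(t0, t0, k, 1.0, frac_loss_final)       # at t0: halfway between ylo and yhi
late = sigmoid(50.0, t0, k, 1.0, frac_loss_final)    # far after t0: close to yhi

print(early, mid, late)
```

The parameter `k` controls how sharply the curve moves between the two levels around `t0`.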