Skip to content

Instantly share code, notes, and snippets.

@aphearin
Created April 14, 2022 15:58
Show Gist options
  • Save aphearin/6b7f3ac994bae7df7cf1089af4fbc2d0 to your computer and use it in GitHub Desktop.
Save aphearin/6b7f3ac994bae7df7cf1089af4fbc2d0 to your computer and use it in GitHub Desktop.
Demonstration of fitting a toy model for subhalo mass loss to itself
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
"""
"""
from jax import jit as jjit
from jax import numpy as jnp
import numpy as np
from collections import OrderedDict
from jax.experimental import optimizers as jax_opt
from jax import value_and_grad
# Fiducial values of the model parameters. The insertion order
# (t0, k, frac_loss_final) matters: get_default_params flattens the values
# into an array in this order, and predict_frac_mass_loss unpacks them the same way.
DEFAULT_PARAMS = OrderedDict(
    [
        ("t0", 3.0),
        ("k", 2.0),
        ("frac_loss_final", 0.2),
    ]
)
def get_default_params(**kwargs):
    """Build the initial parameter array, optionally overriding the defaults

    Parameters
    ----------
    **kwargs : floats, optional
        Replacement values for any of the keys of DEFAULT_PARAMS.
        Keyword arguments that are not keys of DEFAULT_PARAMS are ignored.

    Returns
    -------
    p_init : ndarray
        One entry per key of DEFAULT_PARAMS, stored in the same order as the keys

    """
    values = [kwargs.get(key, default) for key, default in DEFAULT_PARAMS.items()]
    # Force a floating-point dtype: if every value were overridden with a
    # Python int, numpy would infer an integer array, and JAX refuses to
    # differentiate with respect to integer-valued inputs.
    p_init = np.array(values, dtype=float)
    return p_init
def get_adam_opt_funcs(step_size=1e-3, **kwargs):
    """Set up the JAX implementation of Adam for the model parameters

    Parameters
    ----------
    step_size : float, optional
        Step size parameter defining the Adam configuration.
        Default is 1e-3

    **kwargs : floats, optional
        All parameters of the DEFAULT_PARAMS dictionary are accepted.
        These are forwarded to get_default_params to set the initial
        parameter values held by the optimizer state.

    Returns
    -------
    opt_state : state of the optimizer
        Used to carry the parameters from one step to the next.
        Initialized from the array returned by get_default_params

    opt_update : update function
        Used to take the next step down the gradient

    get_params : function to retrieve the parameter array
        Operates on opt_state, so that get_params(opt_state) returns the parameter array

    """
    opt_init, opt_update, get_params = jax_opt.adam(step_size)
    p_init = get_default_params(**kwargs)
    opt_state = opt_init(p_init)
    return opt_state, opt_update, get_params
@jjit
def jax_sigmoid(x, x0, k, ylo, yhi):
    """Sigmoid function of x that transitions between ylo and yhi

    Parameters
    ----------
    x : float or ndarray
        Point(s) at which to evaluate the sigmoid

    x0 : float
        Location of the midpoint, where the result equals (ylo + yhi) / 2

    k : float
        Steepness of the transition. For k > 0 the result approaches
        ylo as x -> -infinity and yhi as x -> +infinity

    ylo : float
        Asymptotic value on the low-x side (for k > 0)

    yhi : float
        Asymptotic value on the high-x side (for k > 0)

    Returns
    -------
    y : float or ndarray
        ylo + (yhi - ylo) * logistic(k * (x - x0))

    """
    # Local import so that this block does not rely on any additional
    # module-level import; it only runs when the function is traced.
    from jax.nn import sigmoid as _logistic

    # Writing the logistic as 1 / (1 + exp(-z)) by hand overflows exp for
    # very negative z; the value is still correct, but reverse-mode
    # autodiff then multiplies 0 by inf and yields NaN gradients.
    # jax.nn.sigmoid carries a derivative rule that stays finite.
    z = k * (x - x0)
    return ylo + (yhi - ylo) * _logistic(z)
@jjit
def predict_frac_mass_loss(params, tarr):
    """Evaluate the toy mass-loss model at the input points

    Parameters
    ----------
    params : sequence of three entries
        The values (t0, k, frac_loss_final), in that order

    tarr : float or ndarray
        Point(s) at which to evaluate the model (presumably times),
        passed to jax_sigmoid as the independent variable

    Returns
    -------
    frac_mass_loss : float or ndarray
        Sigmoid in tarr with midpoint t0 and steepness k. For k > 0 the
        result is close to 1 well below t0 and close to frac_loss_final
        well above t0

    """
    t0, k, frac_loss_final = params
    # The sigmoid starts at unity and settles at the final value
    ylo, yhi = 1, frac_loss_final
    return jax_sigmoid(tarr, t0, k, ylo, yhi)
@jjit
def _mse(pred, target):
    """Mean of the squared elementwise difference between pred and target"""
    return jnp.mean(jnp.square(pred - target))
@jjit
def _loss_function(params, loss_data):
    """Mean squared error of the model prediction relative to the target

    Parameters
    ----------
    params : sequence of three entries
        Model parameters passed to predict_frac_mass_loss

    loss_data : 2-element sequence
        (tarr, target_mass_loss): the points at which to evaluate the
        model and the target values to compare the prediction against

    Returns
    -------
    loss : float
        Value returned by _mse for the prediction and the target

    """
    tarr, target_mass_loss = loss_data
    return _mse(predict_frac_mass_loss(params, tarr), target_mass_loss)
# Jitted function returning the pair (loss, gradient), where the gradient is
# taken with respect to params, the first argument of _loss_function.
_loss_and_grad_func = jjit(value_and_grad(_loss_function, argnums=0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment