Demonstration of fitting a toy model for subhalo mass loss to itself
""" | |
""" | |
from jax import jit as jjit | |
from jax import numpy as jnp | |
import numpy as np | |
from collections import OrderedDict | |
from jax.experimental import optimizers as jax_opt | |
from jax import value_and_grad | |
# Default values of the three toy-model parameters
DEFAULT_PARAMS = OrderedDict(t0=3.0, k=2.0, frac_loss_final=0.2)
def get_default_params(**kwargs):
    """Return the model parameter array, with keyword overrides of DEFAULT_PARAMS."""
    pars = OrderedDict(
        [(key, kwargs.get(key, DEFAULT_PARAMS[key])) for key in DEFAULT_PARAMS.keys()]
    )
    p_init = np.array(list(pars.values()))
    return p_init
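# Illustrative usage (not in the original gist): a keyword argument overrides
# the corresponding default, so get_default_params(k=3.0) returns
# np.array([3.0, 3.0, 0.2]), while t0 and frac_loss_final keep their defaults.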
def get_adam_opt_funcs(step_size=1e-3, **kwargs):
    """Initialize the JAX implementation of Adam and retrieve its state and functions.

    Parameters
    ----------
    step_size : float, optional
        Step size parameter defining the Adam configuration.
        Default is 0.001.

    **kwargs : floats, optional
        All parameters of the DEFAULT_PARAMS dictionary are accepted

    Returns
    -------
    opt_state : state of the optimizer
        Used to carry the parameters from one step to the next

    opt_update : update function
        Used by the train_step function to take the next step down the gradient

    get_params : function to retrieve the parameter array
        Operates on opt_state, so that get_params(opt_state) returns the parameter array

    """
    opt_init, opt_update, get_params = jax_opt.adam(step_size)
    p_init = get_default_params(**kwargs)
    opt_state = opt_init(p_init)
    return opt_state, opt_update, get_params
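# Illustrative sketch (not part of the original gist) of how the triple above
# is consumed. No train_step function appears in this file; a single gradient
# step built from the loss defined further down would read
#
#     params = get_params(opt_state)
#     loss, grads = _loss_and_grad_func(params, loss_data)
#     opt_state = opt_update(step, grads, opt_state)
#
# where step is the integer iteration counter passed to Adam.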
@jjit
def jax_sigmoid(x, x0, k, ylo, yhi):
    """Sigmoid transitioning from ylo to yhi, centered at x0 with steepness k."""
    return ylo + (yhi - ylo) / (1 + jnp.exp(-k * (x - x0)))
@jjit
def predict_frac_mass_loss(params, tarr):
    """Sigmoid-in-time curve equal to 1 at early times, falling to frac_loss_final."""
    t0, k, frac_loss_final = params
    frac_mass_loss = jax_sigmoid(tarr, t0, k, 1, frac_loss_final)
    return frac_mass_loss
@jjit
def _mse(pred, target):
    """Mean squared error between prediction and target."""
    diff = pred - target
    return jnp.mean(diff * diff)
@jjit
def _loss_function(params, loss_data):
    """MSE of the predicted mass loss curve against the target curve."""
    (tarr, target_mass_loss) = loss_data
    pred = predict_frac_mass_loss(params, tarr)
    return _mse(pred, target_mass_loss)


# Compiled function returning the loss and its gradient with respect to params
_loss_and_grad_func = jjit(value_and_grad(_loss_function))
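The gist ends with the compiled loss-and-gradient function, leaving the advertised self-fit implicit. Below is a minimal sketch of that fit, assumed rather than taken from the original file: the target curve is generated from the default parameters, Adam starts from a deliberately perturbed guess, and the time grid, starting guess, and step count are all arbitrary illustrative choices.

tarr = jnp.linspace(0, 14, 50)  # arbitrary time grid for the demonstration
p_target = get_default_params()  # "true" parameters generating the target curve
loss_data = (tarr, predict_frac_mass_loss(p_target, tarr))

# Begin the fit away from the target so the optimizer has work to do
opt_state, opt_update, get_params = get_adam_opt_funcs(
    step_size=1e-2, t0=4.0, k=1.0, frac_loss_final=0.5
)
for step in range(200):
    params = get_params(opt_state)
    loss, grads = _loss_and_grad_func(params, loss_data)
    opt_state = opt_update(step, grads, opt_state)

p_best = get_params(opt_state)  # approaches p_target as the loss shrinks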