# Source: GitHub gist by @smsharma, created September 12, 2023
# (gist id dccc4a1a1f2ca8e9434a96bfa8f0057b).
import time
import jax
import jax.numpy as jnp
# Force JAX onto its GPU backend. This is a global flag and must be set at
# startup, before any JAX computation runs.
# NOTE(review): presumably this errors on machines without a GPU-enabled
# jaxlib — confirm, and remove/override this line to run on CPU.
jax.config.update("jax_platform_name", "gpu")
# from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController, SaveAt
from jax.experimental.ode import odeint
def get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, num_points=1000):
    """Integrate dt/drho_tot = -1 / (3 (rho_tot + P(rho_tot))) over the density range.

    The total density is ``rho_g_vec + rho_nu_vec + rho_NP_vec``. The pressure
    entering the right-hand side is obtained by linearly interpolating
    ``P_NP_vec`` against that total density. Starting from ``t = 1.0`` at
    ``rho_tot_vec[0]``, the ODE is integrated down to ``rho_tot_vec[-1]``.

    NOTE(review): only ``P_NP_vec`` enters the pressure; if the photon and
    neutrino components should contribute as well, add them here — confirm.

    Args:
        rho_g_vec, rho_nu_vec, rho_NP_vec: 1-D density arrays of equal length.
            Their sum is assumed to be strictly decreasing (the flips below
            only give ``jnp.interp`` an ascending grid in that case) — TODO
            confirm for all callers.
        P_NP_vec: 1-D pressure samples on the same grid.
        num_points: number of evenly spaced total-density values, from
            ``rho_tot_vec[0]`` to ``rho_tot_vec[-1]``, at which ``t`` is
            reported.

    Returns:
        Array of shape ``(num_points,)`` holding ``t`` at
        ``jnp.linspace(rho_tot_vec[0], rho_tot_vec[-1], num_points)``.
    """
    rho_tot_vec = rho_g_vec + rho_nu_vec + rho_NP_vec
    # jnp.interp needs ascending sample points, so reverse the density grid
    # together with the matching pressure samples.
    rho_tot_flip = jnp.flip(rho_tot_vec)
    P_tot_flip = jnp.flip(P_NP_vec)

    # This interpolation dominates the runtime: replacing it with a constant 0
    # was measured at ~0.05 s on GPU. The dot-product interpolant from
    # https://github.com/google/jax/issues/16182 speeds this step up.
    def P_tot(rho_tot):
        return jnp.interp(rho_tot, rho_tot_flip, P_tot_flip)

    def dt_drho(rho_tot):
        return 1.0 / (-3.0 * (rho_tot + P_tot(rho_tot)))

    # odeint calls func(y, t, *args): ODE state first, independent variable
    # second. It also requires strictly increasing evaluation points, but the
    # density decreases along the integration. Integrate in s = -rho_tot
    # instead, where dt/ds = -dt/drho_tot evaluated at rho_tot = -s.
    def dt_ds(_t, s):
        return -dt_drho(-s)

    rho_grid = jnp.linspace(rho_tot_vec[0], rho_tot_vec[-1], num_points)
    return odeint(dt_ds, 1.0, -rho_grid, rtol=1e-4, atol=1e-4, mxstep=4096)
@jax.jit
def get_abundances(
    rho_g_vec,
    rho_nu_vec,
    rho_NP_vec,
    P_NP_vec,
):
    """Return a length-9 vector whose first entry is ``get_t(...)[3]``.

    The four inputs are forwarded unchanged to ``get_t``. Only slot 0 is
    filled; the other eight entries are always zero in this version.
    """
    times = get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec)
    # Start from zeros of the solver's dtype so the result dtype matches the
    # integrated times, then drop the single computed value into slot 0.
    result = jnp.zeros(9, dtype=times.dtype)
    return result.at[0].set(times[3])
@jax.jit
def rho_gam(T):
    """Return (pi**2 / 15) * T**4 elementwise.

    This is 2 * pi**2 / 30 * T**4, presumably the photon energy density with
    g = 2 in natural units — confirm units with the caller.
    """
    # 2 * pi**2 / 30 rounds to exactly the same float as pi**2 / 15.
    prefactor = jnp.pi**2 / 15.0
    return prefactor * T**4
##############################################################################################
##############################################################################################
##############################################################################################
##############################################################################################
# Every input vector must have the same length (get_t sums them
# elementwise), so the grid length and the split of rho_extra into its
# log-spaced and zero-padded parts are named once instead of being repeated
# as literals.
n_T = 418
n_rho_extra_nonzero = 342

T_gamma_array = jnp.logspace(jnp.log10(8.617), jnp.log10(3.83e-4), num=n_T)
test_array_CPU = jnp.logspace(
    jnp.log10(3.17473950e03), jnp.log10(1.68692195e-15), num=n_T
)
# rho_extra decays log-uniformly over its first n_rho_extra_nonzero samples
# and is exactly zero on the rest of the grid.
rho_extra_array_CPU = jnp.concatenate(
    (
        jnp.logspace(
            jnp.log10(6.34941711e03),
            jnp.log10(1.63508741e-20),
            num=n_rho_extra_nonzero,
        ),
        jnp.zeros(n_T - n_rho_extra_nonzero),
    )
)
T_gamma_array_gp = T_gamma_array
rho_gamma_array = rho_gam(T_gamma_array_gp)
test_array_gp = test_array_CPU
rho_extra_array_gp = rho_extra_array_CPU

n_batch = 1024
# Batched versions of these 3 arrays: shape (n_batch, n_T), every row an
# identical copy of the 1-D array.
rho_gamma_array, test_array_gp, rho_extra_array_gp = (
    jnp.tile(arr, (n_batch, 1))
    for arr in (rho_gamma_array, test_array_gp, rho_extra_array_gp)
)
##############################################################################################
##############################################################################################
# Build the batched function and its arguments once; both are loop-invariant.
batched_get_abundances = jax.vmap(get_abundances)
batch_args = (
    rho_gamma_array,
    test_array_gp,
    rho_extra_array_gp,
    rho_extra_array_gp / 3,
)

# Compilation run: the first call traces and compiles the jitted
# get_abundances, so this timing includes compilation.
start_time = time.perf_counter()
res = batched_get_abundances(*batch_args)
# JAX dispatches work asynchronously; block until the result is materialized
# so the timer measures the computation, not just the dispatch.
Neff_vec = res[:, 0].block_until_ready()
print("finished in %s seconds" % (time.perf_counter() - start_time))

# Timing runs (already compiled).
for _ in range(10):
    start_time = time.perf_counter()
    res = batched_get_abundances(*batch_args)
    Neff_vec = res[:, 0].block_until_ready()
    print("finished in %s seconds" % (time.perf_counter() - start_time))
# (End of gist; GitHub comment-section footer omitted.)