Basic demo of how to implement a physical model for dark matter halo mass assembly in JAX
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "unsigned-addiction",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "synthetic-palace",
"metadata": {},
"source": [
"This notebook is a quick, self-contained starting guide to implementing a physical model in JAX. We'll start by demonstrating the basic syntax of differentiation with JAX, and then build a simple differentiable model for the mass assembly history of a Milky Way-like dark matter halo. The [JAX documentation](https://jax.readthedocs.io/en/latest/jax.html) is quite good, so if this demo interests you, the official docs are the next place to look."
]
},
{
"cell_type": "markdown",
"id": "electoral-moisture",
"metadata": {},
"source": [
"## First look at jit and vmap\n",
"\n",
"When writing differentiable code with JAX, the typical pattern is to write your function so that it operates on scalar arguments, and then use `vmap` to create a vectorized map of your function over a selection of its arguments. The next cell shows a basic example that uses `vmap` together with `jit`, which tells JAX to compile the function the first time it is called. The argument `in_axes=(0, None)` maps over axis 0 of the first argument and holds the second argument fixed."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "sensitive-ethiopia",
"metadata": {},
"outputs": [],
"source": [
"from jax import numpy as jnp\n",
"from jax import vmap\n",
"from jax import jit as jjit\n",
"\n",
"def my_scalar_function(x, y):\n",
"    \"\"\"Scalar function f(x, y) = x*sin(x) + y*x.\"\"\"\n",
"    return x * jnp.sin(x) + y * x\n",
"\n",
"# Vectorize over x (axis 0), hold y fixed (None), then compile with jit\n",
"my_vmapped_function = jjit(vmap(my_scalar_function, in_axes=(0, None)))"
]
},
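{
"cell_type": "markdown",
"id": "in-axes-sanity-check",
"metadata": {},
"source": [
"To make `in_axes=(0, None)` concrete, here is a small sanity check (a sketch added for illustration, reusing `my_scalar_function` and `my_vmapped_function` from the cell above): the vmapped function should agree with an explicit Python loop that calls the scalar function once per element of `x` while holding `y` fixed. The names `xs`, `y_fixed`, `vmapped`, and `looped` are hypothetical and used only here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "in-axes-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"xs = jnp.linspace(-1.0, 1.0, 5)\n",
"y_fixed = 2.0\n",
"\n",
"# vmap path: one call, mapped over axis 0 of xs with the same y for every element\n",
"vmapped = my_vmapped_function(xs, y_fixed)\n",
"\n",
"# Reference path: call the scalar function element by element\n",
"looped = jnp.array([my_scalar_function(xi, y_fixed) for xi in xs])\n",
"\n",
"print(vmapped.shape)\n",
"print(jnp.allclose(vmapped, looped))"
]
},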
{
"cell_type": "code",
"execution_count": 4,
"id": "color-trail",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"\n",
"x = np.linspace(-10, 10, 500)\n",
"__=ax.plot(x, my_vmapped_function(x, 5), label=r'${\\rm y=5}$')\n",
"__=ax.plot(x, my_vmapped_function(x, 1), label=r'${\\rm y=1}$')\n",
"leg = ax.legend()\n",
"xlabel = ax.set_xlabel(r'$x$')\n",
"ylabel = ax.set_ylabel(r'$f(x)$')\n"
]
},
{
"cell_type": "markdown",
"id": "classical-night",
"metadata": {},
"source": [
"## First look at grad\n",
"\n",
"The JAX library has a `grad` function that uses autodiff to calculate the derivative of any JAX-implemented function. The first thing to know about `grad` is that it operates on _scalar-valued_ functions only. So if you want vectorized behavior of your derivative, you just need to write your original function as a scalar function, use `grad` to take the derivative, and then use `vmap` to vectorize the differentiated function. The next cells show the basic pattern."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "polyphonic-armstrong",
"metadata": {},
"outputs": [],
"source": [
"from jax import grad\n",
"\n",
"# argnums=(0, ) requests the derivative with respect to x only\n",
"my_scalar_function_deriv = grad(my_scalar_function, argnums=(0, ))\n",
"# vmap the scalar derivative over x, holding y fixed\n",
"my_vector_function_deriv = jjit(vmap(my_scalar_function_deriv, in_axes=(0, None)))"
]
},
{
"cell_type": "markdown",
"id": "static-movement",
"metadata": {},
"source": [
"The function `my_scalar_function_deriv` works fine when passed scalar arguments:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "eight-supervisor",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(DeviceArray(1.4593868, dtype=float32),)\n"
]
}
],
"source": [
"result = my_scalar_function_deriv(5.0, 1.0)\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"id": "unlimited-eligibility",
"metadata": {},
"source": [
"But `my_scalar_function_deriv` fails when passed a vector argument:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "acute-topic",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "Gradient only defined for scalar-output functions. Output had shape: (500,).",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-b4b59260f837>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmy_scalar_function_deriv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output had shape: (500,).\n\nThe stack trace above excludes JAX-internal frames.\nThe following is the original exception that occurred, unmodified.\n\n--------------------",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-b4b59260f837>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmy_scalar_function_deriv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/opt/miniconda3/envs/diffhalos/lib/python3.8/site-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/opt/miniconda3/envs/diffhalos/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 758\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mapi_boundary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 759\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 760\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 761\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 762\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/opt/miniconda3/envs/diffhalos/lib/python3.8/site-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/opt/miniconda3/envs/diffhalos/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 824\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 825\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 826\u001b[0;31m \u001b[0m_check_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 827\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 828\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_output_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/opt/miniconda3/envs/diffhalos/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 845\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mShapedArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 846\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 847\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"had shape: {aval.shape}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 848\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 849\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"had abstract value {aval}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output had shape: (500,)."
]
}
],
"source": [
"result = my_scalar_function_deriv(x, 1.0)"
]
},
{
"cell_type": "markdown",
"id": "threaded-reynolds",
"metadata": {},
"source": [
"The above failure is expected. Let's call the vmapped version and inspect the results:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "characteristic-score",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'tuple'>\n",
"1\n",
"500\n"
]
}
],
"source": [
"result = my_vector_function_deriv(x, 1.0)\n",
"print(type(result))\n",
"print(len(result))\n",
"print(len(result[0]))\n"
]
},
{
"cell_type": "markdown",
"id": "sized-privacy",
"metadata": {},
"source": [
"The `my_vector_function_deriv` function returns a tuple with one element for every argument listed in `argnums`. Here we requested only the derivative with respect to `x`, so the tuple has a single element: an array of 500 derivative values. Let's plot the result for $\\partial f/\\partial x$:"
]
},
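{
"cell_type": "markdown",
"id": "argnums-return-type",
"metadata": {},
"source": [
"As a brief aside, the tuple appears because `argnums=(0, )` was passed as a sequence. A minimal sketch of the alternative, assuming we only ever want the derivative with respect to `x`: passing the integer `argnums=0` makes `grad` return the derivative directly rather than a length-1 tuple. The names `my_vector_function_deriv_x`, `xs`, `dfdx`, and `analytic` are introduced here purely for illustration, and the cell reuses `my_scalar_function` and the imports from the cells above. The analytic derivative $\\sin x + x\\cos x + y$ gives a quick check on the autodiff result."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "argnums-return-type-code",
"metadata": {},
"outputs": [],
"source": [
"# Integer argnums=0: grad returns the derivative itself, not a length-1 tuple\n",
"my_vector_function_deriv_x = jjit(\n",
"    vmap(grad(my_scalar_function, argnums=0), in_axes=(0, None))\n",
")\n",
"\n",
"xs = jnp.linspace(-10.0, 10.0, 500)\n",
"dfdx = my_vector_function_deriv_x(xs, 1.0)\n",
"print(type(dfdx), dfdx.shape)\n",
"\n",
"# Compare against the analytic derivative sin(x) + x*cos(x) + y, with y = 1\n",
"analytic = jnp.sin(xs) + xs * jnp.cos(xs) + 1.0\n",
"print(jnp.max(jnp.abs(dfdx - analytic)))"
]
},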
{
"cell_type": "code",
"execution_count": 9,
"id": "based-measure",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"\n",
"x = np.linspace(-10, 10, 500)\n",
"__=ax.plot(x, my_vector_function_deriv(x, 5)[0], label=r'${\\rm y=5}$')\n",
"__=ax.plot(x, my_vector_function_deriv(x, 1)[0], label=r'${\\rm y=1}$')\n",
"leg = ax.legend()\n",
"xlabel = ax.set_xlabel(r'$x$')\n",
"ylabel = ax.set_ylabel(r'$\\partial f/\\partial x$')\n"
]
},
{
"cell_type": "markdown",
"id": "sapphire-microwave",
"metadata": {},
"source": [
"### Now for a more physically interesting example\n",
"\n",
"It turns out that the mass assembly history of dark matter halos is reasonably simple to approximate: \n",
"\n",
"$$M_{\\rm halo}(t)\\propto (t/t_0)^{\\alpha(t)},$$ \n",
"where $t_0$ is the present-day cosmic time in Gyr, and $\\alpha(t)$ is a simple sigmoid function of $\\log_{10}t$ that smoothly transitions the power-law index from a period of rapid early-time growth to a period of slower late-time growth. The next cell gives a JAX-based implementation of this function, working in $\\log_{10}$ of mass and time, with `logmp` the base-10 logarithm of the halo mass at $t_0$. The cell after it plots the typical growth history of a dark matter halo with Milky Way mass."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "hydraulic-bridges",
"metadata": {},
"outputs": [],
"source": [
"@jjit\n",
"def _sigmoid(x, x0, k, ymin, ymax):\n",
"    \"\"\"Basic sigmoid function with asymptotic values ymin and ymax\"\"\"\n",
"    height_diff = ymax - ymin\n",
"    return ymin + height_diff / (1 + jnp.exp(-k * (x - x0)))\n",
"\n",
"@jjit\n",
"def _rolling_plaw_vs_t(t, t0, logmp, x0, k, early, late):\n",
"    \"\"\"Kernel of the rolling power-law between halo mass and time.\"\"\"\n",
"    logt = jnp.log10(t)\n",
"    logt0 = jnp.log10(t0)\n",
"    # Power-law index rolls from `early` to `late` as a sigmoid in log10(t)\n",
"    rolling_index = _sigmoid(logt, x0, k, early, late)\n",
"    # log10(M) = index * log10(t / t0) + log10(M at t0)\n",
"    log_halo_mass = rolling_index * (logt - logt0) + logmp\n",
"    return log_halo_mass"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "detected-triple",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"TODAY = 13.8  # present-day cosmic time in Gyr\n",
"tarr = np.linspace(0.5, TODAY, 5000)\n",
"MILKY_WAY_MASS = 12.0  # log10 of the present-day halo mass in Msun\n",
"x0, k, early, late = 0.25, 3.0, 1.5, 0.5\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.loglog()\n",
"__=ax.plot(tarr, 10**_rolling_plaw_vs_t(tarr, TODAY, MILKY_WAY_MASS, x0, k, early, late))\n",
"\n",
"xlabel = ax.set_xlabel(r'${\\rm cosmic\\ time\\ [Gyr]}$')\n",
"ylabel = ax.set_ylabel(r'$M_{\\rm halo}\\ [M_{\\odot}]$')\n",
"title = ax.set_title(r'${\\rm Milky\\ Way\\ halo\\ mass\\ growth}$')\n"
]
},
{
"cell_type": "markdown",
"id": "veterinary-ocean",
"metadata": {},
"source": [
"Since we have implemented our function in JAX, it is now straightforward to calculate $dM_{\\rm halo}/dt$, the _mass accretion rate_ of our dark matter halo across time, via automatic differentiation. We just need to take care of the appropriate Jacobian factor, since we have implemented our function in terms of logarithmic variables. The next cells show how to do that using the usual scalar-then-vmap pattern, convert to the more conventional units of $M_{\\odot}/{\\rm yr}$, and plot the result."
]
},
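{
"cell_type": "markdown",
"id": "jacobian-factor-derivation",
"metadata": {},
"source": [
"To spell out the Jacobian factor (a short derivation added for clarity, taking $M_{\\rm halo}$ in $M_{\\odot}$ and $t$ in Gyr): autodiff gives us $d\\log_{10}M_{\\rm halo}/dt$, and since $M_{\\rm halo} = 10^{\\log_{10}M_{\\rm halo}}$, the chain rule gives\n",
"\n",
"$$\\frac{dM_{\\rm halo}}{dt} = \\ln(10)\\,M_{\\rm halo}\\,\\frac{d\\log_{10}M_{\\rm halo}}{dt} = \\frac{M_{\\rm halo}}{\\log_{10}e}\\,\\frac{d\\log_{10}M_{\\rm halo}}{dt},$$\n",
"\n",
"using $\\ln(10) = 1/\\log_{10}e$. This rate is in $M_{\\odot}/{\\rm Gyr}$. Dividing by $10^{9}\\ {\\rm yr/Gyr}$ converts it to $M_{\\odot}/{\\rm yr}$, which is the same as subtracting 9 from $\\log_{10}M_{\\rm halo}$ before exponentiating."
]
},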
{
"cell_type": "code",
"execution_count": 12,
"id": "appointed-butter",
"metadata": {},
"outputs": [],
"source": [
"# Scalar derivative d(log10 M)/dt with respect to t, then vmapped over t only\n",
"_d_log_mh_dt = jjit(\n",
"    vmap(grad(_rolling_plaw_vs_t, argnums=0), in_axes=(0, *[None] * 6))\n",
")\n",
"\n",
"@jjit\n",
"def _mass_accretion_rate(t, t0, logmp, x0, k, early, late):\n",
"    \"\"\"Mass accretion rate dM/dt in Msun/yr, for t in Gyr.\"\"\"\n",
"    d_log_mh_dt = _d_log_mh_dt(t, t0, logmp, x0, k, early, late)\n",
"    log_mah = _rolling_plaw_vs_t(t, t0, logmp, x0, k, early, late)\n",
"    # dM/dt = M * ln(10) * dlog10(M)/dt; the -9 converts per-Gyr to per-yr\n",
"    dmhdt = d_log_mh_dt * (10.0 ** (log_mah - 9.0)) / jnp.log10(jnp.e)\n",
"    return dmhdt"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "viral-bidding",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.loglog()\n",
"__=ax.plot(tarr, _mass_accretion_rate(tarr, TODAY, MILKY_WAY_MASS, x0, k, early, late))\n",
"\n",
"xlabel = ax.set_xlabel(r'${\\rm cosmic\\ time\\ [Gyr]}$')\n",
"ylabel = ax.set_ylabel(r'$dM_{\\rm halo}/dt\\ [M_{\\odot}/{\\rm yr}]$')\n",
"title = ax.set_title(r'${\\rm Milky\\ Way\\ halo\\ mass\\ accretion\\ rate}$')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "defensive-socket",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}