Skip to content

Instantly share code, notes, and snippets.

@twiecki
Last active April 13, 2021 20:27
Show Gist options
  • Save twiecki/a77104299535b64b58953de3c84df56f to your computer and use it in GitHub Desktop.
Save twiecki/a77104299535b64b58953de3c84df56f to your computer and use it in GitHub Desktop.
stochastic_volatility.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"cell_type": "markdown",
"source": "<a href=\"https://colab.research.google.com/gist/junpenglao/c8b884797f950d1ef033ca69b253a4a0/stochastic_volatility_jax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
},
{
"metadata": {
"id": "QNyI1zUJkAKw"
},
"cell_type": "markdown",
"source": "# Stochastic Volatility model"
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lGnXuq-QkAK1",
"outputId": "9caf1d60-7859-4f37-e01f-9906969fb7e6",
"trusted": false
},
"cell_type": "code",
"source": "import matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nimport pymc3 as pm\nimport pymc3.sampling_jax\nimport arviz as az\n\n%matplotlib inline\n\nnp.random.seed(0)",
"execution_count": 1,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": "/Users/twiecki/projects/pymc/pymc3/sampling_jax.py:24: UserWarning: This module is experimental.\n warnings.warn(\"This module is experimental.\")\n"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 235
},
"id": "e5QZXr0ckAK3",
"outputId": "1cfe592d-850f-4c3b-a8e3-e05dfb26cd9c",
"trusted": false
},
"cell_type": "code",
"source": "returns = pd.read_csv(pm.get_data(\"SP500.csv\"), index_col=\"Date\")\nreturns[\"change\"] = np.log(returns[\"Close\"]).diff()\nreturns = returns.dropna()\nreturns.head()",
"execution_count": 2,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>Close</th>\n <th>change</th>\n </tr>\n <tr>\n <th>Date</th>\n <th></th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>2008-05-05</th>\n <td>1407.489990</td>\n <td>-0.004544</td>\n </tr>\n <tr>\n <th>2008-05-06</th>\n <td>1418.260010</td>\n <td>0.007623</td>\n </tr>\n <tr>\n <th>2008-05-07</th>\n <td>1392.569946</td>\n <td>-0.018280</td>\n </tr>\n <tr>\n <th>2008-05-08</th>\n <td>1397.680054</td>\n <td>0.003663</td>\n </tr>\n <tr>\n <th>2008-05-09</th>\n <td>1388.280029</td>\n <td>-0.006748</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " Close change\nDate \n2008-05-05 1407.489990 -0.004544\n2008-05-06 1418.260010 0.007623\n2008-05-07 1392.569946 -0.018280\n2008-05-08 1397.680054 0.003663\n2008-05-09 1388.280029 -0.006748"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gMYC4zx9kAK3",
"outputId": "726bcb51-4971-4643-9c56-f264a099ca8f",
"trusted": false
},
"cell_type": "code",
"source": "with pm.Model(check_bounds=False) as model:\n step_size = pm.Exponential(\"step_size\", 10)\n volatility = pm.GaussianRandomWalk(\"volatility\", sigma=step_size, \n shape=returns.shape[0], \n init=pm.Normal.dist(0, step_size))\n nu = pm.Exponential(\"nu\", 0.1)\n obs = pm.StudentT(\n \"obs\", nu=nu, sigma=np.exp(volatility), observed=returns[\"change\"]\n )",
"execution_count": 3,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": "/Users/twiecki/projects/pymc/pymc3/distributions/continuous.py:138: UserWarning: The variable specified for nu has negative support for StudentT, likely making it unsuitable for this parameter.\n warnings.warn(msg)\n"
}
]
},
{
"metadata": {
"id": "RVKPflYjkAK4",
"trusted": false
},
"cell_type": "code",
"source": "# %%time\n# with model:\n# trace = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000)",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"id": "_-RM-GykkAK4",
"trusted": false
},
"cell_type": "code",
"source": "# %%time\n# with model:\n# trace = pm.sampling_jax.sample_tfp_nuts(2000, tune=2000)",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"id": "x_uEuuvTkAK4",
"trusted": false
},
"cell_type": "code",
"source": "import os\nimport jax\n\nimport matplotlib\nimport matplotlib.dates as mdates\nimport matplotlib.pyplot as plt\n\nimport jax.numpy as jnp\nimport jax.random as random\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import SP500, load_dataset\nfrom numpyro.infer.hmc import hmc\nfrom numpyro.infer.util import initialize_model\nfrom numpyro.util import fori_collect\nfrom numpyro.infer import MCMC, NUTS\n\nmatplotlib.use('Agg') # noqa: E402\n\n\ndef model_numpyro(returns):\n step_size = numpyro.sample('step_size',\n dist.Exponential(10.))\n volatility = numpyro.sample('volatility',\n dist.GaussianRandomWalk(scale=step_size, num_steps=jnp.shape(returns)[0]))\n nu = numpyro.sample('nu', dist.Exponential(.1))\n return numpyro.sample('r', dist.StudentT(df=nu, loc=0., scale=jnp.exp(volatility)),\n obs=returns)",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "naiiXZPpkAK5",
"outputId": "3de89d6b-daa7-4397-bb25-ebdd1f74cdec",
"trusted": false
},
"cell_type": "code",
"source": "init_rng_key, sample_rng_key = random.split(random.PRNGKey(1))\nmodel_info = initialize_model(init_rng_key, model_numpyro, model_args=(returns[\"change\"].values,))\ninit_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')\nhmc_state = init_kernel(model_info.param_info, 2000, rng_key=sample_rng_key)\nhmc_state.z",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "{'nu': DeviceArray(1.85527508, dtype=float64),\n 'step_size': DeviceArray(0.86971885, dtype=float64),\n 'volatility': DeviceArray([-0.05782139, -1.0270261 , 1.03265798, ..., 1.56501771,\n -0.18145358, 1.22302515], dtype=float64)}"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"id": "iMjkPn08kAK5",
"trusted": false
},
"cell_type": "code",
"source": "# Run NUTS.\n\n# num_warmup, num_samples = 1000, 2000\n# kernel = NUTS(model_numpyro)\n# mcmc = MCMC(kernel, num_warmup, num_samples)\n# mcmc.run(sample_rng_key, returns=returns[\"change\"].values)\n# mcmc.print_summary()\n# samples_1 = mcmc.get_samples()",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5_p-7J2_kAK5",
"outputId": "cdcb1006-3658-4d05-ada4-33d1196c96e4",
"trusted": false
},
"cell_type": "code",
"source": "fn = jax.jit(model_info.potential_fn)\nprint(fn(hmc_state.z))\n%timeit fn(hmc_state.z).block_until_ready()",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "8714.197040735067\n594 µs ± 50.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HYlcxAn6kAK6",
"outputId": "e758fa72-b9e2-4f0e-b834-7a1da4ab4aae",
"trusted": false
},
"cell_type": "code",
"source": "from theano.link.jax.jax_dispatch import jax_funcify\nimport theano\n\nfgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])\nfns = jax_funcify(fgraph)\nlogp_fn_jax = fns[0]\n\nrv_names = [rv.name for rv in model.free_RVs]\ninit_state = [model.test_point[rv_name] for rv_name in rv_names]\nfn2 = jax.jit(logp_fn_jax)\nfn2(*init_state)\n%timeit fn2(*init_state).block_until_ready()",
"execution_count": 7,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "671 µs ± 45.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IHldqMazkep8",
"outputId": "d65bf37a-713e-4700-c896-cc13abb07619",
"trusted": false
},
"cell_type": "code",
"source": "fn_with_grad = jax.jit(jax.value_and_grad(model_info.potential_fn))\nfn_with_grad(hmc_state.z)\n%timeit fn(hmc_state.z).block_until_ready()\nfn2_with_grad = jax.jit(jax.value_and_grad(logp_fn_jax))\nfn2_with_grad(*init_state)\n%timeit fn2(*init_state).block_until_ready()",
"execution_count": 8,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "667 µs ± 90.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n678 µs ± 57.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TOBIX1HNkAK6",
"outputId": "a2672950-f5c8-4716-8e59-cd0c9d9d1053",
"trusted": false
},
"cell_type": "code",
"source": "init_state2 = [hmc_state.z['step_size'], hmc_state.z['volatility'], hmc_state.z['nu']]\nfn(hmc_state.z), fn2(*init_state2)",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "(DeviceArray(8714.19704074, dtype=float64),\n DeviceArray(-8714.19704074, dtype=float64))"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "189J9pS6kAK6",
"outputId": "57cca92b-35de-49e9-d5b6-53d93d558e4b",
"trusted": false
},
"cell_type": "code",
"source": "model.test_point",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "{'step_size_log__': array(-2.66909801),\n 'volatility': array([0., 0., 0., ..., 0., 0., 0.]),\n 'nu_log__': array(1.93607218)}"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w_ZlLLOmkAK6",
"outputId": "53587241-617d-4e0c-a5ed-ba14fc5b3208",
"trusted": false
},
"cell_type": "code",
"source": "z2 = {\n 'step_size': model.test_point['step_size_log__'],\n 'nu': model.test_point['nu_log__'],\n 'volatility': model.test_point['volatility']\n}\nfn(z2), fn2(*init_state)",
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "(DeviceArray(-2307.90263155, dtype=float64),\n DeviceArray(2307.90263155, dtype=float64))"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"id": "9Pcu4Q2EkAK7",
"trusted": false
},
"cell_type": "code",
"source": "draws=2000\ntune=2000\nchains=4\ntarget_accept=0.8\nrandom_seed=10\nprogress_bar=True\nseed = jax.random.PRNGKey(random_seed)",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CIyw3GLBkAK7",
"outputId": "c0307a4a-7d18-4ece-f83e-6a859026dab8",
"trusted": false
},
"cell_type": "code",
"source": "%%time\ninit_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), hmc_state.z)\n\n@jax.jit\ndef _sample(current_state, seed):\n step_size = jax.tree_map(jax.numpy.ones_like, init_state)\n nuts_kernel = NUTS(\n potential_fn=model_info.potential_fn,\n # model=model,\n target_accept_prob=target_accept,\n adapt_step_size=True,\n adapt_mass_matrix=True,\n dense_mass=False,\n )\n\n pmap_numpyro = MCMC(\n nuts_kernel,\n num_warmup=tune,\n num_samples=draws,\n num_chains=chains,\n postprocess_fn=None,\n chain_method=\"parallel\",\n progress_bar=progress_bar,\n )\n\n pmap_numpyro.run(seed, init_params=current_state, extra_fields=(\"num_steps\",))\n samples = pmap_numpyro.get_samples(group_by_chain=True)\n leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)[\"num_steps\"]\n return samples, leapfrogs_taken\n\nprint(\"Compiling...\")\ntic2 = pd.Timestamp.now()\nmap_seed = jax.random.split(seed, chains)\nposterior, leapfrogs_taken = _sample(init_state_batched, map_seed)\nleapfrogs_taken.block_until_ready()\n# map_seed = jax.random.split(seed, chains)\n# mcmc_samples = _sample(init_state_batched, map_seed)\n# tic4 = pd.Timestamp.now()\n# print(\"Sampling time = \", tic4 - tic3)\n\ntic3 = pd.Timestamp.now()\nprint(\"Compilation + sampling time = \", tic3 - tic2)",
"execution_count": 20,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Compiling...\nCompilation + sampling time = 0 days 00:05:27.456873\nCPU times: user 11min, sys: 16.8 s, total: 11min 17s\nWall time: 5min 27s\n"
}
]
},
{
"metadata": {
"id": "z7u27qErPQ-J",
"trusted": false
},
"cell_type": "code",
"source": "az_trace = az.from_dict(posterior=posterior)",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Df2t46y9LbcU",
"outputId": "e18b968e-be56-4341-9e26-9b486af96e19",
"trusted": false
},
"cell_type": "code",
"source": "%%time\ninit_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), hmc_state.z)\ninit_state_batched_ = [init_state_batched['step_size'], init_state_batched['volatility'], init_state_batched['nu']]\n\n@jax.jit\ndef _sample(current_state, seed):\n step_size = jax.tree_map(jax.numpy.ones_like, init_state)\n nuts_kernel = NUTS(\n potential_fn=lambda x: -logp_fn_jax(*x),\n # model=model,\n target_accept_prob=target_accept,\n adapt_step_size=True,\n adapt_mass_matrix=True,\n dense_mass=False,\n )\n\n pmap_numpyro = MCMC(\n nuts_kernel,\n num_warmup=tune,\n num_samples=draws,\n num_chains=chains,\n postprocess_fn=None,\n chain_method=\"parallel\",\n progress_bar=progress_bar,\n )\n\n pmap_numpyro.run(seed, init_params=current_state, extra_fields=(\"num_steps\",))\n samples = pmap_numpyro.get_samples(group_by_chain=True)\n leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)[\"num_steps\"]\n return samples, leapfrogs_taken\n\nprint(\"Compiling...\")\ntic2 = pd.Timestamp.now()\nposterior_pymc3, leapfrogs_taken_pymc3 = _sample(init_state_batched_, map_seed)\nleapfrogs_taken_pymc3.block_until_ready()\n# map_seed = jax.random.split(seed, chains)\n# mcmc_samples = _sample(init_state_batched, map_seed)\n# tic4 = pd.Timestamp.now()\n# print(\"Sampling time = \", tic4 - tic3)\n\ntic3 = pd.Timestamp.now()\nprint(\"Compilation + sampling time = \", tic3 - tic2)",
"execution_count": 22,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Compiling...\nCompilation + sampling time = 0 days 00:06:11.912123\nCPU times: user 14min 45s, sys: 17.2 s, total: 15min 2s\nWall time: 6min 12s\n"
}
]
},
{
"metadata": {
"id": "QVNVrt7rPaZK",
"trusted": false
},
"cell_type": "code",
"source": "az_trace_pymc3 = az.from_dict(posterior={k:v for k, v in zip(['step_size','volatility','nu'], posterior_pymc3)})",
"execution_count": 16,
"outputs": []
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XGrthQIXM-3j",
"outputId": "f4b00a42-1c1b-4d0c-ff6c-b86afe7b0832",
"trusted": false
},
"cell_type": "code",
"source": "leapfrogs_taken_pymc3.mean(), leapfrogs_taken.mean()",
"execution_count": 23,
"outputs": [
{
"data": {
"text/plain": "(DeviceArray(173.944, dtype=float64), DeviceArray(140.28, dtype=float64))"
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "leapfrogs_taken_pymc3",
"execution_count": 18,
"outputs": [
{
"data": {
"text/plain": "DeviceArray([[127, 127, 127, ..., 127, 127, 127],\n [127, 127, 127, ..., 127, 127, 127],\n [127, 255, 255, ..., 255, 127, 255],\n [127, 255, 127, ..., 255, 127, 127]], dtype=int64)"
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "leapfrogs_taken",
"execution_count": 19,
"outputs": [
{
"data": {
"text/plain": "DeviceArray([[127, 127, 127, ..., 127, 127, 127],\n [127, 127, 127, ..., 127, 127, 127],\n [255, 255, 127, ..., 127, 127, 127],\n [127, 127, 127, ..., 127, 127, 127]], dtype=int64)"
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1ByuHSGZkAK7",
"outputId": "050ed585-f1ac-41ee-b7df-1f707e675240",
"trusted": false
},
"cell_type": "code",
"source": "az.plot_trace(az_trace_pymc3, var_names=[\"step_size\", \"nu\"])",
"execution_count": 25,
"outputs": [
{
"data": {
"text/plain": "array([[<AxesSubplot:title={'center':'step_size'}>,\n <AxesSubplot:title={'center':'step_size'}>],\n [<AxesSubplot:title={'center':'nu'}>,\n <AxesSubplot:title={'center':'nu'}>]], dtype=object)"
},
"execution_count": 25,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"id": "cdc9DMLRkAK8",
"trusted": false
},
"cell_type": "code",
"source": "az.plot_trace(az_trace, var_names=[\"step_size\", \"nu\"])",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"id": "Xk6FvW3LkAK8",
"trusted": false
},
"cell_type": "code",
"source": "from jax import make_jaxpr",
"execution_count": 24,
"outputs": []
},
{
"metadata": {
"id": "u8uXffVEkAK8",
"trusted": false
},
"cell_type": "code",
"source": "print(make_jaxpr(logp_fn_jax)(*init_state))",
"execution_count": 25,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "{ lambda o bh bk bn bq cc cv cw cy dr dt dw ; a b c.\n let d = exp a\n e = mul 10.0 d\n f = sub 2.302585092994046 e\n g = add f a\n h = reduce_sum[ axes=() ] g\n i = reduce_sum[ axes=() ] h\n j = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] i\n k = exp a\n l = pow k -2.0\n m = mul 1.0 l\n n = neg m\n p = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))\n slice_sizes=(1,) ] b o\n q = sub p 0.0\n r = pow q 2.0\n s = mul n r\n t = exp a\n u = pow t -2.0\n v = mul 1.0 u\n w = div v 3.141592653589793\n x = div w 2.0\n y = log x\n z = add s y\n ba = div z 2.0\n bb = exp a\n bc = mul 1.0 bb\n bd = pow bc -2.0\n be = mul 1.0 bd\n bf = neg be\n bg = reshape[ dimensions=None\n new_sizes=(1,) ] bf\n bi = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] b bh\n bj = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] bi\n bl = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] b bk\n bm = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] bl\n bo = add bm bn\n bp = sub bj bo\n br = pow bp bq\n bs = mul bg br\n bt = exp a\n bu = mul 1.0 bt\n bv = pow bu -2.0\n bw = mul 1.0 bv\n bx = div bw 3.141592653589793\n by = div bx 2.0\n bz = log by\n ca = reshape[ dimensions=None\n new_sizes=(1,) ] bz\n cb = add bs ca\n cd = div cb cc\n ce = reduce_sum[ axes=(0,) ] cd\n cf = add ba ce\n cg = reduce_sum[ axes=() ] cf\n ch = reduce_sum[ axes=() ] cg\n ci = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] ch\n cj = exp c\n ck = mul 0.1 cj\n cl = sub -2.3025850929940455 ck\n cm = add cl c\n cn = reduce_sum[ axes=() ] cm\n co = reduce_sum[ axes=() ] cn\n cp = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] co\n cq = exp c\n cr = add cq 1.0\n cs = div cr 2.0\n ct = lgamma cs\n cu = reshape[ 
dimensions=None\n new_sizes=(1,) ] ct\n cx = exp b\n cz = pow cx cy\n da = mul cw cz\n db = exp c\n dc = mul db 3.141592653589793\n dd = reshape[ dimensions=None\n new_sizes=(1,) ] dc\n de = div da dd\n df = log de\n dg = mul cv df\n dh = add cu dg\n di = exp c\n dj = div di 2.0\n dk = lgamma dj\n dl = reshape[ dimensions=None\n new_sizes=(1,) ] dk\n dm = sub dh dl\n dn = exp c\n do = add dn 1.0\n dp = div do 2.0\n dq = reshape[ dimensions=None\n new_sizes=(1,) ] dp\n ds = exp b\n du = pow ds dt\n dv = mul dr du\n dx = mul dv dw\n dy = exp c\n dz = reshape[ dimensions=None\n new_sizes=(1,) ] dy\n ea = div dx dz\n eb = log1p ea\n ec = mul dq eb\n ed = sub dm ec\n ee = reduce_sum[ axes=(0,) ] ed\n ef = reduce_sum[ axes=() ] ee\n eg = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] ef\n eh = concatenate[ dimension=0 ] j ci cp eg\n ei = reduce_sum[ axes=(0,) ] eh\n in (ei,) }\n"
}
]
},
{
"metadata": {
"id": "3sCUA5pIkAK9",
"trusted": false
},
"cell_type": "code",
"source": "print(make_jaxpr(model_info.potential_fn)(z2))",
"execution_count": 26,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "{ lambda m v y bz ; a b c.\n let d = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] b\n e = reduce_sum[ axes=(0,) ] d\n f = reduce_sum[ axes=() ] e\n g = add 0.0 f\n h = exp b\n i = mul h 10.0\n j = sub 2.302585092994046 i\n k = reduce_sum[ axes=() ] j\n l = add g k\n n = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))\n slice_sizes=(1,) ] c m\n o = sub n 0.0\n p = div o h\n q = integer_pow[ y=2 ] p\n r = mul q -0.5\n s = mul 2.5066282746310002 h\n t = log s\n u = sub r t\n w = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] c v\n x = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] w\n z = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] c y\n ba = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] z\n bb = sub x ba\n bc = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] h\n bd = div bb bc\n be = integer_pow[ y=2 ] bd\n bf = mul be -0.5\n bg = mul 2.5066282746310002 bc\n bh = log bg\n bi = sub bf bh\n bj = reduce_sum[ axes=(0,) ] bi\n bk = add u bj\n bl = reduce_sum[ axes=() ] bk\n bm = add l bl\n bn = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] a\n bo = reduce_sum[ axes=(0,) ] bn\n bp = reduce_sum[ axes=() ] bo\n bq = add bm bp\n br = exp a\n bs = mul br 0.1\n bt = sub -2.3025850929940455 bs\n bu = reduce_sum[ axes=() ] bt\n bv = add bq bu\n bw = broadcast_in_dim[ broadcast_dimensions=()\n shape=(2905,) ] br\n bx = add bw 1.0\n by = mul bx -0.5\n ca = exp c\n cb = div bz ca\n cc = pow cb 2.0\n cd = div cc bw\n ce = log1p cd\n cf = mul by ce\n cg = log ca\n ch = log bw\n ci = mul ch 0.5\n cj = add cg ci\n ck = add cj 0.5723649429247001\n cl = mul bw 0.5\n cm = lgamma cl\n cn = add ck cm\n co = add bw 1.0\n cp = mul co 0.5\n cq = lgamma cp\n cr = sub cn cq\n cs 
= sub cf cr\n ct = reduce_sum[ axes=(0,) ] cs\n cu = add bv ct\n cv = neg cu\n in (cv,) }\n"
}
]
},
{
"metadata": {
"id": "i0qhlJo6kAK9"
},
"cell_type": "markdown",
"source": "## References\n\n1. Hoffman & Gelman. (2011). [The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo](http://arxiv.org/abs/1111.4246). "
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/a77104299535b64b58953de3c84df56f"
},
"anaconda-cloud": {},
"colab": {
"collapsed_sections": [],
"include_colab_link": true,
"name": "c8b884797f950d1ef033ca69b253a4a0#file-stochastic_volatility_jax-ipynb",
"provenance": []
},
"gist": {
"id": "a77104299535b64b58953de3c84df56f",
"data": {
"description": "stochastic_volatility.ipynb",
"public": true
}
},
"kernelspec": {
"name": "pymc3theano",
"display_name": "pymc3theano",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.8.5",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
@brandonwillard
Copy link

I just noticed that this example isn't optimizing the FunctionGraph.

@twiecki
Copy link
Author

twiecki commented Apr 12, 2021 via email

@brandonwillard
Copy link

brandonwillard commented Apr 12, 2021

Doing something like the following will optimize the FunctionGraph in roughly the same way that aesara.function does:

from aesara.compile.mode import FAST_RUN

fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
_ = FAST_RUN.optimizer.optimize(fgraph)

Without that step, the JAX function will take the exact form of the log-likelihood graph determined by the Distribution.logp implementations (i.e. no CSE, fusions, in-place operations, etc.).

@twiecki
Copy link
Author

twiecki commented Apr 13, 2021

I suppose pm.sample() already does this?

@brandonwillard
Copy link

This looks like something we need to update in PyMC3, as well.

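Here's a quick comparison of the timing with and without graph optimizations (the example/model is taken from this notebook). The first snippet reuses the already-optimized graph of the compiled `model.logp` function; the second uses the unoptimized `FunctionGraph` built in this notebook: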

fgraph = model.logp.f.maker.fgraph
...
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 198 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
...
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 236 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment