-
-
Save colehaus/fc2586559fe32ce41cab5781929a7924 to your computer and use it in GitHub Desktop.
batched-skew-t-perf.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyPnOKhAvRV+mG/KcmuEPMZE", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/colehaus/fc2586559fe32ce41cab5781929a7924/batched-skew-t-perf.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install numpyro" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "9_a70v1ffvTl", | |
"outputId": "249bae9c-1faa-4704-c33b-ea221f67eeaf" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting numpyro\n", | |
" Downloading numpyro-0.10.1-py3-none-any.whl (292 kB)\n", | |
"\u001b[K |████████████████████████████████| 292 kB 25.6 MB/s \n", | |
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from numpyro) (1.21.6)\n", | |
"Requirement already satisfied: jaxlib>=0.1.65 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.3.15+cuda11.cudnn805)\n", | |
"Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.3.17)\n", | |
"Requirement already satisfied: multipledispatch in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.6.0)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro) (4.64.0)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (4.1.1)\n", | |
"Requirement already satisfied: etils[epath] in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (0.7.1)\n", | |
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (1.2.0)\n", | |
"Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (1.7.3)\n", | |
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (3.3.0)\n", | |
"Requirement already satisfied: zipp in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.2.13->numpyro) (3.8.1)\n", | |
"Requirement already satisfied: importlib_resources in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.2.13->numpyro) (5.9.0)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from multipledispatch->numpyro) (1.15.0)\n", | |
"Installing collected packages: numpyro\n", | |
"Successfully installed numpyro-0.10.1\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "sl3GbgcdfIll", | |
"outputId": "cb08c340-4dcc-4753-bbd5-7172aac3c45d" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"log_prob_skew 0.694505484000004\n", | |
"log_prob_batch_skew 0.7215251789999968\n", | |
"mcmc_skew 2.087386464000005\n", | |
"mcmc_batch_t 1.4569239829999958\n", | |
"mcmc_batch_skew 100.78919051900002\n", | |
"jaxpr_skew 456\n", | |
"jaxpr_batch_t 326\n", | |
"jaxpr_batch_skew 490\n", | |
"log_prob_from_trace_skew 0.20444896600000106\n", | |
"log_prob_from_trace_batch_t 0.16642409199999975\n", | |
"log_prob_from_trace_batch_skew 1.5547143510000012\n", | |
"log_prob_from_trace_no_jit_skew 1.39140366700002\n", | |
"log_prob_from_trace_no_jit_batch_t 1.1463060190000078\n", | |
"log_prob_from_trace_no_jit_batch_skew 1.5753536870000175\n" | |
] | |
} | |
], | |
"source": [ | |
"from __future__ import annotations\n", | |
"\n", | |
"from timeit import timeit\n", | |
"from typing import Any, Callable, Mapping\n", | |
"from typing_extensions import TypedDict\n", | |
"\n", | |
"import jax\n", | |
"from jax import lax\n", | |
"from jax.core import ClosedJaxpr\n", | |
"import jax.numpy as jnp\n", | |
"from jax.random import PRNGKey\n", | |
"from jax.scipy.linalg import cho_solve\n", | |
"from numpyro.distributions import Distribution, MultivariateStudentT, Normal, constraints\n", | |
"from numpyro.distributions.util import promote_shapes, validate_sample\n", | |
"from numpyro.handlers import seed\n", | |
"from numpyro.infer import MCMC, NUTS\n", | |
"from numpy.random import default_rng\n", | |
"from numpy.typing import NDArray\n", | |
"import numpyro\n", | |
"import numpyro.distributions as dist\n", | |
"\n", | |
"def t_cdf_approx(df: NDArray[float] | float, t: NDArray[float] | float):\n", | |
" a = df - 1 / 2\n", | |
" b = 48 * a**2\n", | |
" # Add epsilon to avoid undefined gradient at 0\n", | |
" z = jnp.sqrt(a * jnp.log(1 + t**2 / df) + 1e-24)\n", | |
" u = (\n", | |
" z\n", | |
" + (z**3 + 3 * z) / b\n", | |
" - (4 * z**7 + 33 * z**5 + 240 * z**3 + 855 * z) / (10 * b * (b + 0.8 * z**4 + 100))\n", | |
" )\n", | |
" return Normal(loc=0, scale=1).cdf(u * jnp.sign(t))\n", | |
"\n", | |
"class SkewMultivariateStudentT(Distribution): # type: ignore # pylint: disable=too-many-instance-attributes\n", | |
" arg_constraints = {\n", | |
" \"df\": constraints.positive,\n", | |
" \"loc\": constraints.real_vector,\n", | |
" \"scale_tril\": constraints.lower_cholesky,\n", | |
" \"skewers\": constraints.real_vector,\n", | |
" }\n", | |
" support = constraints.real_vector\n", | |
" reparametrized_params = [\"df\", \"loc\", \"scale_tril\", \"skewers\"]\n", | |
"\n", | |
" def __init__( # pylint: disable=too-many-arguments\n", | |
" self,\n", | |
" df: float,\n", | |
" loc: NDArray[float],\n", | |
" scale_tril: NDArray[float],\n", | |
" skewers: NDArray[float],\n", | |
" validate_args: None = None,\n", | |
" ):\n", | |
" batch_shape = lax.broadcast_shapes(\n", | |
" jnp.shape(df), jnp.shape(loc)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]\n", | |
" )\n", | |
" (self.df,) = promote_shapes(df, shape=batch_shape)\n", | |
" (self.loc,) = promote_shapes(loc, shape=batch_shape + loc.shape[-1:])\n", | |
" (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])\n", | |
" (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])\n", | |
"\n", | |
" self._width = scale_tril.shape[-1]\n", | |
"\n", | |
" self._mv_t = MultivariateStudentT(df=df, scale_tril=scale_tril, loc=loc)\n", | |
" eye = jnp.broadcast_to(jnp.eye(self._width), shape=batch_shape + scale_tril.shape[-2:])\n", | |
" prec_scale_tril = jnp.linalg.cholesky(cho_solve((self.scale_tril, True), eye))\n", | |
" self.prec = jnp.einsum(\"...ij,...hj->...ih\", prec_scale_tril, prec_scale_tril)\n", | |
" self._std_devs = jnp.sqrt(jnp.sum(self.scale_tril * self.scale_tril, axis=-1))\n", | |
"\n", | |
" event_shape = jnp.shape(self.scale_tril)[-1:]\n", | |
" super().__init__(\n", | |
" batch_shape=batch_shape,\n", | |
" event_shape=event_shape,\n", | |
" validate_args=validate_args,\n", | |
" )\n", | |
" @validate_sample\n", | |
" def log_prob(self, value: NDArray[float]) -> NDArray[float]:\n", | |
" distance = value - self.loc\n", | |
" Qy = jnp.einsum(\"...j,...jk,...k->...\", distance, self.prec, distance)\n", | |
" df_term = jnp.sqrt((self.df + self._width) / (Qy + self.df))\n", | |
" distance_term = distance / self._std_devs * df_term[..., jnp.newaxis]\n", | |
" x = jnp.squeeze(self.skewers @ distance_term[..., jnp.newaxis])\n", | |
" skew = t_cdf_approx(self.df + self._width, x)\n", | |
" return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)\n", | |
"\n", | |
"shape = (120, 5)\n", | |
"length, width = shape\n", | |
"\n", | |
"t = SkewMultivariateStudentT(df=3.0, loc=jnp.zeros(width), scale_tril=jnp.eye(width), skewers=jnp.zeros(width))\n", | |
"batch_t = SkewMultivariateStudentT(\n", | |
" df=3.0,\n", | |
" loc=jnp.zeros(width),\n", | |
" scale_tril=jnp.repeat(jnp.eye(width)[jnp.newaxis, ...], length, axis=0),\n", | |
" skewers=jnp.zeros(width),\n", | |
")\n", | |
"\n", | |
"def log_prob_wrapper(dist: Distribution[NDArray[float]]):\n", | |
" @jax.jit\n", | |
" def log_prob(xs: NDArray[float]):\n", | |
" return t.log_prob(xs)\n", | |
" ones = jnp.ones(shape)\n", | |
" return timeit(lambda: log_prob(ones), lambda: log_prob(ones), number=100_000)\n", | |
"\n", | |
"BaseLatents = TypedDict(\"BaseLatents\", {\"cov_chol\": NDArray[float], \"loc\": NDArray[float], \"df\": float})\n", | |
"\n", | |
"def base_model(width: int) -> BaseLatents:\n", | |
" variances = numpyro.sample(\"variances\", dist.HalfNormal(1).expand((width,)))\n", | |
" corr_chol = numpyro.sample(\"correlation\", dist.LKJCholesky(width, concentration=1))\n", | |
" cov_chol = jnp.sqrt(variances)[..., jnp.newaxis] * corr_chol\n", | |
"\n", | |
" loc = numpyro.sample(\"loc\", dist.Normal(loc=0, scale=1).expand((width,)))\n", | |
" df = numpyro.sample(\"df\", dist.Gamma(concentration=2, rate=0.1)) + 2\n", | |
"\n", | |
" return {\"cov_chol\": cov_chol, \"loc\": loc, \"df\": df}\n", | |
"\n", | |
"def skew_model(shape: tuple[int, int], observed: NDArray[float]) -> None:\n", | |
" length, width = shape\n", | |
" base_latents = base_model(width)\n", | |
"\n", | |
" with numpyro.plate(\"data\", length):\n", | |
" numpyro.sample(\n", | |
" \"obs\",\n", | |
" SkewMultivariateStudentT(\n", | |
" df=base_latents[\"df\"],\n", | |
" scale_tril=base_latents[\"cov_chol\"],\n", | |
" loc=base_latents[\"loc\"],\n", | |
" skewers=jnp.zeros(width),\n", | |
" ),\n", | |
" obs=observed,\n", | |
" )\n", | |
"\n", | |
"def batch_skew_model(shape: tuple[int, int], observed: NDArray[float]) -> None:\n", | |
" length, width = shape\n", | |
" base_latents = base_model(width)\n", | |
"\n", | |
" with numpyro.plate(\"data\", length):\n", | |
" numpyro.sample(\n", | |
" \"obs\",\n", | |
" SkewMultivariateStudentT(\n", | |
" df=base_latents[\"df\"],\n", | |
" scale_tril=jnp.repeat(base_latents[\"cov_chol\"][jnp.newaxis, ...], repeats=length, axis=0),\n", | |
" loc=base_latents[\"loc\"],\n", | |
" skewers=jnp.zeros(width),\n", | |
" ),\n", | |
" obs=observed,\n", | |
" )\n", | |
"\n", | |
"def batch_t_model(shape: tuple[int, int], observed: NDArray[float]) -> None:\n", | |
" length, width = shape\n", | |
" base_latents = base_model(width)\n", | |
"\n", | |
" with numpyro.plate(\"data\", length):\n", | |
" numpyro.sample(\n", | |
" \"obs\",\n", | |
" dist.MultivariateStudentT(\n", | |
" df=base_latents[\"df\"],\n", | |
" scale_tril=jnp.repeat(base_latents[\"cov_chol\"][jnp.newaxis, ...], repeats=length, axis=0),\n", | |
" loc=base_latents[\"loc\"],\n", | |
" ),\n", | |
" obs=observed,\n", | |
" )\n", | |
"\n", | |
"rng = default_rng(1234)\n", | |
"\n", | |
"def mcmc_wrapper(model: Callable[..., None]):\n", | |
" mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=500, chain_method=\"parallel\", progress_bar=False)\n", | |
"\n", | |
" @jax.jit\n", | |
" def run(data: NDArray[float]):\n", | |
" mcmc.run(PRNGKey(123), shape=shape, observed=data)\n", | |
" return mcmc.get_samples()\n", | |
" data = rng.standard_normal(shape)\n", | |
" return timeit(lambda: run(data), setup=lambda: run(data), number=20)\n", | |
"\n", | |
"def jaxpr_wrapper(model: Callable[..., None]) -> ClosedJaxpr:\n", | |
" def run(data: NDArray[float]):\n", | |
" return log_prob_from_trace(\n", | |
" seed(model, PRNGKey(0)), model_args=(), model_kwargs={\"shape\": shape, \"observed\": data}, params={}\n", | |
" )\n", | |
" data = rng.standard_normal(shape)\n", | |
" return jax.make_jaxpr(run)(data)\n", | |
"\n", | |
"def log_prob_from_trace(\n", | |
" model: Callable[..., None],\n", | |
" model_args: tuple[Any, ...],\n", | |
" model_kwargs: Mapping[str, Any],\n", | |
" params: Mapping[str, NDArray[float]],\n", | |
"):\n", | |
" site = numpyro.handlers.trace(model).get_trace(*model_args, **model_kwargs)[\"obs\"]\n", | |
" assert site[\"type\"] == \"sample\"\n", | |
" return site[\"fn\"].log_prob(site[\"value\"])\n", | |
"\n", | |
"def log_prob_from_trace_wrapper(model: Callable[..., None]):\n", | |
" @jax.jit\n", | |
" def run(data: NDArray[float]):\n", | |
" return log_prob_from_trace(\n", | |
" seed(model, PRNGKey(0)), model_args=(), model_kwargs={\"shape\": shape, \"observed\": data}, params={}\n", | |
" )\n", | |
" data = rng.standard_normal(shape)\n", | |
" return timeit(lambda: run(data), setup=lambda: run(data), number=10_000)\n", | |
"\n", | |
"def log_prob_from_trace_no_jit_wrapper(model: Callable[..., None]):\n", | |
" def run(data: NDArray[float]):\n", | |
" return log_prob_from_trace(\n", | |
" seed(model, PRNGKey(0)), model_args=(), model_kwargs={\"shape\": shape, \"observed\": data}, params={}\n", | |
" )\n", | |
" data = rng.standard_normal(shape)\n", | |
" return timeit(lambda: run(data), setup=lambda: run(data), number=100)\n", | |
"\n", | |
"print(\"log_prob_skew\", log_prob_wrapper(t))\n", | |
"print(\"log_prob_batch_skew\", log_prob_wrapper(batch_t))\n", | |
"\n", | |
"print(\"mcmc_skew\", mcmc_wrapper(skew_model))\n", | |
"print(\"mcmc_batch_t\", mcmc_wrapper(batch_t_model))\n", | |
"print(\"mcmc_batch_skew\", mcmc_wrapper(batch_skew_model))\n", | |
"\n", | |
"print(\"jaxpr_skew\", jaxpr_wrapper(skew_model).pretty_print().count(\"\\n\"))\n", | |
"print(\"jaxpr_batch_t\", jaxpr_wrapper(batch_t_model).pretty_print().count(\"\\n\"))\n", | |
"print(\"jaxpr_batch_skew\", jaxpr_wrapper(batch_skew_model).pretty_print().count(\"\\n\"))\n", | |
"\n", | |
"print(\"log_prob_from_trace_skew\", log_prob_from_trace_wrapper(skew_model))\n", | |
"print(\"log_prob_from_trace_batch_t\", log_prob_from_trace_wrapper(batch_t_model))\n", | |
"print(\"log_prob_from_trace_batch_skew\", log_prob_from_trace_wrapper(batch_skew_model))\n", | |
"\n", | |
"print(\"log_prob_from_trace_no_jit_skew\", log_prob_from_trace_no_jit_wrapper(skew_model))\n", | |
"print(\"log_prob_from_trace_no_jit_batch_t\", log_prob_from_trace_no_jit_wrapper(batch_t_model))\n", | |
"print(\"log_prob_from_trace_no_jit_batch_skew\", log_prob_from_trace_no_jit_wrapper(batch_skew_model))\n" | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment