Skip to content

Instantly share code, notes, and snippets.

@colehaus
Created September 6, 2022 20:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save colehaus/fc2586559fe32ce41cab5781929a7924 to your computer and use it in GitHub Desktop.
Save colehaus/fc2586559fe32ce41cab5781929a7924 to your computer and use it in GitHub Desktop.
batched-skew-t-perf.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPnOKhAvRV+mG/KcmuEPMZE",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/colehaus/fc2586559fe32ce41cab5781929a7924/batched-skew-t-perf.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install numpyro"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9_a70v1ffvTl",
"outputId": "249bae9c-1faa-4704-c33b-ea221f67eeaf"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting numpyro\n",
" Downloading numpyro-0.10.1-py3-none-any.whl (292 kB)\n",
"\u001b[K |████████████████████████████████| 292 kB 25.6 MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from numpyro) (1.21.6)\n",
"Requirement already satisfied: jaxlib>=0.1.65 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.3.15+cuda11.cudnn805)\n",
"Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.3.17)\n",
"Requirement already satisfied: multipledispatch in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.6.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro) (4.64.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (4.1.1)\n",
"Requirement already satisfied: etils[epath] in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (0.7.1)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (1.2.0)\n",
"Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (1.7.3)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (3.3.0)\n",
"Requirement already satisfied: zipp in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.2.13->numpyro) (3.8.1)\n",
"Requirement already satisfied: importlib_resources in /usr/local/lib/python3.7/dist-packages (from etils[epath]->jax>=0.2.13->numpyro) (5.9.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from multipledispatch->numpyro) (1.15.0)\n",
"Installing collected packages: numpyro\n",
"Successfully installed numpyro-0.10.1\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sl3GbgcdfIll",
"outputId": "cb08c340-4dcc-4753-bbd5-7172aac3c45d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"log_prob_skew 0.694505484000004\n",
"log_prob_batch_skew 0.7215251789999968\n",
"mcmc_skew 2.087386464000005\n",
"mcmc_batch_t 1.4569239829999958\n",
"mcmc_batch_skew 100.78919051900002\n",
"jaxpr_skew 456\n",
"jaxpr_batch_t 326\n",
"jaxpr_batch_skew 490\n",
"log_prob_from_trace_skew 0.20444896600000106\n",
"log_prob_from_trace_batch_t 0.16642409199999975\n",
"log_prob_from_trace_batch_skew 1.5547143510000012\n",
"log_prob_from_trace_no_jit_skew 1.39140366700002\n",
"log_prob_from_trace_no_jit_batch_t 1.1463060190000078\n",
"log_prob_from_trace_no_jit_batch_skew 1.5753536870000175\n"
]
}
],
"source": [
"from __future__ import annotations\n",
"\n",
"from timeit import timeit\n",
"from typing import Any, Callable, Mapping\n",
"from typing_extensions import TypedDict\n",
"\n",
"import jax\n",
"from jax import lax\n",
"from jax.core import ClosedJaxpr\n",
"import jax.numpy as jnp\n",
"from jax.random import PRNGKey\n",
"from jax.scipy.linalg import cho_solve\n",
"from numpyro.distributions import Distribution, MultivariateStudentT, Normal, constraints\n",
"from numpyro.distributions.util import promote_shapes, validate_sample\n",
"from numpyro.handlers import seed\n",
"from numpyro.infer import MCMC, NUTS\n",
"from numpy.random import default_rng\n",
"from numpy.typing import NDArray\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"\n",
"def t_cdf_approx(df: NDArray[float] | float, t: NDArray[float] | float):\n",
"    \"\"\"Approximate the Student-t CDF with ``df`` degrees of freedom at ``t``.\n",
"\n",
"    ``t`` is mapped to an approximately standard-normal deviate by a\n",
"    closed-form series in ``z``, and the standard normal CDF is evaluated\n",
"    there. Only jnp operations are used, so JAX can trace and differentiate it.\n",
"    \"\"\"\n",
"    shifted_df = df - 1 / 2\n",
"    poly_scale = 48 * shifted_df**2\n",
"    # The tiny epsilon keeps sqrt away from 0, where its gradient is undefined.\n",
"    z = jnp.sqrt(shifted_df * jnp.log(1 + t**2 / df) + 1e-24)\n",
"    first_correction = (z**3 + 3 * z) / poly_scale\n",
"    second_correction = (4 * z**7 + 33 * z**5 + 240 * z**3 + 855 * z) / (\n",
"        10 * poly_scale * (poly_scale + 0.8 * z**4 + 100)\n",
"    )\n",
"    normal_deviate = z + first_correction - second_correction\n",
"    # z is non-negative, so the sign of t selects the tail.\n",
"    return Normal(loc=0, scale=1).cdf(normal_deviate * jnp.sign(t))\n",
"\n",
"class SkewMultivariateStudentT(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes\n",
"    \"\"\"Multivariate skew Student-t distribution (log-density only).\n",
"\n",
"    ``log_prob`` returns ``log 2 + log t_d(y) + log T(x)``, where ``t_d`` is a\n",
"    ``MultivariateStudentT`` with the given ``df``, ``loc`` and ``scale_tril``,\n",
"    and ``T`` is ``t_cdf_approx`` with ``df + d`` degrees of freedom evaluated\n",
"    at ``x``: the ``skewers``-weighted residual, divided by the marginal\n",
"    standard deviations and rescaled by ``sqrt((df + d) / (Q + df))``, with\n",
"    ``Q`` the squared Mahalanobis distance of ``y`` from ``loc``. No\n",
"    ``sample`` method is defined in this class.\n",
"    \"\"\"\n",
"\n",
"    arg_constraints = {\n",
"        \"df\": constraints.positive,\n",
"        \"loc\": constraints.real_vector,\n",
"        \"scale_tril\": constraints.lower_cholesky,\n",
"        \"skewers\": constraints.real_vector,\n",
"    }\n",
"    support = constraints.real_vector\n",
"    reparametrized_params = [\"df\", \"loc\", \"scale_tril\", \"skewers\"]\n",
"\n",
"    def __init__(  # pylint: disable=too-many-arguments\n",
"        self,\n",
"        df: float,\n",
"        loc: NDArray[float],\n",
"        scale_tril: NDArray[float],\n",
"        skewers: NDArray[float],\n",
"        validate_args: bool | None = None,\n",
"    ):\n",
"        batch_shape = lax.broadcast_shapes(\n",
"            jnp.shape(df), jnp.shape(loc)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]\n",
"        )\n",
"        (self.df,) = promote_shapes(df, shape=batch_shape)\n",
"        (self.loc,) = promote_shapes(loc, shape=batch_shape + loc.shape[-1:])\n",
"        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])\n",
"        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])\n",
"\n",
"        # Event dimension d.\n",
"        self._width = scale_tril.shape[-1]\n",
"\n",
"        self._mv_t = MultivariateStudentT(df=df, scale_tril=scale_tril, loc=loc)\n",
"        eye = jnp.broadcast_to(jnp.eye(self._width), shape=batch_shape + scale_tril.shape[-2:])\n",
"        # Precision matrix (L L^T)^{-1}, solved directly from the Cholesky factor.\n",
"        # Re-factoring it with cholesky and multiplying the factor by its own\n",
"        # transpose would reproduce this same matrix at the cost of an extra\n",
"        # batched Cholesky decomposition.\n",
"        self.prec = cho_solve((self.scale_tril, True), eye)\n",
"        # Marginal standard deviations: sqrt of the diagonal of L L^T.\n",
"        self._std_devs = jnp.sqrt(jnp.sum(self.scale_tril * self.scale_tril, axis=-1))\n",
"\n",
"        event_shape = jnp.shape(self.scale_tril)[-1:]\n",
"        super().__init__(\n",
"            batch_shape=batch_shape,\n",
"            event_shape=event_shape,\n",
"            validate_args=validate_args,\n",
"        )\n",
"\n",
"    @validate_sample\n",
"    def log_prob(self, value: NDArray[float]) -> NDArray[float]:\n",
"        \"\"\"Return the skew-t log-density of ``value``, reduced over the event axis.\"\"\"\n",
"        distance = value - self.loc\n",
"        # Squared Mahalanobis distance (y - loc)^T prec (y - loc).\n",
"        Qy = jnp.einsum(\"...j,...jk,...k->...\", distance, self.prec, distance)\n",
"        df_term = jnp.sqrt((self.df + self._width) / (Qy + self.df))\n",
"        distance_term = distance / self._std_devs * df_term[..., jnp.newaxis]\n",
"        # Contract over the event axis only. Squeezing a matmul result would also\n",
"        # drop size-1 batch axes (e.g. batch shape (2, 1)) and then mis-broadcast\n",
"        # against the t log-density below.\n",
"        x = jnp.einsum(\"...k,...k->...\", self.skewers, distance_term)\n",
"        skew = t_cdf_approx(self.df + self._width, x)\n",
"        return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)\n",
"\n",
"# Benchmark data: `length` observations of `width`-dimensional vectors.\n",
"shape = (120, 5)\n",
"length, width = shape\n",
"\n",
"# Zero skew, zero loc, identity scale; scalar batch shape.\n",
"t = SkewMultivariateStudentT(\n",
"    df=3.0,\n",
"    loc=jnp.zeros(width),\n",
"    scale_tril=jnp.eye(width),\n",
"    skewers=jnp.zeros(width),\n",
")\n",
"# Same parameters, but the identity scale is stacked `length` times so the\n",
"# distribution carries batch shape (length,).\n",
"batch_t = SkewMultivariateStudentT(\n",
"    df=3.0,\n",
"    loc=jnp.zeros(width),\n",
"    scale_tril=jnp.broadcast_to(jnp.eye(width), (length, width, width)),\n",
"    skewers=jnp.zeros(width),\n",
")\n",
"\n",
"def log_prob_wrapper(dist: Distribution[NDArray[float]]):\n",
"    \"\"\"Time 100,000 calls of a jitted ``dist.log_prob`` on an all-ones input.\n",
"\n",
"    The input has the module-level ``shape``. One untimed call in ``setup``\n",
"    absorbs compilation. Returns the total seconds reported by ``timeit``.\n",
"    \"\"\"\n",
"\n",
"    @jax.jit\n",
"    def log_prob(xs: NDArray[float]):\n",
"        # Benchmark the distribution passed in, not a module-level global.\n",
"        return dist.log_prob(xs)\n",
"\n",
"    ones = jnp.ones(shape)\n",
"\n",
"    def run():\n",
"        # JAX dispatches asynchronously; block so each call is timed to completion.\n",
"        return log_prob(ones).block_until_ready()\n",
"\n",
"    return timeit(run, setup=run, number=100_000)\n",
"\n",
"class BaseLatents(TypedDict):\n",
"    \"\"\"Latent quantities returned by ``base_model``.\"\"\"\n",
"\n",
"    cov_chol: NDArray[float]\n",
"    loc: NDArray[float]\n",
"    df: float\n",
"\n",
"def base_model(width: int) -> BaseLatents:\n",
"    \"\"\"Sample the latent parameters shared by the observation models.\n",
"\n",
"    Sample sites, in order: ``variances``, ``correlation``, ``loc``, ``df``.\n",
"\n",
"    Returns:\n",
"        ``cov_chol``: lower Cholesky factor of the covariance, i.e. the LKJ\n",
"        correlation factor with row ``i`` scaled by ``sqrt(variances[i])``;\n",
"        ``loc``: location vector; ``df``: sampled degrees of freedom plus 2.\n",
"    \"\"\"\n",
"    variances = numpyro.sample(\"variances\", dist.HalfNormal(1).expand((width,)))\n",
"    corr_chol = numpyro.sample(\"correlation\", dist.LKJCholesky(width, concentration=1))\n",
"    std_devs = jnp.sqrt(variances)\n",
"    loc = numpyro.sample(\"loc\", dist.Normal(loc=0, scale=1).expand((width,)))\n",
"    raw_df = numpyro.sample(\"df\", dist.Gamma(concentration=2, rate=0.1))\n",
"    return {\n",
"        \"cov_chol\": std_devs[..., jnp.newaxis] * corr_chol,\n",
"        \"loc\": loc,\n",
"        # Shifted so df > 2 (presumably to keep the t variance finite).\n",
"        \"df\": raw_df + 2,\n",
"    }\n",
"\n",
"def skew_model(shape: tuple[int, int], observed: NDArray[float]) -> None:\n",
"    \"\"\"Condition an unbatched skew-t (all-zero skewers) on ``observed``.\n",
"\n",
"    The distribution has scalar batch shape; the ``data`` plate of size\n",
"    ``shape[0]`` supplies the per-row axis.\n",
"    \"\"\"\n",
"    n_rows, n_cols = shape\n",
"    latents = base_model(n_cols)\n",
"\n",
"    with numpyro.plate(\"data\", n_rows):\n",
"        likelihood = SkewMultivariateStudentT(\n",
"            df=latents[\"df\"],\n",
"            scale_tril=latents[\"cov_chol\"],\n",
"            loc=latents[\"loc\"],\n",
"            skewers=jnp.zeros(n_cols),\n",
"        )\n",
"        numpyro.sample(\"obs\", likelihood, obs=observed)\n",
"\n",
"def batch_skew_model(shape: tuple[int, int], observed: NDArray[float]) -> None:\n",
"    \"\"\"Like ``skew_model`` but with the scale factor stacked once per row.\n",
"\n",
"    Repeating ``cov_chol`` ``shape[0]`` times along a new leading axis gives the\n",
"    distribution an explicit batch shape of ``(shape[0],)``.\n",
"    \"\"\"\n",
"    n_rows, n_cols = shape\n",
"    latents = base_model(n_cols)\n",
"\n",
"    with numpyro.plate(\"data\", n_rows):\n",
"        stacked_chol = jnp.repeat(latents[\"cov_chol\"][jnp.newaxis, ...], repeats=n_rows, axis=0)\n",
"        likelihood = SkewMultivariateStudentT(\n",
"            df=latents[\"df\"],\n",
"            scale_tril=stacked_chol,\n",
"            loc=latents[\"loc\"],\n",
"            skewers=jnp.zeros(n_cols),\n",
"        )\n",
"        numpyro.sample(\"obs\", likelihood, obs=observed)\n",
"\n",
"def batch_t_model(shape: tuple[int, int], observed: NDArray[float]) -> None:\n",
"    \"\"\"Symmetric baseline: ``dist.MultivariateStudentT`` with per-row stacked scale.\n",
"\n",
"    Mirrors ``batch_skew_model`` without the skew term.\n",
"    \"\"\"\n",
"    n_rows, n_cols = shape\n",
"    latents = base_model(n_cols)\n",
"\n",
"    with numpyro.plate(\"data\", n_rows):\n",
"        stacked_chol = jnp.repeat(latents[\"cov_chol\"][jnp.newaxis, ...], repeats=n_rows, axis=0)\n",
"        likelihood = dist.MultivariateStudentT(\n",
"            df=latents[\"df\"],\n",
"            scale_tril=stacked_chol,\n",
"            loc=latents[\"loc\"],\n",
"        )\n",
"        numpyro.sample(\"obs\", likelihood, obs=observed)\n",
"\n",
"# Shared generator: the benchmark wrappers below draw their data from it, so\n",
"# the order in which they are called determines the data each one sees.\n",
"rng = default_rng(seed=1234)\n",
"\n",
"def mcmc_wrapper(model: Callable[..., None]):\n",
"    \"\"\"Time 20 jitted NUTS runs (100 warmup + 500 samples) of ``model``.\n",
"\n",
"    One standard-normal dataset of the module-level ``shape`` is drawn from\n",
"    ``rng``; one untimed run in ``setup`` absorbs compilation. Returns the\n",
"    total seconds reported by ``timeit``.\n",
"    \"\"\"\n",
"    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=500, chain_method=\"parallel\", progress_bar=False)\n",
"\n",
"    @jax.jit\n",
"    def run(data: NDArray[float]):\n",
"        mcmc.run(PRNGKey(123), shape=shape, observed=data)\n",
"        return mcmc.get_samples()\n",
"\n",
"    data = rng.standard_normal(shape)\n",
"\n",
"    def run_blocking():\n",
"        # JAX dispatches asynchronously; wait on every returned sample array so\n",
"        # the timing covers the sampler itself, not just its dispatch.\n",
"        return jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), run(data))\n",
"\n",
"    return timeit(run_blocking, setup=run_blocking, number=20)\n",
"\n",
"def jaxpr_wrapper(model: Callable[..., None]) -> ClosedJaxpr:\n",
"    \"\"\"Return the jaxpr of one ``log_prob_from_trace`` evaluation of ``model``.\n",
"\n",
"    The seeded model is built inside the traced function, so seeding is traced\n",
"    together with the model itself.\n",
"    \"\"\"\n",
"    data = rng.standard_normal(shape)\n",
"\n",
"    def traced_log_prob(observations: NDArray[float]):\n",
"        seeded_model = seed(model, PRNGKey(0))\n",
"        kwargs = {\"shape\": shape, \"observed\": observations}\n",
"        return log_prob_from_trace(seeded_model, model_args=(), model_kwargs=kwargs, params={})\n",
"\n",
"    return jax.make_jaxpr(traced_log_prob)(data)\n",
"\n",
"def log_prob_from_trace(\n",
"    model: Callable[..., None],\n",
"    model_args: tuple[Any, ...],\n",
"    model_kwargs: Mapping[str, Any],\n",
"    params: Mapping[str, NDArray[float]],\n",
"):\n",
"    \"\"\"Return the log-probability of the ``obs`` site from one traced run of ``model``.\n",
"\n",
"    ``params`` maps latent site names to values substituted into the model\n",
"    before tracing, as ``numpyro.infer.util.log_density`` does. An empty\n",
"    mapping substitutes nothing.\n",
"\n",
"    Raises:\n",
"        KeyError: if the model has no site named ``obs``.\n",
"        AssertionError: if the ``obs`` site is not a sample site.\n",
"    \"\"\"\n",
"    substituted = numpyro.handlers.substitute(model, data=params)\n",
"    site = numpyro.handlers.trace(substituted).get_trace(*model_args, **model_kwargs)[\"obs\"]\n",
"    assert site[\"type\"] == \"sample\"\n",
"    return site[\"fn\"].log_prob(site[\"value\"])\n",
"\n",
"def log_prob_from_trace_wrapper(model: Callable[..., None]):\n",
"    \"\"\"Time 10,000 jitted ``log_prob_from_trace`` evaluations of ``model``.\n",
"\n",
"    Data of the module-level ``shape`` is drawn once from ``rng``; one untimed\n",
"    call in ``setup`` absorbs compilation. Returns total seconds from ``timeit``.\n",
"    \"\"\"\n",
"\n",
"    @jax.jit\n",
"    def run(data: NDArray[float]):\n",
"        return log_prob_from_trace(\n",
"            seed(model, PRNGKey(0)), model_args=(), model_kwargs={\"shape\": shape, \"observed\": data}, params={}\n",
"        )\n",
"\n",
"    data = rng.standard_normal(shape)\n",
"\n",
"    def run_blocking():\n",
"        # JAX dispatches asynchronously; wait for the result so the timing\n",
"        # covers the computation, not just its dispatch.\n",
"        return run(data).block_until_ready()\n",
"\n",
"    return timeit(run_blocking, setup=run_blocking, number=10_000)\n",
"\n",
"def log_prob_from_trace_no_jit_wrapper(model: Callable[..., None]):\n",
"    \"\"\"Time 100 un-jitted ``log_prob_from_trace`` evaluations of ``model``.\n",
"\n",
"    Data of the module-level ``shape`` is drawn once from ``rng``; one untimed\n",
"    call runs in ``setup``. Returns total seconds from ``timeit``.\n",
"    \"\"\"\n",
"\n",
"    def run(data: NDArray[float]):\n",
"        return log_prob_from_trace(\n",
"            seed(model, PRNGKey(0)), model_args=(), model_kwargs={\"shape\": shape, \"observed\": data}, params={}\n",
"        )\n",
"\n",
"    data = rng.standard_normal(shape)\n",
"\n",
"    def run_blocking():\n",
"        # Op-by-op JAX execution is also dispatched asynchronously; wait for\n",
"        # the final array so the timing covers the computation.\n",
"        return run(data).block_until_ready()\n",
"\n",
"    return timeit(run_blocking, setup=run_blocking, number=100)\n",
"\n",
"# Run the benchmarks in a fixed order: most wrappers draw data from the\n",
"# shared `rng`, so the order decides which data each benchmark sees.\n",
"benchmarks = [\n",
"    (\"log_prob_skew\", lambda: log_prob_wrapper(t)),\n",
"    (\"log_prob_batch_skew\", lambda: log_prob_wrapper(batch_t)),\n",
"    (\"mcmc_skew\", lambda: mcmc_wrapper(skew_model)),\n",
"    (\"mcmc_batch_t\", lambda: mcmc_wrapper(batch_t_model)),\n",
"    (\"mcmc_batch_skew\", lambda: mcmc_wrapper(batch_skew_model)),\n",
"    (\"jaxpr_skew\", lambda: jaxpr_wrapper(skew_model).pretty_print().count(\"\\n\")),\n",
"    (\"jaxpr_batch_t\", lambda: jaxpr_wrapper(batch_t_model).pretty_print().count(\"\\n\")),\n",
"    (\"jaxpr_batch_skew\", lambda: jaxpr_wrapper(batch_skew_model).pretty_print().count(\"\\n\")),\n",
"    (\"log_prob_from_trace_skew\", lambda: log_prob_from_trace_wrapper(skew_model)),\n",
"    (\"log_prob_from_trace_batch_t\", lambda: log_prob_from_trace_wrapper(batch_t_model)),\n",
"    (\"log_prob_from_trace_batch_skew\", lambda: log_prob_from_trace_wrapper(batch_skew_model)),\n",
"    (\"log_prob_from_trace_no_jit_skew\", lambda: log_prob_from_trace_no_jit_wrapper(skew_model)),\n",
"    (\"log_prob_from_trace_no_jit_batch_t\", lambda: log_prob_from_trace_no_jit_wrapper(batch_t_model)),\n",
"    (\"log_prob_from_trace_no_jit_batch_skew\", lambda: log_prob_from_trace_no_jit_wrapper(batch_skew_model)),\n",
"]\n",
"for label, run_benchmark in benchmarks:\n",
"    print(label, run_benchmark())\n"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment