Last active
December 28, 2021 23:26
-
-
Save omarfsosa/e418b63d5d4f093e393aeafc5076202d to your computer and use it in GitHub Desktop.
Temme approximation for Poisson inverse CDF
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Approximate inverse CDF of the Poisson distribution in JAX.

Based on the approximation method proposed in [1].

References
----------
[1]: https://people.maths.ox.ac.uk/gilesm/codes/poissinv/paper.pdf
"""
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri, xlogy
from jax.scipy.stats import poisson
@jax.jit
def _f(r):
    """
    Temme's function f(r) = sign(r - 1) * sqrt(2 * (1 - r + r * log(r))).

    This is the function whose inverse ``approx_inv_f`` approximates.

    ``r * log(r)`` is computed with ``xlogy`` so that ``f(0)`` evaluates to
    its limit ``-sqrt(2)`` instead of NaN. The radicand is clamped at zero
    because rounding can make it slightly negative when ``r`` is very close
    to 1, which would otherwise produce NaN. Negative ``r`` still yields NaN.
    """
    radicand = 2 * (1 - r + xlogy(r, r))
    return jnp.sign(r - 1) * jnp.sqrt(jnp.maximum(radicand, 0.0))
@jax.jit
def approx_inv_f(y):
    """
    Taylor approximation to the inverse of `_f`, up to order 10.

    Two expansions are used. For ``y >= 0`` the series is centred at
    ``a = f(sqrt(2))``, where it returns ``sqrt(2)``. For ``y < 0`` it is
    centred at ``b = f(1/2)``, where it returns ``1/2``. The polynomials are
    evaluated with Horner's scheme.

    Every operation is elementwise, so scalars and arrays of any shape are
    accepted.

    Parameters
    ----------
    y : array_like
        Value(s) of f(r). The range of f over r > 0 is y > -sqrt(2). Close
        to that bound the truncated series is inaccurate and may return a
        non-positive r.

    Returns
    -------
    jax.Array
        Approximate r such that f(r) == y, with the same shape as ``y``.
    """
    a = 0.38965499961165273634331739362539203787  # f(sqrt(2))
    b = -0.55394297489909075530074661825619291878  # f(1/2)
    # Taylor coefficients in powers of (y - a), lowest order first.
    above_coeffs = (
        1.41421356237309504880168872420969807857,
        1.12430667119464473014907337153145075765,
        0.1531719705475007453738542225714086365,
        -0.00963534867560723115436808341214909258,
        0.00199843564334916907185984291067380747,
        -0.0005604127451221559799488967766232013,
        0.00018374201212080445371642159413933106,
        -0.00006638816844820539098080138927093808,
        0.00002563763078265987339961466411981206,
        -0.00001039337704951179380953670412494928,
        4.37242293758710967917333950699327e-6,
    )
    # Taylor coefficients in powers of (y - b), lowest order first.
    below_coeffs = (
        0.5,
        0.79917078282219776763162980202313238014,
        0.20006420570681594195022346298336739573,
        -0.02957827705804215416248662471355374274,
        0.0131778297383059328355669790592304563,
        -0.0078595297071194875108074176691291394,
        0.00545781700824722781820501229768977374,
        -0.00416696194941366171881547621363511748,
        0.00339537485270363975244716047986888621,
        -0.00290139838458528912391935122813000491,
        0.00257093985067377249483640946187906634,
    )
    # jnp.polyval expects the highest-order coefficient first.
    above = jnp.polyval(jnp.array(above_coeffs[::-1]), y - a)
    below = jnp.polyval(jnp.array(below_coeffs[::-1]), y - b)
    return jnp.where(y < 0, below, above)
@jax.jit
def c_zero(r):
    """
    Temme's correction term c0(r) used by ``temme_approximation``.

    The closed form is

        c0(r) = log(f(r) * sqrt(r) / (r - 1)) / log(r),

    with f(r) = sign(r - 1) * sqrt(2 * (1 - r + r * log(r))). It is 0/0 at
    r = 1 and loses accuracy to cancellation near r = 1. There the Taylor
    expansion is used instead:

        c0(1 + e) = 1/3 - e/36 + 43 * e**2 / 3240 + O(e**3).

    Parameters
    ----------
    r : array_like
        Positive ratio(s). Non-positive values yield NaN away from r = 1.

    Returns
    -------
    jax.Array
        c0 evaluated elementwise, in a floating dtype.
    """
    r = jnp.asarray(r)
    r = r.astype(jnp.promote_types(r.dtype, jnp.float32))
    e = r - 1
    # Switch-over point: the series truncation error grows like |e|**3, while
    # the closed form's rounding error grows like eps / e**2. This cutoff
    # roughly balances the two for the working precision.
    cutoff = 3 * float(jnp.finfo(r.dtype).eps) ** 0.2
    near_one = jnp.abs(e) < cutoff
    # Wherever the series is used, evaluate the closed form at a harmless
    # point instead, so the discarded branch of jnp.where never holds NaN/inf.
    r_safe = jnp.where(near_one, 2.0, r)
    log_r = jnp.log(r_safe)
    # f(r) * sqrt(r) / (r - 1) == sqrt(2 r (1 - r + r log r)) / |r - 1|
    ratio = jnp.sqrt(2 * r_safe * (1 - r_safe + r_safe * log_r)) / jnp.abs(r_safe - 1)
    closed_form = jnp.log(ratio) / log_r
    series = 1 / 3 - e / 36 + 43 * e**2 / 3240
    return jnp.where(near_one, series, closed_form)
@jax.jit
def upper_c(x, lam):
    """
    Leading-order Temme approximation to the regularized upper incomplete
    gamma function Q(x, lam) = Gamma(x, lam) / Gamma(x).

    Uses Q(x, lam) ~ Phi(-sqrt(x) * eta), where Phi is the standard normal
    CDF (``ndtr``) and

        eta = sign(lam/x - 1) * sqrt(2 * (lam/x - 1 - log(lam/x))).

    Because of the sign, the result exceeds 1/2 when x > lam and falls below
    1/2 when x < lam. The radicand is clamped at zero so that rounding near
    lam == x cannot produce NaN.

    Parameters
    ----------
    x : array_like
        Positive shape parameter(s).
    lam : array_like
        Positive rate(s).

    Returns
    -------
    jax.Array
        Approximate Q(x, lam).
    """
    ratio = lam / x
    radicand = jnp.maximum(2 * (ratio - 1 - jnp.log(ratio)), 0.0)
    eta = jnp.sign(ratio - 1) * jnp.sqrt(radicand)
    return ndtr(-jnp.sqrt(x) * eta)
# @jax.jit | |
# def normal_approximation_q2(w, lam): | |
# q1 = lam + jnp.sqrt(lam) * w + (1/3 + 1/6 * w ** 2) | |
# q2 = q1 + lam ** (-0.5) * (-1/36 * w - 1/72 * w ** 3) | |
# delta = (1/40 + 1/80 * w ** 2 + 1/160 * w ** 4) / lam | |
# return q2, delta | |
@jax.jit
def temme_approximation(w, lam):
    """
    Temme-based approximation of the real-valued Poisson quantile.

    Parameters
    ----------
    w : array_like
        Standard normal quantile(s), i.e. ``ndtri(u)``.
    lam : array_like
        Poisson rate(s).

    Returns
    -------
    tuple
        ``(x, delta)``. ``x`` is the approximate real-valued quantile, which
        the caller rounds down. ``delta = 0.01 / lam`` is the error margin
        the caller uses for that rounding.
    """
    ratio = approx_inv_f(w / jnp.sqrt(lam))  # TODO: Clip at 0?
    base = lam * ratio + c_zero(ratio)
    # Empirical correction term (fitted constants).
    corrected = base - 0.0218 / (base + 0.065 * lam)
    return corrected, 0.01 / lam
def table_lookup(u, lam, table_size=30):
    """
    Inverse Poisson CDF by direct table search. Best for when lam < 4.

    Computes P(N <= k) for k = 0, ..., table_size - 1 and returns the
    smallest k with P(N <= k) >= u. If u exceeds every tabulated value, the
    result is ``table_size``. ``u`` and ``lam`` broadcast against each
    other, so both may be arrays.

    Parameters
    ----------
    u : array_like
        Probability level(s).
    lam : array_like
        Poisson rate(s).
    table_size : int, optional
        Number of CDF values tabulated (default 30).

    Returns
    -------
    jax.Array
        Integer quantile(s), with shape ``broadcast(u, lam)``.
    """
    u = jnp.asarray(u)
    lam = jnp.asarray(lam)
    k = jnp.arange(table_size)
    # The trailing axis indexes k; lam broadcasts over it.
    cdf = poisson.cdf(k, lam[..., None])
    # For a non-decreasing table, the number of entries below u equals the
    # left insertion point, i.e. the first k with cdf[k] >= u.
    return jnp.sum(cdf < u[..., None], axis=-1)
def inverse_poisson_cdf(u, lam):
    """
    Approximate inverse CDF (quantile function) of Poisson(lam).

    The target is the smallest integer n with P(N <= n) >= u, returned in a
    floating dtype.

    - lam < 4: the result comes from ``table_lookup``.
    - lam >= 4: Temme's approximation ``x`` is rounded down. When ``x`` lies
      within ``delta`` of the integer below it, the exact Poisson CDF decides
      between the two candidates.
    - u <= P(N = 0) = exp(-lam): the result is 0 directly. For such tiny u
      the series approximation can leave its domain and return NaN.

    Parameters
    ----------
    u : array_like
        Probability level(s) in (0, 1).
    lam : array_like
        Poisson rate(s).

    Returns
    -------
    jax.Array
        Quantile(s).
    """
    w = ndtri(u)
    x, delta = temme_approximation(w, lam)
    n = jnp.floor(x + delta)
    # x is clearly inside [n, n + 1): n is the answer.
    safely_above = x - n > delta
    # Otherwise the answer is n or n - 1. It is n exactly when
    # P(N <= n - 1) < u. Evaluate the exact CDF at the integer candidate,
    # not at the approximate x.
    cdf_below = poisson.cdf(n - 1, lam) < u
    n = jnp.where(safely_above | cdf_below, n, n - 1)
    n = jnp.where(u <= jnp.exp(-lam), 0.0, n)
    return jnp.where(lam < 4, table_lookup(u, lam), n)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment