Skip to content

Instantly share code, notes, and snippets.

@omarfsosa
Last active December 28, 2021 23:26
Show Gist options
  • Save omarfsosa/e418b63d5d4f093e393aeafc5076202d to your computer and use it in GitHub Desktop.
Save omarfsosa/e418b63d5d4f093e393aeafc5076202d to your computer and use it in GitHub Desktop.
Temme approximation for Poisson inverse CDF
"""
Approximate inverse CDF for the poisson distribution in Jax.
Based on the approximation method proposed in [1].
References
----------
[1]: https://people.maths.ox.ac.uk/gilesm/codes/poissinv/paper.pdf
"""
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr
from jax.scipy.special import ndtri
from jax.scipy.stats import poisson
@jax.jit
def _f(r):
    """
    Signed square-root transform ``sign(r - 1) * sqrt(2 * (1 - r + r * log(r)))``.

    The result is negative for ``r < 1``, zero at ``r == 1`` and positive for
    ``r > 1``. `approx_inv_f` approximates the inverse of this map.
    """
    radicand = 2 * (1 - r + r * jnp.log(r))
    return jnp.sign(r - 1) * jnp.sqrt(radicand)
@jax.jit
def approx_inv_f(y):
    """
    Taylor approximation to the inverse of `_f` up to order 10.

    Two degree-10 expansions are evaluated and one is selected per element:

    * ``y >= 0``: expansion about ``a = _f(sqrt(2))`` (zeroth-order term
      ``sqrt(2)``);
    * ``y < 0``: expansion about ``b = _f(0.5)`` (zeroth-order term ``0.5``).

    Every operation is elementwise, so `y` may be a scalar or an array of any
    shape; the result has the same shape. (No ``jax.vmap`` wrapper is used:
    it would add nothing for arrays and would reject 0-d inputs.)
    """
    a = 0.38965499961165273634331739362539203787
    b = -0.55394297489909075530074661825619291878
    da = y - a
    db = y - b
    above = (
        1.41421356237309504880168872420969807857
        + 1.12430667119464473014907337153145075765 * da
        + 0.1531719705475007453738542225714086365 * da ** 2
        - 0.00963534867560723115436808341214909258 * da ** 3
        + 0.00199843564334916907185984291067380747 * da ** 4
        - 0.0005604127451221559799488967766232013 * da ** 5
        + 0.00018374201212080445371642159413933106 * da ** 6
        - 0.00006638816844820539098080138927093808 * da ** 7
        + 0.00002563763078265987339961466411981206 * da ** 8
        - 0.00001039337704951179380953670412494928 * da ** 9
        + 4.37242293758710967917333950699327e-6 * da ** 10
    )
    below = (
        0.5
        + 0.79917078282219776763162980202313238014 * db
        + 0.20006420570681594195022346298336739573 * db ** 2
        - 0.02957827705804215416248662471355374274 * db ** 3
        + 0.0131778297383059328355669790592304563 * db ** 4
        - 0.0078595297071194875108074176691291394 * db ** 5
        + 0.00545781700824722781820501229768977374 * db ** 6
        - 0.00416696194941366171881547621363511748 * db ** 7
        + 0.00339537485270363975244716047986888621 * db ** 8
        - 0.00290139838458528912391935122813000491 * db ** 9
        + 0.00257093985067377249483640946187906634 * db ** 10
    )
    return jnp.where(y < 0, below, above)
@jax.jit
def c_zero(r):
    """
    Evaluate ``log(_f(r) * sqrt(r) / (r - 1)) / log(r)``.

    The closed form is 0/0 at ``r == 1`` (NaN) and suffers catastrophic
    cancellation close to it, although the function itself is smooth there
    with limit 1/3. Near ``r = 1`` the series

        1/3 - (r - 1)/36 + 43 (r - 1)^2 / 3240

    is used instead. The switch-over radius is chosen from the floating-point
    precision of `r` so that the series truncation error and the rounding
    error of the closed form are of similar size.

    Parameters
    ----------
    r : scalar or array
        Positive argument(s); evaluated elementwise.

    Returns
    -------
    Array with the same shape as `r`.
    """
    r = jnp.asarray(r)
    r = r.astype(jnp.result_type(r, float))

    # Closed-form rounding error ~ eps / |r-1|^3, series error ~ 1e-2 |r-1|^3.
    eps = float(jnp.finfo(r.dtype).eps)
    radius = (100.0 * eps) ** (1.0 / 6.0)

    e = r - 1
    near_one = jnp.abs(e) < radius
    series = 1.0 / 3.0 - e / 36.0 + 43.0 * e ** 2 / 3240.0

    # Evaluate the closed form at a harmless point wherever the series is
    # selected, so no NaN/inf is produced in the discarded branch.
    safe_r = jnp.where(near_one, 2.0, r)
    num = jnp.log(_f(safe_r) * jnp.sqrt(safe_r) / (safe_r - 1))
    den = jnp.log(safe_r)
    return jnp.where(near_one, series, num / den)
@jax.jit
def upper_c(x, lam):
    """
    Evaluate ``ndtr(-sqrt(x) * eta)`` with
    ``eta = sqrt(2 * (lam/x - 1 - log(lam/x)))``.

    `ndtr` is the standard normal CDF from ``jax.scipy.special`` (it must be
    imported at module level alongside `ndtri`).

    The quantity under the square root is non-negative mathematically but can
    come out as a tiny negative number through rounding when ``lam/x`` is
    close to 1; it is clamped at zero so the result is 0.5 there instead of
    NaN.

    NOTE(review): ``eta`` carries no sign here, so the returned value never
    exceeds 0.5 -- confirm against the uniform approximation in [1] that this
    is the intended quantity for the comparison in `inverse_poisson_cdf`.
    """
    ratio = lam / x
    radicand = jnp.maximum(2 * (-1 - jnp.log(ratio) + ratio), 0.0)
    eta = jnp.sqrt(radicand)
    return ndtr(-jnp.sqrt(x) * eta)
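# NOTE: disabled code, kept for reference only. It has the same
# (w, lam) -> (estimate, delta) shape as `temme_approximation` and appears
# to be a normal-approximation alternative to it.
# @jax.jit
# def normal_approximation_q2(w, lam):
#     q1 = lam + jnp.sqrt(lam) * w + (1/3 + 1/6 * w ** 2)
#     q2 = q1 + lam ** (-0.5) * (-1/36 * w - 1/72 * w ** 3)
#     delta = (1/40 + 1/80 * w ** 2 + 1/160 * w ** 4) / lam
#     return q2, delta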
@jax.jit
def temme_approximation(w, lam):
    """
    Continuous estimate of the Poisson quantile plus an error tolerance.

    Parameters
    ----------
    w : array
        Standard-normal quantile of the target probability (the caller
        passes ``ndtri(u)``).
    lam : scalar or array
        Poisson rate.

    Returns
    -------
    (x, delta)
        ``x`` is ``lam * r + c_zero(r)`` with ``r = approx_inv_f(w / sqrt(lam))``,
        shifted by an empirical correction; ``delta`` is ``0.01 / lam``, the
        tolerance the caller uses when rounding ``x`` to an integer.
    """
    ratio = approx_inv_f(w / jnp.sqrt(lam))  # TODO: Clip at 0?
    estimate = lam * ratio + c_zero(ratio)
    # Empirical correction applied on top of the asymptotic estimate.
    shift = -0.0218 / (estimate + 0.065 * lam)
    return estimate + shift, 0.01 / lam
def table_lookup(u, lam):
    """
    Invert the Poisson CDF by searching a table of its first 30 values.

    Best for when lam < 4.

    Returns, for each entry of `u`, the index of the first ``k`` in
    ``0..29`` whose CDF value is ``>= u`` (30 if there is none).
    """
    support = jnp.arange(30)
    return jnp.searchsorted(poisson.cdf(support, lam), u)
def inverse_poisson_cdf(u, lam):
    """
    Approximate inverse CDF of the Poisson distribution.

    For ``lam < 4`` the result of `table_lookup` is returned. Otherwise the
    continuous estimate ``x`` from `temme_approximation` is rounded down
    after adding its tolerance ``delta``; that integer is kept when ``x`` is
    more than ``delta`` above it or when ``upper_c(x, lam) < u``, and is
    decremented by one otherwise.

    Parameters
    ----------
    u : array
        Probabilities to invert.
    lam : scalar
        Poisson rate. NOTE(review): `table_lookup` is evaluated
        unconditionally and builds a 1-D table, so `lam` presumably has to be
        a scalar -- confirm with callers.
    """
    x, tol = temme_approximation(ndtri(u), lam)
    candidate = jnp.floor(x + tol)
    clearly_above = x - candidate > tol
    tail_below_u = upper_c(x, lam) < u
    asymptotic = jnp.where(clearly_above | tail_below_u, candidate, candidate - 1)
    return jnp.where(lam < 4, table_lookup(u, lam), asymptotic)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment