Skip to content

Instantly share code, notes, and snippets.

@dfm
Last active July 13, 2021 16:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dfm/6b7f9167ce08284721d8c11c4019c12a to your computer and use it in GitHub Desktop.
Save dfm/6b7f9167ce08284721d8c11c4019c12a to your computer and use it in GitHub Desktop.
Noncentral chi squared distribution
import tensorflow_probability.substrates.jax as tfp
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax
import jax.random as random
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample
def _random_chi2(key, df, shape=(), dtype=jnp.float_):
    """Draw (central) chi-squared variates with ``df`` degrees of freedom.

    A chi-squared variable with ``df`` degrees of freedom is twice a
    unit-rate gamma variable with concentration ``df / 2``, so the draw is
    delegated to ``jax.random.gamma`` and rescaled.

    :param key: PRNG key passed straight to ``jax.random.gamma``.
    :param df: degrees of freedom (scalar or array).
    :param shape: requested output shape, forwarded to ``jax.random.gamma``.
    :param dtype: floating dtype of the draws.
    :return: array of chi-squared samples.
    """
    half_df = 0.5 * df
    gamma_draw = random.gamma(key, half_df, shape=shape, dtype=dtype)
    return 2.0 * gamma_draw
class NoncentralChi2(Distribution):
    """Noncentral chi-squared distribution.

    :param df: degrees of freedom; constrained to be positive.
    :param nc: noncentrality parameter; constrained to be positive.
    :param validate_args: forwarded unchanged to ``Distribution.__init__``.
    """

    arg_constraints = {
        "df": constraints.positive,
        "nc": constraints.positive,
    }
    support = constraints.positive
    reparametrized_params = ["df", "nc"]

    def __init__(self, df, nc, validate_args=None):
        self.df, self.nc = promote_shapes(df, nc)
        # The batch shape is the broadcast of the two parameter shapes.
        batch_shape = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nc))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Follows NumPy's ``random_noncentral_chisquare``:

        * ``df > 1``: ``chi2(df - 1) + (z + sqrt(nc)) ** 2`` with ``z``
          standard normal;
        * otherwise: ``chi2(df + 2 * i)`` with ``i ~ Poisson(nc / 2)``.

        The Poisson and normal draws are generated for every element and the
        applicable branch is picked element-wise with ``jnp.where``.

        :param key: JAX PRNG key.
        :param sample_shape: leading shape of the returned samples.
        :return: array of samples.
        """
        # Ref: https://github.com/numpy/numpy/blob/89c80ba606f4346f8df2a31cfcc0e967045a68ed/numpy/random/src/distributions/distributions.c#L797-L813
        assert is_prng_key(key)
        shape = sample_shape + self.batch_shape + self.event_shape
        key1, key2, key3 = random.split(key, 3)
        i = random.poisson(key1, 0.5 * self.nc, shape=shape)
        n = random.normal(key2, shape=shape) + jnp.sqrt(self.nc)
        cond = jnp.greater(self.df, 1.0)
        # Select the degrees of freedom *before* sampling so the chi-squared
        # draw always receives a positive value (df - 1 > 0 when df > 1,
        # df + 2 * i > 0 otherwise).
        chi2_df = jnp.where(cond, self.df - 1.0, self.df + 2.0 * i)
        chi2 = _random_chi2(key3, chi2_df, shape=shape)
        return jnp.where(cond, chi2 + n * n, chi2)

    @validate_sample
    def log_prob(self, value):
        """Evaluate the log-density at ``value``.

        Same formulation as SciPy's ``_ncx2_log_pdf``: the ``exp(-xs * ns)``
        factor is carried by the exponentially scaled Bessel function so the
        remaining exponent is ``-(xs - ns) ** 2 / 2``.

        :param value: point(s) at which to evaluate the log-density.
        :return: log-density; ``-inf`` wherever the Bessel term is not
            strictly positive.
        """
        # Ref: https://github.com/scipy/scipy/blob/500878e88eacddc7edba93dda7d9ee5f784e50e6/scipy/stats/_distn_infrastructure.py#L597-L610
        df2 = self.df / 2.0 - 1.0
        xs, ns = jnp.sqrt(value), jnp.sqrt(self.nc)
        res = jsp.special.xlogy(df2 / 2.0, value / self.nc) - 0.5 * (xs - ns) ** 2
        corr = tfp.math.bessel_ive(df2, xs * ns) / 2.0
        positive = jnp.greater(corr, 0.0)
        # jnp.where evaluates both branches, so log(0) in the masked branch
        # would turn the gradient into NaN (0 * inf) even though the forward
        # value is discarded. Feed log() a harmless 1.0 on masked elements.
        safe_corr = jnp.where(positive, corr, 1.0)
        return jnp.where(positive, res + jnp.log(safe_corr), -jnp.inf)

    @property
    def mean(self):
        """Mean of the distribution: ``df + nc``."""
        return self.df + self.nc

    @property
    def variance(self):
        """Variance of the distribution: ``2 * (df + 2 * nc)``."""
        return 2.0 * (self.df + 2.0 * self.nc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment