Kuma and HardKuma distributions in JAX using distrax
"""Adapted from https://github.com/bastings/interpretable_predictions"""
import math
import distrax
import jax
import jax.numpy as jnp
EPS = 1e-6


@jax.jit
def hard_tanh(x, min_val=-1.0, max_val=1.0):
    """Clamp x to the closed interval [min_val, max_val]."""
    return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x))


@jax.jit
def lbeta(x):
    """Log of the multivariate Beta function, computed over the last axis."""
    x_abs = jnp.abs(x)
    log_prod_gamma_x = jax.lax.lgamma(x_abs).sum(-1)
    log_gamma_sum_x = jax.lax.lgamma(x_abs.sum(-1))
    return log_prod_gamma_x - log_gamma_sum_x


@jax.jit
def _harmonic_number(x):
    """Compute the harmonic number from its analytic continuation."""
    one = jnp.ones(1)
    return jax.lax.digamma(x + one) - jax.lax.digamma(one)


@jax.jit
def kuma_mean(a, b):
    """Compute the mean of Kumaraswamy(a, b) via its first moment."""
    return kuma_moments(a, b, 1)


@jax.jit
def kuma_moments(a, b, n):
    """
    Compute the nth moment of Kumaraswamy(a, b),
    E[X^n] = b * B(1 + n/a, b), via jax.lax.lgamma.
    """
    arg1 = 1 + n / a
    log_value = jax.lax.lgamma(jnp.abs(arg1))
    log_value += jax.lax.lgamma(jnp.abs(b))
    log_value -= jax.lax.lgamma(jnp.abs(arg1 + b))
    return b * jnp.exp(log_value)
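

# Sanity-check sketch (my addition, not part of the original gist): the
# closed-form mean from kuma_moments should match a Monte Carlo estimate
# drawn through the inverse CDF u -> (1 - (1 - u)^(1/b))^(1/a). The helper
# name, shapes, and sample count below are illustrative assumptions.
def _check_kuma_mean(key, a, b, n_samples=100_000):
    u = jax.random.uniform(
        key, shape=(n_samples,) + a.shape, minval=EPS, maxval=1.0 - EPS
    )
    samples = (1.0 - (1.0 - u) ** (1.0 / b)) ** (1.0 / a)
    return kuma_mean(a, b), samples.mean(axis=0)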


class Kuma(distrax.Distribution):
    """
    A Kumaraswamy distribution, or Kuma for short. It is similar to a Beta
    distribution, though it is not an exponential family. Kuma variables are
    specified by two shape parameters, similar to Beta, though parameter
    settings that would yield a symmetric Beta do not necessarily yield a
    symmetric Kuma.

        X ~ Kuma(a, b)    where a, b > 0

    Or equivalently,

        u ~ U(0, 1)
        x = (1 - (1 - u)^(1/b))^(1/a)

    In practice we sample u from U(0 + eps, 1 - eps) for some small positive
    constant eps to avoid numerical instabilities.
    """
    def __init__(self, params: list):
        self.a = params[0]
        self.b = params[1]

    def params(self):
        return [self.a, self.b]

    def mean(self):
        return kuma_moments(self.a, self.b, 1)

    @property
    def event_shape(self):
        return ()

    @property
    def batch_shape(self):
        return self.a.shape

    def _sample_n(self, key, n, eps=0.001):
        # inverse-CDF sampling: u ~ U(eps, 1 - eps), x = (1 - (1 - u)^(1/b))^(1/a)
        shape = [n] + list(self.a.shape)
        u = jax.random.uniform(key, shape=shape, minval=eps, maxval=1.0 - eps)
        return (1.0 - (1.0 - u) ** jnp.reciprocal(self.b)) ** jnp.reciprocal(self.a)
    def log_prob(self, x):
        """
        Kuma(x|a, b) = U(s(x)|0, 1) |det J_s|
        where x = t(u) and u = s(x) and J_s is the Jacobian matrix of s(x)
        """
        t1 = jnp.log(self.a) + jnp.log(self.b)
        t2 = (self.a - 1.0 + EPS) * jnp.log(x)
        pow_x_a = (x ** self.a) + EPS
        t3b = jnp.log(1.0 - pow_x_a)
        t3 = (self.b - 1.0 + EPS) * t3b
        return t1 + t2 + t3
    def log_cdf(self, x):
        # cdf(x) = 1 - (1 - x^a)^b
        r = 1.0 - ((1.0 - (x ** self.a)) ** self.b)
        r = jnp.log(r + EPS)
        return jax.lax.clamp(math.log(EPS), r, math.log(1 - EPS))
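

# Quick illustrative check (my addition, not part of the original gist):
# with a = b = 1, Kuma reduces to Uniform(0, 1), so log_prob should be ~0
# everywhere on (0, 1), up to the EPS fudge terms above.
def _check_kuma_uniform_case():
    kuma = Kuma([jnp.ones(1), jnp.ones(1)])
    return kuma.log_prob(jnp.array([0.25, 0.5, 0.75]))  # ~ [0., 0., 0.]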


class StretchedVariable(distrax.Distribution):
    """
    A continuous variable over the open interval (left, right).

        X ~ StretchedVariable(RelaxedBinary, [left, right])
        where left < 0 and right > 1

    Or equivalently,

        y ~ RelaxedBinary()
        x = location + y * scale

    where location = left and scale = right - left.
    """
    def __init__(self, dist: distrax.Distribution, support: list):
        """
        :param dist: a RelaxedBinary variable (e.g. BinaryConcrete or Kuma)
        :param support: a pair specifying the limits of the stretched support,
            e.g. [-1, 2]; we use these values to compute
            location = support[0] and scale = support[1] - support[0]
        """
        assert support[0] < support[1], "I need an ordered support, got %s" % support
        self._dist = dist
        self.loc = support[0]
        self.scale = support[1] - support[0]

    def params(self):
        return self._dist.params()

    @property
    def event_shape(self):
        return self._dist.event_shape

    @property
    def batch_shape(self):
        return self._dist.batch_shape

    def _sample_n(self, key, n, eps=0.001):
        # sample a relaxed binary variable
        x_ = self._dist._sample_n(key, n, eps=eps)
        # and stretch it
        return x_ * self.scale + self.loc
    def log_prob(self, x):
        # shrink the stretched variable
        x_ = (x - self.loc) / self.scale
        # and assess the stretched pdf using the original pdf
        # see eq 25 (left) of Louizos et al.
        return self._dist.log_prob(x_) - jnp.log(self.scale)

    def log_cdf(self, x):
        # shrink the stretched variable
        x_ = (x - self.loc) / self.scale
        # and assess its cdf
        # see eq 25 (right) of Louizos et al.
        r = self._dist.log_cdf(x_)
        return jax.lax.clamp(math.log(EPS), r, math.log(1 - EPS))
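

# Example sketch (my addition): a Kuma stretched over [-0.1, 1.1], the support
# used by Louizos et al. (2018) for the Hard Concrete distribution; samples
# then live in the open interval (-0.1, 1.1) and the log-density shifts by
# -log(1.2). The helper name and default support are assumptions.
def _make_stretched_kuma(a, b, support=(-0.1, 1.1)):
    return StretchedVariable(Kuma([a, b]), list(support))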


class HardBinary(distrax.Distribution):
    """
    A continuous variable over the closed interval [0, 1] which can assign
    non-zero probability mass to {0} and {1} (sets of zero measure under a
    standard RelaxedBinary or StretchedVariable).

        X ~ HardBinary(StretchedVariable)

    Or equivalently,

        y ~ StretchedVariable()
        x = hardsigmoid(y)
    """
    def __init__(self, dist: StretchedVariable):
        self._dist = dist

    @property
    def event_shape(self):
        return self._dist.event_shape

    @property
    def batch_shape(self):
        return self._dist.batch_shape

    def _sample_n(self, key, n, eps=0.001):
        # sample a stretched variable and rectify it to [0, 1]
        x_ = self._dist._sample_n(key, n, eps=eps)
        return hard_tanh(x_, min_val=0.0, max_val=1.0)
    def log_prob(self, x):
        """
        We obtain pdf(0) by integrating the stretched variable over [left, 0]:
            HardBinary.pdf(0) = StretchedVariable.cdf(0)
        and pdf(1) by integrating it over [1, right]:
            HardBinary.pdf(1) = 1 - StretchedVariable.cdf(1)
        For values in the open interval (0, 1) we use the density of the
        stretched variable unchanged; it carries exactly the remaining mass,
        since StretchedVariable.cdf(1) - StretchedVariable.cdf(0)
        = 1 - HardBinary.pdf(0) - HardBinary.pdf(1).

        Note that the total mass over the discrete set {0, 1} is
        HardBinary.pdf(0) + HardBinary.pdf(1); in other words, with this
        probability we sample a discrete value. Whenever this probability is
        greater than 0.5, most of the mass is away from continuous samples.
        """
        # cache these for faster computation
        log_cdf_0 = self._dist.log_cdf(jnp.zeros(1))
        cdf_1 = jnp.exp(self._dist.log_cdf(jnp.ones(1)))
        # first we fix log_pdf for 0s and 1s:
        # log Q(0) for x == 0, log (1 - Q(1)) otherwise
        log_p = jnp.where(x == 0.0, log_cdf_0, jnp.log(1.0 - cdf_1))
        # then for those that are in the open (0, 1)
        # see eq 26 of Louizos et al.
        log_p = jnp.where((0.0 < x) & (x < 1.0), self._dist.log_prob(x), log_p)
        return log_p
    def log_cdf(self, x):
        """
        Note that HardBinary.cdf(0) = HardBinary.pdf(0) by definition of
        HardBinary.pdf(0), and that HardBinary.cdf(1) = 1 by definition,
        because the support of HardBinary is the *closed* interval [0, 1],
        not the open interval (left, right) of the stretched variable.
        """
        # for x >= 1 all of the mass is covered, i.e. log cdf = log 1 = 0
        log_c = jnp.where(x < 1.0, self._dist.log_cdf(x), 0.0)
        return jax.lax.clamp(math.log(EPS), log_c, math.log(1 - EPS))


class HardKuma(HardBinary):
    def __init__(self, params: list, support: list):
        super().__init__(StretchedVariable(Kuma(params), support))
        # shortcut to the underlying a and b
        self.a = self._dist._dist.a
        self.b = self._dist._dist.b

    def mean(self):
        return kuma_moments(self.a, self.b, 1)
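

# Minimal usage sketch (my addition, not part of the original gist): a HardKuma
# over the stretched support [-0.1, 1.1]. Rectification maps the stretched
# tails onto point masses at exactly 0 and 1. Shapes and parameter values
# below are illustrative.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    a = jnp.full((4,), 0.5)
    b = jnp.full((4,), 0.5)
    hk = HardKuma([a, b], support=[-0.1, 1.1])
    x = hk._sample_n(key, n=1000)
    print("underlying Kuma mean:", hk.mean())
    print("fraction of exact 0s:", (x == 0.0).mean())
    print("fraction of exact 1s:", (x == 1.0).mean())
    print("log p(0):", hk.log_prob(jnp.zeros(1)))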