Skip to content

Instantly share code, notes, and snippets.

@ririw
Last active February 15, 2022 20:47
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ririw/2e3a4415dc8271bd2d132c476b98b567 to your computer and use it in GitHub Desktop.
Save ririw/2e3a4415dc8271bd2d132c476b98b567 to your computer and use it in GitHub Desktop.
PyMC3 zero-truncated Poisson distribution
import pymc3 as pm
from pymc3.distributions.dist_math import bound, logpow, factln
from pymc3.distributions import draw_values, generate_samples
import theano.tensor as tt
import numpy as np
import scipy.stats.distributions
class ZTP(pm.Discrete):
    """Zero-truncated Poisson distribution for PyMC3.

    Support is the positive integers {1, 2, ...}. For a rate ``mu > 0``::

                     mu^k
        P(X = k) = -------------- ,   k >= 1
                   k! (e^mu - 1)

    See https://en.wikipedia.org/wiki/Zero-truncated_Poisson_distribution

    Parameters
    ----------
    mu : tensor-like
        Rate of the underlying (untruncated) Poisson distribution.
    """

    def __init__(self, mu, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mu = mu = tt.as_tensor_variable(mu)
        # The ZTP mode is floor(mu), but never below 1 because 0 is not in the
        # support (for mu < 1 the mode is 1).
        # NOTE(review): presumably consumed by pm.Discrete as the default test
        # value via its `defaults=('mode',)` — confirm for the installed PyMC3.
        self.mode = tt.maximum(tt.floor(mu).astype('int32'), 1)

    def zpt_cdf(self, mu, size=None):
        """Draw zero-truncated Poisson samples by inverse-CDF sampling.

        Despite its name this is a sampler, not a CDF; the name is kept for
        backward compatibility. A uniform variate is drawn from ``[F(0), 1)``,
        where ``F`` is the Poisson CDF, and mapped through the Poisson quantile
        function, so every result is at least 1.

        Parameters
        ----------
        mu : array-like
            Poisson rate(s); broadcast against an array of shape ``size``.
        size : int, tuple of int or None
            Shape of the uniform draws, as accepted by ``numpy.random.uniform``.

        Returns
        -------
        numpy.ndarray or numpy.int64
            Integer samples (>= 1) from the zero-truncated Poisson.
        """
        mu = np.asarray(mu)
        dist = scipy.stats.distributions.poisson(mu)
        lower_cdf = dist.cdf(0)
        nrm = 1.0 - lower_cdf
        # np.random.uniform accepts None, an int or a tuple for `size`, whereas
        # np.random.rand takes dimensions as separate ints and rejects both
        # None and tuples.
        sample = np.random.uniform(size=size) * nrm + lower_cdf
        # ppf(1.0) is +inf, which cannot be cast to an integer; that value can
        # arise through rounding when mu is tiny and F(0) rounds to 1.
        sample = np.minimum(sample, np.nextafter(1.0, 0.0))
        draws = dist.ppf(sample)
        # A uniform draw of exactly 0 (or the clipping above) can map to the
        # quantile 0, which lies outside the support; clamp to 1.
        return np.maximum(draws, 1).astype('int64')  # Thanks to @omrihar for this fix!

    def random(self, point=None, size=None, repeat=None):
        """Draw random samples from the distribution.

        Parameters
        ----------
        point : dict, optional
            Values for parameters that are themselves random variables.
        size : int or tuple of int, optional
            Number/shape of samples to draw.
        repeat : ignored
            Accepted for signature compatibility; not used.
        """
        # draw_values returns a list with one entry per requested parameter;
        # unpack it so zpt_cdf receives the rate array itself instead of a
        # one-element list (which would add a spurious leading axis).
        mu = draw_values([self.mu], point=point, size=size)[0]
        return generate_samples(self.zpt_cdf, mu,
                                dist_shape=self.shape,
                                size=size)

    def logp(self, value):
        """Log-probability of ``value`` under the zero-truncated Poisson.

        ::

                          mu^k
            PDF = -------------------
                    k! (e^mu - 1)

            log(PDF) = k log(mu) - log(k!) - log(e^mu - 1)

        ``log(e^mu - 1)`` is evaluated as ``mu + log(1 - e^-mu)`` so that it
        neither overflows for large ``mu`` (e.g. ``e^mu`` exceeds the float32
        range above mu ~ 88) nor loses precision for small ``mu``.

        Returns ``-inf`` for negative ``mu`` or for ``value < 1`` (outside the
        support), except that ``mu == 0`` together with ``value == 0`` returns
        0 by convention.
        """
        mu = self.mu
        # log(1 - exp(-mu)) computed stably: expm1 for small mu, log1p for
        # larger mu (switch point 0.683 ~ log 2, per Maechler 2012).
        log1mexp = tt.switch(
            tt.lt(mu, 0.683),
            tt.log(-tt.expm1(-mu)),
            tt.log1p(-tt.exp(-mu)),
        )
        log_norm = mu + log1mexp  # == log(e^mu - 1)
        p = logpow(mu, value) - (factln(value) + log_norm)
        # The support of the zero-truncated Poisson starts at 1.
        log_prob = bound(p, mu >= 0, value >= 1)
        # Return zero when mu and value are both zero.
        return tt.switch(tt.and_(tt.eq(mu, 0), tt.eq(value, 0)),
                         0, log_prob)
@ririw
Copy link
Author

ririw commented Apr 11, 2021

Thanks! I've fixed it up.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment