Skip to content

Instantly share code, notes, and snippets.

@ririw
Last active February 15, 2022 20:47
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ririw/2e3a4415dc8271bd2d132c476b98b567 to your computer and use it in GitHub Desktop.
Save ririw/2e3a4415dc8271bd2d132c476b98b567 to your computer and use it in GitHub Desktop.
PyMC3 zero-truncated Poisson distribution
import pymc3 as pm
from pymc3.distributions.dist_math import bound, logpow, factln
from pymc3.distributions import draw_values, generate_samples
import theano.tensor as tt
import numpy as np
import scipy.stats.distributions
class ZTP(pm.Discrete):
    """Zero-truncated Poisson distribution for PyMC3.

    A Poisson(mu) variable conditioned on being strictly positive::

                        mu^k
        P(k | mu) = --------------,   k = 1, 2, 3, ...
                    k! (e^mu - 1)

    See https://en.wikipedia.org/wiki/Zero-truncated_Poisson_distribution

    Parameters
    ----------
    mu : tensor-like
        Rate of the underlying (untruncated) Poisson distribution. ``logp``
        returns -inf for negative values.
    """

    def __init__(self, mu, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mu = mu = tt.as_tensor_variable(mu)
        # The support starts at 1, so the mode is floor(mu) clamped from
        # *below* at 1. (tt.minimum would cap it at 1 and yield 0 for mu < 1,
        # which is outside the support and would be used as the test value.)
        self.mode = tt.maximum(tt.floor(mu).astype('int32'), 1)

    def zpt_cdf(self, mu, size=None):
        """Draw samples by inverse-transform sampling.

        Uniform draws are mapped into ``[P(X <= 0), 1)`` (the part of the
        Poisson(mu) CDF range that belongs to the outcomes X >= 1) and pushed
        through the Poisson quantile function.

        Parameters
        ----------
        mu : array_like
            Poisson rate(s); broadcast against ``size``.
        size : None, int or tuple of int
            Output shape, as passed in by ``generate_samples``.

        Returns
        -------
        numpy int64 array (or scalar when ``size`` is None and mu is scalar)
            Samples, each >= 1.
        """
        mu = np.asarray(mu)
        dist = scipy.stats.distributions.poisson(mu)
        lower_cdf = dist.cdf(0)
        # random_sample accepts None, an int or a shape tuple; np.random.rand
        # only accepts dimensions as separate int arguments.
        uniform = np.random.random_sample(size)
        sample = uniform * (1 - lower_cdf) + lower_cdf
        # ppf returns floats; cast so downstream tools (e.g. ArviZ) see integer
        # draws (thanks to @omrihar). The clamp covers the boundary draw
        # uniform == 0, for which ppf(lower_cdf) == 0 lies outside the support.
        return np.maximum(dist.ppf(sample), 1).astype('int64')

    def random(self, point=None, size=None, repeat=None):
        """Draw random samples, e.g. for prior/posterior predictive checks.

        ``repeat`` is accepted for signature compatibility and is unused.
        """
        # draw_values returns one entry per requested parameter; unpack it so
        # zpt_cdf receives the rate itself rather than a 1-element list.
        (mu,) = draw_values([self.mu], point=point, size=size)
        return generate_samples(self.zpt_cdf, mu,
                                dist_shape=self.shape,
                                size=size)

    def logp(self, value):
        """Log-probability of ``value``.

            log P(k) = k log(mu) - log(k!) - log(e^mu - 1)

        Returns -inf when ``value < 1`` (outside the support) or ``mu < 0``.
        """
        mu = self.mu
        # log(e^mu - 1) == mu + log1p(-e^-mu) for mu > 0, but the right-hand
        # side never evaluates e^mu, so it does not overflow for large rates
        # (roughly mu > 88 in float32, mu > 709 in float64).
        log_normalizer = mu + tt.log1p(-tt.exp(-mu))
        p = logpow(mu, value) - (factln(value) + log_normalizer)
        # The support is k >= 1; a bound of value >= 0 would give value == 0 a
        # finite log-probability.
        log_prob = bound(p, mu >= 0, value >= 1)
        # Preserved special case: return 0 when mu and value are both zero
        # (the formula above evaluates to +inf there).
        # NOTE(review): this conflicts with the k >= 1 support, and mu == 0
        # with value >= 1 still yields NaN; confirm whether any model relies
        # on this degenerate case.
        return tt.switch(tt.and_(tt.eq(mu, 0), tt.eq(value, 0)),
                         0, log_prob)
@omrihar
Copy link

omrihar commented Apr 9, 2021

Hi @ririw, thanks for this class and the blogpost it accompanies!
It helped me a lot.

If I may make one suggestion: when performing prior predictive checks, I noticed that this class misbehaves when plotting, because its samples are floats rather than integers. Digging a little deeper, I found that while poisson.rvs() returns int64 values, poisson.ppf() returns float64.

There is an easy fix, though: just add .astype('int64') to the end of line 23, and everything works nicely with ArviZ :)

@ririw
Copy link
Author

ririw commented Apr 11, 2021

Thanks! I've fixed it up.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment