Skip to content

Instantly share code, notes, and snippets.

@goldingn
Last active April 27, 2018 07:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goldingn/4192ddff71037ca3f18927b8d273cd2a to your computer and use it in GitHub Desktop.
Save goldingn/4192ddff71037ca3f18927b8d273cd2a to your computer and use it in GitHub Desktop.
prototype of helper functions for marginalising a Poisson random variable in a greta model
# Marginalise a Poisson-distributed latent count N out of a greta likelihood.
#
# Instead of representing the discrete N ~ Poisson(lambda) as a model
# variable, the likelihood is replaced by a finite mixture over
# N = 0, 1, ..., max_n. Each component is the likelihood conditional on that
# N, weighted by the Poisson PMF at N. The support is truncated at max_n, so
# choose max_n large enough that P(N > max_n) is negligible for plausible
# values of lambda. The truncated weights therefore sum to slightly less than
# one. NOTE(review): this relies on greta's mixture() accepting unnormalised
# weights -- confirm for the greta version in use.
#
# Arguments:
#   likelihood_function: a function taking a single value of N and returning
#     a greta distribution for the data conditional on that N.
#   lambda: a (possibly variable) scalar greta array giving the Poisson rate.
#   max_n: a scalar non-negative integer, the largest value of N to include.
#
# Returns the mixture distribution, e.g. for use in `distribution(y) <- ...`.
marginal_poisson <- function(likelihood_function, lambda, max_n) {
  stopifnot(
    is.numeric(max_n),
    length(max_n) == 1,
    max_n >= 0,
    max_n == round(max_n)
  )
  # The Poisson support starts at zero. seq_len(max_n) would silently drop
  # the N = 0 component and its exp(-lambda) share of the probability mass.
  n_seq <- seq.int(0L, max_n)
  wt <- poisson_weights(n_seq, lambda)
  dists <- lapply(n_seq, likelihood_function)
  do.call(mixture, c(dists, list(weights = wt)))
}
# Evaluate the Poisson probability mass function at each count in Ns.
#
# Arguments:
#   Ns: a vector of non-negative integers at which to evaluate the PMF.
#   lambda: a strictly positive scalar rate, either a plain number or a
#     (possibly variable) scalar greta array.
#
# Returns a vector of PMF values, one per element of Ns. It is a greta array
# when lambda is one.
poisson_weights <- function(Ns, lambda) {
  # Work on the log scale. Computing lambda ^ Ns / factorial(Ns) directly
  # overflows for large counts: factorial() is Inf beyond 170, and lambda ^ Ns
  # can reach Inf too, giving Inf / Inf = NaN weights. lfactorial(Ns) is
  # evaluated in plain R because Ns is a fixed numeric vector.
  log_pmf <- Ns * log(lambda) - lambda - lfactorial(Ns)
  exp(log_pmf)
}
# Simulate data from a latent Poisson model:
#   N ~ Poisson(5)     a single latent count shared by every observation
#   p = 0.9 ^ N
#   y ~ Bernoulli(p)   n independent observations
set.seed(123)
n <- 1000
N <- rpois(n = 1, lambda = 5)
p <- 0.9^N
y <- rbinom(n = n, size = 1, prob = p)

# Empirical estimate of N, obtained by inverting mean(y) ~= 0.9 ^ N ...
log(mean(y)) / log(0.9)
# ... and the true value it should be compared with
N
library(greta)

# Prior for the rate of the latent Poisson count N
lambda <- lognormal(0, 1)

# Likelihood of the data conditional on a given value of N
likelihood <- function(N) {
  bernoulli(0.9^N)
}

# Sum N out of the likelihood, considering values of N up to 10
distribution(y) <- marginal_poisson(
  likelihood_function = likelihood,
  lambda = lambda,
  max_n = 10
)

m <- model(lambda)
draws <- mcmc(m, n_samples = 300)
# Plot the posteriors over the Poisson weights for plausible values of N
weights <- poisson_weights(2:6, lambda)
# Pass the draws by name. In greta versions where calculate() collects its
# target greta arrays through `...`, an unnamed draws object would be treated
# as another target rather than as the values to evaluate at.
weights_draws <- calculate(weights, values = draws)
bayesplot::mcmc_intervals(weights_draws)

# Posterior samples of the latent N.
# Conditional on lambda, the posterior of N is proportional to
#   Poisson(N | lambda) * prod_i Bernoulli(y_i | 0.9 ^ N).
# For each posterior draw of lambda, sample N from that discrete distribution
# over 0..10 (the max_n used when fitting). Simply drawing rpois(lambda) would
# ignore the data. Draws are pooled across all chains, not just the first.
lambda_draws <- as.matrix(draws)[, "lambda"]
n_support <- 0:10
# The data log-likelihood for each candidate N does not depend on lambda,
# so compute it once.
log_lik_n <- vapply(
  n_support,
  function(n_val) sum(dbinom(y, size = 1, prob = 0.9^n_val, log = TRUE)),
  numeric(1)
)
N_draws <- vapply(
  lambda_draws,
  function(lambda_val) {
    log_post <- dpois(n_support, lambda_val, log = TRUE) + log_lik_n
    # Subtract the maximum before exponentiating so the weights don't underflow
    sample(n_support, size = 1, prob = exp(log_post - max(log_post)))
  },
  integer(1)
)
hist(N_draws)
summary(N_draws)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment