Skip to content

Instantly share code, notes, and snippets.

@hrlai
Last active July 17, 2024 00:10
Show Gist options
  • Save hrlai/04e4045c58842b1508cd1ceac1be2972 to your computer and use it in GitHub Desktop.
Save hrlai/04e4045c58842b1508cd1ceac1be2972 to your computer and use it in GitHub Desktop.
Obtain LOO-IC or WAIC from greta
#' Obtain LOO-IC or WAIC from \code{greta} Output
#'
#' Builds a matrix of pointwise log-likelihood values (one row per posterior
#' draw, one column per observation) from posterior draws of a location and a
#' scale parameter, then passes it to \code{loo::loo()} or \code{loo::waic()}.
#'
#' @param observed Numeric vector of the observed response variable.
#' @param posterior Named list of posterior draws, such as the output of
#'   \code{greta::calculate(..., nsim = )}. Each element is a vector, matrix
#'   or array whose first dimension indexes the draws; any remaining
#'   dimensions are flattened in column-major order and must line up with
#'   \code{observed} (a single value per draw is recycled across all
#'   observations).
#' @param mean Name of the mean/location element of \code{posterior},
#'   defaults to \code{"mu"}. For \code{family = "lognormal"} this is the
#'   log-scale mean (\code{meanlog}).
#' @param scale Name of the scale element of \code{posterior}, defaults to
#'   \code{"sd"}. For \code{family = "lognormal"} this is \code{sdlog}; for
#'   \code{family = "student"} it is the Student-t scale \code{sigma}.
#' @param family Distribution of the observational model: one of
#'   \code{"normal"} (default), \code{"student"} (location-scale Student-t via
#'   \code{dlst()}, which must be available when called), or
#'   \code{"lognormal"}.
#' @param method \code{"loo"} (default) to calculate the Leave-One-Out
#'   Information Criterion, or \code{"waic"} for the Widely Applicable
#'   Information Criterion, both computed by the \code{loo} package.
#' @param moment_match Logical. Kept for backward compatibility only:
#'   \code{loo()} on a log-likelihood matrix does not perform moment
#'   matching, so \code{TRUE} raises a warning and is otherwise ignored.
#' @param ... Additional arguments passed to the density function.
#' @param df Degrees of freedom used when \code{family = "student"};
#'   defaults to 2. Because it follows \code{...}, it must be given by name.
#'
#' @return The object returned by \code{loo::loo()} or \code{loo::waic()};
#'   see \code{?loo::loo} and \code{?loo::waic} for details.
#'
#' @examples
#' # Self-contained toy "posterior" shaped like greta::calculate(nsim = 100)
#' set.seed(1)
#' obs <- iris$Sepal.Length
#' n <- length(obs)
#' post <- list(
#'   mean = array(rnorm(100 * n, mean(obs), 0.05), dim = c(100, n, 1)),
#'   sd = array(runif(100, 0.75, 0.9), dim = c(100, 1, 1))
#' )
#' loo_greta(obs, post, mean = "mean", scale = "sd", family = "normal")
#'
#' \dontrun{
#' # With posteriors computed from fitted greta models
#' loo_greta(obs, posterior_lognormal, mean = "mean", scale = "sd",
#'           family = "lognormal", method = "loo")
#' loo_greta(obs, posterior_normal, mean = "mean", scale = "sd",
#'           family = "normal", method = "waic")
#' }
#'
#' @import loo
#' @export
loo_greta <- function(observed,
                      posterior,
                      mean = "mu",
                      scale = "sd",
                      family = c("normal", "student", "lognormal"),
                      method = c("loo", "waic"),
                      moment_match = FALSE,
                      ...,
                      df = 2) {
  # check that arguments are consistent with the implemented options
  family <- match.arg(family)
  method <- match.arg(method)

  if (length(observed) == 0) {
    stop("`observed` must contain at least one value.", call. = FALSE)
  }

  mu_raw <- posterior[[mean]]
  sd_raw <- posterior[[scale]]
  missing_pars <- c(mean, scale)[c(is.null(mu_raw), is.null(sd_raw))]
  if (length(missing_pars) > 0) {
    stop(
      "Parameter(s) not found in `posterior`: ",
      paste(missing_pars, collapse = ", "),
      call. = FALSE
    )
  }

  if (isTRUE(moment_match)) {
    # loo() on a matrix has no moment_match argument; it was previously
    # swallowed by `...` without effect, so make that explicit.
    warning(
      "`moment_match = TRUE` is ignored: moment matching is not available ",
      "for a log-likelihood matrix.",
      call. = FALSE
    )
  }

  # Flatten each parameter to a draws x values matrix. For an array of
  # dim (nsim, n, 1), row i holds the same values, in the same order, as
  # x[i, , ]; plain vectors and matrices of draws are accepted too.
  as_draw_matrix <- function(x) matrix(x, nrow = NROW(x))
  mu_draws <- as_draw_matrix(mu_raw)
  sd_draws <- as_draw_matrix(sd_raw)

  n_draws <- nrow(mu_draws)
  n_obs <- length(observed)
  if (n_draws == 0) {
    stop("`posterior` contains no draws.", call. = FALSE)
  }
  if (nrow(sd_draws) != n_draws) {
    stop(
      "`", mean, "` and `", scale, "` have different numbers of draws.",
      call. = FALSE
    )
  }

  # Pointwise log density of `observed` for one draw's location/scale values.
  log_density <- switch(family,
    normal = function(loc, scl) {
      dnorm(observed, loc, scl, log = TRUE, ...)
    },
    student = function(loc, scl) {
      # NOTE(review): `dlst` is not defined in this file; presumably
      # extraDistr::dlst (location-scale Student-t) -- confirm the import.
      dlst(x = observed, df = df, mu = loc, sigma = scl, log = TRUE, ...)
    },
    lognormal = function(loc, scl) {
      dlnorm(observed, loc, scl, log = TRUE, ...)
    }
  )

  # vapply() enforces exactly one log density per observation per draw and
  # returns an n_obs x n_draws matrix -- or a plain vector when n_obs == 1,
  # which the previous t(sapply()) turned into the wrong orientation. The
  # explicit reshape always yields the draws x observations layout loo needs.
  ll_by_draw <- vapply(
    seq_len(n_draws),
    function(i) log_density(mu_draws[i, ], sd_draws[i, ]),
    FUN.VALUE = numeric(n_obs)
  )
  ll_mat <- t(matrix(ll_by_draw, nrow = n_obs))

  switch(method,
    loo = {
      # relative_eff() expects likelihoods (not log-likelihoods); all draws
      # are treated as coming from a single chain.
      r_eff <- loo::relative_eff(exp(ll_mat), chain_id = rep(1L, n_draws))
      loo::loo(ll_mat, r_eff = r_eff)
    },
    waic = loo::waic(ll_mat)
  )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment