Last active
July 17, 2024 00:10
-
-
Save hrlai/04e4045c58842b1508cd1ceac1be2972 to your computer and use it in GitHub Desktop.
Obtain LOO-IC or WAIC from greta
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Obtain LOO-IC or WAIC from \code{greta} Output
#'
#' Builds the pointwise log-likelihood matrix (one row per posterior draw,
#' one column per observation) for the chosen observation model and passes
#' it to \code{loo::loo()} or \code{loo::waic()}.
#'
#' @param observed Vector of the observed response variable
#' @param posterior MCMC posterior from a \code{greta} model: a named list
#'   whose \code{mean} and \code{scale} elements are numeric vectors,
#'   matrices or arrays with posterior draws along the first dimension
#' @param mean Name of the mean/location parameter, defaults to \code{mu}
#' @param scale Name of the scale/variance parameter, defaults to \code{sd}
#' @param family Distribution of the observational model. Currently
#'   \code{"normal"}, \code{"lognormal"} (\code{mean} and \code{scale} are
#'   then used as \code{meanlog} and \code{sdlog}), and \code{"student"}
#'   (location-scale Student-t via \code{dlst}) are implemented
#' @param method \code{"loo"} to calculate the Leave-One-Out Information
#'   Criterion or \code{"waic"} for the Widely Applicable Information
#'   Criterion, both computed with the \code{loo} package
#' @param moment_match Logical, forwarded to \code{loo()} when
#'   \code{method = "loo"}.
#' @param ... additional arguments passed to the density function
#'   (\code{dnorm}, \code{dlnorm} or \code{dlst})
#' @param df Degrees of freedom of the Student-t observation model, only used
#'   when \code{family = "student"}; must be given by name. Defaults to 2.
#'
#' @return LOO-IC or WAIC; see \code{?loo::loo} for more details.
#'
#' @examples
#' \dontrun{
#' obs <- iris$Sepal.Length
#' loo_greta(obs, posterior_lognormal, mean = "mean", scale = "sd",
#'           family = "lognormal", method = "loo")
#' loo_greta(obs, posterior_normal, mean = "mean", scale = "sd",
#'           family = "normal", method = "loo")
#' }
#'
#' @import loo
#' @export
loo_greta <- function(observed,
                      posterior,
                      mean = "mu",
                      scale = "sd",
                      family = c("normal", "student", "lognormal"),
                      method = c("loo", "waic"),
                      moment_match = FALSE,
                      ...,
                      df = 2) {
  # check that arguments are consistent with the implemented options
  family <- match.arg(family)
  method <- match.arg(method)

  # extract the relevant posterior draws
  mu <- posterior[[mean]]
  sd <- posterior[[scale]]
  if (is.null(mu) || is.null(sd)) {
    stop("`posterior` has no element named '", mean, "' and/or '", scale,
         "'.", call. = FALSE)
  }
  nsim <- NROW(mu)
  if (NROW(sd) != nsim) {
    stop("'", mean, "' and '", scale, "' must have the same number of draws.",
         call. = FALSE)
  }

  # Flatten each parameter to a draws x values matrix. R arrays are stored
  # column-major, so row i of matrix(x, nrow = nsim) holds the same values,
  # in the same order, as x[i, , ] does for a 3-D array.
  mu_mat <- matrix(mu, nrow = nsim)
  sd_mat <- matrix(sd, nrow = nsim)

  # log-density of every observation under posterior draw i
  log_density <- switch(family,
    normal = function(i) {
      dnorm(observed, mean = mu_mat[i, ], sd = sd_mat[i, ], log = TRUE, ...)
    },
    lognormal = function(i) {
      dlnorm(observed, meanlog = mu_mat[i, ], sdlog = sd_mat[i, ],
             log = TRUE, ...)
    },
    student = function(i) {
      # NOTE(review): dlst() is not base R (presumably extraDistr::dlst);
      # it must be available on the search path.
      dlst(x = observed, df = df, mu = mu_mat[i, ], sigma = sd_mat[i, ],
           log = TRUE, ...)
    }
  )

  # Stack one row per draw so the matrix is always draws x observations,
  # even for a single observation (t(sapply(...)) would give 1 x nsim).
  log_lik <- do.call(rbind, lapply(seq_len(nsim), log_density))

  # relative efficiencies; all draws are treated as a single chain
  rel_eff <- relative_eff(exp(log_lik), chain_id = rep(1L, nsim))

  # estimate the LOO-IC or WAIC
  switch(method,
    # NOTE(review): loo's matrix method may ignore `moment_match`; proper
    # moment matching goes through loo::loo_moment_match() -- confirm.
    loo = loo(log_lik, r_eff = rel_eff, moment_match = moment_match),
    waic = waic(log_lik, r_eff = rel_eff)
  )
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment