Created
October 2, 2011 13:19
-
-
Save jrnold/1257444 to your computer and use it in GitHub Desktop.
A general Gibbs sampler object using the proto R package.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library("proto") | |
library("mvtnorm") | |
##' Gibbs sampler
##'
##' General Gibbs sampler object built with the proto package. A concrete
##' sampler is created with \code{GibbsSampler$new(...)} and overrides
##' \code{sampler} (and optionally \code{initialize_pars}, \code{continue},
##' \code{logLik} and \code{logPrior}). \code{$run()} executes the chain and
##' records the fields named in \code{pars}; \code{$as.mcmc()} returns the
##' saved draws as a coda \code{mcmc} object.
GibbsSampler <- proto(expr = {
  ##' Initial parameter values (flattened numeric vector written by
  ##' \code{save_init_pars()} at the start of each run).
  initpars <- numeric(0)

  ##' Saved parameter values.
  ##'
  ##' After \code{run()} this is a matrix: column 1 holds the iteration
  ##' index \code{i}, the remaining columns the flattened parameters from
  ##' \code{get_pars()}. Only the first \code{itersaved} rows are filled.
  savedpars <- numeric(0)

  ##' Number of iterations run so far (cumulative across runs).
  i <- 0

  ##' Number of iterations saved in \code{savedpars}.
  ##'
  ##' Because of burn-in and thinning this can differ from \code{i}.
  itersaved <- 0

  ##' Random number seed
  ##'
  ##' Arguments passed to \code{set.seed()} at the start of every run, for
  ##' reproducibility. At the end of a run \code{seed$seed} is replaced by a
  ##' new seed drawn from the generator, so a following (appended) run uses
  ##' fresh but still reproducible random numbers.
  seed <- list(seed = 932987,
               kind = NULL,
               normal.kind = NULL)

  ##' Parameter names
  ##'
  ##' Names of the fields of this object that are saved as parameters.
  pars <- character()

  ##' Metadata about each run of the chain.
  ##'
  ##' A list with one element per call to \code{run()}; each element is a
  ##' list with the elapsed time, the run settings and the first/last
  ##' iteration indices.
  runs <- list()

  ##' Convert a field to a named numeric vector.
  ##'
  ##' @param x name of the field. A field with several values gets names
  ##'   \code{x_1}, \code{x_2}, ...; a scalar field is named \code{x}; an
  ##'   empty field is returned unnamed.
  convert2par <- function(., x) {
    ## Using [[ won't work because it only looks within the current object
    ## and not its parents; `$.proto` follows proto inheritance.
    y <- as.numeric(`$.proto`(., x))
    n <- length(y)
    if (n > 1) {
      names(y) <- paste(x, seq_len(n), sep = "_")
    } else if (n == 1) {
      names(y) <- x
    }
    y
  }

  ##' Return current parameters
  ##'
  ##' @param pars names of the fields to return.
  ##' @param .as.numeric if TRUE, return one flattened named numeric vector;
  ##'   otherwise a list of the raw field values named by \code{pars}.
  get_pars <- function(., pars = .$pars, .as.numeric = TRUE) {
    if (.as.numeric) {
      ## lapply (not sapply) so equal-length parameters are never simplified
      ## into a matrix. The outer list is unnamed, so unlist() keeps the
      ## element names made by convert2par (foo, or foo_1, foo_2, ...).
      values <- lapply(pars, function(x) .$convert2par(x))
      unlist(values)
    } else {
      pardata <- lapply(pars, function(x) `$.proto`(., x))
      names(pardata) <- pars
      pardata
    }
  }

  ##' Save parameters
  ##'
  ##' Writes the current iteration index and flattened parameters into row
  ##' \code{itersaved} of \code{savedpars}, and returns the parameters.
  save_pars <- function(., pars = .$pars) {
    parstosave <- .$get_pars(pars = pars, .as.numeric = TRUE)
    .$savedpars[.$itersaved, ] <- c(.$i, parstosave)
    parstosave
  }

  ##' Save initial parameters into \code{initpars}.
  save_init_pars <- function(., pars = .$pars) {
    .$initpars <- .$get_pars(pars = pars, .as.numeric = TRUE)
  }

  ##' Initialize parameters
  ##'
  ##' Hook called by \code{run()} at the start of a non-appended run. The
  ##' default does nothing. The intended behaviour (not implemented here) is:
  ##' if field x in \code{pars} does not exist and a function \code{init_x}
  ##' does, generate x with it.
  initialize_pars <- function(., pars = .$pars) {
    NULL
  }

  ##' Inner loop of the Gibbs sampler
  ##'
  ##' Called once per iteration by \code{run()}; subclasses update the
  ##' parameters here. The default does nothing.
  sampler <- function(.) NULL

  ##' Main Gibbs sampler loop
  ##'
  ##' Handles the bookkeeping around \code{sampler}: seeding, allocation of
  ##' \code{savedpars}, burn-in, thinning, progress display, early stopping
  ##' through \code{continue()} and recording run metadata in \code{runs}.
  ##'
  ##' @param mcmc number of post-burn-in iterations.
  ##' @param burnin number of initial iterations that are not saved.
  ##' @param thin save every \code{thin}-th post-burn-in iteration.
  ##' @param verbose show a text progress bar.
  ##' @param append continue the previous chain (keep parameters and saved
  ##'   draws) instead of starting a new one.
  ##' @return the sampler object, invisibly.
  run <- function(., mcmc = 2000, burnin = 0, thin = 1, verbose = FALSE,
                  append = FALSE) {
    stopifnot(thin > 0, mcmc >= 0, burnin >= 0)
    do.call(set.seed, .$seed)
    first_iter <- .$i
    if (!append) {
      .$initialize_pars()
    }
    .$save_init_pars()

    ntosave <- floor(mcmc / thin)
    k <- length(.$get_pars())
    ## Column 1 holds the iteration index, the rest the parameters.
    if (!append) {
      .$savedpars <- matrix(as.numeric(NA), ncol = k + 1, nrow = ntosave)
      .$itersaved <- 0
    } else {
      .$savedpars <- rbind(.$savedpars,
                           matrix(as.numeric(NA), ncol = k + 1,
                                  nrow = ntosave))
    }

    start_time <- proc.time()
    totlen <- burnin + mcmc
    ## txtProgressBar() errors when max == min, so skip it for empty runs.
    show_progress <- verbose && totlen > 0
    if (show_progress) {
      pb <- txtProgressBar(0, totlen)
      on.exit(close(pb), add = TRUE)
    }

    for (j in seq_len(totlen)) {
      .$i <- .$i + 1
      .$sampler()
      ## Burn-in and thinning are counted relative to this run (j), not the
      ## cumulative counter .$i, so exactly floor(mcmc / thin) draws are
      ## saved and they always fit in the rows allocated above.
      if (j > burnin && (j - burnin) %% thin == 0) {
        .$itersaved <- .$itersaved + 1
        .$save_pars()
      }
      if (show_progress) {
        setTxtProgressBar(pb, j)
      }
      ## Hook for convergence checking: stop early when continue() is FALSE.
      if (!.$continue()) {
        break
      }
    }

    elapsed <- proc.time() - start_time
    ## Wrap in list() so each run is appended as one record rather than
    ## having its fields spliced into .$runs.
    .$runs <- c(.$runs,
                list(list(time = elapsed,
                          mcmc = mcmc,
                          burnin = burnin,
                          thin = thin,
                          append = append,
                          ## May be smaller than first_iter + burnin + mcmc
                          ## if continue() stopped the loop early.
                          last_iter = .$i,
                          first_iter = first_iter)))

    ## Store a new seed drawn from the current generator state, so the next
    ## run continues with fresh numbers. (.Random.seed[1] is the RNG-kind
    ## code, not a seed, so it must not be stored here.)
    .$seed[["seed"]] <- sample.int(.Machine$integer.max, 1L)
    invisible(.)
  }

  ##' Return saved parameters as a coda mcmc object.
  ##'
  ##' Only the first \code{itersaved} rows of \code{savedpars} are used, so
  ##' NA padding left by an early stop is dropped. The mcmc start/thin
  ##' attributes are left at coda's defaults because saved iterations need
  ##' not be uniformly spaced across appended runs. The actual iteration
  ##' indices are stored in the "indices" attribute.
  as.mcmc <- function(.) {
    saved <- .$savedpars[seq_len(.$itersaved), , drop = FALSE]
    y <- coda::mcmc(data = saved[, -1, drop = FALSE])
    attr(y, "indices") <- saved[, 1]
    y
  }

  ##' Create a child sampler; arguments are passed to \code{proto()}.
  new <- function(., ...) {
    .$proto(...)
  }

  ##' Placeholder: return FALSE to stop \code{run()} early.
  continue <- function(., ...) TRUE

  ##' Placeholder log-likelihood.
  logLik <- function(., ...) NULL

  ##' Placeholder log-prior.
  logPrior <- function(., ...) NULL

  ##' Log posterior (up to a constant): logPrior + logLik.
  logPosterior <- function(., ...) .$logPrior(...) + .$logLik(...)
})
## Trivial examples just to check things.
## Bayesian linear regression with a known observation precision.
## Follows the discussion in Petris, Petrone (2009).
DlmLinearNormal <- GibbsSampler$new(expr = {
  ##' Create a sampler for the posterior of beta in the model y = x beta + e.
  ##'
  ##' y  : response (coerced to a one-column matrix)
  ##' x  : design matrix with one column per coefficient
  ##' V  : known precision of the observations; a scalar or an n x n
  ##'      precision matrix, where n = nrow(x)
  ##' m0 : mean of the prior distribution of beta
  ##' C0 : precision of the prior distribution of beta
  ##'
  ##' Posterior precision: Cn = C0 + x' V x.
  ##' Posterior mean:      mn = Cn^{-1} (x' V y + C0 m0).
  new <- function(., y, x, V, m0, C0) {
    V <- as.matrix(V)
    y <- as.matrix(y)
    x <- as.matrix(x)
    m0 <- as.matrix(m0)
    C0 <- as.matrix(C0)
    ## x' V: a scalar precision scales elementwise, a matrix one needs a
    ## matrix product. The previous V %*% crossprod(x) only conformed when
    ## x had a single column.
    xtV <- if (length(V) == 1L) drop(V) * t(x) else crossprod(x, V)
    Cn <- C0 + xtV %*% x
    ## solve(A, b) avoids forming the explicit inverse of Cn.
    mn <- solve(Cn, xtV %*% y + C0 %*% m0)
    .$proto(x = x,
            y = y,
            V = V,
            m0 = m0,
            C0 = C0,
            Cn = Cn,
            mn = mn,
            beta = rep(0, ncol(x)),
            pars = "beta")
  }

  ##' Draw beta from its posterior N(mn, Cn^{-1}).
  ##'
  ##' The name draw_mu is kept for compatibility; the parameter drawn is
  ##' beta, stored as a plain numeric vector like its initial value.
  draw_mu <- function(.) {
    .$beta <- as.vector(rmvnorm(1, mean = as.vector(.$mn),
                                sigma = solve(.$Cn)))
  }

  ##' One Gibbs iteration: redraw beta.
  sampler <- function(.) {
    .$draw_mu()
  }
})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Continued as a repository here: https://github.com/jrnold/r-proto-gibbs