Skip to content

Instantly share code, notes, and snippets.

@goldingn
Last active February 20, 2018 05:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goldingn/79859cbdd4496f64015fb9c21244d259 to your computer and use it in GitHub Desktop.
Save goldingn/79859cbdd4496f64015fb9c21244d259 to your computer and use it in GitHub Desktop.
hack greta v0.2.4 to use tensorflow HMC
# Get greta running with bayesflow's HMC implementation, driving sampling
# through TensorFlow's session-run syntax
# Define every node of `dag` on its TensorFlow graph, then the overall
# adjusted joint log density; returns the value of the density definition.
build_function <- function(dag) {
  # Temporarily pass the dag's float type to options(), so nodes can read it
  # while they are being defined without it being passed to each explicitly.
  # options() returns the previous value as a list (exact-name lookup, unlike
  # the partially-matching options()$...); passing that list back restores
  # the old state, including "unset".
  old_options <- options(greta_tf_float = dag$tf_float)
  on.exit(options(old_options), add = TRUE)

  # define all nodes
  dag$on_graph(lapply(dag$node_list, function(node) node$define_tf(dag)))

  # define an overall log density with adjustment (this is the return value)
  dag$on_graph(dag$define_joint_density())
}
# Log-density function handed to the HMC sampler: maps a flat free-state
# tensor to the model's adjusted joint log density.
#
# NOTE(review): `dag` is a free variable here; run_mcmc() places it in foo's
# enclosing environment before the sampler calls foo.
foo <- function(free_state) {
  # flush the tf environment, keeping the session, its config and core count
  keep <- c("config", "sess", "n_cores")
  stale <- setdiff(ls(dag$tf_environment), keep)
  rm(list = stale, envir = dag$tf_environment)

  # split up the free state into one tensor per free state variable
  params <- dag$parameters_example
  sizes <- vapply(params,
                  function(x) as.integer(prod(dim(x))),
                  FUN.VALUE = 1L)
  # Pass the sizes as an (unnamed) list so TensorFlow always treats them as
  # split sizes. A bare length-1 integer would be read as a *number* of
  # equal splits, and a named list would be converted to a Python dict.
  pieces <- dag$on_graph(tf$split(free_state, as.list(unname(sizes))))

  # put these tensors in the tf environment as "<name>_free"; the patched
  # variable-node tf() methods (see replace_tf) read them back from there
  free_names <- paste0(names(params), "_free")
  for (i in seq_along(free_names)) {
    assign(free_names[i], pieces[[i]], envir = dag$tf_environment)
  }

  # define the functions in the tf environment, except for the free state
  # variables
  build_function(dag)

  # return the log density
  dag$tf_environment$joint_density_adj
}
# temporary hack:
# Replace the tf() method of variable nodes.
# In a full implementation, add an option to skip defining the free state variable instead.
# Monkey-patch the tf() method of an R6 variable node so that it no longer
# defines its own free-state variable. The patched method instead reads a
# "<name>_free" tensor that foo() has already placed in the dag's tf
# environment. From that tensor it derives the constrained tensor
# ("<name>") and the log Jacobian adjustment ("<name>_adj").
# Always returns NULL invisibly; the node is modified in place.
replace_tf <- function(variable_node) {
  patched_tf <- function(dag) {
    node_name <- dag$tf_name(self)

    # log jacobian adjustment for the free -> constrained mapping
    assign(sprintf('%s_adj', node_name),
           self$tf_adjustment(dag),
           envir = dag$tf_environment)

    # constrained tensor, computed from the externally supplied free tensor
    free_tensor <- get(sprintf('%s_free', node_name),
                       envir = dag$tf_environment)
    assign(node_name,
           self$tf_from_free(free_tensor, dag$tf_environment),
           envir = dag$tf_environment)
  }

  # Run the patch inside the node's own R6 enclosure so `self` resolves to
  # this node. The method binding is locked, so unlock it only long enough
  # to swap in the replacement.
  environment(patched_tf) <- variable_node$.__enclos_env__
  unlockBinding("tf", variable_node)
  variable_node$tf <- patched_tf
  lockBinding("tf", variable_node)

  invisible(NULL)
}
# /temporary hack
# Run one HMC chain with TensorFlow bayesflow's sampler on a greta model
# rebuilt from `targets`, and return the result of running the chain.
#
# Arguments:
#   i          chain index (supplied by sfLapply); not otherwise used.
#   targets    fully named list of greta arrays to build the model from.
#   n_samples  number of HMC iterations.
#   step_size  leapfrog step size (fed to the graph at run time).
#   n_leapfrog number of leapfrog steps per iteration (fed at run time).
run_mcmc <- function(i = 1, targets, n_samples = 10000, step_size = 0.05,
                     n_leapfrog = 15L) {
  target_names <- names(targets)
  if (is.null(target_names) || any(target_names == "")) {
    stop("`targets` must be a fully named list of greta arrays", call. = FALSE)
  }

  # greta::model() names its targets from the *symbols* it is called with,
  # so bind each greta array to its name in a private environment (no
  # clashes with this function's locals). Then evaluate a call built from
  # those symbols there, instead of building the call as text.
  target_env <- new.env(parent = environment())
  for (k in seq_along(targets)) {
    assign(target_names[k], targets[[k]], envir = target_env)
  }
  model_call <- as.call(c(list(quote(greta::model)),
                          lapply(target_names, as.name),
                          list(n_cores = 1)))
  model <- eval(model_call, envir = target_env)
  dag <- model$dag

  # foo() looks `dag` up as a free variable; put this dag in its scope
  assign("dag", dag, envir = environment(foo))

  # modify nodes so that variable nodes don't define their free states
  # (temporary hack, see replace_tf())
  for (node in dag$node_list[dag$node_types == "variable"]) {
    replace_tf(node)
  }

  # create the (all-zero) initial free state
  # NOTE(review): assumes example_parameters() returns a flat vector with one
  # element per free parameter value -- confirm against greta v0.2.4
  n_free <- length(dag$example_parameters())
  free_state <- dag$on_graph(tf$zeros(n_free))

  # Step size and leapfrog count are graph tensors so they can be overridden
  # at run time through feed_dict (e.g. when tuned externally). Keep the
  # names `ss` and `lf`: reticulate's dict() below looks these names up in
  # this frame and uses the tensor objects as the feed_dict keys.
  ss <- dag$on_graph(tf$constant(step_size))
  lf <- dag$on_graph(tf$constant(n_leapfrog))

  draws_tensor <- dag$on_graph(
    tf$contrib$bayesflow$hmc$chain(n_samples, ss, lf,
                                   free_state,
                                   foo,
                                   event_dims = 0L)
  )

  dag$tf_run(sess$run(tf$global_variables_initializer()))
  dag$tf_environment$sess$run(draws_tensor,
                              feed_dict = dict(ss = step_size,
                                               lf = n_leapfrog))
}
library(tensorflow)
library (greta)
# define a model
z <- normal(0, 1, 2)
x <- normal(0, 1, 3)
model <- model(x, z, n_cores = 1)
targets <- model$target_greta_arrays

# run multiple mcmc chains in parallel, each with tensorflow HMC
library(snowfall)
n_chains <- 2
sfInit(parallel = TRUE, cpus = n_chains)
sfLibrary(greta)
sfLibrary(tensorflow)
# run_mcmc itself is shipped by sfLapply(); the functions it calls by name
# must be exported to the workers' global environments
sfExport("foo", "replace_tf", "build_function")

# 2x 10,000 iterations in ~ 18 seconds
# Stop the worker processes even if sampling fails, so they are not leaked.
timing <- tryCatch(
  system.time(draws <- sfLapply(seq_len(n_chains), run_mcmc, targets)),
  finally = sfStop()
)
timing

# trace plots of column 4 of the first element of each chain's output
# (presumably the matrix of sampled free states -- confirm)
plot(draws[[1]][[1]][, 4], type = "l")
lines(draws[[2]][[1]][, 4], col = "light blue")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment