Skip to content

Instantly share code, notes, and snippets.

@jrnold
Created August 17, 2012 08:22
Show Gist options
  • Save jrnold/3376975 to your computer and use it in GitHub Desktop.
Save jrnold/3376975 to your computer and use it in GitHub Desktop.
Compiling a Stan model as a shared object
// Code generated by Stan version 1.0
#include <stan/model/model_header.hpp>
// NOTE(review): everything in this namespace was emitted by stanc (see the
// "Code generated by Stan version 1.0" header); hand edits are lost if the
// model is regenerated.
namespace foo_namespace {
// Bring Stan's autodiff, math, probability and I/O names into scope for the
// generated code below.
using std::vector;
using std::string;
using std::stringstream;
using stan::agrad::var;
using stan::model::prob_grad_ad;
using stan::math::get_base1;
using stan::io::dump;
using std::istream;
using namespace stan::math;
using namespace stan::prob;
using namespace stan::agrad;
// Eigen shorthands: column vector / row vector / matrix of double (`_d`) and
// of autodiff variables (`_v`). Not used by this particular model.
typedef Eigen::Matrix<double,Eigen::Dynamic,1> vector_d;
typedef Eigen::Matrix<double,1,Eigen::Dynamic> row_vector_d;
typedef Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> matrix_d;
typedef Eigen::Matrix<stan::agrad::var,Eigen::Dynamic,1> vector_v;
typedef Eigen::Matrix<stan::agrad::var,1,Eigen::Dynamic> row_vector_v;
typedef Eigen::Matrix<stan::agrad::var,Eigen::Dynamic,Eigen::Dynamic> matrix_v;
// Model class for the Stan program with one real parameter `y` and log
// density `y ~ normal(0, 1)`; no data or transformed data.
class foo : public prob_grad_ad {
private:
public:
// Constructs the model from a data context. This model declares no data,
// so `context__` is never read and the locals below are unused scaffolding
// emitted by the code generator.
// NOTE(review): meaning of the `0` passed to the prob_grad_ad base is not
// visible here (presumably a count of integer parameters) - confirm.
foo(stan::io::var_context& context__)
: prob_grad_ad::prob_grad_ad(0) {
static const char* function__ = "foo_namespace::foo(%1%)";
size_t pos__;
std::vector<int> vals_i__;
std::vector<double> vals_r__;
// validate data
// validate transformed data
set_param_ranges();
} // dump ctor
// Records the parameter counts: no integer parameter ranges, and one real
// (unconstrained) parameter for `y`.
void set_param_ranges() {
num_params_r__ = 0U;
param_ranges_i__.clear();
++num_params_r__;
}
// Reads the initial value of `y` from `var_context__` and writes its
// unconstrained form into `params_r__` / `params_i__` (both are cleared
// first).
// Throws std::runtime_error if `y` is missing or is not a scalar
// (i.e. has a non-zero number of dimensions).
void transform_inits(const stan::io::var_context& var_context__,
std::vector<int>& params_i__,
std::vector<double>& params_r__) {
params_r__.clear();
params_i__.clear();
stan::io::writer<double> writer__(params_r__,params_i__);
size_t pos__;
std::vector<double> vals_r__;
std::vector<int> vals_i__;
if (!(var_context__.contains_r("y")))
throw std::runtime_error("variable y missing");
if (var_context__.dims_r("y").size() != 0)
throw std::runtime_error("require 0 dimensions for variable y");
vals_r__ = var_context__.vals_r("y");
pos__ = 0U;
double y(0);
y = vals_r__[pos__++];
writer__.scalar_unconstrain(y);
params_r__ = writer__.data_r();
params_i__ = writer__.data_i();
}
// Evaluates the model's log density at the unconstrained parameters
// `params_r__` using autodiff variables, so gradients can be taken.
var log_prob(vector<var>& params_r__,
vector<int>& params_i__) {
var lp__(0.0);
// model parameters
stan::io::reader<var> in__(params_r__,params_i__);
// `lp__` is passed so the reader can add any Jacobian term; `y` is an
// unconstrained real, so presumably nothing is added here - confirm.
var y = in__.scalar_constrain(lp__);
// transformed parameters
// validate transformed parameters
// model body
// `<true>` is the generated translation of `~`; presumably it requests
// the density up to a constant (propto) - confirm against stan::prob.
lp__ += stan::prob::normal_log<true>(y, 0, 1);
return lp__;
} // log_prob()
// Fills `names__` with the parameter names, in output order: just "y".
void get_param_names(std::vector<std::string>& names__) {
names__.resize(0);
names__.push_back("y");
}
// Fills `dimss__` with one dimension list per parameter; `y` is a scalar,
// so its list is empty.
void get_dims(std::vector<std::vector<size_t> >& dimss__) {
dimss__.resize(0);
std::vector<size_t> dims__;
dims__.resize(0);
dimss__.push_back(dims__);
}
// Converts the unconstrained parameters back to constrained values and
// writes them to `vars__` (cleared first): here just the value of `y`.
// `function__` and `lp__` are unused generator scaffolding (this model has
// no transformed parameters or generated quantities).
void write_array(std::vector<double>& params_r__,
std::vector<int>& params_i__,
std::vector<double>& vars__) {
vars__.resize(0);
stan::io::reader<double> in__(params_r__,params_i__);
static const char* function__ = "foo_namespace::write_array(%1%)";
// read-transform, write parameters
double y = in__.scalar_constrain();
vars__.push_back(y);
// declare and define transformed parameters
double lp__ = 0.0;
// validate transformed parameters
// write transformed parameters
// declare and define generated quantities
// validate generated quantities
// write generated quantities
}
// Writes the CSV header for this model's output columns ("y") to `o__`,
// terminated by a newline.
// NOTE(review): whether `comma()` emits a leading separator at the start
// of a row depends on stan::io::csv_writer - confirm.
void write_csv_header(std::ostream& o__) {
stan::io::csv_writer writer__(o__);
writer__.comma();
o__ << "y";
writer__.newline();
}
// Writes one CSV row to `o__`: the constrained value of `y`, followed by a
// newline. `function__` and `lp__` are unused generator scaffolding.
void write_csv(std::vector<double>& params_r__,
std::vector<int>& params_i__,
std::ostream& o__) {
stan::io::reader<double> in__(params_r__,params_i__);
stan::io::csv_writer writer__(o__);
static const char* function__ = "foo_namespace::write_csv(%1%)";
// read-transform, write parameters
double y = in__.scalar_constrain();
writer__.write(y);
// declare, define and validate transformed parameters
double lp__ = 0.0;
// write transformed parameters
// declare and define generated quantities
// validate generated quantities
// write generated quantities
writer__.newline();
}
}; // model
} // namespace
// Command-line entry point: runs Stan's NUTS sampler command for the `foo`
// model with the given command-line arguments.
//
// Any std::exception that escapes is reported on stderr together with
// Boost's diagnostic information, and the process exits with status -1.
//
// NOTE(review): when this file is compiled into a shared object for R (see
// stan_shared), this main is never called.
int main(int argc, const char* argv[]) {
  try {
    stan::gm::nuts_command<foo_namespace::foo>(argc, argv);
  } catch (const std::exception& e) {  // catch by const&: nothing is mutated
    std::cerr << std::endl << "Exception: " << e.what() << std::endl;
    std::cerr << "Diagnostic information: " << std::endl
              << boost::diagnostic_information(e) << std::endl;
    return -1;
  }
  return 0;
}
// Single unconstrained real parameter; the model declares no data.
parameters {
real y;
}
// Standard normal distribution on y: the only term in the log density.
model {
y ~ normal(0, 1);
}
library(rstan)
library(stringr)
## Compile a stanc-generated C++ model file into a shared object that exposes
## the model to R as an Rcpp module.
##
## cppfile: path to the C++ file generated by Stan.
## outfile: path of the shared object to create (e.g. "foo.so").
## verbose: if TRUE, print the environment variables and include flags used.
##
## Side effects: sets environment variables from the rstan inline plugin (and
## CLINK_CPPFLAGS) for the rest of the R session.
## Returns the exit status of `R CMD SHLIB` invisibly. Stops with an error if
## the file does not contain exactly one `class <name> :` declaration, or if
## compilation fails.
stan_shared <- function(cppfile, outfile, verbose=TRUE) {
  cppcode <- readLines(cppfile)
  ## Get the model name from the C++ code: the name in the (single)
  ## `class <name> :` declaration.
  model.name <- na.omit(str_match(cppcode, "class\\s+(.*?)\\s+:"))[ , 2]
  if (length(model.name) != 1L) {
    stop(sprintf("expected exactly one model class in '%s', found %d",
                 cppfile, length(model.name)))
  }
  ## Patch the C++ code for use with Rcpp: prepend the rstan header and
  ## append the Rcpp module definition for the model class.
  newcppcode <-
    paste("#include <rstan/rstaninc.hpp>",
          paste(cppcode, collapse = "\n"),
          rstan:::get_Rcpp_module_def_code(model.name),
          sep = "\n")
  ## Write the patched code to a temporary file so the original file is
  ## left untouched.
  newcppfile <- tempfile(fileext = ".cpp")
  writeLines(newcppcode, newcppfile)
  ## Before compiling, set the environment variables for compiling and
  ## linking, as inline::cxxfunction does. getPlugin lives in package
  ## 'inline', so call it explicitly rather than relying on it being attached.
  settings <- inline::getPlugin("rstan")
  ### Copied from inline::cxxfunction
  if (!is.null(env <- settings$env)) {
    do.call(Sys.setenv, env)
    if (isTRUE(verbose)) {
      cat(" >> setting environment variables: \n")
      writeLines(sprintf("%s = %s", names(env), env))
    }
  }
  LinkingTo <- settings$LinkingTo
  if (!is.null(LinkingTo)) {
    ## find.package replaces the deprecated (now defunct) .find.package.
    paths <- find.package(LinkingTo, quiet = TRUE)
    if (length(paths)) {
      flag <- paste(paste0("-I\"", paths, "/include\""),
                    collapse = " ")
      Sys.setenv(CLINK_CPPFLAGS = flag)
      if (isTRUE(verbose)) {
        cat(sprintf("\n >> LinkingTo : %s\n", paste(LinkingTo,
                                                     collapse = ", ")))
        cat("CLINK_CPPFLAGS = ", flag, "\n\n")
      }
    }
  }
  ### End of cxxfunction copied
  ## Also see inline:::compileCode for some platform-independent hacks.
  ## Compile and link in one step with R CMD SHLIB (R CMD COMPILE followed
  ## by R CMD SHLIB would also work).
  R <- file.path(R.home(component = "bin"), "R")
  status <- system2(R, c("CMD", "SHLIB",
                         sprintf("-o %s", shQuote(outfile)),
                         shQuote(newcppfile)))
  ## Fail loudly here rather than letting a later dyn.load() report a
  ## missing or stale shared object.
  if (!identical(as.integer(status), 0L)) {
    stop(sprintf("R CMD SHLIB failed (exit status %s) compiling '%s'",
                 format(status), newcppfile))
  }
  invisible(status)
}
## Compile foo.cpp into a shared object. Use the platform's shared-library
## extension (".so" on Unix-alikes, ".dll" on Windows) rather than
## hard-coding ".so", so the script is not tied to one platform.
so_file <- paste0("foo", .Platform$dynlib.ext)
stan_shared("foo.cpp", so_file)
## Load the shared object; dyn.load() returns its DLLInfo, which is reused
## below instead of looking the library up again by name.
dll <- dyn.load(so_file)
## The newly loaded "foo" library should be listed here.
getLoadedDLLs()
## Load the Rcpp module from the library.
## See section 3.3 of "Exposing C++ functions and classes with Rcpp modules".
mod <- Module("foo", dll)
foo <- mod$foo
## NOTE(review): the constructor signature exposed here comes from
## rstan:::get_Rcpp_module_def_code; confirm whether it needs a data argument.
stanmodel_object <- new(foo)
##' List the names of the classes exposed by an Rcpp module.
##'
##' Adapted from getMethods("show", "Module"). Rcpp has no exported function
##' for listing a Module's classes, so this reaches into Rcpp internals via
##' `:::` and may break with future Rcpp versions.
##'
##' @param object an Rcpp Module object.
##' @return A character vector of class names. If the module is not yet
##'   initialized, a message naming the module and its package is printed
##'   instead and the value of writeLines() (NULL) is returned invisibly.
module_classes <- function(object) {
  mod_ptr <- Rcpp:::.getModulePointer(object, FALSE)
  if (!identical(mod_ptr, Rcpp:::.badModulePointer)) {
    classes_info <- .Call(Rcpp:::Module__classes_info, mod_ptr)
    return(names(classes_info))
  }
  ## Uninitialized module: report which module/package it refers to.
  mod_env <- as.environment(object)
  mod_name <- get("moduleName", envir = mod_env)
  pkg_name <- get("packageName", envir = mod_env)
  writeLines(sprintf("Uninitialized module named \"%s\" from package \"%s\"",
                     mod_name, pkg_name))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment