Skip to content

Instantly share code, notes, and snippets.

@SteveBronder
Created October 4, 2022 16:14
Show Gist options
  • Save SteveBronder/37013c9f1a5a5c17390b08327c20a28b to your computer and use it in GitHub Desktop.
Save SteveBronder/37013c9f1a5a5c17390b08327c20a28b to your computer and use it in GitHub Desktop.
#include <stan/services/sample/hmc_nuts_diag_e_adapt.hpp>
#include <gtest/gtest.h>
#include <stan/io/empty_var_context.hpp>
#include <test/test-models/good/tester.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <iostream>
#include <numeric>
#include <sstream>
#include <vector>
/// gtest fixture that owns a `stan_model` constructed from an empty variable
/// context, plus instrumented logger/writer callbacks for the service tests.
class ServicesSampleHmcNutsDiagEAdapt : public testing::Test {
 public:
  // Builds the model from `context`, passing 0 (presumably the RNG seed) and a
  // pointer to `model_log` (presumably the model's message stream).
  // `context` and `model_log` are declared before `model`, so both are fully
  // constructed by the time `model` is initialized.
  ServicesSampleHmcNutsDiagEAdapt() : model(context, 0, &model_log) {}

  // Stream handed to the model constructor.
  std::stringstream model_log;
  // Instrumented logger callback.
  stan::test::unit::instrumented_logger logger;
  // Instrumented writer callbacks.
  stan::test::unit::instrumented_writer init;
  stan::test::unit::instrumented_writer parameter;
  stan::test::unit::instrumented_writer diagnostic;
  // Empty variable context the model is constructed from.
  stan::io::empty_var_context context;
  // Model under test. Must stay declared after `context` and `model_log`,
  // which its initializer references.
  stan_model model;
};
template <typename M>
void log_density_gradient(M&& model, bool propto, bool jacobian,
const double* theta_unc, double* val,
double* grad) {
static thread_local stan::math::ChainableStack thread_instance;
auto logp = [&model](auto& x) {
return model.template log_prob<true, true, stan::math::var>(x);
};
int N = 5;
Eigen::VectorXd params_unc = Eigen::VectorXd::Map(theta_unc, N);
Eigen::VectorXd grad_vec(N);
stan::math::gradient(logp, params_unc, val, grad, grad + N);
Eigen::VectorXd::Map(grad, N) = grad_vec;
}
// Evaluates the model's log density (propto = true, jacobian = true) and its
// gradient at the unconstrained point (1, 2, ..., N) and prints both.
TEST_F(ServicesSampleHmcNutsDiagEAdapt, call_count) {
  // Size every buffer from the model so it always matches the dimension
  // log_density_gradient reads from theta_unc and writes into grad_vs.
  const int N = static_cast<int>(model.num_params_r());
  std::vector<double> theta_unc(N);
  std::iota(theta_unc.begin(), theta_unc.end(), 1.0);
  std::vector<double> grad_vs(N);
  double val = 0;
  log_density_gradient(model, true, true, theta_unc.data(), &val,
                       grad_vs.data());
  // View the gradient buffer as an Eigen vector for printing (no copy).
  const Eigen::Map<const Eigen::VectorXd> grad_vec(grad_vs.data(), N);
  std::cout << "val: " << val << std::endl;
  std::cout << "\ngrad:\n" << grad_vec << std::endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment