Skip to content

Instantly share code, notes, and snippets.

@bbbales2
Created December 15, 2020 21:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bbbales2/347ea12273ed90b1df35d55cf247989e to your computer and use it in GitHub Desktop.
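Benchmark single-element (scalar) reads and writes on Stan Math varmat (`var_value<Eigen::MatrixXd>`) vs. matvar (`Eigen::Matrix<var, Eigen::Dynamic, 1>`)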
#include <chrono>
#include <utility>

#include <benchmark/benchmark.h>
#include <stan/math.hpp>
// Builds one reverse-mode expression graph, lp = -sum(x * y + x) over two
// random 256x256 matvar matrices, and then repeatedly runs the reverse pass
// over that same graph.
// NOTE(review): the name suggests a throwaway warm-up benchmark registered
// first so start-up effects (CPU frequency scaling, cold caches) do not skew
// the scalar benchmarks — confirm intent.
static void toss_me(benchmark::State& state) {
  using stan::math::sum;
  using stan::math::var;

  const Eigen::MatrixXd x_vals = Eigen::MatrixXd::Random(256, 256);
  const Eigen::MatrixXd y_vals = Eigen::MatrixXd::Random(256, 256);
  Eigen::Matrix<var, -1, -1> x = x_vals;
  Eigen::Matrix<var, -1, -1> y = y_vals;

  // The graph is built once, outside the timed loop; only grad() is measured.
  var lp = 0;
  lp -= sum((stan::math::multiply(x, y) + x).eval());
  benchmark::DoNotOptimize(lp.vi_);

  for (auto _ : state) {
    lp.grad();
    benchmark::ClobberMemory();
    // Clear adjoints so each reverse pass starts from the same state instead
    // of accumulating into the previous iteration's adjoints.
    stan::math::set_zero_all_adjoints();
  }
  stan::math::recover_memory();
}
// Times scalar reads from a matvar (Eigen::Matrix<var, Dynamic, 1>).
// Each iteration sums every element into a var via coeffRef() and runs the
// reverse pass. state.range(0) is the vector length. Manual timing
// (SetIterationTime) excludes vector construction and the arena cleanup done
// by recover_memory().
static void read_scalar_matvar(benchmark::State& state) {
  using stan::math::var;
  for (auto _ : state) {
    Eigen::Matrix<var, Eigen::Dynamic, 1> vector =
        Eigen::VectorXd::Random(state.range(0));

    // steady_clock is monotonic; high_resolution_clock may alias system_clock.
    const auto start = std::chrono::steady_clock::now();
    var lp = 0.0;
    // Eigen::Index matches the signed type returned by size().
    for (Eigen::Index i = 0; i < vector.size(); ++i) {
      lp += vector.coeffRef(i);
    }
    lp.grad();
    const auto end = std::chrono::steady_clock::now();

    const std::chrono::duration<double> elapsed_seconds = end - start;
    state.SetIterationTime(elapsed_seconds.count());

    stan::math::recover_memory();
    benchmark::ClobberMemory();
  }
}
// Times scalar reads from a varmat (var_value<Eigen::MatrixXd> holding a
// single column). Each iteration sums every element into a var via coeffRef()
// and runs the reverse pass. state.range(0) is the vector length. Manual
// timing (SetIterationTime) excludes construction and the arena cleanup done
// by recover_memory().
static void read_scalar_varmat(benchmark::State& state) {
  using stan::math::var;
  for (auto _ : state) {
    stan::math::var_value<Eigen::MatrixXd> vector =
        Eigen::VectorXd::Random(state.range(0));

    // steady_clock is monotonic; high_resolution_clock may alias system_clock.
    const auto start = std::chrono::steady_clock::now();
    var lp = 0.0;
    // Eigen::Index matches the signed type returned by size().
    for (Eigen::Index i = 0; i < vector.size(); ++i) {
      lp += vector.coeffRef(i);
    }
    lp.grad();
    const auto end = std::chrono::steady_clock::now();

    const std::chrono::duration<double> elapsed_seconds = end - start;
    state.SetIterationTime(elapsed_seconds.count());

    stan::math::recover_memory();
    benchmark::ClobberMemory();
  }
}
// Times scalar writes into a matvar (Eigen::Matrix<var, Dynamic, 1>).
// Each element is overwritten through coeffRef() with a freshly computed
// `scalar + scalar`, and then stan::math::grad() runs the reverse pass.
// state.range(0) is the vector length. Manual timing excludes construction
// and the arena cleanup done by recover_memory().
// NOTE(review): no adjoint is seeded before grad(), so this presumably times
// reverse-pass traversal rather than producing meaningful gradients — confirm.
static void write_scalar_matvar(benchmark::State& state) {
  using stan::math::var;
  for (auto _ : state) {
    Eigen::Matrix<var, Eigen::Dynamic, 1> vector =
        Eigen::VectorXd::Random(state.range(0));

    // steady_clock is monotonic; high_resolution_clock may alias system_clock.
    const auto start = std::chrono::steady_clock::now();
    var scalar = 1.0;
    // Eigen::Index matches the signed type returned by size().
    for (Eigen::Index i = 0; i < vector.size(); ++i) {
      vector.coeffRef(i) = scalar + scalar;
    }
    stan::math::grad();
    const auto end = std::chrono::steady_clock::now();

    const std::chrono::duration<double> elapsed_seconds = end - start;
    state.SetIterationTime(elapsed_seconds.count());

    stan::math::recover_memory();
    benchmark::ClobberMemory();
  }
}
// Times scalar writes into a varmat (var_value<Eigen::MatrixXd> holding a
// single column). Each element is overwritten through coeffRef() with a
// freshly computed `scalar + scalar`, and then stan::math::grad() runs the
// reverse pass. state.range(0) is the vector length. Manual timing excludes
// construction and the arena cleanup done by recover_memory().
// NOTE(review): no adjoint is seeded before grad(), so this presumably times
// reverse-pass traversal rather than producing meaningful gradients — confirm.
static void write_scalar_varmat(benchmark::State& state) {
  using stan::math::var;
  for (auto _ : state) {
    stan::math::var_value<Eigen::MatrixXd> vector =
        Eigen::VectorXd::Random(state.range(0));

    // steady_clock is monotonic; high_resolution_clock may alias system_clock.
    const auto start = std::chrono::steady_clock::now();
    var scalar = 1.0;
    // Eigen::Index matches the signed type returned by size().
    for (Eigen::Index i = 0; i < vector.size(); ++i) {
      vector.coeffRef(i) = scalar + scalar;
    }
    stan::math::grad();
    const auto end = std::chrono::steady_clock::now();

    const std::chrono::duration<double> elapsed_seconds = end - start;
    state.SetIterationTime(elapsed_seconds.count());

    stan::math::recover_memory();
    benchmark::ClobberMemory();
  }
}
// Smallest and largest vector lengths for the scalar read/write benchmarks.
// With RangeMultiplier(2), Range(start_val, end_val) sweeps the powers of two
// from 4 up to 2^20. These are compile-time constants (internal linkage)
// rather than mutable globals.
constexpr int start_val = 4;
constexpr int end_val = 1024 * 1024;

BENCHMARK(toss_me);
BENCHMARK(read_scalar_matvar)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK(read_scalar_varmat)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK(write_scalar_matvar)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK(write_scalar_varmat)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK_MAIN();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment