Skip to content

Instantly share code, notes, and snippets.

@bbbales2
Last active September 4, 2018 17:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bbbales2/523db7848df48117dd39788815dfe1a2 to your computer and use it in GitHub Desktop.
Save bbbales2/523db7848df48117dd39788815dfe1a2 to your computer and use it in GitHub Desktop.
Example use of adj_jac_apply to solve Ax = b
#include <stan/math/rev/core.hpp>
#include <test/unit/math/rev/mat/util.hpp>
#include <gtest/gtest.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <random>
#include <sstream>
#include <tuple>
#include <utility>
namespace stan {
namespace math {

/**
 * Functor for adj_jac_apply computing x = A^{-1} b for a symmetric
 * positive-definite matrix A and a vector b.
 *
 * The forward pass Cholesky-factors A once and stores the factor (and, when A
 * needs adjoints, the solution x) in the autodiff arena. The reverse pass
 * reuses that factor instead of refactoring A.
 */
struct SolveFunctor {
  // Arena copy of the Cholesky factor of A, N_ x N_ column major: L in the
  // lower triangle and L^T mirrored into the upper triangle.
  double* A_llt_mem_ = nullptr;
  // Number of rows (== columns) of A.
  int N_ = 0;
  // Arena copy of the solution x; only allocated when needs_adj[0] is true.
  double* x_mem_ = nullptr;

  /**
   * Forward pass: solve A * x = b.
   *
   * @tparam size number of arguments passed through adj_jac_apply
   * @param needs_adj needs_adj[0] is true when A needs adjoints (x is then
   *   saved for the reverse pass)
   * @param A symmetric positive-definite matrix
   * @param b right-hand side with A.rows() rows
   * @return x = A^{-1} b
   * @throw via check_square / check_size_match / check_symmetric /
   *   check_pos_definite if A is not square, sizes mismatch, A is not
   *   symmetric, or A is not positive definite
   */
  template <size_t size>
  Eigen::VectorXd operator()(
      const std::array<bool, size>& needs_adj,
      const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>& A,
      const Eigen::Matrix<double, Eigen::Dynamic, 1>& b) {
    check_square("SolveFunctor", "A", A);
    check_size_match("SolveFunctor", "Rows of A", A.rows(), "Rows of b",
                     b.rows());
    // LLT only reads the lower triangle; without this check a non-symmetric
    // A would be silently treated as its symmetrized lower half.
    check_symmetric("SolveFunctor", "A", A);

    const Eigen::LLT<Eigen::MatrixXd> A_llt = A.llt();
    // A failed factorization leaves a meaningless factor behind; reject it
    // rather than returning a wrong solution and wrong gradients.
    check_pos_definite("SolveFunctor", "A", A_llt);

    Eigen::VectorXd x = A_llt.solve(b);

    // Keep the factor for the reverse pass: L below the diagonal, L^T above.
    N_ = A.rows();
    A_llt_mem_
        = ChainableStack::instance().memalloc_.alloc_array<double>(A.size());
    Eigen::Map<Eigen::MatrixXd> A_llt_matrix(A_llt_mem_, N_, N_);
    A_llt_matrix.template triangularView<Eigen::Lower>() = A_llt.matrixL();
    A_llt_matrix.template triangularView<Eigen::Upper>()
        = A_llt_matrix.transpose();

    // x is only needed for the adjoint of A.
    if (needs_adj[0]) {
      x_mem_
          = ChainableStack::instance().memalloc_.alloc_array<double>(x.size());
      Eigen::Map<Eigen::VectorXd>(x_mem_, x.size()) = x;
    }
    return x;
  }

  /**
   * Reverse pass: map the adjoint of x to the adjoints of (A, b).
   *
   * With x = A^{-1} b and A symmetric:
   *   adj_b = A^{-1} adj_x   (two triangular solves with the stored factor)
   *   adj_A = -adj_b * x^T   (only when needs_adj[0] is true)
   *
   * @tparam size number of arguments passed through adj_jac_apply
   * @param needs_adj needs_adj[0] is true when A needs adjoints
   * @param y_adj adjoint of the output x
   * @return tuple (adj_A, adj_b); adj_A is left empty when needs_adj[0] is
   *   false
   */
  template <size_t size>
  auto multiply_adjoint_jacobian(const std::array<bool, size>& needs_adj,
                                 const Eigen::VectorXd& y_adj) const {
    const Eigen::Map<const Eigen::MatrixXd> A_llt(A_llt_mem_, N_, N_);

    // adj_b is always computed: adj_A is built from it as well.
    Eigen::VectorXd adjb = y_adj;
    A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjb);  // L^-1
    A_llt.template triangularView<Eigen::Upper>().solveInPlace(adjb);  // L^-T

    Eigen::MatrixXd adjA;
    if (needs_adj[0]) {
      const Eigen::Map<const Eigen::VectorXd> x(x_mem_, N_);
      adjA = -adjb * x.transpose();
    }
    return std::make_tuple(std::move(adjA), std::move(adjb));
  }
};

}  // namespace math
}  // namespace stan
TEST(AgradRev, test_custom_solve) {
  // Compare y = d * A^{-1} * b (a scalar) and its gradients computed through
  // adj_jac_apply<SolveFunctor> against the reference mdivide_left_spd.
  using stan::math::var;
  const int N = 5;

  std::mt19937 rng;
  std::normal_distribution<> dist(0, 1.0);

  // Both SolveFunctor and mdivide_left_spd require a symmetric
  // positive-definite A; M * M^T + N * I is SPD for any M.
  Eigen::MatrixXd M(N, N);
  for (int i = 0; i < M.size(); ++i)
    M(i) = dist(rng);
  Eigen::MatrixXd A_val = M * M.transpose();
  A_val.diagonal().array() += static_cast<double>(N);

  Eigen::Matrix<var, Eigen::Dynamic, 1> b(N);
  Eigen::Matrix<var, 1, Eigen::Dynamic> d(N);
  for (int i = 0; i < N; ++i) {
    b(i) = dist(rng);
    d(i) = dist(rng);
  }
  Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> A(N, N);
  for (int i = 0; i < A.size(); ++i)
    A(i) = A_val(i);

  // Pass 1: custom functor via adj_jac_apply.
  Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> x
      = stan::math::adj_jac_apply<stan::math::SolveFunctor>(A, b);
  var y = (d * x)(0);
  y.grad();

  const double y_test = y.val();
  Eigen::VectorXd b_adj_test(N);
  Eigen::MatrixXd A_adj_test(N, N);
  for (int i = 0; i < N; ++i)
    b_adj_test(i) = b(i).adj();
  for (int i = 0; i < A.size(); ++i)
    A_adj_test(i) = A(i).adj();

  // Pass 2: reference implementation on the same inputs.
  stan::math::set_zero_all_adjoints();
  x = stan::math::mdivide_left_spd(A, b);
  y = (d * x)(0);
  y.grad();

  const double tol = 1e-8;
  EXPECT_NEAR(y.val(), y_test, tol);
  for (int i = 0; i < N; ++i)
    EXPECT_NEAR(b(i).adj(), b_adj_test(i), tol) << "adjb(" << i << ")";
  for (int i = 0; i < A.size(); ++i)
    EXPECT_NEAR(A(i).adj(), A_adj_test(i), tol) << "adjA(" << i << ")";

  stan::math::recover_memory();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment