Last active
September 4, 2018 17:22
-
-
Save bbbales2/523db7848df48117dd39788815dfe1a2 to your computer and use it in GitHub Desktop.
Example use of adj_jac_apply to solve Ax = b
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stan/math/rev/core.hpp>
#include <test/unit/math/rev/mat/util.hpp>
#include <gtest/gtest.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <random>
#include <sstream>
#include <tuple>
namespace stan { | |
namespace math { | |
struct SolveFunctor { | |
double* A_llt_mem_; | |
int N_; | |
double* x_mem_; | |
template<size_t size> | |
Eigen::VectorXd operator()(std::array<bool, size> needs_adj, | |
const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>& A, | |
const Eigen::Matrix<double, Eigen::Dynamic, 1>& b) { | |
check_square("SolveFunctor", "A", A); | |
check_size_match("SolveFunctor", "Rows of A", A.rows(), "Rows of b", b.rows()); | |
// Do the solve | |
auto A_llt = A.llt(); | |
Eigen::Matrix<double, Eigen::Dynamic, 1> x = A_llt.solve(b); | |
// Save a copy of the decomposition of A for later | |
A_llt_mem_ | |
= ChainableStack::instance().memalloc_.alloc_array<double>(A.size()); | |
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> > A_llt_matrix(A_llt_mem_, A.rows(), | |
A.cols()); | |
A_llt_matrix.template triangularView<Eigen::Lower>() = A_llt.matrixL(); | |
A_llt_matrix.template triangularView<Eigen::Upper>() = A_llt_matrix.transpose(); | |
N_ = A.rows(); | |
// Save a copy of the solution for later (we only need this if A contains vars) | |
if(needs_adj[0]) { | |
x_mem_ | |
= ChainableStack::instance().memalloc_.alloc_array<double>(x.size()); | |
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, 1> >(x_mem_, x.size()) = x; | |
} | |
return x; | |
} | |
template<size_t size> | |
auto multiply_adjoint_jacobian(const std::array<bool, size>& needs_adj, const Eigen::VectorXd& y_adj) { | |
Eigen::Map<const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> > A_llt(A_llt_mem_, N_, N_); | |
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> adjA; | |
// We gotta compute adjb if either inputs contain vars | |
Eigen::Matrix<double, Eigen::Dynamic, 1> adjb = y_adj; | |
A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjb); | |
A_llt.template triangularView<Eigen::Upper>().solveInPlace(adjb); | |
// We only need to compute adjA in certain cases | |
if(needs_adj[0]) { | |
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, 1> > x(x_mem_, N_); | |
adjA = -adjb * x.transpose(); | |
} | |
return std::make_tuple(adjA, adjb); | |
} | |
}; | |
} | |
} | |
// Compares adj_jac_apply<SolveFunctor> against mdivide_left_spd on the scalar
// y = d * A^-1 * b: the value of y and the adjoints of b and A are printed
// side by side and asserted to agree.
TEST(AgradRev, test_custom_solve) {
  using stan::math::var;
  const int N = 5;

  // Going to work with d * A^-1 * b (it'll be a scalar)
  Eigen::Matrix<var, Eigen::Dynamic, 1> b(N);
  Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> A(N, N);
  Eigen::Matrix<var, 1, Eigen::Dynamic> d(N);

  std::mt19937 rng;
  std::normal_distribution<> dist(0, 1.0);

  for (int i = 0; i < b.rows(); i++) {
    b(i) = dist(rng);
    d(i) = dist(rng);
  }

  // Both solvers compared here work from a Cholesky factorization, so A has
  // to be symmetric positive definite. M * M^T is positive semi-definite;
  // adding N * I makes it strictly positive definite, and averaging with the
  // transpose makes it exactly symmetric.
  Eigen::MatrixXd M(N, N);
  for (int i = 0; i < M.size(); i++) {
    M(i) = 5 * dist(rng);
  }
  const Eigen::MatrixXd MMt = M * M.transpose();
  const Eigen::MatrixXd A_val
      = 0.5 * (MMt + MMt.transpose())
        + static_cast<double>(N) * Eigen::MatrixXd::Identity(N, N);
  for (int i = 0; i < A.size(); i++) {
    A(i) = A_val(i);
  }

  // First pass: the custom functor through adj_jac_apply.
  Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> x
      = stan::math::adj_jac_apply<stan::math::SolveFunctor>(A, b);
  var y = (d * x)(0);
  y.grad();

  // Copy the results out before the adjoints are reset.
  Eigen::Matrix<double, Eigen::Dynamic, 1> bd(b.rows());
  Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> Ad(A.rows(), A.cols());
  const double v = y.val();
  for (int i = 0; i < A.rows(); i++) {
    bd(i) = b(i).adj();
    for (int j = 0; j < A.cols(); j++) {
      Ad(i, j) = A(i, j).adj();
    }
  }

  stan::math::set_zero_all_adjoints();

  // Second pass: the reference implementation.
  x = stan::math::mdivide_left_spd(A, b);
  y = (d * x)(0);
  y.grad();

  std::cout.precision(3);
  std::cout << std::scientific;

  std::cout << "value, left reference, right adj_jac_apply" << '\n';
  std::cout << std::setw(15) << y.val() << " " << std::setw(15) << v << '\n';

  std::cout << "adjb, left reference, right adj_jac_apply" << '\n';
  for (int i = 0; i < b.rows(); i++) {
    std::cout << std::setw(15) << b(i).adj() << " " << std::setw(15) << bd(i)
              << '\n';
  }

  std::cout << "adjA, left reference, right adj_jac_apply" << '\n';
  for (int i = 0; i < A.size(); i++) {
    std::cout << std::setw(15) << A(i).adj() << " " << std::setw(15) << Ad(i)
              << '\n';
  }

  // Actually check the agreement instead of relying on eyeballing the output.
  const double tol = 1e-8;
  EXPECT_NEAR(y.val(), v, tol);
  for (int i = 0; i < b.rows(); i++) {
    EXPECT_NEAR(b(i).adj(), bd(i), tol);
  }
  for (int i = 0; i < A.size(); i++) {
    EXPECT_NEAR(A(i).adj(), Ad(i), tol);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment