-
-
Save nsiccha/9c38758cf7ad754de1fb92701cc12cba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_DG_HPP | |
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_DG_HPP | |
#include <stan/math/prim/meta.hpp> | |
#include <stan/math/prim/err.hpp> | |
#include <stan/math/prim/fun/size.hpp> | |
#include <stan/math/prim/functor/apply.hpp> | |
#include <stan/math/prim/fun/Eigen.hpp> | |
#include <stan/math/rev/core/var.hpp> | |
#include <stan/math/rev/fun/from_var_value.hpp> | |
#include <stan/math/fwd/functor/jacobian.hpp> | |
#include <stan/math/rev/fun/adjoint_of.hpp> | |
#include <stan/math/rev/core/nested_rev_autodiff.hpp> | |
#include <stan/math/prim/functor/ode_dg_quadrature.hpp> | |
#include <stan/math/mix/meta.hpp> | |
#include <stan/math/mix/fun.hpp> | |
#include <stan/math/mix/functor.hpp> | |
#include <stan/math/fwd/fun/pow.hpp> | |
#include <stan/math/rev/fun/pow.hpp> | |
#include <stan/math/prim/fun/abs.hpp> | |
#include <stan/math/prim/fun/max.hpp> | |
// #include <stan/math/mix/core/std_complex.hpp> | |
// #include <stan/math/prim/core/complex_base.hpp> | |
// #include <stan/math | |
// #include <stan/math/mix.hpp> | |
// #include <iostream> | |
namespace stan { | |
namespace math { | |
// fvar<var> fmax(double x, const fvar<var>& y){ | |
// return (x < y) ? y : fvar<var>(x); | |
// } | |
/**
 * `square` for forward-over-reverse (`fvar<var>`) scalars.
 *
 * Value: val^2. Tangent (chain rule): d * 2 * val.
 */
inline fvar<var> square(fvar<var> x) {
  const var& value = x.val_;
  const var tangent = x.d_ * 2 * value;
  return fvar<var>(square(value), tangent);
}
/**
 * `cos` for forward-over-reverse (`fvar<var>`) scalars.
 *
 * Value: cos(val). Tangent: d * -sin(val).
 */
inline fvar<var> cos(const fvar<var>& x) {
  using std::cos;
  using std::sin;
  const var& angle = x.val_;
  const var tangent = x.d_ * -sin(angle);
  return fvar<var>(cos(angle), tangent);
}
/**
 * `sin` for forward-over-reverse (`fvar<var>`) scalars.
 *
 * Value: sin(val). Tangent: d * cos(val).
 */
inline fvar<var> sin(const fvar<var>& x) {
  using std::cos;
  using std::sin;
  const var& angle = x.val_;
  const var tangent = x.d_ * cos(angle);
  return fvar<var>(sin(angle), tangent);
}
// auto pow(const fvar<var>& x, const fvar<var>& y){return pow(x,y);} | |
/**
 * `pow` for two forward-over-reverse (`fvar<var>`) scalars.
 *
 * Value: p = x1^x2. Tangent: (dx2 * log(x1) + x2 * dx1 / x1) * p.
 * NOTE(review): the tangent divides by x1 and takes log(x1), so it is not
 * finite for x1 <= 0.
 */
inline fvar<var> pow(const fvar<var>& x1, const fvar<var>& x2) {
  using std::log;
  using std::pow;
  const var power(pow(x1.val_, x2.val_));
  const var tangent
      = (x2.d_ * log(x1.val_) + x2.val_ * x1.d_ / x1.val_) * power;
  return fvar<var>(power, tangent);
}
// template <class T> void my_real(T) = delete; | |
// auto my_pow(double x, double y){return std::pow(x,y);} | |
// auto my_pow(const fvar<double>& x, const fvar<double>& y){return pow(x,y);} | |
// auto my_pow(const var& x, const var& y){return pow(x,y);} | |
// // auto my_pow(const fvar<var>& x, double y){return pow(x,y);} | |
// auto my_pow(const fvar<var>& x, const fvar<var>& y){return pow(x,y);} | |
// template <typename X> | |
// auto my_pow(X x, X y){return pow(x,y);} | |
// auto my_real(var x){return x;} | |
// auto my_real(std::complex<var> x){return x.real();} | |
// auto my_real(fvar<var> x){return x;} | |
// template <typename T> auto my_real(T x){return x.hello();} | |
template <typename T, typename V> | |
std::enable_if_t< | |
std::is_same<scalar_type_t<T>, double>::value, | |
void | |
> safe_adjoint_increment(const T& x, const V& inc){} | |
template <typename T, typename V> | |
std::enable_if_t< | |
std::is_same<scalar_type_t<T>, var>::value, | |
void | |
> safe_adjoint_increment(T& x, const V& inc){ | |
x.adj() += inc; | |
} | |
template < | |
typename F, | |
typename Y, | |
typename P | |
> | |
struct ode_solver{ | |
typedef Eigen::Matrix<double, Eigen::Dynamic, 1> vector_type; | |
typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> matrix_type; | |
typedef Eigen::Matrix<var, Eigen::Dynamic, 1> var_vector_type; | |
typedef Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> var_matrix_type; | |
typedef Eigen::Matrix<fvar<var>, Eigen::Dynamic, 1> fvarvar_vector_type; | |
typedef Eigen::Matrix<fvar<var>, Eigen::Dynamic, 1> fvar_var_vector_type; | |
typedef scalar_type_t<Y> scalar_type_y; | |
typedef scalar_type_t<P> scalar_type_p; | |
static constexpr bool implemented = ( | |
std::is_same<scalar_type_y, double>::value || std::is_same<scalar_type_y, var>::value | |
) && ( | |
std::is_same<scalar_type_p, double>::value || std::is_same<scalar_type_p, var>::value | |
); | |
static constexpr bool parameter_sensitivities = ( | |
std::is_same<scalar_type_p, var>::value | |
); | |
static constexpr bool initial_sensitivities = ( | |
std::is_same<scalar_type_y, var>::value | |
); | |
static constexpr bool reverse_mode = ( | |
parameter_sensitivities || initial_sensitivities | |
// std::is_same<scalar_type_y, var>::value || std::is_same<scalar_type_p, var>::value | |
); | |
typedef std::conditional_t< | |
reverse_mode, var_vector_type, vector_type | |
> y_type; | |
typedef std::vector<y_type> return_type; | |
typedef Eigen::HouseholderQR<matrix_type> qr_type; | |
template <typename T> | |
using carena_t = std::conditional_t<reverse_mode, arena_t<T>, T>; | |
typedef carena_t<vector_type> cvector_type; | |
typedef carena_t<matrix_type> cmatrix_type; | |
typedef carena_t<var_vector_type> cvar_vector_type; | |
typedef carena_t<var_matrix_type> cvar_matrix_type; | |
typedef Eigen::Map<vector_type> vector_map_type; | |
typedef Eigen::Map<matrix_type> matrix_map_type; | |
typedef Eigen::Map<const matrix_type> const_matrix_map_type; | |
typedef Eigen::HouseholderQR<cmatrix_type> cqr_type; | |
// Inputs | |
const F& f_; | |
double t0_; | |
carena_t<const Y&> y0_; | |
carena_t<const std::vector<double>&> ts_; | |
carena_t<const P&> params_; | |
const cvector_type val_params_; | |
// Outputs | |
// MEMORY LEAK?! | |
return_type ys_; | |
return_type ys() const {return ys_;} | |
// carena_t<return_type> ys_; | |
// return_type ys() const { | |
// return_type rv; | |
// rv.reserve(ys_.size()); | |
// for(const auto& y : ys_){ | |
// rv.emplace_back(y); | |
// } | |
// return rv; | |
// } | |
// Forward/Reverse working memory | |
const int no_states_; | |
const int no_dofs_; | |
const int no_iterations_; | |
const bool approximate_; | |
const ode_dg_quadrature& q_; | |
double h_; | |
double Kh_; | |
double Kt_; | |
double tl_; | |
double tr_; | |
cvector_type yl_; | |
cvector_type yr_; | |
cvector_type fl_; | |
cmatrix_type dfl_; | |
cvector_type coefficients_; | |
cvector_type rhs_; | |
int no_blocks_; | |
carena_t<std::vector<matrix_type>> Ks_; | |
// MEMORY LEAK?! | |
std::vector<qr_type> qr_Ks_; | |
// Scratch memory for temporaries | |
cvector_type tmp_y_; | |
cmatrix_type tmp_coefficients_; | |
cmatrix_type tmp_coefficients2_; | |
// cvar_matrix_type tmp_rev_dfl_; | |
cvar_matrix_type tmp_rev_coefficients_; | |
cvar_matrix_type tmp_rev_coefficients2_; | |
// Reverse cache + working memory | |
const int no_timesteps_; | |
int cache_idx_; | |
carena_t<std::vector<int>> Ky_idx_cache_; | |
// const int cache_size_; | |
// carena_t<std::vector<vector_type>> y_cache_; | |
cvector_type t_cache_; | |
cmatrix_type y_cache_; | |
cmatrix_type coefficients_cache_; | |
cvector_type adj_yl_; | |
cvector_type adj_fl_; | |
cmatrix_type adj_dfl_; | |
cvector_type adj_params_; | |
cvector_type adj_coefficients_; | |
cvector_type adj_rhs_; | |
template<typename T> | |
matrix_map_type as_matrix(T &x){ | |
return matrix_map_type(x.data(), no_states_, no_dofs_); | |
} | |
template<typename T> | |
const_matrix_map_type as_matrix(const T &x){ | |
return const_matrix_map_type(x.data(), no_states_, no_dofs_); | |
} | |
template<typename T> | |
vector_map_type as_vector(T &x){ | |
return vector_map_type(x.data(), no_states_ * no_dofs_); | |
} | |
ode_solver( | |
const F& f, | |
double t0, | |
const Y& y0, | |
const std::vector<double>& ts, | |
const P& params, | |
int no_dofs, | |
double h, | |
double Kh, | |
int no_iterations, | |
bool approximate | |
) : | |
f_(f), t0_(t0), y0_(y0), ts_(to_arena_if<reverse_mode>(ts)), params_(params), | |
val_params_(value_of(params)), | |
no_states_(y0.size()), no_dofs_(no_dofs), | |
no_iterations_(no_iterations), approximate_(approximate), | |
q_(ode_dg_quadrature::instance(no_dofs)), | |
h_(h), Kh_(Kh), | |
tl_(t0), tr_(t0), yl_(value_of(y0)), yr_(value_of(y0)), | |
fl_(no_states_), dfl_(no_states_, no_states_), | |
coefficients_(no_dofs_ * no_states_), | |
rhs_(no_dofs_ * no_states_), | |
no_blocks_(1+(no_dofs - 1) / 2), | |
Ks_(no_blocks_), qr_Ks_(no_blocks_), | |
tmp_y_(no_states_), | |
tmp_coefficients_(no_states_, no_dofs_), | |
no_timesteps_(std::ceil((ts[ts.size()-1] - t0) / h_)), | |
cache_idx_(0), Ky_idx_cache_(0){ | |
ys_.reserve(ts.size()); | |
for(int block = 0; block < no_blocks_; ++block){ | |
int bs = block_size(block); | |
Ks_[block] = matrix_type(bs, bs); | |
qr_Ks_[block] = qr_type(bs, bs); | |
} | |
if(reverse_mode){ | |
tmp_coefficients2_ = matrix_type::Zero(no_states_, no_dofs_); | |
// tmp_rev_dfl_ = matrix_type::Zero(no_states_, no_states_); | |
tmp_rev_coefficients_ = matrix_type::Zero(no_states_, no_dofs_); | |
tmp_rev_coefficients2_ = matrix_type::Zero(no_states_, no_dofs_); | |
t_cache_ = vector_type::Zero(1+no_timesteps_); | |
y_cache_ = matrix_type::Zero(no_states_, 1+no_timesteps_); | |
if(approximate_){ | |
coefficients_cache_ = matrix_type::Zero(no_dofs_ * no_states_, 1+no_timesteps_); | |
}else{ | |
coefficients_cache_ = matrix_type::Zero(no_dofs_ * no_states_, no_iterations_); | |
} | |
adj_yl_ = vector_type::Zero(no_states_); | |
adj_fl_ = vector_type::Zero(no_states_); | |
adj_dfl_ = matrix_type::Zero(no_states_, no_states_); | |
adj_params_ = vector_type::Zero(no_states_); | |
adj_coefficients_ = vector_type::Zero(no_dofs_ * no_states_); | |
adj_rhs_ = vector_type::Zero(no_dofs_ * no_states_); | |
} | |
} | |
int block_begin(int block) const { | |
return 2 * block * no_states_; | |
} | |
int no_subblocks(int block) const { | |
return ((no_dofs_ % 2) && (2 * block + 1 == no_dofs_)) ? 1 : 2; | |
} | |
int block_size(int block) const { | |
return no_subblocks(block) * no_states_; | |
} | |
void update_fl(){ | |
fl_ = f_(tl_, yl_, val_params_); | |
} | |
void update_fdf(double t, const vector_type& y){ | |
vector_type fl; | |
matrix_type dfl; | |
jacobian([this, t](const auto& y_){return f_(t, y_, val_params_);}, y, fl, dfl); | |
fl_ = fl; | |
dfl_ = dfl; | |
} | |
void update_K(double t, const vector_type& y){ | |
Kt_ = t; | |
update_fdf(t, y); | |
const auto &Id = matrix_type::Identity(no_states_, no_states_); | |
for(int block = 0; block < no_blocks_; ++block){ | |
int nsb = no_subblocks(block); | |
for(int i = 0; i < nsb; ++i){ | |
for(int j = 0; j < nsb; ++j){ | |
Ks_[block].block( | |
i*no_states_, j*no_states_, | |
no_states_, no_states_ | |
) = ( | |
q_.time_matrix_(2*block+i, 2*block+j) * Id | |
-h_ * q_.mass_matrix_(2*block+i, 2*block+j) * dfl_ | |
); | |
} | |
} | |
qr_Ks_[block].compute(Ks_[block]); | |
} | |
} | |
void update_rhs(){ | |
update_fl(); | |
as_matrix(rhs_).noalias() = -h_ * fl_ * q_.constant_test_weight_.transpose(); | |
} | |
void set_constant(const vector_type &y){ | |
as_matrix(coefficients_).noalias() = y * q_.constant_basis_weight_.transpose(); | |
} | |
void update_coefficients(){ | |
set_constant(yl_); | |
for(int block = 0; block < no_blocks_; ++block){ | |
int bb = block_begin(block); | |
int bs = block_size(block); | |
coefficients_.segment(bb, bs) -= qr_Ks_[block].solve( | |
rhs_.segment(bb, bs) | |
); | |
} | |
} | |
template<bool reverse_mode_engaged=false> | |
void iterate_coefficients(){ | |
for(int it = 0; it < no_iterations_; ++it){ | |
if(reverse_mode_engaged){ | |
coefficients_cache_.col(it) = coefficients_; | |
} | |
// const vector_type& yl = | |
// as_matrix(coefficients_) | |
// * q_.basis_function_coefficients_.col(0); | |
// as_matrix(rhs_) | |
as_matrix(rhs_).noalias() = | |
as_matrix(coefficients_) * q_.time_matrix_.transpose(); | |
as_matrix(rhs_).noalias() -= | |
yl_ * q_.test_function_coefficients_.col(0).transpose(); | |
// as_matrix(rhs_) = as_matrix(coefficients_) * q_.time_matrix_.transpose() | |
// - yl_ * q_.test_function_coefficients_.col(0).transpose(); | |
const auto& val_params = val_params_; | |
tmp_coefficients_.noalias() = | |
as_matrix(coefficients_) * q_.basis_at_quadrature_points_; | |
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){ | |
double qt = tl_ + q_.quadrature_points_[qi] * h_; | |
tmp_coefficients_.col(qi) = f_(qt, tmp_coefficients_.col(qi), val_params); | |
// fl_ = f_(qt, interpolate(qt), val_params); | |
// tmp_y_ = | |
// as_matrix(coefficients_) * | |
// q_.basis_at_quadrature_points_.col(qi);//interpolate(qt); | |
// fl_ = f_(qt, tmp_y_, val_params); | |
// as_matrix(rhs_) -= | |
// h_ | |
// * fl_ | |
// * q_.weighted_test_functions_.col(qi).transpose(); | |
} | |
as_matrix(rhs_).noalias() -= | |
h_ * tmp_coefficients_ * q_.weighted_test_functions_.transpose(); | |
for(int block = 0; block < no_blocks_; ++block){ | |
int bb = block_begin(block); | |
int bs = block_size(block); | |
coefficients_.segment(bb, bs) -= qr_Ks_[block].solve( | |
rhs_.segment(bb, bs) | |
); | |
} | |
} | |
} | |
void forward_step(){ | |
tl_ = tr_; | |
yl_ = yr_; | |
if(reverse_mode){ | |
t_cache_(cache_idx_) = tl_; | |
y_cache_.col(cache_idx_) = yl_; | |
} | |
if(tl_ >= Kt_ + Kh_){ | |
if(reverse_mode){ | |
Ky_idx_cache_.push_back(cache_idx_); | |
} | |
update_K(tl_, yl_); | |
} | |
if(reverse_mode){++cache_idx_;} | |
update_rhs(); | |
update_coefficients(); | |
iterate_coefficients<false>(); | |
tr_ += h_; | |
yr_ = as_matrix(coefficients_) * q_.basis_at_one_; | |
if(reverse_mode && approximate_){ | |
coefficients_cache_.col(cache_idx_-1) = coefficients_; | |
} | |
} | |
vector_type basis_functions_at(double t) const { | |
double xi = (t - tl_) / h_; | |
vector_type xi_pow(no_dofs_); | |
xi_pow(0) = 1; | |
for(int i = 1; i < no_dofs_; ++i){ | |
xi_pow(i) = xi * xi_pow(i-1); | |
} | |
return q_.basis_function_coefficients_ * xi_pow; | |
} | |
vector_type interpolate(double t){ | |
if(t == tr_){ | |
return yr_; | |
} | |
return as_matrix(coefficients_) * basis_functions_at(t); | |
} | |
void forward_sweep(){ | |
if(reverse_mode){ | |
Ky_idx_cache_.push_back(0); | |
} | |
update_K(t0_, value_of(y0_)); | |
for(auto t: ts_){ | |
while(t > tr_){ | |
forward_step(); | |
} | |
// requires t in (tl, tr]; | |
ys_.emplace_back(interpolate(t)); | |
} | |
} | |
template <typename T_y, typename T_f> | |
vector_type accumulate_f( | |
double t, const T_y& y, const T_f& adj_f | |
){ | |
nested_rev_autodiff nested; | |
var_vector_type rev_y = y; | |
var_vector_type rev_f = f_(t, rev_y, params_); | |
rev_f.adj() += adj_f; | |
grad(); | |
return rev_y.adj(); | |
} | |
// auto a | |
auto accumulate_qf(){ | |
nested_rev_autodiff nested; | |
// tmp_coefficients_.noalias() = | |
// as_matrix(coefficients_) * q_.basis_at_quadrature_points_; | |
// tmp_coefficients2_.noalias() = | |
// -h_ | |
// * as_matrix(adj_rhs_) | |
// * q_.weighted_test_functions_; | |
tmp_rev_coefficients_ = as_matrix(coefficients_) * q_.basis_at_quadrature_points_; | |
auto& rev_y = tmp_rev_coefficients_; | |
auto& rev_f = tmp_rev_coefficients2_; | |
// var_matrix_type rev_y = as_matrix(coefficients_) * q_.basis_at_quadrature_points_; | |
// var_matrix_type rev_f(no_states_, no_dofs_); | |
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){ | |
double qt = tl_ + q_.quadrature_points_[qi] * h_; | |
rev_f.col(qi) = f_(qt, rev_y.col(qi), params_); | |
// tmp_rev_coefficients2_.col(qi) = f_(qt, rev_y.col(qi), params_); | |
// rev_f.col(qi).adj() += tmp_coefficients2_.col(qi) | |
} | |
rev_f.adj() = -h_ | |
// tmp_rev_coefficients2_.adj() = -h_ | |
* as_matrix(adj_rhs_) | |
* q_.weighted_test_functions_; | |
grad(); | |
tmp_coefficients_ = rev_y.adj(); | |
return tmp_coefficients_ * q_.basis_at_quadrature_points_.transpose(); | |
} | |
// template < | |
// typename T | |
// > | |
void accumulate_solve(const cmatrix_type &x){ | |
for(int block = 0; block < no_blocks_; ++block){ | |
int bb = block_begin(block); | |
int bs = block_size(block); | |
adj_rhs_.segment(bb, bs) = -( | |
qr_Ks_[block].householderQ() | |
* qr_Ks_[block].matrixQR() | |
.template triangularView<Eigen::Upper>() | |
.transpose() | |
.solve(adj_coefficients_.segment(bb, bs)) | |
); | |
} | |
adj_dfl_.noalias() += h_ * as_matrix(adj_rhs_) * x.transpose(); | |
} | |
void reverse_iterate_coefficients(){ | |
for(int it = no_iterations_-1; it >= 0; --it){ | |
// c(i+1) = c(i) - K \ rhs(i) | |
as_vector(tmp_coefficients_) = coefficients_cache_.col(it) - coefficients_; | |
accumulate_solve(tmp_coefficients_); | |
// rhs = c * rho.T - yl_ x t0 - <f x t> | |
as_matrix(adj_coefficients_).noalias() += | |
as_matrix(adj_rhs_) * q_.time_matrix_; | |
adj_yl_.noalias() -= | |
as_matrix(adj_rhs_) | |
* q_.test_function_coefficients_.col(0); | |
coefficients_ = coefficients_cache_.col(it); | |
as_matrix(adj_coefficients_).noalias() += accumulate_qf(); | |
} | |
} | |
// adj(coefficients) -> adj(yl, params/dfl) | |
void approximate_collapse(){ | |
// std::cout << no_dofs_ << "-" << no_iterations_ << "-" << h_ << ": " << max(abs(rhs_)) << std::endl; | |
// rhs(c*; yl, params) =!= 0 | |
// rhs(c) = c * rho.T - yl_ x t0 - <f(c) x t> | |
tmp_coefficients_.noalias() = | |
as_matrix(coefficients_) * q_.basis_at_quadrature_points_; | |
const auto& y = tmp_coefficients_; | |
auto& residual = tmp_coefficients2_; | |
// adj_rhs_ *= 0; | |
adj_rhs_.setZero(); | |
{ | |
// nested_rev_autodiff nested; | |
// tmp_rev_coefficients_ = y; | |
// auto& rev_y = tmp_rev_coefficients_; | |
// auto& rev_f = tmp_rev_coefficients2_; | |
// for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){ | |
// double qt = tl_ + q_.quadrature_points_[qi] * h_; | |
// rev_f.col(qi) = f_(qt, rev_y.col(qi), val_params_); | |
// } | |
for(int i = 0; i < 1+no_iterations_/2; ++i){ | |
// for(int i = 0; i < 0; ++i){ | |
residual = as_matrix(adj_coefficients_); | |
if(i > 0){ | |
residual.noalias() -= as_matrix(adj_rhs_) * q_.time_matrix_; | |
nested_rev_autodiff nested; | |
tmp_rev_coefficients_ = y; | |
auto& rev_y = tmp_rev_coefficients_; | |
auto& rev_f = tmp_rev_coefficients2_; | |
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){ | |
double qt = tl_ + q_.quadrature_points_[qi] * h_; | |
rev_f.col(qi) = f_(qt, rev_y.col(qi), val_params_); | |
} | |
// nested.set_zero_all_adjoints(); | |
rev_f.adj().noalias() = as_matrix(adj_rhs_) * q_.weighted_test_functions_; | |
grad(); | |
matrix_type tmp(rev_y.adj()); | |
residual.noalias() += | |
h_ | |
* tmp | |
// * as_matrix(rev_y.adj()) | |
* q_.basis_at_quadrature_points_.transpose(); | |
} | |
for(int block = 0; block < no_blocks_; ++block){ | |
int bb = block_begin(block); | |
int bs = block_size(block); | |
adj_rhs_.segment(bb, bs) += ( | |
qr_Ks_[block].householderQ() | |
* qr_Ks_[block].matrixQR() | |
.template triangularView<Eigen::Upper>() | |
.transpose() | |
.solve(as_vector(residual).segment(bb, bs)) | |
); | |
} | |
} | |
} | |
if(parameter_sensitivities){ | |
nested_rev_autodiff nested; | |
auto& rev_f = tmp_rev_coefficients2_; | |
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){ | |
double qt = tl_ + q_.quadrature_points_[qi] * h_; | |
rev_f.col(qi) = f_(qt, y.col(qi), params_); | |
} | |
rev_f.adj().noalias() = h_ | |
* as_matrix(adj_rhs_) | |
* q_.weighted_test_functions_; | |
grad(); | |
} | |
adj_yl_.noalias() += | |
as_matrix(adj_rhs_) * q_.test_function_coefficients_.col(0); | |
} | |
// adj(coefficients) -> adj(yl, params/dfl) | |
void reverse_collapse(){ | |
// std::cout << cache_idx_ << "-" << tl_ << "-COLLAPSE" << std::endl; | |
if(approximate_){ | |
approximate_collapse(); | |
return; | |
} | |
reverse_iterate_coefficients(); | |
// coefficients -= K \ rhs | |
tmp_coefficients_.noalias() = | |
yl_ * q_.constant_basis_weight_.transpose() | |
- as_matrix(coefficients_); | |
accumulate_solve(tmp_coefficients_); | |
// coefficients = cbw x yl | |
adj_yl_.noalias() += as_matrix(adj_coefficients_) * q_.constant_basis_weight_; | |
// rhs = -h * ctw x f(yl) | |
adj_fl_.noalias() = -h_ * as_matrix(adj_rhs_) * q_.constant_test_weight_; | |
// f = f(yl, params) | |
adj_yl_.noalias() += accumulate_f(tl_, yl_, adj_fl_); | |
if(Ky_idx_cache_.back() == cache_idx_){ | |
// std::cout << "ACCUMULATING DF" << std::endl; | |
// int dummy; | |
// std::cin >> dummy; | |
adj_yl_ += accumulate_df(); | |
// adj_dfl_ *= 0; | |
adj_dfl_.setZero(); | |
} | |
} | |
void reverse_move(){ | |
// std::cout << cache_idx_ << "-" << tl_ << "-MOVE" << std::endl; | |
tr_ = tl_; | |
yr_ = yl_; | |
--cache_idx_; | |
tl_ = t_cache_(cache_idx_); | |
yl_ = y_cache_.col(cache_idx_); | |
} | |
// adj(y1) -> adj(coefficients) | |
// y1 = basis_at_one @ coefficients | |
void reverse_expand(){ | |
// std::cout << cache_idx_ << "-" << tl_ << "-EXPAND" << std::endl; | |
as_matrix(adj_coefficients_).noalias() = adj_yl_ * q_.basis_at_one_.transpose(); | |
adj_yl_ = vector_type::Zero(no_states_); | |
if(approximate_){ | |
coefficients_ = coefficients_cache_.col(cache_idx_); | |
}else{ | |
update_rhs(); | |
update_coefficients(); | |
iterate_coefficients<true>(); | |
} | |
} | |
void reverse_step(){ | |
reverse_collapse(); // adj(coefficients) -> adj(y1, df1) | |
reverse_move(); // move t, y backwards | |
if(Ky_idx_cache_.back() > cache_idx_){ | |
Ky_idx_cache_.pop_back(); | |
int idx = Ky_idx_cache_.back(); | |
update_K(t_cache_(idx), y_cache_.col(idx)); | |
} | |
reverse_expand(); // adj(y1) -> adj(coefficients) | |
} | |
template <typename T> | |
void accumulate_y(double t, const T& y){} | |
void accumulate_y(double t, const var_vector_type& y){ | |
if(t == tl_){ | |
adj_yl_ += y.adj(); | |
}else{ | |
as_matrix(adj_coefficients_).noalias() += y.adj() * basis_functions_at(t).transpose(); | |
} | |
} | |
vector_type accumulate_df(){ | |
nested_rev_autodiff nested; | |
var_vector_type rev_yl = value_of(yl_); | |
var_vector_type rev_fl = value_of(fl_); | |
var_matrix_type rev_dfl = value_of(dfl_); | |
jacobian([this](const fvarvar_vector_type& y){ | |
return f_(tl_, y, params_.template cast<fvar<var>>()); | |
}, rev_yl, rev_fl, rev_dfl); | |
rev_dfl.adj() = adj_dfl_; | |
grad(); | |
return rev_yl.adj(); | |
} | |
void reverse_sweep(){ | |
cache_idx_--; | |
// for(int i = 0; i < t_cache_.size(); ++i){ | |
// std::cout << i << ": " << t_cache_[i] << std::endl; | |
// } | |
// for(int i = 0; i < Ky_idx_cache_.size(); ++i){ | |
// std::cout << i << ": " << Ky_idx_cache_[i] << std::endl; | |
// } | |
reverse_expand(); | |
for(int i = ts_.size() - 1; i >= 0; --i){ | |
double t = ts_[i]; | |
const auto& y = ys_[i]; | |
while(t < tl_){ | |
reverse_step(); | |
} | |
// requires t in [tl, tr]; | |
accumulate_y(t, y); | |
} | |
while(cache_idx_){ | |
reverse_step(); | |
} | |
reverse_collapse(); | |
// adj_yl_ += accumulate_df(); | |
safe_adjoint_increment(y0_, adj_yl_); | |
qr_Ks_.clear(); | |
qr_Ks_.shrink_to_fit(); | |
} | |
}; | |
struct ode_solver_factory{ | |
int no_dofs_; | |
int no_iterations_; | |
double h_; | |
double Kh_; | |
bool approximate_; | |
ode_solver_factory() | |
: no_dofs_(1), no_iterations_(0), | |
h_(1.), Kh_(0), approximate_(false){} | |
template < | |
typename F, | |
typename Y, | |
typename P | |
> | |
ode_solver<F,Y,P> initialize( | |
const F& f, | |
double t0, | |
const Y& y0, | |
const std::vector<double>& ts, | |
const P& params | |
){ | |
return ode_solver<F,Y,P>( | |
f, t0, y0, ts, params, | |
no_dofs_, h_, (Kh_ > 0) ? Kh_ : (10 * h_), | |
no_iterations_, approximate_ | |
); | |
} | |
}; | |
template < | |
typename SF, typename F, typename Y, typename P, | |
typename = std::enable_if_t< | |
ode_solver<F,Y,P>::implemented, | |
void | |
> | |
> | |
auto ode_dg_impl( | |
SF solver_factory, | |
const F& f, | |
double t0, | |
const Y& y0, | |
const std::vector<double>& ts, | |
const P& params | |
){ | |
auto solver = solver_factory.initialize(f, t0, y0, ts, params); | |
// std::cout << "FORWARD SWEEP" << std::endl; | |
solver.forward_sweep(); | |
if(solver.reverse_mode){ | |
reverse_pass_callback([solver]() mutable { | |
// std::cout << "REVERSE SWEEP" << std::endl; | |
solver.reverse_sweep(); | |
}); | |
} | |
return solver.ys(); | |
} | |
} | |
} | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment