Skip to content

Instantly share code, notes, and snippets.

@nsiccha
Created February 17, 2022 16:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nsiccha/9c38758cf7ad754de1fb92701cc12cba to your computer and use it in GitHub Desktop.
Save nsiccha/9c38758cf7ad754de1fb92701cc12cba to your computer and use it in GitHub Desktop.
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_DG_HPP
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_DG_HPP
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/rev/core/var.hpp>
#include <stan/math/rev/fun/from_var_value.hpp>
#include <stan/math/fwd/functor/jacobian.hpp>
#include <stan/math/rev/fun/adjoint_of.hpp>
#include <stan/math/rev/core/nested_rev_autodiff.hpp>
#include <stan/math/prim/functor/ode_dg_quadrature.hpp>
#include <stan/math/mix/meta.hpp>
#include <stan/math/mix/fun.hpp>
#include <stan/math/mix/functor.hpp>
#include <stan/math/fwd/fun/pow.hpp>
#include <stan/math/rev/fun/pow.hpp>
#include <stan/math/prim/fun/abs.hpp>
#include <stan/math/prim/fun/max.hpp>
// #include <stan/math/mix/core/std_complex.hpp>
// #include <stan/math/prim/core/complex_base.hpp>
// #include <stan/math
// #include <stan/math/mix.hpp>
// #include <iostream>
namespace stan {
namespace math {
// fvar<var> fmax(double x, const fvar<var>& y){
// return (x < y) ? y : fvar<var>(x);
// }
/**
 * Square of a forward-over-reverse scalar.
 * Value: x^2; tangent: x' * 2 * x.
 */
inline fvar<var> square(fvar<var> x){
  const var squared_value = square(x.val_);
  const var tangent = x.d_ * 2 * x.val_;
  return fvar<var>(squared_value, tangent);
}
/**
 * Cosine of a forward-over-reverse scalar.
 * Value: cos(x); tangent: x' * -sin(x).
 */
inline fvar<var> cos(const fvar<var>& x) {
  using std::cos;
  using std::sin;
  const var neg_sin = -sin(x.val_);
  const var cos_value = cos(x.val_);
  return fvar<var>(cos_value, x.d_ * neg_sin);
}
/**
 * Sine of a forward-over-reverse scalar.
 * Value: sin(x); tangent: x' * cos(x).
 */
inline fvar<var> sin(const fvar<var>& x) {
  using std::cos;
  using std::sin;
  const var cos_value = cos(x.val_);
  const var sin_value = sin(x.val_);
  return fvar<var>(sin_value, x.d_ * cos_value);
}
// auto pow(const fvar<var>& x, const fvar<var>& y){return pow(x,y);}
/**
 * x1^x2 for forward-over-reverse scalars.
 * Tangent: (x2' * log(x1) + x2 * x1' / x1) * x1^x2 (requires x1 > 0 for a
 * finite log term).
 */
inline fvar<var> pow(const fvar<var>& x1, const fvar<var>& x2) {
  using std::log;
  using std::pow;
  const var power = pow(x1.val_, x2.val_);
  const var exponent_term = x2.d_ * log(x1.val_);
  const var base_term = x2.val_ * x1.d_ / x1.val_;
  return fvar<var>(power, (exponent_term + base_term) * power);
}
// template <class T> void my_real(T) = delete;
// auto my_pow(double x, double y){return std::pow(x,y);}
// auto my_pow(const fvar<double>& x, const fvar<double>& y){return pow(x,y);}
// auto my_pow(const var& x, const var& y){return pow(x,y);}
// // auto my_pow(const fvar<var>& x, double y){return pow(x,y);}
// auto my_pow(const fvar<var>& x, const fvar<var>& y){return pow(x,y);}
// template <typename X>
// auto my_pow(X x, X y){return pow(x,y);}
// auto my_real(var x){return x;}
// auto my_real(std::complex<var> x){return x.real();}
// auto my_real(fvar<var> x){return x;}
// template <typename T> auto my_real(T x){return x.hello();}
/**
 * Adjoint increment for containers whose scalar type is double: there are
 * no adjoints to update, so this overload does nothing.
 */
template <typename T, typename V>
std::enable_if_t<std::is_same<scalar_type_t<T>, double>::value, void>
safe_adjoint_increment(const T& /* target */, const V& /* increment */) {}

/**
 * Adjoint increment for containers whose scalar type is var:
 * target.adj() += increment.
 */
template <typename T, typename V>
std::enable_if_t<std::is_same<scalar_type_t<T>, var>::value, void>
safe_adjoint_increment(T& target, const V& increment) {
  target.adj() += increment;
}
template <
typename F,
typename Y,
typename P
>
struct ode_solver{
typedef Eigen::Matrix<double, Eigen::Dynamic, 1> vector_type;
typedef Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> matrix_type;
typedef Eigen::Matrix<var, Eigen::Dynamic, 1> var_vector_type;
typedef Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> var_matrix_type;
typedef Eigen::Matrix<fvar<var>, Eigen::Dynamic, 1> fvarvar_vector_type;
typedef Eigen::Matrix<fvar<var>, Eigen::Dynamic, 1> fvar_var_vector_type;
typedef scalar_type_t<Y> scalar_type_y;
typedef scalar_type_t<P> scalar_type_p;
static constexpr bool implemented = (
std::is_same<scalar_type_y, double>::value || std::is_same<scalar_type_y, var>::value
) && (
std::is_same<scalar_type_p, double>::value || std::is_same<scalar_type_p, var>::value
);
static constexpr bool parameter_sensitivities = (
std::is_same<scalar_type_p, var>::value
);
static constexpr bool initial_sensitivities = (
std::is_same<scalar_type_y, var>::value
);
static constexpr bool reverse_mode = (
parameter_sensitivities || initial_sensitivities
// std::is_same<scalar_type_y, var>::value || std::is_same<scalar_type_p, var>::value
);
typedef std::conditional_t<
reverse_mode, var_vector_type, vector_type
> y_type;
typedef std::vector<y_type> return_type;
typedef Eigen::HouseholderQR<matrix_type> qr_type;
template <typename T>
using carena_t = std::conditional_t<reverse_mode, arena_t<T>, T>;
typedef carena_t<vector_type> cvector_type;
typedef carena_t<matrix_type> cmatrix_type;
typedef carena_t<var_vector_type> cvar_vector_type;
typedef carena_t<var_matrix_type> cvar_matrix_type;
typedef Eigen::Map<vector_type> vector_map_type;
typedef Eigen::Map<matrix_type> matrix_map_type;
typedef Eigen::Map<const matrix_type> const_matrix_map_type;
typedef Eigen::HouseholderQR<cmatrix_type> cqr_type;
// Inputs
const F& f_;
double t0_;
carena_t<const Y&> y0_;
carena_t<const std::vector<double>&> ts_;
carena_t<const P&> params_;
const cvector_type val_params_;
// Outputs
// MEMORY LEAK?!
return_type ys_;
return_type ys() const {return ys_;}
// carena_t<return_type> ys_;
// return_type ys() const {
// return_type rv;
// rv.reserve(ys_.size());
// for(const auto& y : ys_){
// rv.emplace_back(y);
// }
// return rv;
// }
// Forward/Reverse working memory
const int no_states_;
const int no_dofs_;
const int no_iterations_;
const bool approximate_;
const ode_dg_quadrature& q_;
double h_;
double Kh_;
double Kt_;
double tl_;
double tr_;
cvector_type yl_;
cvector_type yr_;
cvector_type fl_;
cmatrix_type dfl_;
cvector_type coefficients_;
cvector_type rhs_;
int no_blocks_;
carena_t<std::vector<matrix_type>> Ks_;
// MEMORY LEAK?!
std::vector<qr_type> qr_Ks_;
// Scratch memory for temporaries
cvector_type tmp_y_;
cmatrix_type tmp_coefficients_;
cmatrix_type tmp_coefficients2_;
// cvar_matrix_type tmp_rev_dfl_;
cvar_matrix_type tmp_rev_coefficients_;
cvar_matrix_type tmp_rev_coefficients2_;
// Reverse cache + working memory
const int no_timesteps_;
int cache_idx_;
carena_t<std::vector<int>> Ky_idx_cache_;
// const int cache_size_;
// carena_t<std::vector<vector_type>> y_cache_;
cvector_type t_cache_;
cmatrix_type y_cache_;
cmatrix_type coefficients_cache_;
cvector_type adj_yl_;
cvector_type adj_fl_;
cmatrix_type adj_dfl_;
cvector_type adj_params_;
cvector_type adj_coefficients_;
cvector_type adj_rhs_;
template<typename T>
matrix_map_type as_matrix(T &x){
return matrix_map_type(x.data(), no_states_, no_dofs_);
}
template<typename T>
const_matrix_map_type as_matrix(const T &x){
return const_matrix_map_type(x.data(), no_states_, no_dofs_);
}
template<typename T>
vector_map_type as_vector(T &x){
return vector_map_type(x.data(), no_states_ * no_dofs_);
}
ode_solver(
const F& f,
double t0,
const Y& y0,
const std::vector<double>& ts,
const P& params,
int no_dofs,
double h,
double Kh,
int no_iterations,
bool approximate
) :
f_(f), t0_(t0), y0_(y0), ts_(to_arena_if<reverse_mode>(ts)), params_(params),
val_params_(value_of(params)),
no_states_(y0.size()), no_dofs_(no_dofs),
no_iterations_(no_iterations), approximate_(approximate),
q_(ode_dg_quadrature::instance(no_dofs)),
h_(h), Kh_(Kh),
tl_(t0), tr_(t0), yl_(value_of(y0)), yr_(value_of(y0)),
fl_(no_states_), dfl_(no_states_, no_states_),
coefficients_(no_dofs_ * no_states_),
rhs_(no_dofs_ * no_states_),
no_blocks_(1+(no_dofs - 1) / 2),
Ks_(no_blocks_), qr_Ks_(no_blocks_),
tmp_y_(no_states_),
tmp_coefficients_(no_states_, no_dofs_),
no_timesteps_(std::ceil((ts[ts.size()-1] - t0) / h_)),
cache_idx_(0), Ky_idx_cache_(0){
ys_.reserve(ts.size());
for(int block = 0; block < no_blocks_; ++block){
int bs = block_size(block);
Ks_[block] = matrix_type(bs, bs);
qr_Ks_[block] = qr_type(bs, bs);
}
if(reverse_mode){
tmp_coefficients2_ = matrix_type::Zero(no_states_, no_dofs_);
// tmp_rev_dfl_ = matrix_type::Zero(no_states_, no_states_);
tmp_rev_coefficients_ = matrix_type::Zero(no_states_, no_dofs_);
tmp_rev_coefficients2_ = matrix_type::Zero(no_states_, no_dofs_);
t_cache_ = vector_type::Zero(1+no_timesteps_);
y_cache_ = matrix_type::Zero(no_states_, 1+no_timesteps_);
if(approximate_){
coefficients_cache_ = matrix_type::Zero(no_dofs_ * no_states_, 1+no_timesteps_);
}else{
coefficients_cache_ = matrix_type::Zero(no_dofs_ * no_states_, no_iterations_);
}
adj_yl_ = vector_type::Zero(no_states_);
adj_fl_ = vector_type::Zero(no_states_);
adj_dfl_ = matrix_type::Zero(no_states_, no_states_);
adj_params_ = vector_type::Zero(no_states_);
adj_coefficients_ = vector_type::Zero(no_dofs_ * no_states_);
adj_rhs_ = vector_type::Zero(no_dofs_ * no_states_);
}
}
int block_begin(int block) const {
return 2 * block * no_states_;
}
int no_subblocks(int block) const {
return ((no_dofs_ % 2) && (2 * block + 1 == no_dofs_)) ? 1 : 2;
}
int block_size(int block) const {
return no_subblocks(block) * no_states_;
}
void update_fl(){
fl_ = f_(tl_, yl_, val_params_);
}
void update_fdf(double t, const vector_type& y){
vector_type fl;
matrix_type dfl;
jacobian([this, t](const auto& y_){return f_(t, y_, val_params_);}, y, fl, dfl);
fl_ = fl;
dfl_ = dfl;
}
void update_K(double t, const vector_type& y){
Kt_ = t;
update_fdf(t, y);
const auto &Id = matrix_type::Identity(no_states_, no_states_);
for(int block = 0; block < no_blocks_; ++block){
int nsb = no_subblocks(block);
for(int i = 0; i < nsb; ++i){
for(int j = 0; j < nsb; ++j){
Ks_[block].block(
i*no_states_, j*no_states_,
no_states_, no_states_
) = (
q_.time_matrix_(2*block+i, 2*block+j) * Id
-h_ * q_.mass_matrix_(2*block+i, 2*block+j) * dfl_
);
}
}
qr_Ks_[block].compute(Ks_[block]);
}
}
void update_rhs(){
update_fl();
as_matrix(rhs_).noalias() = -h_ * fl_ * q_.constant_test_weight_.transpose();
}
void set_constant(const vector_type &y){
as_matrix(coefficients_).noalias() = y * q_.constant_basis_weight_.transpose();
}
void update_coefficients(){
set_constant(yl_);
for(int block = 0; block < no_blocks_; ++block){
int bb = block_begin(block);
int bs = block_size(block);
coefficients_.segment(bb, bs) -= qr_Ks_[block].solve(
rhs_.segment(bb, bs)
);
}
}
template<bool reverse_mode_engaged=false>
void iterate_coefficients(){
for(int it = 0; it < no_iterations_; ++it){
if(reverse_mode_engaged){
coefficients_cache_.col(it) = coefficients_;
}
// const vector_type& yl =
// as_matrix(coefficients_)
// * q_.basis_function_coefficients_.col(0);
// as_matrix(rhs_)
as_matrix(rhs_).noalias() =
as_matrix(coefficients_) * q_.time_matrix_.transpose();
as_matrix(rhs_).noalias() -=
yl_ * q_.test_function_coefficients_.col(0).transpose();
// as_matrix(rhs_) = as_matrix(coefficients_) * q_.time_matrix_.transpose()
// - yl_ * q_.test_function_coefficients_.col(0).transpose();
const auto& val_params = val_params_;
tmp_coefficients_.noalias() =
as_matrix(coefficients_) * q_.basis_at_quadrature_points_;
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){
double qt = tl_ + q_.quadrature_points_[qi] * h_;
tmp_coefficients_.col(qi) = f_(qt, tmp_coefficients_.col(qi), val_params);
// fl_ = f_(qt, interpolate(qt), val_params);
// tmp_y_ =
// as_matrix(coefficients_) *
// q_.basis_at_quadrature_points_.col(qi);//interpolate(qt);
// fl_ = f_(qt, tmp_y_, val_params);
// as_matrix(rhs_) -=
// h_
// * fl_
// * q_.weighted_test_functions_.col(qi).transpose();
}
as_matrix(rhs_).noalias() -=
h_ * tmp_coefficients_ * q_.weighted_test_functions_.transpose();
for(int block = 0; block < no_blocks_; ++block){
int bb = block_begin(block);
int bs = block_size(block);
coefficients_.segment(bb, bs) -= qr_Ks_[block].solve(
rhs_.segment(bb, bs)
);
}
}
}
void forward_step(){
tl_ = tr_;
yl_ = yr_;
if(reverse_mode){
t_cache_(cache_idx_) = tl_;
y_cache_.col(cache_idx_) = yl_;
}
if(tl_ >= Kt_ + Kh_){
if(reverse_mode){
Ky_idx_cache_.push_back(cache_idx_);
}
update_K(tl_, yl_);
}
if(reverse_mode){++cache_idx_;}
update_rhs();
update_coefficients();
iterate_coefficients<false>();
tr_ += h_;
yr_ = as_matrix(coefficients_) * q_.basis_at_one_;
if(reverse_mode && approximate_){
coefficients_cache_.col(cache_idx_-1) = coefficients_;
}
}
vector_type basis_functions_at(double t) const {
double xi = (t - tl_) / h_;
vector_type xi_pow(no_dofs_);
xi_pow(0) = 1;
for(int i = 1; i < no_dofs_; ++i){
xi_pow(i) = xi * xi_pow(i-1);
}
return q_.basis_function_coefficients_ * xi_pow;
}
vector_type interpolate(double t){
if(t == tr_){
return yr_;
}
return as_matrix(coefficients_) * basis_functions_at(t);
}
void forward_sweep(){
if(reverse_mode){
Ky_idx_cache_.push_back(0);
}
update_K(t0_, value_of(y0_));
for(auto t: ts_){
while(t > tr_){
forward_step();
}
// requires t in (tl, tr];
ys_.emplace_back(interpolate(t));
}
}
template <typename T_y, typename T_f>
vector_type accumulate_f(
double t, const T_y& y, const T_f& adj_f
){
nested_rev_autodiff nested;
var_vector_type rev_y = y;
var_vector_type rev_f = f_(t, rev_y, params_);
rev_f.adj() += adj_f;
grad();
return rev_y.adj();
}
// auto a
auto accumulate_qf(){
nested_rev_autodiff nested;
// tmp_coefficients_.noalias() =
// as_matrix(coefficients_) * q_.basis_at_quadrature_points_;
// tmp_coefficients2_.noalias() =
// -h_
// * as_matrix(adj_rhs_)
// * q_.weighted_test_functions_;
tmp_rev_coefficients_ = as_matrix(coefficients_) * q_.basis_at_quadrature_points_;
auto& rev_y = tmp_rev_coefficients_;
auto& rev_f = tmp_rev_coefficients2_;
// var_matrix_type rev_y = as_matrix(coefficients_) * q_.basis_at_quadrature_points_;
// var_matrix_type rev_f(no_states_, no_dofs_);
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){
double qt = tl_ + q_.quadrature_points_[qi] * h_;
rev_f.col(qi) = f_(qt, rev_y.col(qi), params_);
// tmp_rev_coefficients2_.col(qi) = f_(qt, rev_y.col(qi), params_);
// rev_f.col(qi).adj() += tmp_coefficients2_.col(qi)
}
rev_f.adj() = -h_
// tmp_rev_coefficients2_.adj() = -h_
* as_matrix(adj_rhs_)
* q_.weighted_test_functions_;
grad();
tmp_coefficients_ = rev_y.adj();
return tmp_coefficients_ * q_.basis_at_quadrature_points_.transpose();
}
// template <
// typename T
// >
void accumulate_solve(const cmatrix_type &x){
for(int block = 0; block < no_blocks_; ++block){
int bb = block_begin(block);
int bs = block_size(block);
adj_rhs_.segment(bb, bs) = -(
qr_Ks_[block].householderQ()
* qr_Ks_[block].matrixQR()
.template triangularView<Eigen::Upper>()
.transpose()
.solve(adj_coefficients_.segment(bb, bs))
);
}
adj_dfl_.noalias() += h_ * as_matrix(adj_rhs_) * x.transpose();
}
void reverse_iterate_coefficients(){
for(int it = no_iterations_-1; it >= 0; --it){
// c(i+1) = c(i) - K \ rhs(i)
as_vector(tmp_coefficients_) = coefficients_cache_.col(it) - coefficients_;
accumulate_solve(tmp_coefficients_);
// rhs = c * rho.T - yl_ x t0 - <f x t>
as_matrix(adj_coefficients_).noalias() +=
as_matrix(adj_rhs_) * q_.time_matrix_;
adj_yl_.noalias() -=
as_matrix(adj_rhs_)
* q_.test_function_coefficients_.col(0);
coefficients_ = coefficients_cache_.col(it);
as_matrix(adj_coefficients_).noalias() += accumulate_qf();
}
}
// adj(coefficients) -> adj(yl, params/dfl)
void approximate_collapse(){
// std::cout << no_dofs_ << "-" << no_iterations_ << "-" << h_ << ": " << max(abs(rhs_)) << std::endl;
// rhs(c*; yl, params) =!= 0
// rhs(c) = c * rho.T - yl_ x t0 - <f(c) x t>
tmp_coefficients_.noalias() =
as_matrix(coefficients_) * q_.basis_at_quadrature_points_;
const auto& y = tmp_coefficients_;
auto& residual = tmp_coefficients2_;
// adj_rhs_ *= 0;
adj_rhs_.setZero();
{
// nested_rev_autodiff nested;
// tmp_rev_coefficients_ = y;
// auto& rev_y = tmp_rev_coefficients_;
// auto& rev_f = tmp_rev_coefficients2_;
// for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){
// double qt = tl_ + q_.quadrature_points_[qi] * h_;
// rev_f.col(qi) = f_(qt, rev_y.col(qi), val_params_);
// }
for(int i = 0; i < 1+no_iterations_/2; ++i){
// for(int i = 0; i < 0; ++i){
residual = as_matrix(adj_coefficients_);
if(i > 0){
residual.noalias() -= as_matrix(adj_rhs_) * q_.time_matrix_;
nested_rev_autodiff nested;
tmp_rev_coefficients_ = y;
auto& rev_y = tmp_rev_coefficients_;
auto& rev_f = tmp_rev_coefficients2_;
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){
double qt = tl_ + q_.quadrature_points_[qi] * h_;
rev_f.col(qi) = f_(qt, rev_y.col(qi), val_params_);
}
// nested.set_zero_all_adjoints();
rev_f.adj().noalias() = as_matrix(adj_rhs_) * q_.weighted_test_functions_;
grad();
matrix_type tmp(rev_y.adj());
residual.noalias() +=
h_
* tmp
// * as_matrix(rev_y.adj())
* q_.basis_at_quadrature_points_.transpose();
}
for(int block = 0; block < no_blocks_; ++block){
int bb = block_begin(block);
int bs = block_size(block);
adj_rhs_.segment(bb, bs) += (
qr_Ks_[block].householderQ()
* qr_Ks_[block].matrixQR()
.template triangularView<Eigen::Upper>()
.transpose()
.solve(as_vector(residual).segment(bb, bs))
);
}
}
}
if(parameter_sensitivities){
nested_rev_autodiff nested;
auto& rev_f = tmp_rev_coefficients2_;
for(int qi = 0; qi < q_.quadrature_points_.size(); ++qi){
double qt = tl_ + q_.quadrature_points_[qi] * h_;
rev_f.col(qi) = f_(qt, y.col(qi), params_);
}
rev_f.adj().noalias() = h_
* as_matrix(adj_rhs_)
* q_.weighted_test_functions_;
grad();
}
adj_yl_.noalias() +=
as_matrix(adj_rhs_) * q_.test_function_coefficients_.col(0);
}
// adj(coefficients) -> adj(yl, params/dfl)
void reverse_collapse(){
// std::cout << cache_idx_ << "-" << tl_ << "-COLLAPSE" << std::endl;
if(approximate_){
approximate_collapse();
return;
}
reverse_iterate_coefficients();
// coefficients -= K \ rhs
tmp_coefficients_.noalias() =
yl_ * q_.constant_basis_weight_.transpose()
- as_matrix(coefficients_);
accumulate_solve(tmp_coefficients_);
// coefficients = cbw x yl
adj_yl_.noalias() += as_matrix(adj_coefficients_) * q_.constant_basis_weight_;
// rhs = -h * ctw x f(yl)
adj_fl_.noalias() = -h_ * as_matrix(adj_rhs_) * q_.constant_test_weight_;
// f = f(yl, params)
adj_yl_.noalias() += accumulate_f(tl_, yl_, adj_fl_);
if(Ky_idx_cache_.back() == cache_idx_){
// std::cout << "ACCUMULATING DF" << std::endl;
// int dummy;
// std::cin >> dummy;
adj_yl_ += accumulate_df();
// adj_dfl_ *= 0;
adj_dfl_.setZero();
}
}
void reverse_move(){
// std::cout << cache_idx_ << "-" << tl_ << "-MOVE" << std::endl;
tr_ = tl_;
yr_ = yl_;
--cache_idx_;
tl_ = t_cache_(cache_idx_);
yl_ = y_cache_.col(cache_idx_);
}
// adj(y1) -> adj(coefficients)
// y1 = basis_at_one @ coefficients
void reverse_expand(){
// std::cout << cache_idx_ << "-" << tl_ << "-EXPAND" << std::endl;
as_matrix(adj_coefficients_).noalias() = adj_yl_ * q_.basis_at_one_.transpose();
adj_yl_ = vector_type::Zero(no_states_);
if(approximate_){
coefficients_ = coefficients_cache_.col(cache_idx_);
}else{
update_rhs();
update_coefficients();
iterate_coefficients<true>();
}
}
void reverse_step(){
reverse_collapse(); // adj(coefficients) -> adj(y1, df1)
reverse_move(); // move t, y backwards
if(Ky_idx_cache_.back() > cache_idx_){
Ky_idx_cache_.pop_back();
int idx = Ky_idx_cache_.back();
update_K(t_cache_(idx), y_cache_.col(idx));
}
reverse_expand(); // adj(y1) -> adj(coefficients)
}
template <typename T>
void accumulate_y(double t, const T& y){}
void accumulate_y(double t, const var_vector_type& y){
if(t == tl_){
adj_yl_ += y.adj();
}else{
as_matrix(adj_coefficients_).noalias() += y.adj() * basis_functions_at(t).transpose();
}
}
vector_type accumulate_df(){
nested_rev_autodiff nested;
var_vector_type rev_yl = value_of(yl_);
var_vector_type rev_fl = value_of(fl_);
var_matrix_type rev_dfl = value_of(dfl_);
jacobian([this](const fvarvar_vector_type& y){
return f_(tl_, y, params_.template cast<fvar<var>>());
}, rev_yl, rev_fl, rev_dfl);
rev_dfl.adj() = adj_dfl_;
grad();
return rev_yl.adj();
}
void reverse_sweep(){
cache_idx_--;
// for(int i = 0; i < t_cache_.size(); ++i){
// std::cout << i << ": " << t_cache_[i] << std::endl;
// }
// for(int i = 0; i < Ky_idx_cache_.size(); ++i){
// std::cout << i << ": " << Ky_idx_cache_[i] << std::endl;
// }
reverse_expand();
for(int i = ts_.size() - 1; i >= 0; --i){
double t = ts_[i];
const auto& y = ys_[i];
while(t < tl_){
reverse_step();
}
// requires t in [tl, tr];
accumulate_y(t, y);
}
while(cache_idx_){
reverse_step();
}
reverse_collapse();
// adj_yl_ += accumulate_df();
safe_adjoint_increment(y0_, adj_yl_);
qr_Ks_.clear();
qr_Ks_.shrink_to_fit();
}
};
/**
 * Holds the tuning options for ode_solver and builds solvers from them.
 *
 * Defaults: one dof per state, no extra iterations, unit step size, exact
 * (non-approximate) adjoint. A non-positive Kh_ means K is rebuilt after
 * 10 * h_ of integration time.
 */
struct ode_solver_factory{
  int no_dofs_;
  int no_iterations_;
  double h_;
  double Kh_;
  bool approximate_;
  ode_solver_factory()
    : no_dofs_{1}, no_iterations_{0},
      h_{1.0}, Kh_{0.0}, approximate_{false} {}
  /** Build a solver for y' = f(t, y, params) from the stored options. */
  template <typename F, typename Y, typename P>
  ode_solver<F,Y,P> initialize(
    const F& f,
    double t0,
    const Y& y0,
    const std::vector<double>& ts,
    const P& params
  ){
    double refresh_interval = 10 * h_;
    if(Kh_ > 0){
      refresh_interval = Kh_;
    }
    return ode_solver<F,Y,P>(
      f, t0, y0, ts, params,
      no_dofs_, h_, refresh_interval,
      no_iterations_, approximate_
    );
  }
};
/**
 * Solve y' = f(t, y, params) from (t0, y0) with a DG solver built by
 * `solver_factory`, and return the state at every time in ts.
 *
 * In reverse mode, a copy of the solver is registered with
 * reverse_pass_callback. During the reverse pass it propagates the adjoints
 * of the returned states back to y0 and params.
 *
 * The solver steps forward from t0 and interpolates on the current step,
 * so the inputs are validated first. Stan's check_* helpers throw if:
 *  - t0 or any entry of ts is not finite;
 *  - ts is empty;
 *  - ts is not sorted;
 *  - ts[0] < t0.
 *
 * @param solver_factory object providing initialize(f, t0, y0, ts, params)
 * @param f right-hand side functor, called as f(t, y, params)
 * @param t0 initial time
 * @param y0 initial state
 * @param ts output times
 * @param params parameters passed to f
 * @return one state vector per entry of ts
 */
template <
  typename SF, typename F, typename Y, typename P,
  typename = std::enable_if_t<
    ode_solver<F,Y,P>::implemented,
    void
  >
>
auto ode_dg_impl(
  SF solver_factory,
  const F& f,
  double t0,
  const Y& y0,
  const std::vector<double>& ts,
  const P& params
){
  static constexpr const char* function = "ode_dg";
  check_finite(function, "initial time", t0);
  check_nonzero_size(function, "times", ts);
  check_finite(function, "times", ts);
  check_sorted(function, "times", ts);
  check_less_or_equal(function, "initial time", t0, ts[0]);
  auto solver = solver_factory.initialize(f, t0, y0, ts, params);
  solver.forward_sweep();
  if(solver.reverse_mode){
    // The lambda owns its own copy of the solver, so the reverse pass does
    // not depend on this stack frame.
    reverse_pass_callback([solver]() mutable {
      solver.reverse_sweep();
    });
  }
  return solver.ys();
}
}
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment