Skip to content

Instantly share code, notes, and snippets.

@boennecd

boennecd/chol-deriv.cpp

Last active Mar 12, 2020
Embed
What would you like to do?
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
/* d vech(chol(X)) / d vech(X). See mathoverflow.net/a/232129/134083
 *
 * Args:
 *   X: symmetric positive definite matrix.
 *   upper: logical for whether vech denotes the upper triangular part.
 *
 * Returns:
 *   A square matrix of order n (n + 1) / 2 where n is the dimension of X.
 *   Rows index the elements of the Cholesky factor and columns index the
 *   elements of X. With upper = false both are ordered like the column-wise
 *   stacked lower triangle and the factor is the lower triangular one. With
 *   upper = true both are ordered like the column-wise stacked upper
 *   triangle and the factor is the upper triangular one.
 */
// [[Rcpp::export]]
arma::mat dchol(arma::mat const &X, bool const upper = false)
{
  using arma::uword;
  uword const ndim = X.n_rows,
             nvech = (ndim * (ndim + 1)) / 2;
  arma::mat out(nvech, nvech, arma::fill::zeros);

  /* F is the upper triangular factor and Fi is its inverse */
  arma::mat const F = arma::chol(X),
                 Fi = arma::inv(arma::trimatu(F));

  /* position in the column-wise stacked upper triangle of the element whose
   * coordinates in the lower triangle are (row, col) with row >= col. This
   * closed form replaces a linear search per looked up index. */
  auto const upper_idx = [](uword const row, uword const col){
    return (row * (row + 1)) / 2 + col;
  };

  /* for the current (i, j) we cache
   *   s[k] = F(j, i) Fi(k, j) / 2 + sum_{m > j} F(m, i) Fi(k, m)
   * such that every entry of the output row is a combination of two cached
   * values instead of a new dot product. */
  arma::vec s(ndim);

  uword r(0); /* index of (i, j) in the lower triangular ordering */
  for(uword j = 0; j < ndim; j++)
    for(uword i = j; i < ndim; i++, r++){
      for(uword k = 0; k < ndim; k++){
        double acc = F.at(j, i) * Fi.at(k, j) / 2.;
        for(uword m = j + 1; m < ndim; m++)
          acc += F.at(m, i) * Fi.at(k, m);
        s[k] = acc;
      }

      uword const row = upper ? upper_idx(i, j) : r;

      /* only the columns of X up to and including column j are filled. The
       * remaining entries keep the zero they were initialized with. */
      uword c(0); /* index of (l, k) in the lower triangular ordering */
      for(uword k = 0; k <= j; k++){
        double const fi_kj = Fi.at(k, j);

        /* a diagonal element of X enters once */
        out.at(row, upper ? upper_idx(k, k) : c) = s[k] * fi_kj;
        c++;

        /* an off-diagonal element of X enters as both X(l, k) and X(k, l) */
        for(uword l = k + 1; l < ndim; l++, c++)
          out.at(row, upper ? upper_idx(l, k) : c) =
            s[k] * Fi.at(l, j) + s[l] * fi_kj;
      }
    }

  return out;
}
/*** R
options(digits = 4)
# library() stops with an error if matrixcalc is not installed whereas
# require() only returns FALSE and lets the script fail later
library(matrixcalc)
set.seed(2349025)
# dimension of the test matrix
n <- 10
# random symmetric positive definite matrix
Z <- drop(rWishart(1, 2 * n, diag(n)))
#####
# simple R-version of the derivative of the Cholesky Factor
# Maps the column-wise stacked lower triangle of a symmetric positive
# definite matrix to its lower triangular Cholesky factor.
#
# Args:
#   xin: numeric vector with the n * (n + 1) / 2 lower triangular elements.
#
# Returns: the n x n lower triangular Cholesky factor.
fn <- function(xin) {
  # the dimension follows from the length rather than from a global variable
  n <- as.integer(round((sqrt(8 * length(xin) + 1) - 1) / 2))
  stopifnot(n * (n + 1L) / 2 == length(xin))
  x <- matrix(nrow = n, ncol = n)
  x[lower.tri(x, TRUE)] <- xin
  # mirror the lower triangle to get the full symmetric matrix
  x[upper.tri(x)] <- t(x)[upper.tri(x)]
  t(chol(x))
}
# Jacobian of the lower triangular Cholesky factor with respect to the
# column-wise stacked lower triangle of Z, computed with dense matrices from
# the matrixcalc package.
#
# Args:
#   Z: symmetric positive definite matrix.
#
# Returns: a square matrix of order n * (n + 1) / 2 where n is nrow(Z).
dchol_R <- function(Z) {
  stopifnot(is.matrix(Z), nrow(Z) == ncol(Z))
  # use the dimension of the argument rather than a global variable
  n <- nrow(Z)
  # lower triangular Cholesky factor. chol() only reads the upper triangle
  # of its argument so t(Z) makes it use the lower triangle of Z
  X <- t(chol(t(Z)))
  L <- elimination.matrix(n)
  # derivative of vech(X X^T) with respect to vech(X). Its inverse is the
  # derivative we are after
  d <- L %*% (diag(n^2) + commutation.matrix(r = n)) %*%
    tcrossprod(X %x% diag(n), L)
  solve(d)
}
# compare the R version with a numerical approximation of the Jacobian
library(numDeriv)
keep <- lower.tri(Z, TRUE)
jac <- jacobian(fn, Z[keep])
d <- dchol_R(Z)
# fn returns the full matrix so only the lower triangular rows are kept
all.equal(jac[keep, ], d)
#R> [1] TRUE
#####
# C++ version
all.equal(d, dchol(Z))
#R> [1] TRUE
#####
# same function but with the upper triangular parts
# Maps the column-wise stacked upper triangle of a symmetric positive
# definite matrix to its upper triangular Cholesky factor.
#
# Args:
#   xin: numeric vector with the n * (n + 1) / 2 upper triangular elements.
#
# Returns: the n x n upper triangular Cholesky factor.
fnU <- function(xin) {
  # the dimension follows from the length rather than from a global variable
  n <- as.integer(round((sqrt(8 * length(xin) + 1) - 1) / 2))
  stopifnot(n * (n + 1L) / 2 == length(xin))
  x <- matrix(nrow = n, ncol = n)
  x[upper.tri(x, TRUE)] <- xin
  # mirror the upper triangle to get the full symmetric matrix
  x[lower.tri(x)] <- t(x)[lower.tri(x)]
  chol(x)
}
keep <- upper.tri(Z, TRUE)
jac <- jacobian(fnU, Z[keep])
# fnU returns the full matrix so only the upper triangular rows are kept
all.equal(jac[keep, ], dchol(Z, upper = TRUE))
#R> [1] TRUE
#####
# benchmarks
microbenchmark::microbenchmark(
  R = dchol_R(Z),
  `C++` = dchol(Z),
  `C++ (upper)` = dchol(Z, upper = TRUE),
  times = 1000
)
#R> Unit: microseconds
#R> expr min lq mean median uq max neval
#R> R 11233.06 11651.94 12810.15 12959.27 13388.65 44049.72 1000
#R> C++ 14.33 15.68 21.67 16.88 30.17 99.71 1000
#R> C++ (upper) 16.09 18.38 26.83 21.55 35.07 1467.63 1000
# the R function is mainly slow due to the use of the matrixcalc package
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment