public
anonymous / gist:5987975
Created

dmvnorm_arma

  • Download Gist
gistfile1.txt
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
#include <stdexcept>

#include <RcppArmadillo.h>
#include <Rcpp.h>
#include <omp.h>
 
// [[Rcpp::depends("RcppArmadillo")]]
// Squared Mahalanobis distance of every row of `x` from `center`:
//   D_i = (x_i - center) * cov^{-1} * (x_i - center)'
//
// Parameters:
//   x      - n x d matrix, one observation per row.
//   center - length-d row vector subtracted from every row of `x`.
//   cov    - d x d matrix. It is factorised with a Cholesky decomposition,
//            so it is expected to be symmetric positive definite.
//   cores  - number of OpenMP threads used to centre the data (values < 1
//            are treated as 1). It is applied only to the parallel region
//            below; the process-wide OpenMP thread setting is not modified.
// Returns:
//   Length-n column vector of squared distances.
// Errors:
//   Throws std::invalid_argument when the dimensions of x, center and cov
//   disagree. chol() fails (Armadillo raises an error) when cov is not
//   positive definite.
// [[Rcpp::export]]
arma::vec Mahalanobis(const arma::mat& x, const arma::rowvec& center,
                      const arma::mat& cov, int cores) {
  const arma::uword n = x.n_rows;
  const arma::uword d = x.n_cols;

  // Validate before the parallel region: an exception thrown inside an
  // OpenMP region cannot propagate out of it and would terminate the process.
  if (center.n_elem != d) {
    throw std::invalid_argument("Mahalanobis: length of center must equal ncol(x)");
  }
  if (cov.n_rows != d || cov.n_cols != d) {
    throw std::invalid_argument("Mahalanobis: cov must be a d x d matrix with d = ncol(x)");
  }

  // The num_threads clause requires a positive value.
  const int threads = (cores > 0) ? cores : 1;
  // Signed counter for compatibility with OpenMP 2.x loop rules.
  const int n_obs = static_cast<int>(n);

  // Centred observations stored transposed (d x n), one observation per
  // column: each iteration writes a single contiguous column, so threads do
  // not interleave writes within cache lines, and plain element access
  // neither allocates nor throws inside the parallel region.
  arma::mat centred_t(d, n);
#pragma omp parallel for num_threads(threads) schedule(static)
  for (int i = 0; i < n_obs; ++i) {
    const arma::uword row = static_cast<arma::uword>(i);
    for (arma::uword j = 0; j < d; ++j) {
      centred_t.at(j, row) = x.at(row, j) - center[j];
    }
  }

  // cov = U' U with U upper triangular. Solving U' Z = centred_t yields
  // columns z_i with ||z_i||^2 = (x_i - center) cov^{-1} (x_i - center)',
  // so cov^{-1} is never formed explicitly.
  const arma::mat U = arma::chol(cov);
  const arma::mat L = U.t();
  const arma::mat z = arma::solve(arma::trimatl(L), centred_t);

  // Column sums of z % z are the squared distances (1 x n); return as a column.
  return arma::vec(arma::trans(arma::sum(z % z, 0)));
}
 
 
 
 
// Multivariate normal density N(mean, sigma) evaluated at every row of `x`:
//   log f(x_i) = -( d * log(2*pi) + log|sigma| + D_i ) / 2
// where d = ncol(x) and D_i is the squared Mahalanobis distance of x_i
// from `mean` (computed by Mahalanobis()).
//
// Parameters:
//   x     - n x d matrix, one observation per row.
//   mean  - length-d row vector.
//   sigma - d x d covariance matrix; expected to be symmetric positive
//           definite (its Cholesky factor gives the log-determinant).
//   log   - if true return log-densities, otherwise densities. The name
//           shadows the C math function but is kept because Rcpp exports
//           it as the R argument name.
//   cores - forwarded unchanged to Mahalanobis().
// Returns:
//   Length-n column vector of densities (or log-densities).
// Errors:
//   Throws std::invalid_argument when the dimensions of x, mean and sigma
//   disagree. chol() fails (Armadillo raises an error) when sigma is not
//   positive definite.
// [[Rcpp::export]]
arma::vec dmvnorm_arma(const arma::mat& x, const arma::rowvec& mean,
                       const arma::mat& sigma, bool log, int cores) {
  const arma::uword d = x.n_cols;

  // Check shapes before calling Mahalanobis(), which may do its work inside
  // an OpenMP region from which a size-mismatch exception cannot escape.
  if (mean.n_elem != d) {
    throw std::invalid_argument("dmvnorm_arma: length of mean must equal ncol(x)");
  }
  if (sigma.n_rows != d || sigma.n_cols != d) {
    throw std::invalid_argument("dmvnorm_arma: sigma must be a d x d matrix with d = ncol(x)");
  }

  // sigma = U' U  =>  log|sigma| = 2 * sum(log(diag(U))). Cheaper than a full
  // eigendecomposition, and it fails loudly instead of taking the log of a
  // negative eigenvalue (NaN) when sigma is not positive definite.
  const arma::mat U = arma::chol(sigma);
  const double logdet = 2.0 * arma::accu(arma::log(U.diag()));

  // log(2 * pi)
  const double log2pi = 1.8378770664093454835606594728112352797227949472755668;

  const arma::vec distval = Mahalanobis(x, mean, sigma, cores);
  const arma::vec logretval =
      -0.5 * (static_cast<double>(d) * log2pi + logdet + distval);

  if (log) {
    return logretval;
  }
  return arma::exp(logretval);
}

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.