Created July 12, 2013 21:24
dmvnorm_arma
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(openmp)]]
#include <RcppArmadillo.h>  // already includes Rcpp.h, so no separate include is needed
#include <omp.h>
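// Squared Mahalanobis distance of each row of x from center, given covariance cov.
// [[Rcpp::export]]
arma::vec Mahalanobis(const arma::mat& x, const arma::rowvec& center, const arma::mat& cov, int cores) {
    omp_set_num_threads(cores);
    const int n = x.n_rows;
    arma::mat x_cen;
    x_cen.copy_size(x);
    // Center each row; every iteration writes to a distinct row, so the loop is safe to run in parallel.
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < n; i++) {
        x_cen.row(i) = x.row(i) - center;
    }
    // Row sums of (x_cen * cov^{-1}) % x_cen give the quadratic form for each observation.
    return arma::sum((x_cen * cov.i()) % x_cen, 1);
}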
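// Multivariate normal density for each row of x (log-density if log is true).
// [[Rcpp::export]]
arma::vec dmvnorm_arma(const arma::mat& x, const arma::rowvec& mean, const arma::mat& sigma, bool log, int cores) {
    arma::vec distval = Mahalanobis(x, mean, sigma, cores);
    // log|sigma| as the sum of the log-eigenvalues; assumes sigma is symmetric positive definite.
    double logdet = arma::sum(arma::log(arma::eig_sym(sigma)));
    const double log2pi = 1.8378770664093454835606594728112352797227949472755668;  // log(2 * pi)
    arma::vec logretval = -((x.n_cols * log2pi + logdet + distval) / 2);
    if (log) {
        return logretval;
    }
    return arma::exp(logretval);
}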