Last active
August 29, 2015 14:17
-
-
Save vsbuffalo/1f5dd54959e9cbef491a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <RcppEigen.h> | |
using namespace Eigen; | |
using namespace Rcpp; | |
// [[Rcpp::depends(RcppEigen)]] | |
// Sherman-Morrison rank-one update of a matrix inverse.
//
// Given Ap = inverse(A) (n x n) and column vectors u, v (n x 1), returns
// inverse(A + u * v^T) without re-inverting. R formula for reference:
//   Ap - (Ap %*% u %*% t(v) %*% Ap)/drop(1 + t(v) %*% Ap %*% u)
//
// If 1 + v^T * Ap * u == 0 then A + u * v^T is singular, and the division
// below yields Inf/NaN entries (no check is performed).
// [[Rcpp::export]]
MatrixXd sherman(const Map<MatrixXd> Ap, const Map<MatrixXd> u, const Map<MatrixXd> v) {
    // Form the two rank-one factors first so the whole update is O(n^2).
    // Evaluating Ap * u * v^T * Ap left to right would materialise an n x n
    // outer product and then multiply it by Ap, an O(n^3) matrix product.
    const MatrixXd Ap_u = Ap * u;               // n x 1
    const MatrixXd vt_Ap = v.transpose() * Ap;  // 1 x n
    const double denom = 1.0 + (vt_Ap * u)(0, 0);
    return Ap - (Ap_u * vt_Ap) / denom;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sherman-Morrison formula for updating an inverted matrix | |
# includes R and C++ implementations, as well as benchmarks | |
library(microbenchmark) | |
library(RcppEigen) | |
library(Rcpp) | |
sourceCpp("sherman.cpp") | |
# Absolute tolerance used when comparing Sherman-Morrison results to solve().
EPS <- 1e-8
set.seed(0)
# 5 x 5 test matrix filled (column-wise) with 25 distinct integers from 1..100
m <- matrix(sample.int(100, 25), ncol = 5)
# Sherman-Morrison rank-one update of a matrix inverse (pure R).
#
# Ap: the precomputed inverse of a square matrix A (n x n).
# u, v: n x 1 column matrices.
# Returns solve(A + u %*% t(v)) without re-inverting, i.e.
#   Ap - (Ap %*% u %*% t(v) %*% Ap) / drop(1 + t(v) %*% Ap %*% u)
# If 1 + t(v) %*% Ap %*% u is 0, A + u %*% t(v) is singular and the result
# contains Inf/NaN.
sherman_r <- function(Ap, u, v) {
  # Build the rank-one factors first so the update costs O(n^2). Multiplying
  # left to right would create an n x n outer product and then do a full
  # O(n^3) matrix product with Ap.
  Ap_u <- Ap %*% u              # n x 1
  vt_Ap <- crossprod(v, Ap)     # 1 x n, same as t(v) %*% Ap
  denom <- drop(1 + vt_Ap %*% u)
  Ap - (Ap_u %*% vt_Ap) / denom
}
## Test 1
# u = v = 0 makes the update a no-op, so the result must equal solve(m).
u <- cbind(rep(0, 5))
v <- cbind(rep(0, 5))
ms <- solve(m)
stopifnot(all(abs(sherman(ms, u, v) - solve(m)) < EPS))
sherman(ms, u, v)
sherman_r(ms, u, v)
## Test 2
# v1 is the 3rd unit vector, so u1 %*% t(v1) is zero except column 3, which
# equals u1: this update adds u1 to the third column of m.
u1 <- cbind(runif(5))
v1 <- cbind(c(0, 0, 1, 0, 0))
ms <- solve(m)
sherman(ms, u1, v1)
sherman_r(ms, u1, v1)
solve(m + u1 %*% t(v1))
# Results agree only up to floating-point error, hence the EPS tolerance.
# Both implementations are checked against a direct solve().
stopifnot(
  all(abs(sherman(ms, u1, v1) - solve(m + u1 %*% t(v1))) < EPS),
  all(abs(sherman_r(ms, u1, v1) - solve(m + u1 %*% t(v1))) < EPS)
)
# It works!
## Benchmarking
# 1000 x 1000 matrix of distinct values from 1..1e8. sample.int() draws the
# same values as sample(as.double(1:1e8), 1e6) but without first allocating
# a 1e8-element (~800 MB) double vector.
m <- matrix(as.double(sample.int(1e8, 1e6)), nrow = 1000)
# As in Test 2, v is the 3rd unit vector, so u %*% t(v) adds u to column 3.
u <- cbind(rnorm(nrow(m)))
v <- cbind(rep(0, nrow(m)))
# Set v itself (not the Test 2 vector v1) so the benchmarked update is not the
# trivial all-zero one.
v[3, 1] <- 1
ms <- solve(m)
# Thin wrappers so each microbenchmark entry is a single function call.

# Baseline: rebuild the rank-one-updated matrix and invert it from scratch.
use_solve <- function(m, u, v) {
  updated <- m + u %*% t(v)
  solve(updated)
}

# Compiled Sherman-Morrison update (from sherman.cpp) of the precomputed
# inverse ms.
use_sherman <- function(ms, u, v) sherman(ms, u, v)

# Pure-R Sherman-Morrison update of the precomputed inverse ms.
use_sherman_r <- function(ms, u, v) sherman_r(ms, u, v)
res <- microbenchmark(
  cpp_sherman = use_sherman(ms, u, v),
  r_sherman = use_sherman_r(ms, u, v),
  solve = use_solve(m, u, v),
  times = 30
)
# summary() returns the timing table as a data.frame; print() is only for
# display. Add each expression's mean time relative to the fastest one, then
# show the table including that new column.
pres <- summary(res)
pres$index <- pres$mean / min(pres$mean)
print(pres, row.names = FALSE)
# Sample output from an earlier run. The time unit is chosen automatically by
# microbenchmark (presumably milliseconds here). Re-run to refresh: the timings
# depend on the machine, the BLAS, and the sherman implementations.
#        expr       min        lq      mean    median        uq       max neval     index
# cpp_sherman  225.5816  233.3039  267.2072  239.3411  291.8987  481.3756    30  1.000000
#   r_sherman  940.6525  981.0046 1142.3853 1102.2758 1268.8268 1551.9803    30  4.275278
#       solve 2516.1753 2556.3144 2703.4512 2613.0607 2749.8054 3290.5046    30 10.117433
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment