Skip to content

Instantly share code, notes, and snippets.

@vsbuffalo
Last active August 29, 2015 14:17
Show Gist options
  • Save vsbuffalo/1f5dd54959e9cbef491a to your computer and use it in GitHub Desktop.
Save vsbuffalo/1f5dd54959e9cbef491a to your computer and use it in GitHub Desktop.
#include <RcppEigen.h>
using namespace Eigen;
using namespace Rcpp;
// Sherman-Morrison rank-one update of an already-inverted matrix.
//
// Computes
//     Ap - (Ap u)(v^T Ap) / (1 + v^T Ap u)
// which is the same expression as the R reference formula
//     Ap - (Ap %*% u %*% t(v) %*% Ap)/drop(1 + t(v) %*% Ap %*% u)
//
// Parameters:
//   Ap  square matrix (the current inverse)
//   u   column of the update, expected to be n x 1
//   v   column of the update, expected to be n x 1
// Returns a new matrix with the same dimensions as Ap.
//
// Notes:
//   * The products are grouped as (Ap u) and (v^T Ap) so only matrix-vector
//     products and one outer product are formed; multiplying left to right
//     would build an n x n intermediate and then do a full n x n by n x n
//     product.
//   * The denominator is not checked for zero; in that case the result
//     contains Inf/NaN, exactly as the R formula would.
// [[Rcpp::depends(RcppEigen)]]
// [[Rcpp::export]]
MatrixXd sherman(const Map<MatrixXd> Ap, const Map<MatrixXd> u, const Map<MatrixXd> v) {
    // Ap * u is needed for both the numerator and the denominator: compute once.
    const MatrixXd Au = Ap * u;
    // Row vector v^T * Ap (1 x n for column-vector v).
    const MatrixXd vtA = v.transpose() * Ap;
    // Evaluate the 1 x 1 product into a concrete matrix before reading (0, 0).
    const MatrixXd vtAu = v.transpose() * Au;
    const double denom = 1.0 + vtAu(0, 0);
    return Ap - (Au * vtA) / denom;
}
# Sherman-Morrison formula for updating an inverted matrix
# includes R and C++ implementations, as well as benchmarks
# Packages: benchmarking harness plus the Rcpp/RcppEigen toolchain.
library(microbenchmark)
library(RcppEigen)
library(Rcpp)

# Compile sherman.cpp and expose its exported sherman() function to R.
sourceCpp("sherman.cpp")

# Tolerance used for the element-wise comparisons below.
EPS <- 1e-8

# Fixed seed so the random matrices are reproducible.
set.seed(0)

# test matrix: 5 x 5, filled with 25 distinct integers drawn from 1..100
m <- matrix(sample.int(100L, size = 25L), nrow = 5L)
# Sherman-Morrison rank-one update of an inverse, in plain R.
#
# Evaluates Ap - (Ap u)(v' Ap) / (1 + v' Ap u).
#
# Args:
#   Ap: square numeric matrix (the current inverse).
#   u:  one-column matrix with nrow(Ap) rows.
#   v:  one-column matrix with nrow(Ap) rows.
# Returns:
#   A matrix with the same dimensions as Ap.
#
# The products are grouped so that only matrix-vector products and a single
# outer product are formed. Chaining `Ap %*% u %*% t(v) %*% Ap` left to right
# builds an n x n intermediate and then multiplies two n x n matrices.
# The denominator is not checked for zero; a zero denominator yields Inf/NaN.
sherman_r <- function(Ap, u, v) {
  Au <- Ap %*% u                      # n x 1, reused in the denominator
  vtA <- crossprod(v, Ap)             # t(v) %*% Ap, 1 x n
  denom <- drop(1 + crossprod(v, Au)) # scalar 1 + t(v) %*% Ap %*% u
  Ap - (Au %*% vtA) / denom
}
## Test 1
# Zero update vectors: the "updated" inverse must equal the original inverse.
u <- matrix(0, nrow = 5, ncol = 1)
v <- matrix(0, nrow = 5, ncol = 1)
ms <- solve(m)
abs(sherman(ms, u, v) - ms) < EPS
sherman(ms, u, v)
sherman_r(ms, u, v)

## Test 2
# v1 is the 3rd unit vector, so u1 %*% t(v1) adds the random column u1 to
# column 3 of m.
u1 <- matrix(runif(5), ncol = 1)
v1 <- matrix(c(0, 0, 1, 0, 0), ncol = 1)
# m has not changed since Test 1, so ms is still solve(m).
sherman(ms, u1, v1)
sherman_r(ms, u1, v1)
solve(m + u1 %*% t(v1))
# Compare with a tolerance: the two routes differ by floating point error.
abs(sherman(ms, u1, v1) - solve(m + u1 %*% t(v1))) < EPS
# It works!
## Benchmarking
# 1000 x 1000 matrix of distinct doubles drawn from 1..1e8. Drawing integer
# indices with sample.int() and converting afterwards yields the same values
# as sampling from as.double(1:1e8), without first allocating that
# 1e8-element double vector.
m <- matrix(as.double(sample.int(1e8, 1e6)), nrow = 1000)
# Rank-one update: v is the 3rd unit vector, so u %*% t(v) adds u to column 3.
u <- cbind(rnorm(nrow(m)))
v <- cbind(rep(0, nrow(m)))
# Set the entry on v itself (assigning to v1 here would leave v all zeros and
# the benchmark would time a no-op update).
v[3, 1] <- 1
ms <- solve(m)
# Baseline: apply the rank-one update to m and invert the result from scratch.
#
# Args:
#   m: square matrix.
#   u, v: one-column matrices defining the update u %*% t(v).
# Returns:
#   solve() of the updated matrix.
use_solve <- function(m, u, v) {
  rank_one <- u %*% t(v)
  updated <- m + rank_one
  solve(updated)
}
# Benchmark wrapper around the compiled sherman() implementation.
# Takes the already-inverted matrix `ms` and the update vectors `u`, `v`,
# and returns whatever sherman() returns for them.
use_sherman <- function(ms, u, v) sherman(ms, u, v)
# Benchmark wrapper around the pure-R sherman_r() implementation.
# Takes the already-inverted matrix `ms` and the update vectors `u`, `v`,
# and returns whatever sherman_r() returns for them.
use_sherman_r <- function(ms, u, v) sherman_r(ms, u, v)
# Time the three approaches on the same inputs, 30 evaluations each.
res <- microbenchmark(
  cpp_sherman = {
    use_sherman(ms, u, v)
  },
  r_sherman = {
    use_sherman_r(ms, u, v)
  },
  solve = {
    use_solve(m, u, v)
  },
  times = 30
)
# Show the timings and keep what print() returns.
# NOTE(review): this relies on print() handing back the summary table with a
# `mean` column -- confirm against the installed microbenchmark version.
pres <- print(res)
# Mean time of each method relative to the fastest mean.
pres$index <- pres[["mean"]] / min(pres[["mean"]])
# expr min lq mean median uq max neval index
# cpp_sherman 225.5816 233.3039 267.2072 239.3411 291.8987 481.3756 30 1.000000
# r_sherman 940.6525 981.0046 1142.3853 1102.2758 1268.8268 1551.9803 30 4.275278
# solve 2516.1753 2556.3144 2703.4512 2613.0607 2749.8054 3290.5046 30 10.117433
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment