Skip to content

Instantly share code, notes, and snippets.

Last active August 5, 2020 13:45
Show Gist options
  • Save boennecd/4fcd03c03443ba1b5bd49aa5e24475f3 to your computer and use it in GitHub Desktop.
Save boennecd/4fcd03c03443ba1b5bd49aa5e24475f3 to your computer and use it in GitHub Desktop.
Computes the commutation matrix fast-ish.
#include <Rcpp.h>
#include <memory.h>

#include <cstddef>
#include <memory>
/// Builds the index permutation that applies the n * m commutation matrix K
/// (or its transpose) to a vector without forming the matrix.
///
/// With zero-based i in [0, n) and j in [0, m), K has its ones at row
/// `i * m + j`, column `j * n + i`. Hence `(K x)[i * m + j] = x[j * n + i]`
/// and `(K^T x)[j * n + i] = x[i * m + j]`.
///
/// @param n first dimension of the commutation matrix.
/// @param m second dimension of the commutation matrix.
/// @param transpose if true, return the indices for K^T x instead of K x.
/// @return array `idx` with n * m elements such that `result[k] = x[idx[k]]`.
std::unique_ptr<size_t[]> get_commutation_unequal_vec
  (unsigned const n, unsigned const m, bool const transpose){
  // All index arithmetic is done in size_t: products such as n * n * m wrap
  // around in an unsigned already for moderately large n and m.
  size_t const n_first = n;
  size_t const n_second = m;
  size_t const nm = n_first * n_second;

  std::unique_ptr<size_t[]> out(new size_t[nm]);

  for(size_t i = 0; i < n_first; ++i)
    for(size_t j = 0; j < n_second; ++j){
      size_t const row = i * n_second + j;
      size_t const col = j * n_first + i;

      if(transpose)
        out[col] = row; // K^T has its one at (col, row)
      else
        out[row] = col;
    }

  return out;
}
/// Computes the product of the n * m commutation matrix (or its transpose)
/// with the vector `x` without forming the matrix.
///
/// @param n first dimension of the commutation matrix.
/// @param m second dimension of the commutation matrix.
/// @param x vector to multiply; must have n * m elements.
/// @param transpose if true, the transposed commutation matrix is used.
/// @return vector with n * m elements holding the permuted entries of `x`.
///
/// An R error is raised if `x` does not have n * m elements.
// [[Rcpp::export(rng = false)]]
Rcpp::NumericVector commutation_dot
  (unsigned const n, unsigned const m, Rcpp::NumericVector x,
   bool const transpose){
  // widen before multiplying so that the product cannot wrap in unsigned
  size_t const nm = static_cast<size_t>(n) * m;
  if(static_cast<size_t>(x.size()) != nm)
    Rcpp::stop("commutation_dot: 'x' must have n * m elements");

  Rcpp::NumericVector out(nm);
  auto const indices = get_commutation_unequal_vec(n, m, transpose);
  for(size_t i = 0; i < nm; ++i)
    out[i] = x[indices[i]];

  return out;
}
/// Builds the dense (n * m) x (n * m) commutation matrix.
///
/// With zero-based i in [0, n) and j in [0, m), the matrix has a one at row
/// `i * m + j`, column `j * n + i` and zeros elsewhere.
///
/// @param n first dimension of the commutation matrix.
/// @param m second dimension of the commutation matrix.
/// @return the dense commutation matrix.
Rcpp::NumericMatrix get_commutation_unequal
  (unsigned const n, unsigned const m){
  // Offsets are computed in size_t: strides such as n * n * m wrap around in
  // an unsigned for large inputs.
  size_t const n_first = n;
  size_t const n_second = m;
  size_t const nm = n_first * n_second;

  Rcpp::NumericMatrix out(static_cast<int>(nm), static_cast<int>(nm));
  // begin() rather than &out[0] so that nothing is indexed when nm is zero
  double * const o = out.begin();

  for(size_t i = 0; i < n_first; ++i)
    for(size_t j = 0; j < n_second; ++j){
      size_t const row = i * n_second + j;
      size_t const col = j * n_first + i;
      o[col * nm + row] = 1.; // the storage is column-major
    }

  return out;
}
/// Builds the dense (m * m) x (m * m) commutation matrix for the square case.
///
/// The matrix is symmetric when both dimensions are equal, so each
/// off-diagonal pair is written together and the diagonal entry
/// (i * m + i, i * m + i) is written once per i.
///
/// @param m common dimension of the commutation matrix.
/// @return the dense commutation matrix.
Rcpp::NumericMatrix get_commutation_equal(unsigned const m){
  // Offsets are computed in size_t: m * m * m wraps around in an unsigned
  // for large m.
  size_t const dim = m;
  size_t const mm = dim * dim;

  Rcpp::NumericMatrix out(static_cast<int>(mm), static_cast<int>(mm));
  // begin() rather than &out[0] so that nothing is indexed when m is zero
  double * const o = out.begin();

  for(size_t i = 0; i < dim; ++i){
    for(size_t j = 0; j < i; ++j){
      size_t const a = i * dim + j;
      size_t const b = j * dim + i;
      o[b * mm + a] = 1.; // entry (a, b); the storage is column-major
      o[a * mm + b] = 1.; // mirrored entry (b, a)
    }

    size_t const d = i * dim + i;
    o[d * mm + d] = 1.;
  }

  return out;
}
/// Returns the dense (n * m) x (n * m) commutation matrix.
///
/// Dispatches to get_commutation_equal when n == m and to
/// get_commutation_unequal otherwise.
///
/// @param n first dimension of the commutation matrix.
/// @param m second dimension of the commutation matrix.
// [[Rcpp::export(rng = false)]]
Rcpp::NumericMatrix get_commutation(unsigned const n, unsigned const m){
  if(n == m)
    return get_commutation_equal(n);

  return get_commutation_unequal(n, m);
}
/*** R
options(digits = 3)

# commutation_dot must agree with multiplying by the dense commutation matrix,
# both for the matrix itself and for its transpose
for(i in 1:5)
  for(j in 1:5){
    x <- rnorm(i * j)
    K <- get_commutation(i, j)
    stopifnot(all.equal(drop(K %*% x),
                        commutation_dot(i, j, x, FALSE)))
    stopifnot(all.equal(drop(crossprod(K, x)),
                        commutation_dot(i, j, x, TRUE)))
  }

# benchmark: dense matrix-vector product versus the index based version
i <- 20L
j <- 10L
x <- rnorm(i * j)
microbenchmark::microbenchmark(
  `Using dense`  = get_commutation(i, j) %*% x,
  `Using sparse` = commutation_dot(i, j, x, FALSE))
#R> Unit: microseconds
#R>          expr    min     lq   mean median     uq    max neval
#R>   Using dense 131.03 131.79 133.78 132.53 133.14 177.61   100
#R>  Using sparse   2.24   2.43   2.94   2.87   3.11   9.76   100
# Plain R reference implementation of the (m * m) x (m * m) commutation
# matrix: a one is placed at row (i - 1) * m + j, column (j - 1) * m + i for
# all i, j in 1, ..., m.
#
# Args:
#   m: dimension.
# Returns:
#   A dense numeric matrix with m * m rows and columns.
get_commutation_R <- function(m){
  out <- matrix(0., nrow = m * m, ncol = m * m)
  # seq_len rather than 1:m so that m = 0 yields an empty matrix
  for(i in seq_len(m))
    for(j in seq_len(m))
      out[(i - 1L) * m + j, (j - 1L) * m + i] <- 1.

  out
}
# equal: compare with matrixcalc and with the plain R version when n == m
for(i in 2:10){
  stopifnot(all.equal(matrixcalc::commutation.matrix(i),
                      get_commutation_R(i)))
  stopifnot(all.equal(matrixcalc::commutation.matrix(i),
                      get_commutation(i, i)))
}

# unequal: compare with matrixcalc for both orderings of the dimensions
for(i in 3:10)
  for(j in 2:(i - 1L)){
    stopifnot(all.equal(matrixcalc::commutation.matrix(i, j),
                        get_commutation(i, j)))
    stopifnot(all.equal(matrixcalc::commutation.matrix(j, i),
                        get_commutation(j, i)))
  }
# benchmark: equal
microbenchmark::microbenchmark(
  matrixcalc = matrixcalc::commutation.matrix(4L),
  R          = get_commutation_R(4L),
  cpp        = get_commutation(4L, 4L),
  times = 1000)
#R> Unit: nanoseconds
#R>        expr    min     lq   mean median     uq     max neval
#R>  matrixcalc 403333 422449 461087 433004 444748 3127575  1000
#R>           R   4205   4661   5330   5224   5624   54460  1000
#R>         cpp    638    848   1288   1264   1510   20620  1000

microbenchmark::microbenchmark(
  matrixcalc = matrixcalc::commutation.matrix(20L),
  R          = get_commutation_R(20L),
  cpp        = get_commutation(20L, 20L),
  times = 25)
#R> Unit: microseconds
#R>        expr      min       lq   mean   median       uq    max neval
#R>  matrixcalc 531256.1 565798.3 571792 572429.9 579425.0 622050    25
#R>           R    159.4    172.4    184    179.4    183.1    266    25
#R>         cpp     39.1     44.3    103     48.4     52.9   1381    25

# benchmark: unequal
microbenchmark::microbenchmark(
  matrixcalc = matrixcalc::commutation.matrix(17L, 20L),
  cpp        = get_commutation(17L, 20L),
  times = 25)
#R> Unit: microseconds
#R>        expr    min     lq   mean   median     uq    max neval
#R>  matrixcalc 318443 324061 338861 329943.3 358941 363992    25
#R>         cpp     29     37    140     42.1     46   1368    25
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment