@emitanaka
Created April 1, 2021 04:54
as_tidymatrix <- function(x) {
  structure(list(.value = as.vector(x),
                 .row = as.vector(row(x)),
                 .col = as.vector(col(x))),
            row.names = seq_len(length(as.vector(x))),
            nrow = nrow(x), # need these attributes if storing it as sparse later
            ncol = ncol(x),
            class = c("tidymatrix", "data.frame"))
}

# `...` keeps the method signature consistent with the as.matrix() generic
as.matrix.tidymatrix <- function(x, ...) {
  matrix(data = x$.value,
         nrow = attr(x, "nrow"),
         ncol = attr(x, "ncol"))
}

`%*%.tidymatrix` <- function(x, y) {
  res <- as.matrix(x) %*% as.matrix(y)
  as_tidymatrix(res)
}

x <- as_tidymatrix(matrix(1:6, nrow = 2))
y <- as_tidymatrix(matrix(1:9, nrow = 3))
x %*% y
# error -> .Primitive("%*%") has some check built-in :(
# "requires numeric/complex matrix/vector arguments"

@emitanaka
Author

emitanaka commented Apr 1, 2021

`+.tidymatrix` <- function(x, y) {
  res <- as.matrix(x) + as.matrix(y)
  as_tidymatrix(res)
}

z <- as_tidymatrix(matrix(1, nrow = 3, ncol = 3))
y + z
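
The `+` method above is picked up because `+` belongs to the S3 `Ops` group generic, so R looks for a class method before running the internal code. A minimal sketch, assuming the other arithmetic and comparison operators are wanted too: a single `Ops.tidymatrix` method (hypothetical, not part of the gist) can cover the whole group. The constructor and converter are repeated here only so the block runs on its own.

```r
# Constructor and converter repeated from the gist so this block is self-contained.
as_tidymatrix <- function(x) {
  structure(list(.value = as.vector(x),
                 .row = as.vector(row(x)),
                 .col = as.vector(col(x))),
            row.names = seq_len(length(x)),
            nrow = nrow(x),
            ncol = ncol(x),
            class = c("tidymatrix", "data.frame"))
}

as.matrix.tidymatrix <- function(x, ...) {
  matrix(x$.value, nrow = attr(x, "nrow"), ncol = attr(x, "ncol"))
}

# Hypothetical group method (an assumption, not in the gist): one definition
# handles +, -, *, /, ==, < and the other members of the Ops group.
Ops.tidymatrix <- function(e1, e2) {
  # Convert tidymatrix operands to base matrices; leave anything else untouched.
  to_base <- function(m) if (inherits(m, "tidymatrix")) as.matrix(m) else m
  op <- get(.Generic, mode = "function")
  # e2 is missing for unary calls such as -x.
  res <- if (missing(e2)) op(to_base(e1)) else op(to_base(e1), to_base(e2))
  as_tidymatrix(res)
}

y <- as_tidymatrix(matrix(1:9, nrow = 3))
z <- as_tidymatrix(matrix(1, nrow = 3, ncol = 3))
y - z
y * 2
```

If a specific method such as `+.tidymatrix` is also defined, R uses it in preference to the group method, so the two can coexist.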

@emitanaka
Author

This works. Thanks Brenton Wiernik for the info.

`%*%` <- function(x, y) {
  UseMethod("%*%")
}

`%*%.matrix` <- function(x, y) {
  .Primitive("%*%")(x, y)
}

`%*%.tidymatrix` <- function(x, y) {
  res <- as.matrix(x) %*% as.matrix(y)
  as_tidymatrix(res)
}

x <- as_tidymatrix(matrix(1:6, nrow = 2))
y <- as_tidymatrix(matrix(1:9, nrow = 3))
x %*% y
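
One caveat with masking `%*%` this way: `UseMethod()` dispatches on the first argument only, and with just a `matrix` and a `tidymatrix` method a call such as `1:3 %*% 1:3` has no method to land on. A minimal sketch of a fallback, assuming a `%*%.default` method (hypothetical, not in the gist) that hands everything else to the base operator:

```r
# Constructor and converter repeated from the gist so this block is self-contained.
as_tidymatrix <- function(x) {
  structure(list(.value = as.vector(x),
                 .row = as.vector(row(x)),
                 .col = as.vector(col(x))),
            row.names = seq_len(length(x)),
            nrow = nrow(x),
            ncol = ncol(x),
            class = c("tidymatrix", "data.frame"))
}

as.matrix.tidymatrix <- function(x, ...) {
  matrix(x$.value, nrow = attr(x, "nrow"), ncol = attr(x, "ncol"))
}

# Mask the primitive with an S3 generic, as in the comment above.
`%*%` <- function(x, y) UseMethod("%*%")

# Hypothetical fallback (an assumption, not in the gist): plain matrices and
# vectors go straight to the base operator.
`%*%.default` <- function(x, y) base::`%*%`(x, y)

# Calling base::`%*%` explicitly avoids going through the masking generic again.
`%*%.tidymatrix` <- function(x, y) {
  as_tidymatrix(base::`%*%`(as.matrix(x), as.matrix(y)))
}

x <- as_tidymatrix(matrix(1:6, nrow = 2))
y <- as_tidymatrix(matrix(1:9, nrow = 3))
x %*% y       # dispatches to the tidymatrix method
1:3 %*% 1:3   # falls through to the default method
```

A plain matrix on the left with a tidymatrix on the right still reaches the default method, so that case would need its own handling. As far as I know, R 4.3.0 and later dispatch S3 methods for `%*%` directly, so the mask may be unnecessary there.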
