Skip to content

Instantly share code, notes, and snippets.

@mrdwab
Last active April 10, 2019 06:29
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save mrdwab/933ffeaa7a1d718bd10a to your computer and use it in GitHub Desktop.
Save mrdwab/933ffeaa7a1d718bd10a to your computer and use it in GitHub Desktop.
An attempt to rewrite `stratified` for `data.table`. The `data.frame` version can be found at https://gist.github.com/mrdwab/6424112
#' Stratified random sampling on a data.table
#'
#' Draws a random sample from every stratum (combination of the `group`
#' columns) of `indt`.
#'
#' @param indt A `data.table`, or an object `as.data.table()` can convert.
#'   The caller's object is never modified: the function works on a copy.
#' @param group Character vector of grouping column names, or numeric
#'   column positions (resolved against `indt` as it was passed in).
#' @param size Desired sample size. One of:
#'   * a single number `>= 1`: that many rows per group;
#'   * a single number `< 1`: that proportion of each group (rounded);
#'   * a named numeric vector with one entry per group, where the names are
#'     the group labels (for several grouping columns, the values pasted
#'     together with a single space).
#' @param select Optional named list; each element gives the values to keep
#'   for the column of that name. Rows must match every element.
#' @param replace Passed to the sampler: sample with replacement?
#' @param keep.rownames Passed to `as.data.table()` when `indt` is converted.
#' @param bothSets If `TRUE`, also return the rows that were not sampled.
#' @return A `data.table` with the sampled rows, ordered by `group`; or, when
#'   `bothSets = TRUE`, a list with elements `SAMP1` (sampled rows) and
#'   `SAMP2` (all remaining rows).
stratifiedDT <- function(indt, group, size, select = NULL,
                         replace = FALSE, keep.rownames = FALSE,
                         bothSets = FALSE) {
  # Resolve numeric positions before any conversion, because
  # keep.rownames = TRUE would insert a leading column and shift positions.
  if (is.numeric(group)) group <- names(indt)[group]

  if (!is.data.table(indt)) {
    indt <- as.data.table(indt, keep.rownames = keep.rownames)
  } else if (is.null(select)) {
    # setkeyv() and `:=` below work by reference; without a copy the caller's
    # table would be reordered and temporarily gain a helper column. (When
    # 'select' is given, the row subset below already yields a new table.)
    indt <- copy(indt)
  }

  if (!all(group %in% names(indt))) {
    stop("Please verify your 'group' argument")
  }
  if (!is.numeric(size) || length(size) == 0L || anyNA(size) ||
      any(size < 0)) {
    stop("'size' must be a non-empty, non-negative numeric vector")
  }

  if (!is.null(select)) {
    if (is.null(names(select))) stop("'select' must be a named list")
    if (!all(names(select) %in% names(indt)))
      stop("Please verify your 'select' argument")
    # Keep the rows that satisfy every condition. Reduce() over a list of
    # logical vectors also works for a one-row table, where a vapply()
    # result would collapse to a plain vector and break rowSums().
    keep <- Reduce(`&`, lapply(names(select), function(x)
      indt[[x]] %in% select[[x]]))
    indt <- indt[keep, ]
  }

  # One row per stratum, with its size in column N.
  df.table <- indt[, .N, by = group]

  if (length(size) > 1) {
    if (length(size) != nrow(df.table))
      stop("Number of groups is ", nrow(df.table),
           " but number of sizes supplied is ", length(size))
    if (is.null(names(size)))
      stop("size should be entered as a named vector")
    # Label of each stratum: its grouping values pasted with a space.
    group_labels <- do.call(paste, df.table[, group, with = FALSE])
    if (!all(names(size) %in% group_labels) ||
        !all(group_labels %in% names(size))) {
      stop("Named vector supplied with names ",
           paste(names(size), collapse = ", "),
           "\n but the names for the group levels are ",
           paste(group_labels, collapse = ", "))
    }
    # Look sizes up by label rather than merging on the raw columns, so that
    # factor/numeric grouping columns and several grouping columns work too.
    ss <- unname(size)[match(group_labels, names(size))]
  } else if (size < 1) {
    # Proportional sampling.
    ss <- round(df.table$N * size, digits = 0)
  } else {
    if (all(df.table$N >= size) || isTRUE(replace)) {
      ss <- rep(size, nrow(df.table))
    } else {
      message(
        "Some groups\n---",
        do.call(paste, c(df.table[df.table$N < size][, group, with = FALSE],
                         sep = ".", collapse = ", ")),
        "---\ncontain fewer observations",
        " than desired number of samples.\n",
        "All observations have been returned from those groups.")
      ss <- pmin(df.table$N, size)
    }
  }
  n <- cbind(df.table, ss = ss)

  setkeyv(indt, group)
  setkeyv(n, group)
  # Temporary row id (in key order) used to pull the sampled rows back out.
  indt[, .RNID := .I]

  # For every stratum in 'n', draw 'ss' row ids. Indexing through
  # sample.int() instead of sample(.RNID, ...) matters for one-row strata:
  # sample() treats a single number x as 1:x and could return a row that
  # belongs to another stratum.
  picked <- indt[n, list(
    .RNID = .RNID[sample.int(length(.RNID), ss, replace)]),
    by = .EACHI]$`.RNID`
  out1 <- indt[picked]

  if (isTRUE(bothSets)) {
    out2 <- indt[!(.RNID %in% out1$`.RNID`)]
    out1[, .RNID := NULL]
    out2[, .RNID := NULL]
    list(SAMP1 = out1, SAMP2 = out2)
  } else {
    # Trailing [] makes the result print at the console after `:=`.
    out1[, .RNID := NULL][]
  }
}
@mrdwab
Copy link
Author

mrdwab commented Aug 13, 2014

There seems to be a big slowdown (in both the data.frame and data.table versions) when more groups are factored into the equation. But the difference is quite pronounced in the data.table version.

.EACHI makes the data.table version MUCH faster!!!

# Reproducible benchmark data: 100,000 rows with an id, three categorical
# columns (A: 52 letters, D: month names, E: "M"/"F") and two numeric ones.
# The columns are generated in the order ID, A, B, C, D, E so the random
# number stream is consumed exactly as the seed expects.
set.seed(1)
n <- 100000
dat1 <- data.frame(
  ID = seq_len(n),
  A = sample(c(letters, LETTERS), size = n, replace = TRUE),
  B = rnorm(n),
  C = abs(round(rnorm(n), digits = 1)),
  D = sample(month.name, size = n, replace = TRUE),
  E = sample(c("M", "F"), size = n, replace = TRUE)
)

library(microbenchmark)

# Zero-argument wrappers handed to microbenchmark(). The df* functions call
# stratified() and the dt* functions call stratifiedDT(); the numeric suffix
# is the number of grouping columns. Each draws 5 rows per group from dat1.
df1 <- function() {
  stratified(dat1, "A", 5)
}
df2 <- function() {
  stratified(dat1, c("A", "D"), 5)
}
df3 <- function() {
  stratified(dat1, c("A", "D", "E"), 5)
}
dt1 <- function() {
  stratifiedDT(dat1, "A", 5)
}
dt2 <- function() {
  stratifiedDT(dat1, c("A", "D"), 5)
}
dt3 <- function() {
  stratifiedDT(dat1, c("A", "D", "E"), 5)
}

# One grouping column ("A"): stratified() vs stratifiedDT(), 20 runs each.
# The commented lines are the timings as posted by the author.
microbenchmark(df1(), dt1(), times = 20)
# Unit: milliseconds
#   expr      min       lq   median       uq      max neval
#  df1() 198.2978 213.7986 235.4996 277.2846 431.9715    20
#  dt1() 512.2372 594.8178 634.0051 714.0381 860.4061    20

# Two grouping columns ("A", "D"), 10 runs each.
microbenchmark(df2(), dt2(), times = 10)
# Unit: seconds
#   expr      min       lq   median       uq      max neval
#  df2() 1.535403 1.575516 1.651618 1.883049 1.948000    10
#  dt2() 5.942521 6.082964 6.307843 6.530196 6.788508    10

# Three grouping columns ("A", "D", "E"), 3 runs each.
microbenchmark(df3(), dt3(), times = 3)
# Unit: seconds
#   expr       min        lq    median        uq       max neval
#  df3()  3.562429  3.714761  3.867093  3.946384  4.025676     3
#  dt3() 12.955109 13.269335 13.583561 13.662457 13.741354     3

@mrdwab
Copy link
Author

mrdwab commented Sep 18, 2014

Changing the last few lines of the function to use `.EACHI`, as in the following, makes a big difference:

# Excerpt from the end of the sampling function; `indt`, `n` and the GrpKey
# column are created earlier, outside this excerpt.
# Key the table on GrpKey so the join below matches rows by group
# (presumably GrpKey is a single column identifying each stratum -- confirm).
setkey(indt, GrpKey)
# Temporary row id, assigned after keying, used to pull sampled rows back out.
indt[, .RNID := sequence(nrow(indt))]
# Lookup table: one row per name of `n`, with the matching value of `n` in
# .ss (presumably `n` is a named vector of per-group sample sizes -- confirm).
matchdt <- data.table(GrpKey = names(n), 
                    .ss = n, key = "GrpKey")
# Join indt to matchdt and, for each row of matchdt (by = .EACHI), draw .ss
# row ids without replacement; then subset indt by the drawn ids.
# NOTE(review): sample() treats a length-one numeric x as 1:x, so a group
# with a single row may yield a row id from a different group.
out <- indt[indt[matchdt, list(
.RNID = sample(.RNID, .ss, FALSE)), by = .EACHI]$`.RNID`]
# Drop the two helper columns; the trailing [] makes the result print.
out[, c("GrpKey", ".RNID") := NULL][]

Here are some new benchmarks:

# Same three comparisons re-run; the commented lines are the timings as
# posted by the author.
# One grouping column, 20 runs each.
microbenchmark(df1(), dt1(), times = 20)
# Unit: milliseconds
#   expr      min       lq   median       uq       max neval
#  df1() 86.21819 87.62596 89.66106 93.42402 115.06335    20
#  dt1() 38.72521 38.90182 39.16075 39.77761  56.36423    20
# Two grouping columns, 10 runs each.
microbenchmark(df2(), dt2(), times = 10)
# Unit: milliseconds
#   expr      min        lq    median       uq       max neval
#  df2() 702.7874 715.86875 735.96753 757.0229 954.10381    10
#  dt2()  62.0462  62.74495  63.08022  64.7172  69.79193    10
# Three grouping columns, 3 runs each.
microbenchmark(df3(), dt3(), times = 3)
# Unit: milliseconds
#   expr        min         lq     median         uq       max neval
#  df3() 1529.44492 1553.13860 1576.83227 1581.50246 1586.1726     3
#  dt3()   84.64663   86.77603   88.90543   89.87317   90.8409     3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment