Skip to content

Instantly share code, notes, and snippets.

@mrdwab
Last active April 10, 2019 06:29
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save mrdwab/933ffeaa7a1d718bd10a to your computer and use it in GitHub Desktop.
Save mrdwab/933ffeaa7a1d718bd10a to your computer and use it in GitHub Desktop.
An attempt to rewrite the `stratified()` function for `data.table`. The `data.frame` version can be found at https://gist.github.com/mrdwab/6424112
# Stratified random sampling with data.table.
#
# Draws a random sample of rows from every group (stratum) of `indt`.
#
# Arguments:
#   indt          A data.table, or anything as.data.table() accepts. A
#                 data.table input is copied before it is keyed and given a
#                 temporary column, so the caller's table is left untouched.
#   group         Names (or positions in `indt`) of the grouping columns.
#   size          Either one number -- a proportion of each group when < 1
#                 (rounded), or a count per group when >= 1 (capped at the
#                 group size, with a message, unless replace = TRUE) -- or a
#                 named vector with exactly one size per group. With several
#                 grouping columns the names are the group values pasted
#                 together with a single space, e.g. "A x".
#   select        Optional named list; only rows where every named column
#                 takes one of the listed values are sampled from.
#   replace       Sample with replacement?
#   keep.rownames Passed to as.data.table() when `indt` is not already a
#                 data.table.
#   bothSets      If TRUE, also return the rows that were not drawn.
#
# Returns the sampled rows as a data.table or, when bothSets = TRUE, a list
# with SAMP1 (sampled rows) and SAMP2 (all other rows).
stratifiedDT <- function(indt, group, size, select = NULL,
                         replace = FALSE, keep.rownames = FALSE,
                         bothSets = FALSE) {
  if (is.numeric(group)) group <- names(indt)[group]
  if (!is.numeric(size) || length(size) == 0L || anyNA(size) ||
      any(size < 0)) {
    stop("'size' must contain non-negative numbers only")
  }

  input_is_dt <- is.data.table(indt)
  if (!input_is_dt) {
    indt <- as.data.table(indt, keep.rownames = keep.rownames)
  }

  if (!is.null(select)) {
    if (is.null(names(select))) stop("'select' must be a named list")
    if (!all(names(select) %in% names(indt)))
      stop("Please verify your 'select' argument")
    # Reduce() instead of vapply() + rowSums(): vapply() collapses to a plain
    # vector for a one-row table, which rowSums() rejects.
    keep <- Reduce(`&`, lapply(names(select), function(x)
      indt[[x]] %in% select[[x]]), TRUE)
    indt <- indt[keep]
  } else if (input_is_dt) {
    # setkeyv() and `:=` below work by reference; use a private copy so the
    # caller's table is neither reordered nor given a .RNID column.
    indt <- copy(indt)
  }

  grp_counts <- indt[, .N, by = group]
  n_groups <- nrow(grp_counts)

  # Per-group sample sizes, in the row order of grp_counts.
  if (length(size) > 1) {
    if (length(size) != n_groups)
      stop("Number of groups is ", n_groups,
           " but number of sizes supplied is ", length(size))
    if (is.null(names(size)))
      stop("size should be entered as a named vector")
    grp_keys <- do.call(paste, grp_counts[, group, with = FALSE])
    if (!setequal(names(size), grp_keys))
      stop("Named vector supplied with names ",
           paste(names(size), collapse = ", "),
           "\n but the names for the group levels are ",
           paste(grp_keys, collapse = ", "))
    # Look sizes up by the pasted key rather than merging on a single
    # character column: works for any number and type of grouping columns.
    ss <- unname(size[grp_keys])
  } else if (size < 1) {
    ss <- round(grp_counts$N * size, digits = 0)
  } else if (all(grp_counts$N >= size) || isTRUE(replace)) {
    ss <- rep_len(size, n_groups)
  } else {
    short <- grp_counts$N < size
    message(
      "Some groups\n---",
      do.call(paste, c(grp_counts[short, group, with = FALSE],
                       sep = ".", collapse = ", ")),
      "---\ncontain fewer observations",
      " than desired number of samples.\n",
      "All observations have been returned from those groups.")
    ss <- pmin(grp_counts$N, size)
  }

  # Named `.ss` (not `ss`) so a user column called "ss" cannot shadow the
  # sizes inside the join's j, where columns of indt take precedence.
  sizes_dt <- grp_counts
  set(sizes_dt, j = ".ss", value = ss)

  setkeyv(indt, group)
  setkeyv(sizes_dt, group)
  # Row ids are assigned after keying, so they equal row positions in indt.
  indt[, .RNID := seq_len(nrow(indt))]
  # Index .RNID through sample.int(): sample(x, k) treats a single number x
  # as 1:x, so a one-row group would otherwise draw an unrelated row id.
  picked <- indt[sizes_dt, list(
    .RNID = .RNID[sample.int(length(.RNID), .ss, replace)]),
    by = .EACHI]$.RNID
  out1 <- indt[picked]

  if (isTRUE(bothSets)) {
    not_picked <- !(indt$.RNID %in% picked)
    out2 <- indt[not_picked]
    out1[, .RNID := NULL]
    out2[, .RNID := NULL]
    list(SAMP1 = out1, SAMP2 = out2)
  } else {
    out1[, .RNID := NULL][]
  }
}
@mrdwab
Copy link
Author

mrdwab commented Sep 18, 2014

Changing the last few lines to use `by = .EACHI`, as in the following, makes a big difference:

# Fragment of the sampling step: assumes `indt` is a data.table with a
# `GrpKey` column and `n` is a named vector of per-group sample sizes whose
# names are GrpKey values -- NOTE(review): confirm against the full function.
setkey(indt, GrpKey)
# Row ids, assigned after keying so they match row positions in `indt`.
indt[, .RNID := sequence(nrow(indt))]
# One row per group with its sample size, keyed for the join below.
matchdt <- data.table(GrpKey = names(n), 
                    .ss = n, key = "GrpKey")
# by = .EACHI evaluates j once per row of matchdt, sampling that group's ids.
# NOTE(review): sample() treats a single number x as 1:x, so a one-row group
# can draw an unrelated row id; .RNID[sample.int(length(.RNID), .ss)] avoids it.
out <- indt[indt[matchdt, list(
.RNID = sample(.RNID, .ss, FALSE)), by = .EACHI]$`.RNID`]
# Drop the helper columns; the trailing [] returns the table visibly.
out[, c("GrpKey", ".RNID") := NULL][]

Here are some new benchmarks:

microbenchmark(df1(), dt1(), times = 20)
# Unit: milliseconds
#   expr      min       lq   median       uq       max neval
#  df1() 86.21819 87.62596 89.66106 93.42402 115.06335    20
#  dt1() 38.72521 38.90182 39.16075 39.77761  56.36423    20
microbenchmark(df2(), dt2(), times = 10)
# Unit: milliseconds
#   expr      min        lq    median       uq       max neval
#  df2() 702.7874 715.86875 735.96753 757.0229 954.10381    10
#  dt2()  62.0462  62.74495  63.08022  64.7172  69.79193    10
microbenchmark(df3(), dt3(), times = 3)
# Unit: milliseconds
#   expr        min         lq     median         uq       max neval
#  df3() 1529.44492 1553.13860 1576.83227 1581.50246 1586.1726     3
#  dt3()   84.64663   86.77603   88.90543   89.87317   90.8409     3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment