Skip to content

Instantly share code, notes, and snippets.

@ledell
Last active August 31, 2015 18:53
Show Gist options
  • Save ledell/bd4e227d4e5ff426c41d to your computer and use it in GitHub Desktop.
Save ledell/bd4e227d4e5ff426c41d to your computer and use it in GitHub Desktop.
#' Cross-validation folds with optional outcome and cluster stratification
#'
#' Update for `SuperLearner::CVFolds` that enables stratification by outcome
#' and cluster ID.
#'
#' @param N Number of observations (rows).
#' @param id Optional vector of cluster IDs, one per observation, or `NULL`.
#'   When supplied, all rows sharing an ID are placed in the same fold.
#' @param Y Outcome vector. Only used when `cvControl$stratifyCV` is `TRUE`,
#'   in which case it must take exactly two distinct values; the smaller
#'   value plays the role of `Y = 0` and the larger that of `Y = 1`.
#' @param cvControl List with elements `V` (number of folds), `stratifyCV`
#'   and `shuffle` (logical flags), and optionally `validRows`, a
#'   precomputed fold list that is returned unchanged.
#' @return A list of integer vectors of validation row indices, one per
#'   fold. Without `id` the list is named `"1"`, `"2"`, ...; with `id` it is
#'   unnamed. In the unstratified, no-`id` case it holds fewer than `V` folds
#'   when `N < V`.
CVFolds2 <- function(N, id, Y, cvControl) {
  if (!is.null(cvControl$validRows)) {
    return(cvControl$validRows)
  }
  stratifyCV <- cvControl$stratifyCV
  shuffle <- cvControl$shuffle
  V <- cvControl$V

  # Deal the elements of `x` round-robin into (at most) V groups, permuting
  # them first when shuffling. `x[sample.int(length(x))]` draws from the RNG
  # exactly like `sample(1:n)` / `sample(x)` did, so seeded results are
  # unchanged, but it stays correct when `x` has a single element (where
  # `sample(x)` would permute `1:x` instead).
  deal <- function(x) {
    if (shuffle) {
      x <- x[sample.int(length(x))]
    }
    split(x, rep(seq_len(V), length.out = length(x)))
  }

  # Row indices of every observation whose cluster is listed in `ids`.
  rows_of <- function(ids) which(id %in% ids)

  if (!stratifyCV) {
    if (is.null(id)) {
      return(deal(seq_len(N)))
    }
    ids <- unique(id)
    if (length(ids) < V) {
      stop("number of unique ids is less than the number of folds")
    }
    id_folds <- deal(ids)
    return(lapply(seq_len(V), function(v) rows_of(id_folds[[v]])))
  }

  # Stratified by (binary) outcome.
  if (length(unique(Y)) != 2) {
    stop("stratifyCV only implemented for binary Y")
  }
  strata <- sort(unique(Y))
  rows0 <- which(Y == strata[1])
  rows1 <- which(Y == strata[2])
  if (length(rows1) < V || length(rows0) < V) {
    stop("number of (Y=1) or (Y=0) is less than the number of folds")
  }

  if (is.null(id)) {
    # The Y=0 stratum is dealt (and shuffled) before the Y=1 stratum,
    # matching the original order of RNG draws.
    folds0 <- deal(rows0)
    folds1 <- deal(rows1)
    validRows <- lapply(seq_len(V), function(v) c(folds0[[v]], folds1[[v]]))
    names(validRows) <- paste(seq_len(V))
    return(validRows)
  }

  # By cluster: a cluster with at least one Y=1 row counts as a "Y=1"
  # cluster; the rest contain only Y=0 rows. Each group is spread across
  # folds separately (Y=1 clusters first, as in the original RNG order).
  ids1 <- unique(id[rows1])
  ids0 <- setdiff(unique(id), ids1)
  if (length(ids1) < V || length(ids0) < V) {
    stop(paste("number of unique ids with (Y=1), or with only (Y=0),",
               "is less than the number of folds"))
  }
  folds1 <- deal(ids1)
  folds0 <- deal(ids0)
  lapply(seq_len(V), function(v) {
    c(rows_of(folds1[[v]]), rows_of(folds0[[v]]))
  })
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment