@moredatapls
Last active June 17, 2019 09:27
R: sample rows from a data.frame with different fractions per class

I needed an easy way to sample rows from a data.frame with a different fraction per class (dplyr::sample_frac() only supports a single fraction), so I wrote the little function below.
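
To illustrate the limitation: a minimal sketch, assuming dplyr is installed and using the built-in `iris` data set (not part of the original gist). Even on grouped data, `dplyr::sample_frac()` applies the same fraction to every group.

```r
# Grouped sampling with a single fraction: every class is sampled at 10 %.
grouped <- dplyr::group_by(iris, Species)
sampled_same <- dplyr::sample_frac(grouped, size = 0.1)

# Count the sampled rows per class; each species contributes 5 of its 50 rows.
per_class <- table(sampled_same$Species)
print(per_class)
```

There is no way to ask for, say, 10 % of one species and 50 % of another in a single call, which is the gap the function below fills.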

#' Samples random rows from a data.frame based on per-class fractions.
#' `dplyr::sample_frac` only supports a single fraction for the entire dataset.
#' This function works best if there are only a few classes, because a
#' fraction has to be specified for every class.
#'
#' The notation is consistent with the one used for `dplyr::sample_frac`.
#'
#' @param tbl The data.frame to sample from
#' @param classCol The column containing the class labels
#' @param sizes The per-class fractions as a `list("class name" = fraction)`
#' @param ... Other parameters to pass to `dplyr::sample_frac`, such as `replace = TRUE`
#'
#' @return A subset of `tbl`
#'
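sample_frac_class <- function(tbl, classCol, sizes, ...) {
  # Extract the class labels from the column given by `classCol`
  class_ <- as.factor(dplyr::pull(tbl, !!dplyr::enquo(classCol)))
  stopifnot(
    is.list(sizes),
    all(levels(class_) %in% names(sizes)),
    all(sapply(sizes, function(size) size >= 0 & size <= 1))
  )
  # Sample the rows of a single class using that class's fraction.
  # Filter on `class_` rather than `tbl$class`, so that columns with a
  # name other than "class" work as well.
  sample_ <- function(clazz) {
    dplyr::sample_frac(tbl[which(class_ == clazz), ], size = sizes[[clazz]], ...)
  }
  # Sample each class separately and stack the results
  do.call(rbind, lapply(names(sizes), sample_))
}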

testthat::context("sampling")
testthat::test_that("per-class sampling works", {
  # the expected rows depend on this seed and on R's sampling algorithm,
  # whose default changed in R 3.6.0
  set.seed(42)
  # define input and output
  data <- data.frame(
    id = c(1, 2, 3, 4, 5),
    class = c(1, 2, 2, 2, 3),
    val = c("abc", "def", "geh", "ijk", "lmn"),
    # keep `val` a factor on R >= 4.0 too, so that levels(data$val) is defined
    stringsAsFactors = TRUE
  )
  sizes <- list("1" = 1, "2" = 2/3, "3" = 0)
  expected <- data.frame(
    id = c(1, 2, 4),
    class = c(1, 2, 2),
    val = factor(c("abc", "def", "ijk"), levels = levels(data$val))
  )
  expected_replace <- data.frame(
    id = c(1, 2, 2),
    class = c(1, 2, 2),
    val = factor(c("abc", "def", "def"), levels = levels(data$val))
  )
  # run it
  actual <- sample_frac_class(data, class, sizes)
  actual_replace <- sample_frac_class(data, class, sizes, replace = TRUE)
  # check it
  testthat::expect(isTRUE(dplyr::all_equal(expected, actual)), "sampling without replacement")
  testthat::expect(isTRUE(dplyr::all_equal(expected_replace, actual_replace)), "sampling with replacement")
})
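
For comparison, the same idea can be sketched in base R without dplyr. This is only an illustration under a few assumptions: the name `sample_frac_class_base` is hypothetical and not part of the gist, the class column is passed as a string, and the number of rows per class is taken as `round(n * fraction)`.

```r
# Hypothetical base-R variant of per-class fraction sampling (not from the gist).
sample_frac_class_base <- function(tbl, class_col, sizes, replace = FALSE) {
  stopifnot(
    is.character(class_col),
    class_col %in% names(tbl),
    is.list(sizes)
  )
  classes <- as.character(tbl[[class_col]])
  # Every class present in the data needs a fraction
  stopifnot(all(unique(classes) %in% names(sizes)))

  pick <- function(cls) {
    idx <- which(classes == cls)
    n <- round(length(idx) * sizes[[cls]])  # assumed rounding rule
    # Sample positions within idx; sample(idx, n) would treat a single
    # index k as the range 1:k.
    idx[sample.int(length(idx), size = n, replace = replace)]
  }

  rows <- unlist(lapply(names(sizes), pick))
  tbl[rows, , drop = FALSE]
}

set.seed(1)
fractions <- list(setosa = 0.1, versicolor = 0.5, virginica = 0)
sampled <- sample_frac_class_base(iris, "Species", fractions)
print(table(sampled$Species))
```

With 50 rows per species in `iris`, this selects 5 setosa, 25 versicolor and 0 virginica rows; which rows are picked depends on the seed.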