Skip to content

Instantly share code, notes, and snippets.

@davidtedfordholt
Created July 30, 2021 14:03
Show Gist options
  • Save davidtedfordholt/3c7e0afdae183bdb0ce705543e561a00 to your computer and use it in GitHub Desktop.
Save davidtedfordholt/3c7e0afdae183bdb0ce705543e561a00 to your computer and use it in GitHub Desktop.
A function for randomly sampling whole groups from data frames, tibbles, and tsibbles, making it easy to create train and test sets from time series data.
#' Sample groups randomly in a grouped data frame
#'
#' Randomly selects whole groups and returns every row belonging to each
#' selected group, e.g. to split a panel of time series into train and test
#' sets by series.
#'
#' @param .data A data frame, tibble, grouped tibble, or tsibble.
#' @param ... Unquoted names of the variables that define the groups. If
#'   unspecified, the groups come from `.data` itself: its grouping variables
#'   for a grouped data frame, or its key variables for a tsibble (together
#'   with any grouping variables for a grouped tsibble). A plain, ungrouped
#'   data frame must have at least one variable supplied here.
#' @param n If a whole number of at least 1, the number of unique groups to
#'   keep. If strictly between 0 and 1, the proportion of unique groups to
#'   keep. Because `n` comes after `...`, it must always be passed by name
#'   (`n = 20`); a positional value would be treated as a grouping variable.
#'
#' @return All rows of `.data` that belong to the sampled groups, with the
#'   original columns. Data frame inputs return a tibble grouped the same way
#'   as `.data`; tsibble inputs return a tsibble with the original key and
#'   index.
#' @export
#'
#' @examples
#' df <- data.frame(id = rep(1:10, each = 3), value = 1:30)
#' sample_groups(df, id, n = 3)
#'
#' data <- pull_tender_data_for_forecasting(
#' level = 'zip3',
#' metric = 'tenders',
#' first_date = '2020-11-01')
#' data %>%
#' distinct(id, direction)
#' data %>%
#' group_by(id, direction) %>%
#' sample_groups(n = 20)
sample_groups <-
function(.data, ..., n) {
  # S3 dispatch on the class of `.data` (data.frame, grouped_df, tbl_ts,
  # grouped_ts).
  UseMethod("sample_groups")
}
#' @export
sample_groups.data.frame <-
function(.data, ..., n) {
  # A plain data frame carries no grouping of its own, so the columns that
  # define the groups have to be named in `...`.
  group_cols <- unname(purrr::map_chr(rlang::exprs(...), as.character))
  if (length(group_cols) > 0) {
    return(sample_groups_engine(.data, group_cols, n))
  }
  rlang::abort("`sample_groups` requires either a grouped dataframe or columns to define groups")
}
#' @export
sample_groups.tbl_ts <-
function(.data, ..., n) {
  # Columns named in `...` define the groups when given; otherwise the
  # tsibble's key variables do.
  requested <- unname(purrr::map_chr(rlang::exprs(...), as.character))
  keys <- if (length(requested) == 0) tsibble::key_vars(.data) else requested
  if (length(keys) == 0) {
    rlang::abort("`sample_groups` requires either a grouped dataframe or columns to define groups")
  }
  sampled <- sample_groups_engine(.data, keys, n)
  # Rebuild a tsibble with the input's key and index, then drop any grouping.
  dplyr::ungroup(
    tsibble::as_tsibble(
      sampled,
      key = tsibble::key_vars(.data),
      index = !!tsibble::index_var(.data)
    )
  )
}
#' @export
sample_groups.grouped_df <-
function(.data, ..., n) {
  # By default the data's own grouping variables define the groups; columns
  # named in `...` replace them (with a message saying so).
  requested <- unname(purrr::map_chr(rlang::exprs(...), as.character))
  if (length(requested) == 0) {
    return(sample_groups_engine(.data, dplyr::group_vars(.data), n))
  }
  rlang::inform("sampling from specified variables, rather than grouped variables")
  sample_groups_engine(.data, requested, n)
}
#' @export
sample_groups.grouped_ts <-
function(.data, ..., n) {
  # Sample groups of a grouped tsibble. The tsibble's key variables are always
  # part of the sampling groups; columns named in `...` replace the grouping
  # variables, otherwise the grouping variables are used alongside the keys.
  key_cols <- tsibble::key_vars(.data)
  dots <- unname(purrr::map_chr(rlang::exprs(...), as.character))
  if (length(dots) != 0) {
    # unique(): a key variable may also be named in `...`; the keys are used
    # as join columns downstream, and dplyr requires those to be distinct.
    keys <- unique(c(dots, key_cols))
    rlang::inform("sampling from key(s) and specified variables, rather than grouped variables")
  } else {
    group_cols <- dplyr::group_vars(.data)
    # Base-R membership test (no dependency on a `%not_in%` helper).
    if (!all(key_cols %in% group_cols)) {
      rlang::inform("key variables are being added to the grouped variables")
    }
    keys <- unique(c(group_cols, key_cols))
  }
  if (length(keys) == 0) {
    rlang::abort("`sample_groups` requires either a grouped dataframe or columns to define groups")
  }
  sample_groups_engine(.data, keys, n) %>%
    tsibble::as_tsibble(
      key = tsibble::key_vars(.data),
      index = !!tsibble::index_var(.data))
}
# Internal worker shared by every `sample_groups()` method.
#
# Samples `n` unique combinations of the `keys` columns of `.data` and returns
# every row of `.data` belonging to a sampled combination, with the original
# column order and the original grouping of `.data` restored.
#
# `keys` is a character vector of column names. `n` is either a whole number
# >= 1 (how many groups to keep) or a proportion strictly between 0 and 1 of
# the unique groups, rounded up. If `n` exceeds the number of groups,
# `dplyr::slice_sample()` keeps them all.
sample_groups_engine <-
function(.data, keys, n) {
  invalid_n <- "`n` must be an integer greater than 0 or a number between 0 and 1."
  if (!is.numeric(n) || length(n) != 1L || !is.finite(n) || n <= 0) {
    rlang::abort(invalid_n)
  }
  if (n >= 1 && n != floor(n)) {
    rlang::abort(invalid_n)
  }

  # Read column names and grouping here, outside any data-masking/tidyselect
  # verb, so `.data` can only mean this argument and never a `.data` pronoun.
  data_cols <- names(.data)
  group_cols <- dplyr::group_vars(.data)

  # One row per unique group.
  groups <- dplyr::distinct(
    dplyr::select(dplyr::ungroup(dplyr::as_tibble(.data)), dplyr::all_of(keys))
  )

  # A proportion is applied to the number of unique groups, not to the number
  # of key columns.
  if (n < 1) {
    n <- ceiling(nrow(groups) * n)
  }

  sampled <- dplyr::slice_sample(groups, n = n)
  result <- dplyr::left_join(sampled, .data, by = keys)
  result <- dplyr::select(result, dplyr::all_of(data_cols))
  # With no grouping variables this leaves `result` ungrouped.
  dplyr::group_by(result, !!!rlang::syms(group_cols))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment