# Gist "summarize_with_margins" by @moodymudskipper, created August 4, 2023
# https://gist.github.com/moodymudskipper/6347418d82fea2160178422aa574dec2
#' Grouped operations with margins
#'
#' * `summarize_with_margins()` is similar to summarize but creates an additional
#' `"(all)"` category for each grouping variable. It assumes a hierarchy of groups
#' and the higher level groups should be provided first. Regular groups, not
#' used for totals/subtotals can be provided through the `.more_groups` arg
#' and will be used as parent groups.
#' * `mutate_over_margins()` is meant to be applied right after `summarize_with_margins(, .groups = "keep")`
#' when we want a window function to be applied by grouping set, it detects grouping
#' sets based on `"(all)"` values in grouping columns.
#'
#' Categories are converted to character, that's necessary to add a "(all)" category,
#' missing values are kept as missing.
#'
#' We expect the `.by` argument to often be a subset of the categories below (enumerated here for
#' easy copy and paste) :
#' * Time: year, month, week, timestamp or jahr, monat, kalenderwoche, kw, zeitstempel
#' * Organization structure : gf, vertrieb_level_02, vertrieb_level_03, vertrieb_level_04,
#' vertrieb_level_05, vertrieb_level_06, vertrieb_level_07, vermittler_name,
#' vermittler_nr
#' * Customer: partner_nr, kanton, plz
#' * Insurance Coverage: contract_number, contract_product, contract_categor
#'
#'
#' @param .data A data frame (lazy or not)
#' @param ... Name-value pairs as used in `dplyr::summarize()` and `dplyr::mutate()`.
#' In `summarize_with_margins()` they are evaluated on the data of each grouping
#' set before the `"(all)"` values are filled in, so they may refer to grouping
#' columns (e.g. `n_distinct(am)`).
#' @param .by grouping columns, starting from the highest parent, optional if the
#' data is already grouped.
#' @param .more_groups Additional grouping columns (tidy-select) for
#' `summarize_with_margins()`. They are used as parent groups of every grouping
#' set and never receive an `"(all)"` margin. They are selected from the
#' ungrouped data, so they can be combined with grouped input.
#' @param .groups What to do with groups after the transformation : "drop" or "keep".
#' By default they are dropped, which is different from the default behavior from dplyr that
#' keeps them with `mutate()` and peels one off with `summarize()`.
#' `.groups = "keep"` will also keep groups if they are provided with `.by`,
#' another divergence from dplyr, designed to facilitate the use of `mutate_over_margins()`
#'
#' @return a data frame
#' @export
#'
#' @examples
#' df <- summarize_with_margins(mtcars, mpg = mean(mpg, na.rm = TRUE), .by = c(cyl, vs, am))
#' df
#' # here we repeat the `.by`
#' mutate_over_margins(df, n = n(), min_flag = mpg == min(mpg), .by = c(cyl, vs, am))
#'
#' # use `.groups = "keep"` in `summarize_with_margins()` to avoid this repetition
#' mtcars |>
#' summarize_with_margins(mpg = mean(mpg, na.rm = TRUE), .by = c(cyl, vs, am), .groups = "keep") |>
#' mutate_over_margins(n = n(), min_flag = mpg == min(mpg))
summarize_with_margins <- function(.data, ..., .by = NULL, .more_groups = NULL, .groups = "drop") {
  .groups <- rlang::arg_match(.groups, c("drop", "keep"))
  group_vars <- fetch_group_vars(.data, {{ .by }})

  # `.more_groups` is resolved on the ungrouped data: `fetch_group_vars()`
  # refuses an explicit selection on grouped input, which would otherwise make
  # `.more_groups` unusable whenever the main groups come from `group_by()`.
  if (rlang::quo_is_null(rlang::enquo(.more_groups))) {
    additional_groups <- character(0)
  } else {
    additional_groups <- fetch_group_vars(dplyr::ungroup(.data), {{ .more_groups }})
  }

  .data <- .data %>%
    dplyr::ungroup() %>%
    mutate(across(all_of(group_vars), as.character))

  # One summary per grouping set, from the most detailed one (all `group_vars`)
  # down to the grand total (none of them).
  .data <- map(
    seq.int(length(group_vars), 0L),
    function(i) {
      group_vars_i <- group_vars[seq_len(i)]
      subtotal_vars <- setdiff(group_vars, group_vars_i)
      subtotal_value_pairs <- set_names(rep("(all)", length(subtotal_vars)), subtotal_vars)
      # Summarize BEFORE filling the margin columns: were the "(all)" constants
      # created first inside `summarize()`, expressions in `...` referring to a
      # rolled-up grouping column would see "(all)" instead of the data.
      summary_i <- summarize(.data, ..., .by = all_of(c(additional_groups, group_vars_i)))
      summary_i <- mutate(summary_i, !!!subtotal_value_pairs)
      # Identical column order for every grouping set so the union lines up.
      dplyr::relocate(summary_i, all_of(c(additional_groups, group_vars)))
    }
  )

  # bind_rows doesn't support lazy tables
  .data <- Reduce(dplyr::union_all, .data)
  if (.groups == "keep") {
    # Plain symbols are understood by both local and lazy `group_by()` methods.
    .data <- dplyr::group_by(.data, !!!rlang::syms(group_vars))
  }
  .data
}
#' @export
#' @rdname summarize_with_margins
mutate_over_margins <- function(.data, ..., .by = NULL, .groups = "drop") {
  .groups <- rlang::arg_match(.groups, c("drop", "keep"))
  group_vars <- fetch_group_vars(.data, {{ .by }})
  .data <- dplyr::ungroup(.data)

  # One logical flag per grouping column, TRUE when the row is a margin
  # ("(all)") for that column. The combination of flags identifies the grouping
  # set a row belongs to, so windowing over the flags windows per grouping set.
  temp_colnames <- paste0("temp_margin_", group_vars)
  .data <-
    .data %>%
    # `!is.na(.x) & ...` rather than a bare `==`: a missing category is a regular
    # (non-margin) value, but `NA == "(all)"` is NA, which would split those rows
    # off into a window group of their own. This form also translates to SQL.
    mutate(across(
      all_of(group_vars),
      ~ !is.na(.x) & .x == "(all)",
      .names = "temp_margin_{.col}"
    )) %>%
    mutate(..., .by = all_of(temp_colnames)) %>%
    select(-all_of(temp_colnames))

  if (.groups == "keep") {
    # Plain symbols are understood by both local and lazy `group_by()` methods.
    .data <- dplyr::group_by(.data, !!!rlang::syms(group_vars))
  }
  .data
}
# Resolve the grouping columns used by the margin helpers.
#
# Exactly one source of groups is accepted: either `.by` (a tidy-selection,
# e.g. `c(a, b)`) on ungrouped `.data`, or `.by = NULL` on grouped `.data`, in
# which case its existing groups are used. Returns the column names as a
# character vector, in selection / grouping order. Aborts otherwise.
fetch_group_vars <- function(.data, .by = NULL) {
  .by <- rlang::enquo(.by)
  # `dplyr::group_vars()` has methods for local and lazy tables alike, while
  # `dplyr::is_grouped_df()` only recognises local grouped data frames.
  existing_groups <- dplyr::group_vars(.data)
  is_grouped <- length(existing_groups) > 0

  if (rlang::quo_is_null(.by)) {
    if (!is_grouped) {
      rlang::abort("`.by` must be provided or `.data` must be grouped.")
    }
    return(existing_groups)
  }
  if (is_grouped) {
    rlang::abort("`.by` must not be provided if `.data` is grouped.")
  }
  # Selecting through dplyr (rather than `tidyselect::eval_select()` on `.data`
  # itself) resolves the selection against the real columns of lazy tables too.
  colnames(dplyr::select(.data, !!.by))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment