Skip to content

Instantly share code, notes, and snippets.

@namarkus
Created May 4, 2023 06:15
Show Gist options
  • Save namarkus/3512e982f5fffd766d77f0db534e134d to your computer and use it in GitHub Desktop.
Save namarkus/3512e982f5fffd766d77f0db534e134d to your computer and use it in GitHub Desktop.
add_total function for dplyr pipelines
#' Append a total row (or per-group subtotal rows) to a data frame
#'
#' Intended for the end of a dplyr pipe. Summary rows are computed with
#' `summarise()` and appended to `df` with `bind_rows()`. In the summary rows,
#' factor and character columns that are not grouping columns are filled with
#' `label`. Columns that are neither grouping columns nor produced by the
#' summary end up `NA`.
#'
#' @param df A data frame or tibble. Any existing grouping is ignored when
#'   computing the summary rows.
#' @param grp_vars Optional tidyselect specification of columns to compute
#'   subtotals by (one summary row per group). When missing, a single
#'   grand-total row is produced.
#' @param sum_vars Optional tidyselect specification of columns to total with
#'   `sum()`. When missing, the summaries are taken from `...` instead.
#' @param label Value written into the non-grouping factor/character columns
#'   of the summary rows. It is also added as the last level of those factor
#'   columns in `df`.
#' @param ... Named summary expressions, as in `summarise()`. Only used when
#'   `sum_vars` is missing. NOTE: R partially matches argument names before
#'   `...`, so a summary named e.g. `sum` or `lab` binds to `sum_vars` /
#'   `label` instead of landing in `...`.
#' @return `df` with the summary row(s) appended.
add_total <- function(df, grp_vars, sum_vars, label = "Total", ...) {
  has_grp <- !missing(grp_vars)
  has_sum <- !missing(sum_vars)

  # Without grp_vars, a bare group_by() drops any grouping df already carries,
  # so every no-grp call yields exactly one grand-total row.
  grouped <- if (has_grp) {
    df %>% group_by(across({{ grp_vars }}))
  } else {
    df %>% group_by()
  }
  # Grouping columns keep their own values in the summary rows. Every other
  # factor column is filled with `label` there.
  grp_names <- group_vars(grouped)

  # Inside summarise(), across() skips grouping columns, so only the
  # non-grouping factor/character columns are overwritten with `label`.
  total <- if (has_sum) {
    summarise(
      grouped,
      across({{ sum_vars }}, sum),
      across(!{{ sum_vars }} & where(is.factor), ~ factor(label)),
      across(!{{ sum_vars }} & where(is.character), ~ label),
      .groups = "drop"
    )
  } else {
    summarise(
      grouped,
      ...,
      across(where(is.factor), ~ factor(label)),
      across(where(is.character), ~ label),
      .groups = "drop"
    )
  }

  # Register `label` as a level of the labelled factor columns before binding.
  # That keeps it after the existing levels and keeps the factors compatible.
  # Grouping columns are excluded so they do not gain an unused level.
  df %>%
    mutate(across(
      !any_of(grp_names) & where(is.factor),
      ~ forcats::fct_expand(.x, label)
    )) %>%
    bind_rows(total)
}
# Tests ----
# One row per cyl/vs combination. cyl and vs are converted to factors so the
# appended summary rows can carry the "Total" label in them.
df <-
  mtcars %>%
  as_tibble() %>%
  mutate(across(c(cyl, vs), factor)) %>%
  group_by(cyl, vs) %>%
  summarise(n = n(), mmpg = mean(mpg), .groups = "drop")

# Grand total of n. mmpg is not summed, so it is NA in the total row.
df %>% add_total(sum_vars = n)
# One subtotal row per cyl level, summing n.
df %>% add_total(grp_vars = cyl, sum_vars = n)
# Custom summaries via `...`: n-weighted mean of mmpg plus the summed count.
# mmpg is computed first, so it still sees the per-row n.
df %>% add_total(mmpg = weighted.mean(mmpg, n), n = sum(n))
# Same custom summaries, but one subtotal row per cyl level.
df %>% add_total(grp_vars = cyl, mmpg = weighted.mean(mmpg, n), n = sum(n))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment