Skip to content

Instantly share code, notes, and snippets.

@bakaburg1
Last active October 1, 2020 23:18
Show Gist options
  • Save bakaburg1/893ea82db683b45d71477e8535a2a72e to your computer and use it in GitHub Desktop.
Save bakaburg1/893ea82db683b45d71477e8535a2a72e to your computer and use it in GitHub Desktop.
Small helper functions to interact with partykit::ctree().
#' Extract ctree rules.
#'
#' Parses the printed representation of a fitted tree and returns, for every
#' non-root node, the conjunction of split conditions leading from the root to
#' that node. Rules can optionally be formatted so that they are ready to be
#' parsed and evaluated for data filtering.
#'
#' @param tree An object produced by \code{partykit::ctree()}.
#' @param rule.as.text Whether to collapse rule conditions into a single string
#'   (joined by \code{" & "}) or leave them as a character vector per node.
#' @param eval.ready Whether to format set conditions (\code{var in a, b}) as
#'   \code{var \%in\% c("a", "b")} so the rule can be parsed and evaluated.
#'   Turns \code{rule.as.text} automatically on.
#'
#' @return A tibble with one row per non-root node and the columns \code{rule},
#'   \code{id} (the node id) and \code{depth} (the depth of the node). When
#'   \code{length(tree) == 1} an empty \code{data.frame} with the same columns
#'   is returned.
#'
#' @import dplyr
#' @import stringr
#' @import partykit
#' @importFrom readr parse_number
#'
#' @export
#'
#' @examples
#' \dontrun{
#' tree <- partykit::ctree(Species ~ ., data = iris)
#' get_ctree_rules(tree)
#' get_ctree_rules(tree, eval.ready = TRUE)
#' }
get_ctree_rules <- function(tree, rule.as.text = TRUE, eval.ready = FALSE) {
  # Load (without attaching) partykit so that its print method is registered;
  # the rules are recovered from the printed tree below.
  if (!requireNamespace("partykit", quietly = TRUE)) {
    stop("Package 'partykit' is required by get_ctree_rules().", call. = FALSE)
  }

  empty_rules <- data.frame(
    rule = character(0), id = numeric(0), depth = numeric(0)
  )

  if (length(tree) == 1) {
    return(empty_rules)
  }

  # Documented contract: eval-ready rules are always returned as text.
  if (eval.ready) {
    rule.as.text <- TRUE
  }

  out <- utils::capture.output(tree)

  # The node listing starts right after the "[1] root" line and ends at the
  # first blank line that follows it.
  root_line <- which(out == "[1] root")
  if (length(root_line) == 0) {
    stop("Could not find the '[1] root' line in the printed tree.",
         call. = FALSE)
  }
  root_line <- root_line[1]
  blank_after <- which(out == "")
  blank_after <- blank_after[blank_after > root_line]
  last_line <- if (length(blank_after) > 0) blank_after[1] - 1 else length(out)
  if (last_line <= root_line) {
    return(empty_rules)
  }

  # Drop the first "|   " marker so that the number of remaining bars plus one
  # is the depth of the node.
  node_lines <- stringr::str_remove(out[(root_line + 1):last_line], "\\|\\s+")
  ids <- readr::parse_number(stringr::str_extract(node_lines, "\\[\\d+\\]"))
  depths <- stringr::str_count(node_lines, "\\|") + 1

  # Strip the indentation bars, the "[id]" label and the terminal-node summary
  # (everything from the first ":" on), leaving only the split condition.
  conditions <- stringr::str_remove_all(node_lines, "\\|\\s+")
  conditions <- stringr::str_remove(conditions, "\\[\\d+\\]")
  conditions <- stringr::str_remove(conditions, ":.*")
  conditions <- stringr::str_squish(conditions)

  if (eval.ready) {
    # Only set conditions need rewriting. The closing '")' is appended
    # unconditionally so that levels ending in a digit are closed as well.
    is_set <- stringr::str_detect(conditions, " in ")
    sets <- stringr::str_replace(conditions[is_set], " in ", " %in% c(\"")
    sets <- stringr::str_replace_all(sets, ", ", "\", \"")
    conditions[is_set] <- paste0(sets, "\")")
  }

  # For every node collect the conditions on its path: among the nodes with an
  # id not larger than the current one, the node with the largest id at each
  # depth up to the current depth.
  paths <- lapply(seq_along(ids), function(i) {
    candidate <- ids <= ids[i] & depths <= depths[i]
    on_path <- tapply(ids[candidate], depths[candidate], max)
    conditions[ids %in% on_path]
  })

  rule <- if (rule.as.text) {
    vapply(paths, paste, character(1), collapse = " & ")
  } else {
    paths
  }

  dplyr::tibble(rule = rule, id = ids, depth = depths)
}
#' Simplify ctree rules.
#'
#' Remove redundant components of a rule keeping only the shortest set
#' definition (e.g.: if many conditions in a rule represent nested sets, only
#' those necessary to define the innermost set are kept). Conditions are
#' grouped by variable and comparison operator and, within each group, only
#' the last one in the rule is kept. The remaining conditions are rearranged
#' alphabetically for easier comparison.
#'
#' @param rules A character vector of rules, each one a set of conditions
#'   joined by \code{" & "}.
#'
#' @return A character vector with the simplified rules, named after the
#'   original rules. Empty (\code{""}) and missing rules are dropped via
#'   \code{na.omit()}.
#'
#' @import dplyr
#' @import stringr
#'
#' @export
#'
#' @examples
#' simplify_rules("a > 1 & b <= 2 & a > 3")
#' simplify_rules(c("x <= 5 & max <= 3", ""))
simplify_rules <- function(rules) {
  # A condition's key is its variable name followed by the comparison
  # operator. The match is anchored and lazy, and the operator must be
  # followed by a space, so it cannot run into the right-hand side of the
  # condition (e.g. a factor level starting with "in").
  key_pattern <- "^.*? (<=|>=|==|!=|<|>|=|%in%|in)(?= )"

  simplify_one <- function(rule) {
    if (is.na(rule) || rule == "") {
      return(NA_character_)
    }

    components <- strsplit(rule, " & ", fixed = TRUE)[[1]]

    # Components without a recognizable operator act as their own key, so
    # they are kept unless repeated verbatim.
    hit <- regexpr(key_pattern, components, perl = TRUE)
    keys <- components
    keys[hit > 0] <- regmatches(components, hit)

    # Keys are compared exactly (not as substrings), so "x <=" is never
    # confused with "max <=". For each key the last condition is kept.
    keep <- !duplicated(keys, fromLast = TRUE)

    paste(sort(components[keep]), collapse = " & ")
  }

  # vapply() guarantees a character vector even for empty input.
  stats::na.omit(vapply(rules, simplify_one, character(1)))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment