Skip to content

Instantly share code, notes, and snippets.

@abmathewks
Created April 8, 2021 17:56
Show Gist options
  • Save abmathewks/f08463bb5570353da87ff08e7a645a41 to your computer and use it in GitHub Desktop.
Save abmathewks/f08463bb5570353da87ff08e7a645a41 to your computer and use it in GitHub Desktop.
#' GET_NETWORK_SUMMARY
#'
#' Learns a Bayesian network structure over the selected columns with
#' `bnlearn::hc()` and returns the learned arcs together with the strength
#' reported for each arc by `bnlearn::arc.strength()`. Optionally the data
#' are row-sampled first, and optionally one network is learned per distinct
#' value of each grouping column.
#'
#' The data.table package is expected to be attached (the function calls
#' `setDT()`, `rbindlist()` and `setorderv()` unqualified) and bnlearn must be
#' installed.
#'
#' @param DATA_DT A data.table holding the data.
#' @param USE_THESE_FEATURES Character vector with the names of the columns
#'   of `DATA_DT` that are used as network nodes. Required.
#' @param DO_SAMPLE Logical; if `TRUE` a random subset of rows is used.
#' @param DO_SAMPLE_FRAC If `DO_SAMPLE` is `TRUE`, the fraction of rows to
#'   keep; a single number greater than 0 and at most 1.
#' @param BY_GROUP Logical; if `TRUE` a separate network is learned for every
#'   distinct non-missing value of every column named in `BY_GROUP_VAL`.
#' @param BY_GROUP_VAL Character vector with the names of the grouping
#'   columns. Required when `BY_GROUP` is `TRUE`, ignored otherwise.
#' @param DEBUG Logical; if `TRUE` progress messages are emitted.
#'
#' @return A data.table with one row per learned arc and the columns
#'   `run_date`, `child_var` (node the arc points to), `parent_var` (node the
#'   arc starts from) and `arc_strength`, ordered by decreasing
#'   `arc_strength`. When `BY_GROUP` is `TRUE` the columns `group_name` and
#'   `group_name_value` (stored as character) are added after `run_date`.
#'   Nodes without parents do not appear as `child_var`.
#' @export
#'
#' @examples
#' \dontrun{
#' library(data.table)
#' data(learning.test, package = "bnlearn")
#' dt <- as.data.table(learning.test)
#' GET_NETWORK_SUMMARY(dt, USE_THESE_FEATURES = c("A", "B", "C"),
#'                     DO_SAMPLE = FALSE, DEBUG = FALSE)
#' }
GET_NETWORK_SUMMARY <- function(DATA_DT,
                                USE_THESE_FEATURES = NULL,
                                DO_SAMPLE = TRUE,
                                DO_SAMPLE_FRAC = 0.70,
                                BY_GROUP = FALSE,
                                BY_GROUP_VAL = NULL,
                                DEBUG = TRUE) {

  if (DEBUG) message('|===>> DEBUG: GET_NETWORK_SUMMARY - Checking function arguments. \n')

  # ---- Argument validation (all checks run before any data is touched) ----
  # Argument names are spelled out literally: deparse(substitute(x)) yields an
  # empty string when x is missing, which made the old messages uninformative.
  if (missing(DATA_DT)) {
    stop("Argument DATA_DT is missing.", call. = FALSE)
  }
  if (!inherits(DATA_DT, "data.table")) {
    stop("Argument DATA_DT is not a data.table.", call. = FALSE)
  }
  if (is.null(USE_THESE_FEATURES)) {
    stop("Argument USE_THESE_FEATURES is missing.", call. = FALSE)
  }
  if (!is.character(USE_THESE_FEATURES) || length(USE_THESE_FEATURES) == 0L) {
    stop("Argument USE_THESE_FEATURES must be a non-empty character vector.",
         call. = FALSE)
  }
  unknown_features <- setdiff(USE_THESE_FEATURES, names(DATA_DT))
  if (length(unknown_features) > 0L) {
    stop("USE_THESE_FEATURES contains columns not found in DATA_DT: ",
         paste0(unknown_features, collapse = ", "), call. = FALSE)
  }
  if (DO_SAMPLE) {
    frac_ok <- is.numeric(DO_SAMPLE_FRAC) && length(DO_SAMPLE_FRAC) == 1L &&
      !is.na(DO_SAMPLE_FRAC) && DO_SAMPLE_FRAC > 0 && DO_SAMPLE_FRAC <= 1
    if (!frac_ok) {
      stop("Argument DO_SAMPLE_FRAC must be a single number in (0, 1].",
           call. = FALSE)
    }
  }
  if (BY_GROUP) {
    if (!is.character(BY_GROUP_VAL) || length(BY_GROUP_VAL) == 0L) {
      stop("Argument BY_GROUP_VAL must name at least one column when ",
           "BY_GROUP is TRUE.", call. = FALSE)
    }
    unknown_groups <- setdiff(BY_GROUP_VAL, names(DATA_DT))
    if (length(unknown_groups) > 0L) {
      stop("BY_GROUP_VAL contains columns not found in DATA_DT: ",
           paste0(unknown_groups, collapse = ", "), call. = FALSE)
    }
  }

  # ---- Optional row sampling ----
  if (DO_SAMPLE) {
    if (DEBUG) message("|===>> DEBUG: GET_NETWORK_SUMMARY - Sampling DATA_DT percentage: ", DO_SAMPLE_FRAC, ' \n')
    n_rows <- nrow(DATA_DT)
    # floor() makes the truncation of the fractional sample size explicit.
    keep_rows <- sample.int(n_rows, floor(n_rows * DO_SAMPLE_FRAC))
    DATA_DT <- DATA_DT[keep_rows, ]
  }

  if (DEBUG) message('|===>> DEBUG: GET_NETWORK_SUMMARY - Processing Bayesian Network. \n')

  # One date for the whole run so every row carries the same stamp.
  run_date <- Sys.Date()

  # Learn a network on one data subset and return its arcs as a data.table
  # (one row per arc; zero rows, same columns, when no arc was learned).
  summarise_arcs <- function(dt_sub, group_name = NULL, group_value = NULL) {
    feature_df <- as.data.frame(dt_sub)[, USE_THESE_FEATURES, drop = FALSE]
    net <- bnlearn::hc(feature_df)
    # arc.strength() already lists every parent -> child pair of `net`, so
    # the parent sets do not have to be recovered from a fitted network.
    arcs <- bnlearn::arc.strength(net, feature_df)
    child_nodes <- as.character(arcs[["to"]])
    n_arcs <- length(child_nodes)

    if (DEBUG) {
      for (each_element in unique(child_nodes)) {
        message('|===>> DEBUG: GET_NETWORK_SUMMARY - Processing variable: ', each_element, ' \n')
      }
    }

    columns <- list(run_date = rep(run_date, n_arcs))
    if (!is.null(group_name)) {
      columns[["group_name"]] <- rep(group_name, n_arcs)
      # Stored as character so tables from differently typed grouping
      # columns can always be stacked.
      columns[["group_name_value"]] <- rep(as.character(group_value), n_arcs)
    }
    columns[["child_var"]] <- child_nodes
    columns[["parent_var"]] <- as.character(arcs[["from"]])
    columns[["arc_strength"]] <- arcs[["strength"]]
    network_summary_output_tmp <- setDT(columns)

    if (DEBUG) {
      message('|===>> DEBUG: GET_NETWORK_SUMMARY - network_summary_output_tmp has been created and has ',
              nrow(network_summary_output_tmp), ' rows. \n')
    }
    network_summary_output_tmp
  }

  # Results are appended positionally: keying the list by node name (as
  # before) made every group overwrite the results of the previous one.
  network_summary_output_lst <- list()

  if (!BY_GROUP) {
    network_summary_output_lst[[1L]] <- summarise_arcs(DATA_DT)
  } else {
    for (each_group in BY_GROUP_VAL) {
      if (DEBUG) message('|===>> DEBUG: GET_NETWORK_SUMMARY - Processing variable: ', each_group, ' \n')

      group_column <- DATA_DT[[each_group]]
      group_values <- unique(group_column)
      group_values <- group_values[!is.na(group_values)]

      for (value_idx in seq_along(group_values)) {
        each_group_val <- group_values[value_idx]
        if (DEBUG) message('|===>> DEBUG: GET_NETWORK_SUMMARY - Processing variable: ', as.character(each_group_val), ' \n')

        # which() drops NA comparisons instead of propagating them.
        group_rows <- which(group_column == each_group_val)
        network_summary_output_lst[[length(network_summary_output_lst) + 1L]] <-
          summarise_arcs(DATA_DT[group_rows, ], each_group, each_group_val)
      }
    }
  }

  network_summary_output <- rbindlist(network_summary_output_lst)
  # Ordering needs the column to exist; skip it when nothing was produced.
  if (nrow(network_summary_output) > 0L) {
    setorderv(network_summary_output, "arc_strength", order = -1L)
  }
  network_summary_output
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment