Loop to run simulations of inTrees methods across the RF, GBM, and rpart algorithms. Code supporting blog post: http://matthewdharris.com/2016/07/30/one-tree-to-rule-them-all-intrees-and-rule-based-learing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library("data.table")
library("rowr")
library("inTrees")
library("dplyr")
library("randomForest")
library("xtable")
library("caret")
library("gbm")
library("rpart")
library("reshape2")
library("ggplot2")
library("ggalt")
### Read Data
# Replace the placeholder path with the location of your own CSV file.
dat <- fread("YOUR/DATA/DIRECTORY/FILE.csv")
dat <- data.frame(dat)

# Remove some unneeded columns.
# Uses base `%in%` with negation: no `%ni%` operator is defined in this script.
# `drop = FALSE` keeps a data.frame even if only one column survives.
drop_cols <- c("V1", "tpi_sd250c", "tpi_cls250c", "e_trail_dist")
dat_trimmed <- dat[, !(colnames(dat) %in% drop_cols), drop = FALSE]

## set up data partition variables
train_v_test_fraction <- 0.75    # 0 to 1
absence_presence_balance <- 1    # 1 for balanced; typically 3
data_reduction_fraction <- 0.05  # 0 to 1
runs <- 20                       # resample repeats

# One row per run: the run index followed by the test error of each learner.
simulated_results <- matrix(nrow = runs, ncol = 6)
colnames(simulated_results) <- c("run", "RF Agg", "GB agg", "rpart", "RF", "GBM")
# Number of boosting iterations; shared by the gbm fit and its prediction call
# so the two values cannot drift apart.
n_trees <- 100

# mae() is called with `pred` / `obs` arguments below, but no definition is
# visible in this script. Keep whatever the session already provides; otherwise
# fall back to the plain mean absolute error.
if (!exists("mae", mode = "function")) {
  mae <- function(pred, obs) {
    mean(abs(pred - obs))
  }
}

# Repeat the resample / fit / score cycle `runs` times and store one row of
# hold-out errors per run in `simulated_results`.
for (i in seq_len(runs)) {
  message(paste0("processing run: ", i))

  # THIS FUNCTION NOT PROVIDED! It is a big intricate function, but the bottom
  # line is for you to have a test and a train sample.
  # NOTE(review): `sites_per_train` and `sites_per_test` are not defined in the
  # visible script either -- define them before running.
  test_train <- get.train.test.sets(dat_trimmed, train_v_test_fraction,
                                    sites_per_train, sites_per_test,
                                    absence_presence_balance,
                                    data_reduction_fraction)
  train <- test_train[["train"]]
  test <- test_train[["test"]]

  # Slim down to just a few (rounded) predictor variables.
  dat_rnd <- data.frame(e_hyd_min = round(train$e_hyd_min, 1),
                        std_32c = round(train$std_32c, 2),
                        elev_2_drainh = round(train$elev_2_drainh, 2))
  target <- as.factor(train$presence)
  test_learned <- test[, c("e_hyd_min", "std_32c", "elev_2_drainh", "presence")]

  ## Random forest -> inTrees rule learner
  rf <- randomForest(dat_rnd, target)
  treeList <- RF2List(rf)                             # rf object to inTrees' format
  exec <- extractRules(treeList, dat_rnd)             # R-executable conditions
  ruleMetric <- getRuleMetric(exec, dat_rnd, target)  # get rule metrics
  ruleMetric <- pruneRule(ruleMetric, dat_rnd, target)
  ruleMetric <- selectRuleRRF(ruleMetric, dat_rnd, target)
  learner <- buildLearner(ruleMetric, dat_rnd, target)
  # Not used for scoring; kept so the rule set of the last run can be inspected.
  Simp_Learner <- presentRules(ruleMetric, colnames(dat_rnd))

  ## GBM -> inTrees rule learner
  # gbm is fitted on a numeric copy of the target.
  # assumes `presence` is coded 0/1 -- TODO confirm
  target1 <- as.numeric(as.character(target))
  gb_train <- data.frame(presence = target1, dat_rnd)
  gb <- gbm(presence ~ ., data = gb_train, distribution = "bernoulli",
            n.trees = n_trees)
  treeList_gb <- GBM2List(gb, dat_rnd)
  ruleExec_gb <- extractRules(treeList_gb, dat_rnd)
  ruleExec_gb <- unique(ruleExec_gb)
  ruleMetric_gb <- getRuleMetric(ruleExec_gb, dat_rnd, target)
  ruleMetric_gb <- pruneRule(ruleMetric_gb, dat_rnd, target)
  ruleMetric_gb <- unique(ruleMetric_gb)
  learner_gb <- buildLearner(ruleMetric_gb, dat_rnd, target)

  ## Score the two rule learners on the hold-out sample
  test_learned$pred_rfagg <- as.numeric(applyLearner(learner, test_learned))
  rfagg_err <- mae(pred = test_learned$pred_rfagg, obs = test_learned$presence)

  test_learned$pred_gbagg <- as.numeric(applyLearner(learner_gb, test_learned))
  gbagg_err <- mae(pred = test_learned$pred_gbagg, obs = test_learned$presence)

  ## Single rpart tree
  rpart_train <- data.frame(presence = target, dat_rnd)
  rp <- rpart(presence ~ ., data = rpart_train)
  test_learned$pred_rp <- predict(rp, newdata = test_learned, type = "class")
  test_learned$pred_rp <- as.numeric(as.character(test_learned$pred_rp))
  rp_err <- mae(pred = test_learned$pred_rp, obs = test_learned$presence)

  ## Full random forest
  test_learned$pred_rf <- predict(rf, newdata = test_learned, type = "class")
  test_learned$pred_rf <- as.numeric(as.character(test_learned$pred_rf))
  rf_err <- mae(pred = test_learned$pred_rf, obs = test_learned$presence)

  ## Full GBM: predicted probabilities are thresholded at 0.5
  test_learned$pred_gbm <- predict(gb, newdata = test_learned,
                                   n.trees = n_trees, type = "response")
  test_learned$pred_gbm <- ifelse(test_learned$pred_gbm > 0.5, 1, 0)
  gbm_err <- mae(pred = test_learned$pred_gbm, obs = test_learned$presence)

  # Same column order as colnames(simulated_results).
  simulated_results[i, ] <- c(i, rfagg_err, gbagg_err, rp_err, rf_err, gbm_err)
}
print(simulated_results)

# preparing aggregate results plot
# melt() on a matrix yields Var1 (row index = run), Var2 (column name = model)
# and value. The "run" column is only an index, so drop it by name rather than
# by position (a positional `-c(1:runs)` depends on melt ordering and
# misbehaves when runs is 0).
sim_melt <- reshape2::melt(simulated_results)
sim_melt <- sim_melt[sim_melt$Var2 != "run", ]

# Per-model summaries of the error across runs.
median_err <- group_by(sim_melt, Var2) %>%
  summarise(median = median(value)) %>%
  data.frame()
mean_err <- group_by(sim_melt, Var2) %>%
  summarise(mean = mean(value)) %>%
  data.frame()

# Order the model factor by increasing mean error so the x axis is sorted.
models_by_mean <- mean_err[order(mean_err$mean, decreasing = FALSE), "Var2"]
sim_melt$Variable_ordered_mean <- factor(sim_melt$Var2, levels = models_by_mean)
# Hold-out share shown in the subtitle, derived from the configured split so
# the label stays correct when train_v_test_fraction changes (0.75 -> "25").
holdout_pct <- round((1 - train_v_test_fraction) * 100)

# One jittered point per run and model; faint lines connect the models of the
# same run (grouped by Var1).
p1 <- ggplot(sim_melt, aes(x = Variable_ordered_mean, y = value,
                           group = Variable_ordered_mean)) +
  # geom_boxplot(width = 0.5, color = "gray90") +
  geom_line(aes(group = Var1), color = "gray68", alpha = 0.20) +
  # geom_line(aes(group = Var1, color = as.factor(Var1)), alpha = 0.75) +
  geom_jitter(width = 0.25) +
  # geom_point() +
  theme_bw() +
  labs(title = "Aggregate Prediction Error for Different Learners",
       subtitle = paste0("Values for ", runs,
                         " Resamples Ordered by Mean Error of ", holdout_pct,
                         "% Hold-Out Sample"),
       x = "Model",
       y = "Mean Absolute Error") +
  scale_y_continuous(breaks = seq(0.25, 0.5, 0.05)) +
  theme(
    legend.position = "none",
    panel.border = element_rect(colour = "gray90"),
    axis.text.x = element_text(angle = 90, size = 8, hjust = 1,
                               family = "Trebuchet MS"),
    axis.text.y = element_text(size = 8, family = "Trebuchet MS"),
    axis.title = element_text(size = 10, family = "Trebuchet MS",
                              face = "bold"),
    plot.title = element_text(family = "TrebuchetMS-Bold"),
    plot.subtitle = element_text(family = "TrebuchetMS-Italic")
  )
plot(p1)

# Save the plot explicitly instead of relying on whichever plot was drawn last.
fn <- paste0("dot_line_plot_", runs, "_resamples.png")
ggsave(filename = fn, plot = p1, width = 6, height = 4)
# create results table
# Layout: one row per run plus a final "Median" row; one column per learner
# plus a final "Median" column.

# Per-run median across the learners. The first column is only the run index,
# so it is excluded from the median.
error_cols <- simulated_results[, -1, drop = FALSE]
median_sims_row <- round(apply(error_cols, 1, median), 3)
median_sims <- cbind(error_cols, Median = median_sims_row)

# Per-column median across the runs, appended below the per-run rows.
median_sims_col <- round(apply(median_sims, 2, median), 3)
median_sims <- rbind(median_sims, median_sims_col)

# Row labels follow the configured number of runs.
rownames(median_sims) <- c(seq_len(runs), "Median")
median_sims <- round(median_sims, digits = 3)

sims <- data.frame(median_sims, stringsAsFactors = FALSE)
simsx <- xtable(sims)
print(simsx, type = "html")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment