Skip to content

Instantly share code, notes, and snippets.

@mrecos
Created July 30, 2016 02:59
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrecos/9adcfcb494e6f834f4899982b6a46d4e to your computer and use it in GitHub Desktop.
Save mrecos/9adcfcb494e6f834f4899982b6a46d4e to your computer and use it in GitHub Desktop.
Loop to run simulations of inTrees methods across RF, GBM, and rpart algorithms. Code supporting blog post: http://matthewdharris.com/2016/07/30/one-tree-to-rule-them-all-intrees-and-rule-based-learing
library("data.table")
library("rowr")
library("inTrees")
library("dplyr")
library("randomForest")
library("xtable")
library("caret")
library("gbm")
library("rpart")
library("reshape2")
library("ggplot2")
library("ggalt")
### Read Data
# Load the raw table with data.table::fread(), then convert it to a plain
# data.frame so base-R column indexing is used from here on.
dat <- fread("YOUR/DATA/DIRECTORY/FILE.csv")
dat <- data.frame(dat)
# Remove some unneeded columns.
# A negated `%in%` is used instead of `%ni%`: `%ni%` is not defined anywhere in
# this script and does not appear to be provided by the packages attached above.
# `drop = FALSE` keeps the result a data.frame even if a single column remains.
dat_trimmed <- dat[
  ,
  !(colnames(dat) %in% c("V1", "tpi_sd250c", "tpi_cls250c", "e_trail_dist")),
  drop = FALSE
]
## set up data partition variables
train_v_test_fraction <- 0.75 # 0 to 1
absence_presence_balance <- 1 # 1 for balanced; typically 3
data_reduction_fraction <- 0.05 # 0 to 1
runs <- 20 # resample repeats
# One row per resample: the run index followed by the error of each learner.
simulated_results <- matrix(nrow = runs, ncol = 6)
colnames(simulated_results) <- c("run", "RF Agg", "GB agg", "rpart", "RF", "GBM")
# Number of boosting iterations, shared by the gbm() fit and its predict() call
# so the two cannot drift apart.
gbm_n_trees <- 100
# Repeat the train/test resample `runs` times; each pass fits every learner on
# the training sample and stores its test-set error in row i of
# `simulated_results`. seq_len() is used so that runs == 0 performs no passes.
for (i in seq_len(runs)) {
  message("processing run: ", i)
  # THIS FUNCTION NOT PROVIDED! it is a big intricate function, but bottom line
  # is for you to have a test and train sample
  # NOTE(review): `sites_per_train` and `sites_per_test` are not defined in this
  # script either; they must exist before the loop runs - confirm their values.
  test_train <- get.train.test.sets(dat_trimmed, train_v_test_fraction,
                                    sites_per_train, sites_per_test,
                                    absence_presence_balance,
                                    data_reduction_fraction)
  train <- test_train[["train"]]
  test <- test_train[["test"]]

  # slim down to just a few variables (rounded for training)
  dat_rnd <- data.frame(e_hyd_min = round(train$e_hyd_min, 1),
                        std_32c = round(train$std_32c, 2),
                        elev_2_drainh = round(train$elev_2_drainh, 2))
  target <- as.factor(train$presence)
  # Test predictors are kept in the same column order as `dat_rnd`; prediction
  # columns are appended to this data.frame further down.
  test_learned <- test[, c("e_hyd_min", "std_32c", "elev_2_drainh", "presence")]

  ## Random forest -> inTrees rule learner
  rf <- randomForest(dat_rnd, target)
  treeList <- RF2List(rf) # transform rf object to an inTrees' format
  exec <- extractRules(treeList, dat_rnd) # R-executable conditions
  ruleMetric <- getRuleMetric(exec, dat_rnd, target) # get rule metrics
  ruleMetric <- pruneRule(ruleMetric, dat_rnd, target)
  ruleMetric <- selectRuleRRF(ruleMetric, dat_rnd, target)
  learner <- buildLearner(ruleMetric, dat_rnd, target)
  # Readable form of the selected rules; not used below, kept for inspection.
  Simp_Learner <- presentRules(ruleMetric, colnames(dat_rnd))

  ## GBM -> inTrees rule learner
  # gbm() with distribution = "bernoulli" is given a numeric response, so the
  # factor target is converted back to numbers first.
  target1 <- as.numeric(as.character(target))
  gb_train <- data.frame(presence = target1, dat_rnd)
  gb <- gbm(presence ~ ., data = gb_train, distribution = "bernoulli",
            n.trees = gbm_n_trees)
  treeList_gb <- GBM2List(gb, dat_rnd)
  ruleExec_gb <- extractRules(treeList_gb, dat_rnd)
  ruleExec_gb <- unique(ruleExec_gb)
  ruleMetric_gb <- getRuleMetric(ruleExec_gb, dat_rnd, target)
  ruleMetric_gb <- pruneRule(ruleMetric_gb, dat_rnd, target)
  ruleMetric_gb <- unique(ruleMetric_gb)
  learner_gb <- buildLearner(ruleMetric_gb, dat_rnd, target)

  ## Test-set error of the two rule learners
  # NOTE(review): `mae()` is not defined in this script; it is called with
  # `pred` and `obs` arguments and presumably returns mean absolute error -
  # confirm which helper or package supplies it.
  test_learned$pred_rfagg <- as.numeric(applyLearner(learner, test_learned))
  rfagg_err <- mae(pred = test_learned$pred_rfagg, obs = test_learned$presence)
  test_learned$pred_gbagg <- as.numeric(applyLearner(learner_gb, test_learned))
  gbagg_err <- mae(pred = test_learned$pred_gbagg, obs = test_learned$presence)

  ## Single rpart tree
  rpart_train <- data.frame(presence = target, dat_rnd)
  rp <- rpart(presence ~ ., data = rpart_train)
  test_learned$pred_rp <- predict(rp, newdata = test_learned, type = "class")
  # Factor predictions are converted via character so the labels, not the
  # internal level codes, become the numeric values.
  test_learned$pred_rp <- as.numeric(as.character(test_learned$pred_rp))
  rp_err <- mae(pred = test_learned$pred_rp, obs = test_learned$presence)

  ## Full random forest
  test_learned$pred_rf <- predict(rf, newdata = test_learned, type = "class")
  test_learned$pred_rf <- as.numeric(as.character(test_learned$pred_rf))
  rf_err <- mae(pred = test_learned$pred_rf, obs = test_learned$presence)

  ## Full GBM: predicted responses are thresholded at 0.5 into 0/1 classes
  test_learned$pred_gbm <- predict(gb, newdata = test_learned,
                                   n.trees = gbm_n_trees, type = "response")
  test_learned$pred_gbm <- ifelse(test_learned$pred_gbm > 0.5, 1, 0)
  gbm_err <- mae(pred = test_learned$pred_gbm, obs = test_learned$presence)

  # Order must match colnames(simulated_results).
  simulated_results[i, ] <- c(i, rfagg_err, gbagg_err, rp_err, rf_err, gbm_err)
}
print(simulated_results)
# preparing aggregate results plot
# melt() turns the results matrix into long form with columns Var1 (row index),
# Var2 (column name) and value. The "run" column only holds the run index, so
# its rows are removed by name rather than by position.
sim_melt <- reshape2::melt(simulated_results)
sim_melt <- sim_melt[sim_melt$Var2 != "run", ]
# Per-model median error (not used by the plot; kept for inspection).
median_err <- group_by(sim_melt, Var2) %>%
  summarise(median = median(value)) %>%
  data.frame()
# Per-model mean error, used to order the x axis.
mean_err <- group_by(sim_melt, Var2) %>%
  summarise(mean = mean(value)) %>%
  data.frame()
# Re-level the model factor so models appear from lowest to highest mean error.
sim_melt$Variable_ordered_mean <- factor(
  sim_melt$Var2,
  levels = mean_err[order(mean_err$mean, decreasing = FALSE), "Var2"]
)
# Hold-out share for the subtitle, derived from the split setting instead of a
# hard-coded "25" so the label stays correct if the fraction changes.
holdout_pct <- round((1 - train_v_test_fraction) * 100)
p1 <- ggplot(sim_melt, aes(x = Variable_ordered_mean, y = value,
                           group = Variable_ordered_mean)) +
  # geom_boxplot(width = 0.5, color = "gray90") +
  # One faint line per resample (Var1) connecting its error across models.
  geom_line(aes(group = Var1), color = "gray68", alpha = 0.20) +
  # geom_line(aes(group = Var1, color = as.factor(Var1)), alpha = 0.75) +
  geom_jitter(width = 0.25) +
  # geom_point() +
  theme_bw() +
  labs(title = "Aggregate Prediction Error for Different Learners",
       subtitle = paste0("Values for ", runs,
                         " Resamples Ordered by Mean Error of ", holdout_pct,
                         "% Hold-Out Sample"),
       x = "Model",
       y = "Mean Absolute Error") +
  scale_y_continuous(breaks = seq(0.25, 0.5, 0.05)) +
  # NOTE(review): the font family names below must be available to the
  # graphics device in use - confirm on the target platform.
  theme(
    legend.position = "none",
    panel.border = element_rect(colour = "gray90"),
    axis.text.x = element_text(angle = 90, size = 8, hjust = 1, family = "Trebuchet MS"),
    axis.text.y = element_text(size = 8, family = "Trebuchet MS"),
    axis.title = element_text(size = 10, family = "Trebuchet MS", face = "bold"),
    plot.title = element_text(family = "TrebuchetMS-Bold"),
    plot.subtitle = element_text(family = "TrebuchetMS-Italic")
  )
plot(p1)
fn <- paste0("dot_line_plot_", runs, "_resamples.png")
# Pass the plot explicitly rather than relying on ggsave()'s "last plot" default.
ggsave(filename = fn, plot = p1, width = 6, height = 4)
# create results table
# Per-run median across the learners. The first column holds the run index,
# not an error value, so it is excluded before taking the row median.
median_sims_row <- round(
  apply(simulated_results[, -1, drop = FALSE], 1, median),
  3
)
median_sims <- cbind(simulated_results, median_sims_row)
# Per-column median across runs, appended as a final summary row beneath the
# per-run rows.
median_sims_col <- round(apply(median_sims, 2, median), 3)
median_sims <- rbind(median_sims, median_sims_col)
# Label rows by run number; derived from the matrix so it matches any `runs`.
rownames(median_sims) <- c(seq_len(nrow(simulated_results)), "Median")
# Drop the run-index column now that the row names carry it.
median_sims <- median_sims[, -1, drop = FALSE]
# The last column is the per-run median added above; give it a display name.
colnames(median_sims) <- c(colnames(median_sims)[-ncol(median_sims)], "Median")
median_sims <- round(median_sims, digits = 3)
sims <- data.frame(median_sims, stringsAsFactors = FALSE)
simsx <- xtable(sims)
print(simsx, type = "html")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment