Last active
July 7, 2018 04:25
-
-
Save halflearned/ffc0cfc6d822b11416bc3c1ab52773e1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse)
# Fit the same random forest twice on one simulated data set, changing only
# the order in which the regressors appear in the formula, and report both
# out-of-bag prediction errors.
#
# Args:
#   n:         number of simulated observations.
#   p:         number of regressors x1, ..., xp (independent standard normals).
#   num.trees: number of trees per forest; NULL means ranger's documented
#              default of 500.
#   mtry:      number of candidate variables per split; NULL means ranger's
#              documented default of floor(sqrt(p)).
#
# Returns:
#   A one-row tibble with columns n, p, num.trees, mtry, error_A, error_B,
#   diff (error_A - error_B) and ratio (error_A / error_B).
test_fun <- function(n = 200, p = 20, num.trees = NULL, mtry = NULL) {
  # Resolve NULL to explicit values up front: ranger() validates num.trees as
  # numeric, and tibble() drops NULL columns, which would remove num.trees and
  # mtry from the result.
  if (is.null(num.trees)) {
    num.trees <- 500
  }
  if (is.null(mtry)) {
    mtry <- floor(sqrt(p))
  }

  # Creating regressors: x1, ..., xp
  X <- as.data.frame(matrix(rnorm(n * p), n, p))
  colnames(X) <- paste0("x", seq_len(p))

  # Let x1 be the only relevant covariate
  Y <- X[, 1]
  train <- cbind(X, Y)

  order_A <- colnames(X)  # Regular column order; x1 first
  order_B <- rev(order_A) # Reverse column order; x1 last
  fmla_A <- as.formula(str_c("Y ~", str_c(order_A, collapse = " + ")))
  fmla_B <- as.formula(str_c("Y ~", str_c(order_B, collapse = " + ")))

  # Train same forest twice; the seed is re-set before each fit so the only
  # difference between the two runs is the column order in the formula.
  set.seed(1234)
  rforest_A <- ranger::ranger(
    formula = fmla_A,
    data = train,
    num.trees = num.trees,
    mtry = mtry
  )
  set.seed(1234)
  rforest_B <- ranger::ranger(
    formula = fmla_B,
    data = train,
    num.trees = num.trees,
    mtry = mtry
  )

  # Reset the seed: removing .Random.seed makes R re-initialise the generator
  # on the next draw, so repeated calls do not simulate identical data.
  rm(.Random.seed, envir = globalenv())

  error_A <- rforest_A$prediction.error
  error_B <- rforest_B$prediction.error
  tibble(
    n = n, p = p, num.trees = num.trees, mtry = mtry,
    error_A = error_A, error_B = error_B,
    diff = error_A - error_B, ratio = error_A / error_B
  )
}
# Helper function to repeat the experiment n_sim times and stack the
# per-replication results into one tibble (one row per run). No averaging
# happens here; aggregation is left to the caller.
#
# Args:
#   num.trees, mtry, n, p: forwarded unchanged to test_fun().
#   n_sim:                 number of replications.
#
# Returns:
#   The row-bound results of n_sim calls to test_fun().
iterate <- function(num.trees, mtry, n, p, n_sim = 5000) {
  print("Running with")
  print(c(n, p, num.trees, mtry))
  # purrr::rerun() is deprecated; map over a replication index instead.
  # The index itself is unused because every run has the same parameters.
  replications <- purrr::map(
    seq_len(n_sim),
    function(i) test_fun(num.trees = num.trees, mtry = mtry, n = n, p = p)
  )
  bind_rows(replications)
}
# Parameter grid: one row per (num.trees, mtry, n, p) combination.
# Combinations with mtry > p are filtered out.
params <- filter(
  expand.grid(
    num.trees = 20,
    mtry = seq_len(20),
    n = 100,
    p = 20
  ),
  mtry <= p
)

# Run the simulation for every row of the grid; the columns of `params` are
# passed to iterate() as its arguments.
output <- bind_rows(purrr::pmap(params, iterate))

# Column means of the simulation results within each mtry value.
avg_output <- output %>%
  group_by(mtry) %>%
  summarize_all(mean)

# Let's see if error_A - error_B is significantly different from zero:
# regress diff on one dummy per mtry value with no intercept, and show the
# coefficient table rounded to four decimals.
table <- lm(diff ~ 0 + factor(mtry), data = output) %>%
  summary() %>%
  coef() %>%
  round(4)
print(table)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment