-
-
Save walterst/9444f900e5382ba50c92 to your computer and use it in GitHub Desktop.
Library associated with https://gist.github.com/walterst/2222618976a66b3fc8dd. Written by Zhenjiang Xu.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Parse the command-line options for the generalization run.
##
## Extra arguments can be read from a file given with `--file PATH` or
## `--file=PATH`: the file's whitespace-separated tokens are appended after
## the command-line arguments before parsing.
## Returns the options list produced by optparse::parse_args().
interface_generalize <- function() {
  library(optparse)
  meta.fp <- 'mapping_file.txt'
  otus.fp <- 'closed_reference_otu_table_rare.txt'
  n.cores <- 12
  ## common models
  regression <- c("pls", "mars", "svm_rb", "svm_poly", "svm_linear", "knn",
                  "cart", "rf", "cubist", "gbm")

  ## Append the tokens of every `--file` argument to `args`. Both the
  ## `--file PATH` and `--file=PATH` spellings are recognised, because
  ## optparse accepts both and the file must be read in either case.
  expand.file.args <- function(args) {
    pos <- which(args == "--file" | startsWith(args, "--file="))
    paths <- ifelse(args[pos] == "--file",
                    args[pos + 1L],  # NA when `--file` is the last token
                    substring(args[pos], nchar("--file=") + 1L))
    for (file.arg in paths) {
      if (is.na(file.arg) || !file.exists(file.arg)) {
        stop("The args file ", file.arg, " does not exist.")
      }
      args <- c(args, scan(file.arg, what='character'))
    }
    args
  }

  option_list <- list(
    make_option(c("-i", "--input_otu_table"), default=otus.fp,
                action="store", type="character",
                help="Input OTU table [default %default]"),
    make_option(c("-m", "--metadata"), default=meta.fp,
                action="store", type="character",
                help="Mapping file [default %default]"),
    make_option(c("--input_otu_table_2"), default=otus.fp,
                action="store", type="character",
                help="Second input OTU table [default %default]"),
    make_option(c("--metadata_2"), default=meta.fp,
                action="store", type="character",
                help="Second mapping file [default %default]"),
    make_option(c("-f", "--fields"), default=NULL,
                action="store", type="character",
                help="Fields or categories to test [default %default]"),
    make_option(c("-r", "--models"), default=NULL,
                action="store", type="character",
                help=paste("Regression models to use.",
                           "It can be pls, mars, nnet, svm_rb, svm_poly, svm_linear,",
                           "knn, cart, m5, ctree, rf, cubist, and gbm. [default %default]",
                           "If no models are specified, all the common models are run: ",
                           paste(regression, collapse=', '),
                           sep='\n\t\t')),
    make_option(c("-o", "--output"), default="benchmark",
                action="store", type="character",
                help="Output file names. [default %default]. It will save pdf and Rdata files."),
    make_option(c("-c", "--cores"), default=n.cores,
                action="store", type="integer",
                help="Number of CPU cores [default %default]"),
    make_option(c("-s", "--split"), default=1,
                action="store", type="double",
                help=paste("Use only a fraction of the data as the training set and",
                           "hold out the rest for final testing. It should be between 0",
                           "and 1. [default %default].",
                           sep='\n\t\t')),
    make_option(c("--category"), default=NULL,
                action="store", type="character",
                help=paste("Provide a category and a value to extract the samples",
                           "of that category which has the value. For example,",
                           "'SITE::nostril,skin' will only use the samples that are collected",
                           "from nostril and skin.",
                           sep='\n\t\t')),
    make_option(c("--numeric"), default=NULL,
                action="store", type="character",
                help=paste("Similar to '--category' but applies to numeric metadata.",
                           "For example, 'PH::6,12' will only use the samples with PH between 6 and 12 (not including 6 and 12);",
                           "'DAYS::,16' will only use the samples with days less than 16. [default %default]",
                           sep='\n\t\t')),
    ## NOTE(review): help text duplicates --numeric; confirm how --numeric2
    ## differs (presumably it applies to the second dataset).
    make_option(c("--numeric2"), default=NULL,
                action="store", type="character",
                help=paste("Similar to '--category' but applies to numeric metadata.",
                           "For example, 'PH::6,12' will only use the samples with PH between 6 and 12 (not including 6 and 12);",
                           "'DAYS::,16' will only use the samples with days less than 16. [default %default]",
                           sep='\n\t\t')),
    make_option(c("--add_category"), default=NULL,
                action="store", type="character",
                help="Add the categorical field in metadata as predictors"),
    make_option(c("--add_numeric"), default=NULL,
                action="store", type="character",
                help="Add the numeric field in metadata as predictors"),
    make_option(c("--file"), default=NULL,
                action="store", type="character",
                help=paste("File with extra command-line arguments (whitespace separated);",
                           "they are appended after the arguments given on the command line.",
                           sep='\n\t\t')),
    ## NOTE(review): store_true with default TRUE means --verbose can never
    ## be switched off from the command line.
    make_option(c("-v", "--verbose"), default=TRUE,
                action="store_true", type="logical",
                help="Output running information? [default %default]"),
    make_option(c("--diagnostic"), default=FALSE,
                action="store_true", type="logical",
                help="Run diagnostics? [default %default]"),
    make_option(c("-d", "--debug"), default=FALSE,
                action="store_true", type="logical",
                help="Debug mode [default %default]"))

  args <- expand.file.args(commandArgs(trailingOnly=TRUE))
  opt <- parse_args(OptionParser(usage = "Rscript %prog [options] file.",
                                 description = paste("Example:",
                                                     "# Run on date and sex with cubist and random forest models",
                                                     "Rscript %prog -f DATE,SEX -r cubist,rf",
                                                     "# Run on PH with random forest model and save output rf.Rdata and rf.pdf",
                                                     "Rscript %prog -f PH -r rf -o rf",
                                                     sep="\n\t"),
                                 option_list = option_list),
                    args=args)
  opt
}
## Parse the command-line options for a benchmark run (regression or
## classification).
##
## Extra arguments can be read from a file given with `--file PATH` or
## `--file=PATH`: the file's whitespace-separated tokens are appended after
## the command-line arguments before parsing.
## Returns the options list produced by optparse::parse_args().
interface <- function() {
  library(optparse)
  meta.fp <- 'mapping_file.txt'
  otus.fp <- 'closed_reference_otu_table_rare.txt'
  n.cores <- 12
  ## common models
  regression <- c("pls", "mars", "svm_rb", "svm_poly", "svm_linear", "knn",
                  "cart", "rf", "cubist", "gbm")
  classification <- c("pls", "glm", "lda", "glmnet", "pam", "fda", "svmRadial",
                      "knn", "nb", "mda", "sparseLDA", "PART", "C5.0Rules",
                      "rf", "gbm")

  ## Append the tokens of every `--file` argument to `args`. Both the
  ## `--file PATH` and `--file=PATH` spellings are recognised, because
  ## optparse accepts both and the file must be read in either case.
  expand.file.args <- function(args) {
    pos <- which(args == "--file" | startsWith(args, "--file="))
    paths <- ifelse(args[pos] == "--file",
                    args[pos + 1L],  # NA when `--file` is the last token
                    substring(args[pos], nchar("--file=") + 1L))
    for (file.arg in paths) {
      if (is.na(file.arg) || !file.exists(file.arg)) {
        stop("The args file ", file.arg, " does not exist.")
      }
      args <- c(args, scan(file.arg, what='character'))
    }
    args
  }

  option_list <- list(
    make_option(c("-i", "--input_otu_table"), default=otus.fp,
                action="store", type="character",
                help="Input OTU table [default %default]"),
    make_option(c("-m", "--metadata"), default=meta.fp,
                action="store", type="character",
                help="Mapping file [default %default]"),
    make_option(c("-f", "--fields"), default=NULL,
                action="store", type="character",
                help="Fields or categories to test [default %default]"),
    ## NOTE(review): help text duplicates --fields; confirm the intended
    ## meaning of --replicate.
    make_option(c("--replicate"), default=NULL,
                action="store", type="character",
                help="Fields or categories to test [default %default]"),
    make_option(c("-r", "--models"), default=NULL,
                action="store", type="character",
                help=paste("Models to use.",
                           "Regression: pls, mars, nnet, svm_rb, svm_poly, svm_linear,",
                           "knn, cart, m5, ctree, rf, cubist, and gbm.",
                           "Classification: pls, glm, lda, glmnet, pam, fda, svmRadial,",
                           "knn, nb, mda, sparseLDA, PART, C5.0Rules, rf, and gbm. [default %default]",
                           "If no models are specified, all the common models are run: ",
                           paste("regression:", paste(regression, collapse=', ')),
                           paste("classification:", paste(classification, collapse=', ')),
                           sep='\n\t\t')),
    make_option(c("-o", "--output"), default="benchmark",
                action="store", type="character",
                help="Output file names. [default %default]. It will save pdf and Rdata files."),
    make_option(c("-c", "--cores"), default=n.cores,
                action="store", type="integer",
                help="Number of CPU cores [default %default]"),
    make_option(c("-s", "--split"), default=1,
                action="store", type="double",
                help=paste("Use only a fraction of the data as the training set and",
                           "hold out the rest for final testing. It should be between 0",
                           "and 1. [default %default].",
                           sep='\n\t\t')),
    make_option(c("--category"), default=NULL,
                action="store", type="character",
                help=paste("Provide a category and a value to extract the samples",
                           "of that category which has the value. For example,",
                           "'SITE::nostril,skin' will only use the samples that are collected",
                           "from nostril and skin.",
                           sep='\n\t\t')),
    make_option(c("--numeric"), default=NULL,
                action="store", type="character",
                help=paste("Similar to '--category' but applies to numeric metadata.",
                           "For example, 'PH::6,12' will only use the samples with PH between 6 and 12 (not including 6 and 12);",
                           "'DAYS::,16' will only use the samples with days less than 16. [default %default]",
                           sep='\n\t\t')),
    make_option(c("--add_category"), default=NULL,
                action="store", type="character",
                help="Add the categorical field in metadata as predictors"),
    make_option(c("--add_numeric"), default=NULL,
                action="store", type="character",
                help="Add the numeric field in metadata as predictors"),
    make_option(c("--file"), default=NULL,
                action="store", type="character",
                help=paste("File with extra command-line arguments (whitespace separated);",
                           "they are appended after the arguments given on the command line.",
                           sep='\n\t\t')),
    make_option(c("--balance"), default=FALSE,
                action="store_true", type="logical",
                help="Balance the classes. Only for classification."),
    make_option(c("--feature_selection"), default=FALSE,
                action="store_true", type="logical",
                help="Do feature selection with RFE? Only RF is supported currently."),
    ## NOTE(review): store_true with default TRUE means --verbose can never
    ## be switched off from the command line.
    make_option(c("-v", "--verbose"), default=TRUE,
                action="store_true", type="logical",
                help="Output running information? [default %default]"),
    make_option(c("-d", "--debug"), default=FALSE,
                action="store_true", type="logical",
                help="Debug mode [default %default]"))

  args <- expand.file.args(commandArgs(trailingOnly=TRUE))
  opt <- parse_args(OptionParser(usage = "Rscript %prog [options] file.",
                                 description = paste("Example:",
                                                     "# Run on date and sex with cubist and random forest models",
                                                     "Rscript %prog -f DATE,SEX -r cubist,rf",
                                                     "# Run on PH with random forest model and save output rf.Rdata and rf.pdf",
                                                     "Rscript %prog -f PH -r rf -o rf",
                                                     sep="\n\t"),
                                 option_list = option_list),
                    args=args)
  opt
}
## Plot the feature importance to show the most important taxa.
##
## imp:     a caret variable-importance object; `imp$importance` has one row
##          per feature (row names are feature IDs, possibly back-quoted)
##          and its first column is used for ranking.
## tax.16s: character vector of taxonomy strings, named by feature ID.
## topImp:  number of top features to plot; capped at the number of
##          features available.
## ...:     passed on to plot().
plot.imp <- function (imp, tax.16s, topImp=10, ...) {
  x <- imp$importance
  ## Rank on the first column as a plain vector (`x[[1]]`); `x[1]` is a
  ## one-column data.frame, which order() should not be given.
  x <- x[order(-x[[1]]), , drop=FALSE]
  ## Never ask for more features than exist; otherwise the labels contain
  ## NAs and no longer line up with the `at` positions.
  topImp <- min(topImp, nrow(x))
  taxId <- rev(gsub("`+", "", rownames(x)[seq_len(topImp)]))
  taxImp <- tax.16s[taxId]
  ## This axis function plots the tax ID on the left side of the y axis
  ## and the taxonomy string on the right side.
  axis.sigmasq <- function(side, ...) {
    switch(side,
           left = panel.axis(side=side, outside=TRUE, text.cex=0.7,
                             at=seq_len(topImp), labels=taxId),
           right = panel.axis(side=side, outside=TRUE, text.cex=0.7,
                              at=seq_len(topImp), labels=taxImp),
           axis.default(side=side, ...))
  }
  ## plot the top `topImp` variable importances
  plot(imp, top=topImp, axis=axis.sigmasq, ...)
}
## Summarise the resampling performance of a tuned caret model.
##
## model.tuned: a caret `train` object (uses `$resample` and `$method`).
## metric:      column names or numbers of `model.tuned$resample` to summarise.
## stats:       names of summary functions; "se" (standard error of the mean)
##              is defined locally, and any other name is looked up as a
##              function from here.
## Returns a data.frame with one row per statistic (row names = `stats`),
## one column per metric, plus a `Model` column holding the method name.
## The shape is the same even when a single statistic is requested.
accuracy <- function(model.tuned, metric=c('RMSE', 'Rsquared'), stats=c('mean', 'se')) {
  se <- function(x) sd(x)/sqrt(length(x))
  env <- environment()
  stat.funs <- lapply(setNames(stats, stats),
                      function(s) get(s, envir = env, mode = "function"))
  resample <- model.tuned$resample[, metric, drop=FALSE]
  ## One named numeric vector (one value per statistic) for each metric
  ## column; cbind() keeps the statistics as rows even for a single
  ## statistic, where apply() would have collapsed to a plain vector.
  per.metric <- lapply(resample, function(col) {
    vapply(stat.funs, function(f) f(col), numeric(1))
  })
  res <- as.data.frame(do.call(cbind, per.metric))
  res$Model <- model.tuned$method
  res
}
## Tune one caret classification model with fixed random seeds.
##
## trainX, trainY: predictors and class labels passed to caret::train().
## model: one of "pls", "glm", "lda", "glmnet", "pam", "fda", "svmRadial",
##        "knn", "nb", "mda", "sparseLDA", "PART", "C5.0Rules", "rf", "gbm".
## ctrl:  a caret trainControl(). When NULL, 5 x repeated 10-fold CV with
##        the one-SE rule, class probabilities and saved predictions is
##        used; a two-level factor `trainY` also gets twoClassSummary
##        (ROC) metrics.
## idx:   resampling index passed to trainControl() (only when ctrl is NULL).
## ...:   forwarded to train() for "pls", "rf" and "gbm" only.
## Returns the fitted `train` object; NULL with a warning for an unknown
## model; NA (after printing the error) if training fails. A failure while
## plotting only produces a warning and keeps the fitted model.
classification.tune <- function (trainX, trainY, model, ctrl=NULL, idx=NULL, ...) {
  require(caret)
  if (is.null(ctrl)) {
    set.seed(1)
    two.class <- is.factor(trainY) && nlevels(trainY) == 2
    ctrl <- trainControl(method = "repeatedcv",
                         repeats = 5, number = 10, # 10-fold CV, 5 repeats
                         selectionFunction = "oneSE",
                         summaryFunction = if (two.class) {
                           function(...) c(twoClassSummary(...), defaultSummary(...))
                         } else defaultSummary,
                         classProbs = TRUE,
                         index = idx,
                         savePredictions = TRUE)
  }
  ## ROC for two classes, Kappa for multi-class problems.
  metric <- if (nlevels(trainY) > 2) "Kappa" else "ROC"
  ## Models fitted with caret's default grid search (`tuneLength` candidates)
  ## and no extra `...` arguments.
  default.tuned <- c("glm", "lda", "glmnet", "pam", "fda", "svmRadial", "knn",
                     "nb", "mda", "sparseLDA", "PART", "C5.0Rules")

  ## Draw the tuning profile of `obj`. caret's plot.train returns a lattice
  ## object, which is only drawn when printed (nothing auto-prints inside a
  ## function). It also errors for models without tuning parameters
  ## (e.g. glm, lda), so a plotting failure is downgraded to a warning
  ## instead of discarding the fitted model.
  draw <- function(obj, ...) {
    tryCatch({
      p <- plot(obj, ...)
      if (inherits(p, "trellis")) print(p)
    }, error = function(err) {
      warning("Could not plot the '", model, "' model: ",
              conditionMessage(err), call. = FALSE)
    })
    invisible(NULL)
  }

  tuned <- tryCatch( {
    if ('pls' == model) {
      cat("\n---- running PLS...\n")
      set.seed(10)
      fit <- train(x = trainX,
                   y = trainY,
                   method = "pls",
                   trControl = ctrl,
                   metric = "Kappa",
                   tuneGrid = expand.grid(.ncomp = 1:9),
                   preProc = c("center", "scale"),
                   ...)
      draw(fit)
      fit
    } else if (model %in% default.tuned) {
      set.seed(10)
      fit <- train(x = trainX,
                   y = trainY,
                   method = model,
                   trControl = ctrl,
                   metric = metric,
                   ## C5.0Rules uses a smaller search than the others
                   tuneLength = if (model == "C5.0Rules") 3 else 10)
      draw(fit)
      fit
    } else if ('rf' == model) {
      ## Ten mtry candidates from 10 (or p, if smaller) up to p, without
      ## duplicates and never above the number of predictors.
      p <- ncol(trainX)
      mtryGrid <- data.frame(.mtry = unique(floor(seq(min(10, p), p, length.out = 10))))
      set.seed(10)
      fit <- train(x = trainX,
                   y = trainY,
                   method = "rf",
                   trControl = ctrl,
                   metric = metric,
                   ntree = 1000,
                   tuneGrid = mtryGrid,
                   importance = TRUE,
                   ...)
      draw(fit)
      fit
    } else if ('gbm' == model) {
      cat("\n---- running boosting tree...\n")
      gbmGrid <- expand.grid(.interaction.depth = seq(1, 7, by = 2),
                             .n.trees = seq(100, 1000, by = 50),
                             .shrinkage = c(0.01, 0.1))
      set.seed(10)
      fit <- train(x = trainX,
                   y = trainY,
                   method = "gbm",
                   trControl = ctrl,
                   # metric = if (nlevels(trainY) > 2) "Kappa" else "ROC",
                   metric = "Kappa",
                   tuneGrid = gbmGrid,
                   verbose = FALSE,
                   ...)
      draw(fit, auto.key = list(columns = 4, lines = TRUE))
      fit
    } else {
      warning("Unknown classification model: ", model, call. = FALSE)
      NULL
    }
  }, error=function(err) {
    cat("original error message:\n")
    print(err)
    return (NA)
  } )
  tuned
}
## Tune one caret regression model with fixed random seeds.
##
## Uses the "within one standard error" (oneSE) model selection and
## 5 x repeated 10-fold CV, unless the caller supplies `ctrl`.
##
## trainX, trainY: predictors and numeric response passed to caret::train().
## model: one of "lm", "pls", "mars", "nnet", "svm_rb", "svm_poly",
##        "svm_linear", "knn", "cart", "ctree", "m5", "gbm", "rf", "cubist".
## ctrl:  a caret trainControl(); NULL selects the default described above.
## ...:   forwarded to train() for every model except "lm".
## Returns the fitted `train` object; NULL with a warning for an unknown
## model; NA (after printing the error) if training fails. A failure while
## plotting only produces a warning and keeps the fitted model.
regression.tune <- function (trainX, trainY, model, ctrl=NULL, ...) {
  require(caret)
  if (is.null(ctrl)) {
    set.seed(1)
    ctrl <- trainControl(method="repeatedcv", number=10, repeats=5,
                         savePredictions = TRUE,
                         selectionFunction = "oneSE")
  }

  ## Draw `obj`. caret's plot.train returns a lattice object, which is only
  ## drawn when printed (nothing auto-prints inside a function). A plotting
  ## error is downgraded to a warning so it cannot discard a fitted model.
  draw <- function(obj, ...) {
    tryCatch({
      p <- plot(obj, ...)
      if (inherits(p, "trellis")) print(p)
    }, error = function(err) {
      warning("Could not plot the '", model, "' model: ",
              conditionMessage(err), call. = FALSE)
    })
    invisible(NULL)
  }

  tuned <- tryCatch( {
    switch(as.character(model),
      lm = {
        cat("\n---- running Linear regression...\n")
        set.seed(10)
        train(x = trainX, y = trainY, method = "lm", trControl = ctrl)
      },
      pls = {
        cat("\n---- running PLS...\n")
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "pls", trControl = ctrl,
                     tuneGrid = expand.grid(.ncomp = 1:9),
                     preProc = c("center", "scale"),
                     ...)
        draw(fit)
        fit
      },
      mars = {
        ## No need for data transformation;
        ## correlated features will confound feature importance.
        cat("\n---- running MARS...\n")
        marsGrid <- expand.grid(.degree = 1:2,
                                .nprune = c(1:9, seq(10, 100, by=2)))
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "earth", trControl = ctrl,
                     tuneGrid = marsGrid,
                     ...)
        draw(fit)
        fit
      },
      nnet = {
        cat("\n---- running neural network...\n")
        nnetGrid <- expand.grid(.decay = c(0, 0.01, .1),
                                .size = c(1, 3, 5, 7, 9),
                                .bag = FALSE)
        set.seed(10)
        fit <- train(x = trainX, y = trainY, trControl = ctrl,
                     method = "avNNet",
                     tuneGrid = nnetGrid,
                     preProc = c("center", "scale"),
                     linout = TRUE,
                     trace = FALSE,
                     ## enough weights for the largest network (size 9)
                     MaxNWts = 9 * (ncol(trainX) + 1) + 9 + 1,
                     allowParallel = FALSE,
                     maxit = 100,
                     ...)
        draw(fit)
        fit
      },
      svm_rb = {
        cat("\n---- running SVM radial basis...\n")
        set.seed(10)
        svmRGrid <- expand.grid(.C = 2^(-2:5),
                                .sigma = c(0.1, 0.3, 0.5))
        ## Use the same resampling control as every other model so that
        ## results are comparable (it was previously left on caret's
        ## default bootstrap).
        fit <- train(x = trainX, y = trainY, method = "svmRadial",
                     trControl = ctrl,
                     tuneGrid = svmRGrid,
                     preProc = c("center", "scale"),
                     ...)
        draw(fit, scales = list(x = list(log = 2)))
        fit
      },
      svm_poly = {
        cat("\n---- running SVM polynomial...\n")
        svmPGrid <- expand.grid(.degree = 1:2,
                                .scale = c(0.01, 0.005, 0.001),
                                .C = 2^(-2:5))
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "svmPoly", trControl = ctrl,
                     preProc = c("center", "scale"),
                     tuneGrid = svmPGrid,
                     ...)
        draw(fit, scales = list(x = list(log = 2), between = list(x = .5, y = 1)))
        fit
      },
      svm_linear = {
        cat("\n---- running SVM linear...\n")
        set.seed(10)
        svmLGrid <- expand.grid(.C = 2^(-2:5))
        fit <- train(x = trainX, y = trainY, method = "svmLinear", trControl = ctrl,
                     tuneGrid = svmLGrid,
                     preProc = c("center", "scale"),
                     ...)
        draw(fit, scales = list(x = list(log = 2), between = list(x = .5, y = 1)))
        fit
      },
      knn = {
        cat("\n---- running KNN...\n")
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "knn", trControl = ctrl,
                     tuneGrid = data.frame(.k = 1:20),
                     preProc = c("center", "scale"),
                     ...)
        draw(fit)
        fit
      },
      cart = {
        cat("\n---- running CART...\n")
        library(rpart)
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "rpart", trControl = ctrl,
                     ## tune the complexity parameter
                     tuneLength = 25,
                     ...)
        draw(fit, scales = list(x = list(log = 10)))
        fit
      },
      ctree = {
        cat("\n---- running conditional inference tree...\n")
        cGrid <- data.frame(.mincriterion = sort(c(.95, seq(.75, .99, length.out = 2))))
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "ctree", trControl = ctrl,
                     tuneGrid = cGrid,
                     ...)
        draw(fit)
        draw(fit$finalModel)
        fit
      },
      m5 = {
        cat("\n---- running M5...\n")
        set.seed(10)
        ## caret only loads the RWeka namespace, so qualify Weka_control.
        fit <- train(x = trainX, y = trainY, method = "M5", trControl = ctrl,
                     control = RWeka::Weka_control(M = 10),
                     ...)
        draw(fit)
        draw(fit$finalModel)
        fit
      },
      gbm = {
        cat("\n---- running boosting tree...\n")
        gbmGrid <- expand.grid(.interaction.depth = seq(1, 7, by = 2),
                               .n.trees = seq(100, 1000, by = 50),
                               .shrinkage = c(0.01, 0.1))
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "gbm", trControl = ctrl,
                     tuneGrid = gbmGrid,
                     verbose = FALSE,
                     ...)
        draw(fit, auto.key = list(columns = 4, lines = TRUE))
        fit
      },
      rf = {
        cat("\n---- running random forest...\n")
        ## Ten mtry candidates from 10 (or p, if smaller) up to p, without
        ## duplicates and never above the number of predictors.
        p <- ncol(trainX)
        mtryGrid <- data.frame(.mtry = unique(floor(seq(min(10, p), p, length.out = 10))))
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "rf", trControl = ctrl,
                     tuneGrid = mtryGrid,
                     ntree = 1000,
                     importance = TRUE,
                     ...)
        draw(fit)
        print(fit)
        fit
      },
      cubist = {
        cat("\n---- running cubist...\n")
        cbGrid <- expand.grid(.committees = c(1:10, 20, 50, 75, 100),
                              .neighbors = c(0, 1, 5, 9))
        set.seed(10)
        fit <- train(x = trainX, y = trainY, method = "cubist", trControl = ctrl,
                     tuneGrid = cbGrid,
                     ...)
        draw(fit, auto.key = list(columns = 4, lines = TRUE))
        fit
      },
      {
        warning("Unknown regression model: ", model, call. = FALSE)
        NULL
      })
  }, error=function(err) {
    cat("ORIGINAL ERROR MESSAGE:\n")
    print(err)
    return (NA)
  })
  tuned
}
## Scatter-plot observed vs. predicted values with ggplot2: one plot per
## prediction column, overlaid with a random permutation of the
## observations (a no-skill baseline) and the y = x line.
##
## testResults: data.frame whose 1st column holds the observed values and
##              whose remaining columns hold predictions (one per method).
##              The first column may have any name.
## Each plot is printed; the list of ggplot objects is returned invisibly.
y.yhat.ggplot2 <- function(testResults) {
  require(ggplot2)
  require(reshape2)
  testResults <- as.data.frame(testResults)
  obs <- testResults[, 1]
  ## 100 random permutations of the observations, one per column. Index
  ## with sample.int() so a length-1 numeric `obs` is not treated as
  ## `1:obs` by sample(), and keep a matrix in every case.
  rand <- matrix(replicate(100, obs[sample.int(length(obs))]),
                 nrow = length(obs))
  plots <- lapply(seq_along(testResults)[-1L], function(i) {
    df <- testResults[, c(1L, i)]
    ## melt() below keys on a column named "obs"
    names(df)[1L] <- "obs"
    df$random <- rand[, i]
    df.2 <- melt(df, id.vars = 'obs', value.name = 'prediction')
    p <- ggplot(df.2, aes(x = obs, y = prediction, color = variable)) +
      geom_point(shape = 16, size = 3) +
      ## constant slope/intercept are parameters, not mapped aesthetics
      geom_abline(slope = 1, intercept = 0)
    print(p)
    p
  })
  invisible(plots)
}
## Base-graphics plot of predicted vs. observed values, one plot per
## prediction column.
##
## testResults: data.frame whose 1st column holds the observed values and
##              whose remaining columns hold predictions (one per method).
## Each plot shows the model's RMSE/R^2 and, in blue, one random permutation
## of the observations plus the mean +/- sd of RMSE/R^2 over 100 such
## permutations (a no-skill baseline).
y.yhat <- function(testResults) {
  method.names <- names(testResults)
  obs <- testResults[, 1]
  ## 100 random permutations of the observations, one per column. Index
  ## with sample.int() so a length-1 numeric `obs` is not treated as
  ## `1:obs` by sample(), and keep a matrix in every case.
  rand <- matrix(replicate(100, obs[sample.int(length(obs))]),
                 nrow = length(obs))
  rand.rmse <- apply(rand, 2, caret::RMSE, obs)
  rand.r2 <- apply(rand, 2, caret::R2, obs)
  fmt <- function(v) format(round(v, 2), nsmall=2)
  ## The baseline summary does not depend on the method, so build it once.
  ## The unicode is for the plus/minus symbol.
  rand.rmse.str <- paste(fmt(mean(rand.rmse)), '\u00b1', fmt(sd(rand.rmse)))
  rand.r2.str <- paste(fmt(mean(rand.r2)), '\u00b1', fmt(sd(rand.r2)))
  for (i in seq_along(testResults)[-1L]) {
    pred <- testResults[, i]
    plot(pred ~ obs,
         xlab=method.names[1],
         ylab=method.names[i],
         pch=20)
    ## plot one random permutation result too
    points(x=obs, y=rand[, 1], col='blue', pch=20)
    abline(0, 1, col="red")
    rmse <- fmt(caret::RMSE(pred, obs))
    rsq <- fmt(caret::R2(pred, obs))
    mtext(paste(c("RMSE","R^2 "), c(rmse, rsq), sep='=', collapse=' '),
          line=1)
    mtext(paste(c("RMSE","R^2 "), c(rand.rmse.str, rand.r2.str), sep='=', collapse=' '),
          col='blue')
  }
  invisible(NULL)
}
## Learning-curve diagnostics for a regression model.
##
## For each training-set size in `sizes`, a subset of the training data is
## drawn with caret::createDataPartition(), `model` is tuned on it with
## regression.tune(), and the training error, mean cross-validation error
## and held-out test error are recorded. The whole sweep runs `repeats` times.
##
## sizes:   training-set sizes to evaluate. When NULL, up to `steps` sizes
##          from 32 (or the full training set, if smaller) to nrow(trainX).
## metrics: caret performance metrics to record.
## Returns a matrix with one row per (size, repeat), named
## "Fraction<k>.Repeat<r>", and columns "Fraction" followed by
## "<metric>.Train", "<metric>.CV" and "<metric>.Test".
diagnostics <- function(trainX, trainY, testX, testY, sizes=NULL, steps=15, repeats=3,
                        model='rf', metrics=c('RMSE', 'Rsquared')) {
  require(caret)
  n.train <- nrow(trainX)
  if (is.null(sizes)) {
    maxSize <- n.train
    ## never start above the number of available training samples
    minSize <- min(32, maxSize)
    if ((maxSize - minSize) <= steps) {
      sizes <- minSize:maxSize
    } else {
      sizes <- round(seq(minSize, maxSize, length.out=steps))
    }
  }
  partitions <- sizes / n.train
  grid <- expand.grid(metric = metrics, type = c('Train', 'CV', 'Test'),
                      stringsAsFactors = FALSE)
  labels <- paste(grid$metric, grid$type, sep='.')

  ## Errors for one training fraction, in the same order as `labels`.
  one.run <- function(partition) {
    rows <- createDataPartition(trainY, p=partition, list=FALSE)[, 1]
    trX <- trainX[rows, , drop=FALSE]
    trY <- trainY[rows]
    tuned <- regression.tune(trX, trY, model)
    train.err <- postResample(pred = predict(tuned, trX), obs = trY)[metrics]
    cv.err <- colMeans(tuned$resample[, metrics, drop=FALSE])
    test.err <- postResample(pred = predict(tuned, testX), obs = testY)[metrics]
    c(partition, train.err, cv.err, test.err)
  }
  ## 3-D array: value x size x repeat
  error.array <- replicate(repeats, sapply(partitions, one.run))
  ## The second dimension has one entry per requested size, which need
  ## not equal `steps`.
  dimnames(error.array) <- list(d1=c('Fraction', labels),
                                d2=paste0('Fraction', seq_along(sizes)),
                                d3=paste0('Repeat', seq_len(repeats)))
  t(as.data.frame(error.array))
}
## Plot the learning curves produced by diagnostics() for one metric.
##
## diagn:  output of diagnostics() (a matrix) or a data.frame of the same
##         layout: a "Fraction" column, "<metric>.<Train|CV|Test>" columns
##         and row names ending in ".Repeat<r>".
## metric: a single metric name, e.g. "RMSE" or "Rsquared". Rsquared is
##         plotted as 1 - Rsquared so that lower is better for every curve.
## Returns the ggplot object.
diagn.plot <- function (diagn, metric) {
  ## inherits(): class() of a matrix has length 2, which `if` rejects.
  if (!inherits(diagn, 'data.frame'))
    diagn <- as.data.frame(diagn)
  if (!is.character(metric) || length(metric) != 1L || is.na(metric))
    stop('Wrong metric')
  require(ggplot2)
  require(reshape2)
  diagn$Repeat <- gsub('^.*\\.', '', rownames(diagn))
  cols <- grep(paste0('^', metric), colnames(diagn), value=TRUE)
  long <- melt(diagn[, c('Fraction', 'Repeat', cols)],
               id.vars=c('Fraction', 'Repeat'))
  y.label <- metric
  if (metric == 'Rsquared') {
    long$value <- 1 - long$value
    y.label <- paste('1 -', metric)
  }
  ggplot(long, aes(x=Fraction, y=value,
                   colour=variable, shape=Repeat,
                   group=interaction(variable, Repeat))) +
    labs(y=y.label, x='Fraction', title='Learning Curve') +
    geom_line() + geom_point(size=3)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment