-
-
Save walterst/2222618976a66b3fc8dd to your computer and use it in GitHub Desktop.
See the comment below about the additional files and configuration needed. These scripts were written by Zhenjiang Xu.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env Rscript | |
## Code written by Zhenjiang Xu | |
## Read a tab-separated, QIIME-style table.
##
## QIIME tables may start with several '#' lines; only the last of those
## leading lines is the real header (e.g. "#OTU ID ..." or "#SampleID ...").
## All leading '#' lines before it are skipped.
##
## Args:
##   filename: path (or connection) handed to readLines().
##   ...:      extra arguments forwarded to read.table().  header, sep,
##             comment.char and check.names are already set here.
##
## Returns:
##   The data.frame produced by read.table(); column names are kept
##   verbatim (check.names=FALSE), so the first one usually starts with '#'.
read.table.x <- function (filename, ...) {
    lines <- readLines(filename)
    if (length(lines) == 0)
        stop("File ", filename, " is empty.")
    ## Only the unbroken run of '#' lines at the top of the file counts as
    ## comments.  A '#' line further down is data and must not be mistaken
    ## for the header (that would silently discard every row above it).
    is.comment <- grepl("^#", lines)
    if (all(is.comment)) {
        n.leading <- length(lines)
    } else {
        n.leading <- which(!is.comment)[1] - 1
    }
    ## the header is the last leading comment line, or line 1 if none
    start <- max(n.leading, 1)
    x <- read.table(text=lines[start:length(lines)],
                    header=TRUE,
                    sep='\t',
                    comment.char='',
                    check.names=FALSE,
                    ...)
    x
}
library(optparse)

## Replace with path to ml_util.R on local system
source('/Users/tony/code/r_scripts_new/ml_util.R')

## Parse the command line.  interface() is not defined in this file;
## presumably it comes from ml_util.R sourced above.
opt <- interface()
if (opt$debug) save.image('debug.Rdata')
if (opt$verbose) {
    cat("Running command with args:\n",
        paste(commandArgs(), collapse = " "),
        '\n')
}

## The fraction of samples used for training must lie in (0, 1].
## Either bound being violated is an error, hence "||".
if (opt$split <= 0 || opt$split > 1) {
    stop("The split arg should be greater than 0 and not greater than 1.")
}

## Models to benchmark: a comma-separated list from the command line, or a
## default set.
if (is.null(opt$models)) {
    ## NOTE(review): `regression` is not defined in this file (presumably a
    ## default model list from ml_util.R).  For a classification script a
    ## classification default may have been intended -- confirm.
    models <- regression
} else {
    models <- strsplit(opt$models, ',')[[1]]
}

library(caret)

## register a parallel backend only when more than one core is requested
if (opt$cores > 1) {
    library(doMC)
    registerDoMC(opt$cores)
}
## Load the mapping file and decide which meta data fields to classify on.
meta <- read.table.x(opt$metadata)
meta.col <- colnames(meta)

if (!is.null(opt$fields)) {
    ## explicit, comma-separated list of fields: every one must exist
    outcome.col <- strsplit(opt$fields, ',')[[1]]
    x <- which(!(outcome.col %in% meta.col))
    if (length(x) > 0) {
        stop("Field(s) ", paste(outcome.col[x], collapse=','), " do not exist in meta data")
    }
} else {
    ## no field given: use every column except the bookkeeping ones below
    boring <- c("#SampleID", "BarcodeSequence", "LinkerPrimerSequence",
                "TARGET_SUBFRAGMENT", "ASSIGNED_FROM_GEO",
                "EXPERIMENT_CENTER", "RUN_PREFIX", "TAXON_ID",
                "ILLUMINA_TECHNOLOGY", "COMMON_NAME",
                "EXTRACTED_DNA_AVAIL_NOW", "SAMPLE_CENTER",
                "STUDY_CENTER", "Description")
    outcome.col <- names(meta)
    outcome.col <- outcome.col[!(outcome.col %in% boring)]
}
## Keep only the samples whose categorical meta data match --category.
## Format: "FIELD::value,value:_:FIELD::value", e.g.
## "SITE::nostril,skin:_:SEX::male".  Filters are applied one after another.
if (!is.null(opt$category)) {
    extract <- strsplit(strsplit(opt$category, ':_:')[[1]], '::')
    for (x in extract) {
        if (!(x[1] %in% meta.col)) {
            stop("The field ", x[1], " does not exist in meta data")
        }
        ## i: the column as currently filtered; j: the values to keep
        i <- meta[[x[1]]]
        j <- strsplit(x[2], ',')[[1]]
        ## sanity check to catch typos: every requested value must occur
        if (any(!(j %in% i))) {
            stop("You specified non-existing values for field ", x[1], " in meta data")
        }
        meta <- meta[i %in% j, ]
    }
}
## Keep only the samples whose numeric meta data fall inside an open
## interval given by --numeric.
## Format: "FIELD::lower,upper:_:FIELD::lower,upper"; either bound may be
## left empty, e.g. "PH::6,12:_:TEMP::,32" keeps 6 < PH < 12 and TEMP < 32.
if (!is.null(opt$numeric)) {
    extract <- strsplit(opt$numeric, ':_:')[[1]]
    extract <- strsplit(extract, '::')
    for (x in extract) {
        if (!(x[1] %in% meta.col))
            stop("The field ", x[1], " does not exist in meta data")
        ## The column may hold "None", "NA", etc. (R then reads it as
        ## character instead of numeric); those entries become NA here.
        i <- as.numeric(as.character(meta[[x[1]]]))
        ## j[1] = lower bound, j[2] = upper bound; NA means "no bound"
        j <- as.numeric(strsplit(x[2], ',')[[1]])
        n <- rep(TRUE, nrow(meta))
        if (!is.na(j[1]))
            n <- n & i > j[1]
        if (!is.na(j[2]))
            n <- n & i < j[2]
        ## NA & TRUE is NA, and indexing a data.frame with an NA logical
        ## yields an all-NA row (which later breaks the row names).  A value
        ## that cannot be compared is not inside the interval: drop it.
        n[is.na(n)] <- FALSE
        meta <- meta[n, ]
    }
}
## Load the OTU table: first column = OTU IDs, last column = taxonomy
## strings, the columns in between = per-sample values.
otus <- read.table.x(opt$input_otu_table)
rownames(otus) <- otus[[1]]

## Taxonomy lookup keyed by OTU ID: strip the leading "Root; " and insert a
## newline for every three levels of taxonomy.
tax.16s <- gsub("([^;]*); ([^;]*); ([^;]*); ", '\\1; \\2; \\3\n',
                gsub("^Root; ", "", otus[, ncol(otus)]))
names(tax.16s) <- otus[[1]]

## Sample IDs are taken from the first column of the mapping file as-is.
## (Disabled alternative: remove the 6-digit suffix of the sample IDs,
##  meta.sid <- gsub(".[0-9]{6}$", "", as.character(meta[[1]])) )
meta.sid <- as.character(meta[[1]])
rownames(meta) <- meta.sid

## keep the samples present in both tables, in the same order, and turn
## the OTU table into samples (rows) x OTUs (columns)
sample.ids <- intersect(meta.sid, colnames(otus))
meta <- meta[sample.ids, ]
otus <- data.frame(t(otus[, sample.ids]), check.names=FALSE)

if (opt$debug) save.image('debug.Rdata')
## Add numeric meta data fields (--add_numeric, comma-separated) as extra
## predictor columns in front of the OTU columns.
if (!is.null(opt$add_numeric)) {
    add.pred <- strsplit(opt$add_numeric, ',', fixed=TRUE)[[1]]
    not.in <- which(!(add.pred %in% colnames(meta)))
    if (length(not.in) > 0) {
        stop(paste(c(add.pred[not.in], "not in the meta data!!!"), collapse=' '))
    }
    ## drop=FALSE keeps a data.frame even for a single field, so the check
    ## below always runs once per column (not once per value) and
    ## not.numeric indexes add.pred correctly.
    to.add <- meta[, add.pred, drop=FALSE]
    not.numeric <- which(!vapply(to.add, is.numeric, logical(1)))
    if (length(not.numeric) > 0) {
        stop(paste(c(add.pred[not.numeric], "not numeric!!!"), collapse=' '))
    }
    otus <- cbind(to.add, otus)
    ## restore the requested names in case the subsetting altered them
    colnames(otus)[seq_along(add.pred)] <- add.pred
}
## Add categorical meta data fields (--add_category, comma-separated) as
## predictors: each field is expanded into dummy columns with caret's
## dummyVars() and the result is placed in front of the existing predictors.
if (!is.null(opt$add_category)) {
    add.pred <- strsplit(opt$add_category, ',', fixed=TRUE)[[1]]
    not.in <- which(!(add.pred %in% colnames(meta)))
    if (length(not.in) > 0) {
        stop(paste(c(add.pred[not.in], "not in the meta data!!!"), collapse=' '))
    }
    if (opt$debug) save.image('debug.Rdata')
    ## drop=FALSE: a single field must stay a data.frame for dummyVars()
    x <- meta[, add.pred, drop=FALSE]
    dummy <- dummyVars(~., data=x)
    otus <- cbind(predict(dummy, x), otus)
}
## ---------------------------------------------------------------------
## Main loop: for every outcome field, tune/benchmark the requested models
## and collect resampling accuracies, feature importances and the tuned
## model objects.  Plots go to "<output>.pdf"; the workspace is saved to
## "<output>.Rdata" after every model.
## ---------------------------------------------------------------------
pdf(sprintf("%s.pdf", opt$output))
min.sample.size <- 12
accuracies <- data.frame()
top.features <- data.frame()
big.tuned.list <- list()

## Replicate IDs for all samples.  This vector is never modified inside the
## loop: every label subsets it afresh, because the set of usable samples
## (not.na) differs from label to label.
if (!is.null(opt$replicate)) {
    replicate <- meta[, opt$replicate]
} else {
    replicate <- NULL
}

for (label in outcome.col) {
    if (opt$verbose)
        cat("=========", label, ":\n")
    outcome <- as.character(meta[[label]])
    ## the space in the class names cause problem for some models
    outcome <- as.factor(gsub("[[:space:]]+", "_", outcome))

    ## usable samples: outcome present and no missing predictor value
    outcome.na <- is.na(outcome)
    not.na <- (!outcome.na) & complete.cases(otus)

    ## drop classes that lost all their samples so that nlevels() below
    ## counts only classes that are really present
    outcome <- droplevels(outcome[not.na])

    if (opt$verbose) {
        cat("---- a glimpse of outcome:\n")
        print(table(outcome))
    }

    if (opt$debug) save.image('debug.Rdata')

    if (length(outcome) < min.sample.size)
        stop("There should be at least ", min.sample.size, " samples.")

    ## nothing to classify if there are fewer than 2 classes
    if (nlevels(outcome) < 2) {
        if (opt$verbose)
            cat("outcome has less than 2 classes. skip it.\n")
        next
    }

    train.set <- otus[not.na, , drop=FALSE]

    ## split into training and test sets; with split == 1 every sample is
    ## used for training and the test set is empty
    if (opt$split < 1) {
        set.seed(1)
        training.rows <- createDataPartition(outcome, p=opt$split, list=FALSE)
    } else {
        training.rows <- seq_along(outcome)
    }

    train.full <- train.set[training.rows, , drop=FALSE]
    test.full <- train.set[-training.rows, , drop=FALSE]
    train.outcome <- outcome[training.rows]
    test.outcome <- outcome[-training.rows]

    ## Resampling indices that keep all samples of one replicate in the
    ## same fold.  They are row numbers of the *training* data, so they are
    ## derived from the replicate IDs of exactly those rows.
    if (!is.null(opt$replicate)) {
        train.rep <- replicate[not.na][training.rows]
        uniq_rep <- unique(train.rep)
        ## caret default: 10 folds repeated 5 times
        idx_rep <- createMultiFolds(uniq_rep)
        idx <- lapply(idx_rep,
                      function (i) {
                          reps <- uniq_rep[i]
                          which(train.rep %in% reps)
                      })
    } else {
        idx <- NULL
    }

    if (opt$debug) save.image('debug.Rdata')

    ## remove near-zero-variance predictors (decided on the training set)
    nzv <- nearZeroVar(train.full)
    if (length(nzv) > 0) {
        train.full <- train.full[, -nzv, drop=FALSE]
        test.full <- test.full[, -nzv, drop=FALSE]
    }
    ## remove predictors correlated above 0.9 with another predictor
    tooHigh <- findCorrelation(cor(train.full), .9)
    if (length(tooHigh) > 0) {
        train.full <- train.full[, -tooHigh, drop=FALSE]
        test.full <- test.full[, -tooHigh, drop=FALSE]
    }

    ## save the test set results in a data.frame
    if (length(test.outcome) > 0)
        testResults <- data.frame(obs=test.outcome)

    if (opt$debug) save.image('debug.Rdata')

    ## benchmark the specified models
    tuned.list <- list()
    accu <- data.frame()
    top.f <- data.frame()
    for (model in models) {
        if (opt$feature_selection) {
            ## recursive feature elimination with random forest,
            ## regardless of the model name
            fiveStats <- function(...) c(twoClassSummary(...), defaultSummary(...))
            ctrl <- rfeControl(method = "repeatedcv",
                               repeats = 5, number=10,
                               index = idx,
                               saveDetails = TRUE)
            ## random forest
            ctrl$functions <- rfFuncs
            if (nlevels(train.outcome) == 2) {
                ctrl$functions$summary <- fiveStats
            } else {
                ctrl$functions$summary <- defaultSummary
            }
            set.seed(721)
            ## NOTE(review): seq() fails when fewer than 20 predictors are
            ## left after filtering.
            tuned <- rfe(train.full,
                         train.outcome,
                         sizes = seq(10, ncol(train.full)-10, by=10),
                         metric = "Kappa",
                         ntree = 1000,
                         rfeControl = ctrl)
            tuned$method = 'rf'
        } else {
            ## classification.tune() is not defined in this file
            ## (presumably in ml_util.R); anything that is not a caret
            ## 'train' object is treated as a failed fit
            tuned <- classification.tune(train.full, train.outcome, model, idx=idx)
            if (!inherits(tuned, 'train')) {
                cat("Warning message:\nModel ", model, " failed.\n")
                next
            }
        }

        tuned.list[[model]] <- tuned

        if (opt$verbose) {
            print(tuned)
        }

        save.image(sprintf("%s.Rdata", opt$output))

        ## add a new column - Model
        accu <- rbind(accu, cbind(tuned$resample, Model=tuned$method))

        if (length(test.outcome) > 0)
            testResults[model] <- predict(tuned, test.full)

        imp <- varImp(tuned, scale=FALSE)
        top.f <- rbind(top.f,
                       data.frame(imp$importance[order(imp$importance,
                                                       decreasing=TRUE),,drop=FALSE],
                                  Model=model))

        ## plot top features (plot.imp() presumably comes from ml_util.R)
        pimp <- plot.imp(imp, tax.16s, main=model)
        print(pimp, position=c(0, 0, 0.56, 1))

        save.image(sprintf("%s.Rdata", opt$output))
    }

    ## every model failed: there is nothing to record or plot for this label
    if (length(tuned.list) == 0) {
        if (opt$verbose)
            cat("no model could be fitted. skip it.\n")
        next
    }

    accu$Field <- label
    accuracies <- rbind(accuracies, accu)
    top.f$Field <- label
    top.features <- rbind(top.features, top.f)

    big.tuned.list[[label]] <- tuned.list

    if (length(tuned.list) > 1) {
        ## compare the model performances
        resamp <- resamples(tuned.list)
        m.diff <- diff(resamp)
        if (opt$verbose) print(summary(m.diff))
        print(dotplot(m.diff))
    }

    ## plot predicted vs. observed classes on the held-out test set
    if (length(test.outcome) > 0) {
        method.names <- names(testResults)
        obs <- testResults[, 1]
        ## column 1 is the observation; the remaining columns (if any) are
        ## the predictions of the models that succeeded
        for (i in seq_along(testResults)[-1]) {
            pred <- testResults[, i]
            plot(pred ~ obs,
                 xlab=method.names[1], ylab=method.names[i])
            ## the outcome is a factor, so report the fraction of correctly
            ## classified test samples (RMSE / R^2 are regression metrics
            ## and are not defined for class labels)
            accuracy <- format(round(mean(as.character(pred) == as.character(obs)), 2),
                               nsmall=2)
            legend("topleft", text.col="blue",
                   legend=paste("Accuracy", accuracy, sep='='))
        }
    }
}

dev.off()

save.image(sprintf("%s.Rdata", opt$output))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This was tested in R version 3.1.2. Additional packages needed in R: caret, pROC, doMC, pls, e1071, and optparse. e.g.:
install.packages(c('caret','pROC','doMC','pls','e1071','optparse'))
Versions tested:
caret 6.0.58
pROC 1.8
doMC 1.3.4
pls 2.5.0
e1071 1.6.7
optparse 1.3.0
The R scripts ml_util.R and regression.R must be downloaded as well:
https://gist.github.com/walterst/9444f900e5382ba50c92
https://gist.github.com/walterst/5e6f47bf12f8ac334e36
For OTU level data, the taxonomy strings can be added to the feature importance scores with the script located here:
https://gist.github.com/walterst/71b587c1eb7a90a69297
This line must be changed in both regression.R and classification.R to point to the location of ml_util.R on your system:
source('/Users/tony/code/r_scripts_new/ml_util.R')
To run the script, you'll need an OTU table in tab-separated format, and a QIIME-compatible metadata mapping file. If you have sparse metadata, leave those fields empty in your mapping file rather than putting in strings such as "NA". You cannot have one category of samples completely missing all metadata, or the supervised learning will fail. Run supervised learning using the example given below.
Example of running scripts with metadata added:
Rscript classification.R -i otu_table_tab_sep -m mapping_fp -o output_base_filename -f metadata_cat -r 'rf' --add_numeric "comma-separated categories" --add_category "comma-separated categories"
Example command run with real data:
Rscript classification.R -i ../filtered_otu_tables/otu_table_even_14553_no_HC.txt -m ../validate_mapping/IBD_mapping_mar_16_metadata_sparse_no_HC.txt -o ../ROC_results_new_scripts/even_sampled_no_HC -f IBD_TYPE -r 'rf' --add_numeric "BMI,CALPROTECTIN,FROM_HEALTHY_W,FROM_HEALTHY_UW,IBD_TODAY,NUM_BOWEL_MOVE,NUM_DIARRHEA," --add_category "CPT_FLARE,CPT_FLARE_TIME,IN_HEALTHY_W,IN_HEALTHY_UW,BAC_CULTURE,AB_PAIN,BLOOD_IN_STOOL,MUCUS_IN_STOOL,RELAPSE,GASTRO"
To parse the resulting .Rdata files:
library(caret)
load('Rdata file')
Create and write confusion matrix
cm = confusionMatrix(big.tuned.list$IBD_TYPE$rf)
replace IBD_TYPE with the metadata category used for the -f parameter when calling classification.R earlier
write.table(cm$table, 'cm.txt', sep='\t')
Create, sort, and write importance scores
i = imp$importance
i = i[order(rowSums(i), decreasing=T), ]
write.table(i, 'imp.txt', sep='\t')
To get average accuracies (percent of all correctly classified samples):
a=accuracies[-nrow(accuracies),]
m=mean(as.numeric(a$Accuracy))
s = sd(as.numeric(a$Accuracy))
To create ROC AUC plots (replace IBD_DIAGNOSIS with metadata category, and CCD with subset of metadata to query):
library(pROC)
aa = big.tuned.list$IBD_DIAGNOSIS$rf
pred = aa$pred$CCD
ref = (as.character(aa$pred$obs) == 'CCD')
plot.roc(x=ref, predictor=pred, main="ROC Plots", col="blue")
pred = aa$pred$HC
ref = (as.character(aa$pred$obs) == 'HC')
plot.roc(x=ref, predictor=pred, col="green", add=TRUE)
And so on for other categories...
legend("bottomright", legend=c("CCD", "HC"), col=c("blue", "green"), lwd=2)