Skip to content

Instantly share code, notes, and snippets.

@aaroncharlton
Created May 12, 2015 20:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save aaroncharlton/2a58914de3471052f798 to your computer and use it in GitHub Desktop.
Save aaroncharlton/2a58914de3471052f798 to your computer and use it in GitHub Desktop.
## Model selection
# Follows Max Kuhn's caret tutorial: http://topepo.github.io/caret/training.html

# Load the Sonar data set and inspect the structure of its first ten columns.
library(mlbench)
data(Sonar)
str(Sonar[, 1:10])

library(caret)

# Seeded partition on the Class column with p = 0.75: the selected rows become
# the training set, every remaining row goes to the test set.
set.seed(998)
inTraining <- createDataPartition(Sonar[["Class"]], p = 0.75, list = FALSE)
training <- Sonar[inTraining, ]
testing <- Sonar[-inTraining, ]
# Basic parameter tuning ----
# Resampling scheme: 10-fold cross-validation, repeated 10 times.
fitControl <- trainControl(method = "repeatedcv", number = 10, repeats = 10)

# Fit a gbm model of Class on all other columns, using caret's default grid.
set.seed(825)
gbmFit1 <- train(
  Class ~ .,
  data = training,
  method = "gbm",
  trControl = fitControl,
  # `verbose` is not a train() option; it is passed through to gbm().
  verbose = FALSE
)
gbmFit1
# Alternate tuning grids ----
# Candidate models: 30 values of n.trees (50, 100, ..., 1500) crossed with
# three interaction depths; shrinkage and n.minobsinnode are held fixed.
gbmGrid <- expand.grid(
  n.trees = (1:30) * 50,
  interaction.depth = c(1, 5, 9),
  shrinkage = 0.1,
  n.minobsinnode = 10
)
nrow(gbmGrid)

# Same seed and resampling control as the first fit, but the exact models to
# evaluate are now given explicitly through `tuneGrid`.
set.seed(825)
gbmFit2 <- train(
  Class ~ .,
  data = training,
  method = "gbm",
  trControl = fitControl,
  verbose = FALSE,
  tuneGrid = gbmGrid
)
gbmFit2
# Plotting the resampling profile ----
# The caret lattice theme is (re)applied before each lattice plot.

# Default metric.
trellis.par.set(caretTheme())
plot(gbmFit2)

# Same profile, showing Kappa instead.
trellis.par.set(caretTheme())
plot(gbmFit2, metric = "Kappa")

# Kappa as a level plot, with the x-axis labels rotated by 90 degrees.
trellis.par.set(caretTheme())
plot(
  gbmFit2,
  metric = "Kappa",
  plotType = "level",
  scales = list(x = list(rot = 90))
)

# ggplot2 version of the profile.
ggplot(gbmFit2)
# Tuning on ROC ----
# Redefine the resampling control: same 10 x 10 repeated CV, but request class
# probabilities and use twoClassSummary as the summary function so that "ROC"
# can be used as the selection metric below.
fitControl <- trainControl(
  method = "repeatedcv",
  number = 10,
  repeats = 10,
  classProbs = TRUE,
  summaryFunction = twoClassSummary
)

# Refit gbm over the same grid, this time optimizing ROC.
set.seed(825)
gbmFit3 <- train(
  Class ~ .,
  data = training,
  method = "gbm",
  trControl = fitControl,
  verbose = FALSE,
  tuneGrid = gbmGrid,
  metric = "ROC"
)
gbmFit3
## Choosing the final model ----
# Row index into the resampling results chosen by tolerance() with tol = 2 on
# the ROC column (larger is better).
whichTwoPct <- tolerance(
  gbmFit3$results,
  metric = "ROC",
  tol = 2,
  maximize = TRUE
)
cat("best model within 2 pct of best:\n")
gbmFit3$results[whichTwoPct, 1:6]

## Extracting predictions and class probabilities ----
# Predicted classes for the first few test rows, then the class probabilities.
predict(gbmFit3, newdata = head(testing))
predict(gbmFit3, newdata = head(testing), type = "prob")
## Between models ----
# Fit two further model types with the same seed and the same resampling
# control as gbmFit3, so their resampling results can be compared afterwards.

# Radial-basis SVM. Predictors are centered and scaled; `preProcess` is
# spelled out in full rather than relying on partial argument matching
# (`preProc`). tuneLength = 8 lets train() build the candidate grid itself.
set.seed(825)
svmFit <- train(
  Class ~ .,
  data = training,
  method = "svmRadial",
  trControl = fitControl,
  preProcess = c("center", "scale"),
  tuneLength = 8,
  metric = "ROC"
)
svmFit

# Regularized discriminant analysis, with tuneLength = 4.
set.seed(825)
rdaFit <- train(
  Class ~ .,
  data = training,
  method = "rda",
  trControl = fitControl,
  tuneLength = 4,
  metric = "ROC"
)
rdaFit
# Comparing resampling distributions ----
# Collect the resampling results of the three fitted models under short labels.
resamps <- resamples(
  list(
    GBM = gbmFit3,
    SVM = svmFit,
    RDA = rdaFit
  )
)
resamps
summary(resamps)

# Custom lattice theme used by several plots below. It must be defined before
# the first trellis.par.set(theme1) call; without it the script stops with
# "object 'theme1' not found". It starts from the current lattice settings and
# overrides only the plotting symbol and line appearance.
theme1 <- trellis.par.get()
theme1$plot.symbol$col <- rgb(0.2, 0.2, 0.2, 0.4)
theme1$plot.symbol$pch <- 16
theme1$plot.line$col <- rgb(1, 0, 0, 0.7)
theme1$plot.line$lwd <- 2

# Box-and-whisker plots of the resampled metrics, three panels in one row.
trellis.par.set(theme1)
bwplot(resamps, layout = c(3, 1))

# Dot plot of the ROC metric only.
trellis.par.set(caretTheme())
dotplot(resamps, metric = "ROC")

# Bland-Altman style plot, followed by a scatter-plot matrix.
trellis.par.set(theme1)
xyplot(resamps, what = "BlandAltman")
splom(resamps)

# Pairwise differences between the models' resampling results.
difValues <- diff(resamps)
difValues
summary(difValues)

# Plots of the pairwise differences.
trellis.par.set(theme1)
bwplot(difValues, layout = c(3, 1))
trellis.par.set(caretTheme())
dotplot(difValues)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment