@jeffreyiacono
Last active August 29, 2015 14:04
PML: combining predictors
library(ISLR)
library(ggplot2)
library(caret)
data(Wage)
# drop logwage: it is the log of the response (wage), so keeping it as a
# predictor would leak the answer
Wage <- subset(Wage, select = -c(logwage))
# hold out 30% of the rows as a validation set (never used for fitting)
inBuild <- createDataPartition(y = Wage$wage, p = 0.7, list = FALSE)
validation <- Wage[-inBuild, ]
# split the remaining build data into training (70%) and test (30%) sets
buildData <- Wage[inBuild, ]
inTrain <- createDataPartition(y = buildData$wage, p = 0.7, list = FALSE)
training <- buildData[inTrain, ]
testing <- buildData[-inTrain, ]
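# train two base models: a linear model (glm) and a random forest
mod1 <- train(wage ~ ., method = "glm", data = training)
# number = 3 (3-fold CV) is an argument of trainControl(), not of train()
mod2 <- train(wage ~ ., method = "rf", data = training, trControl = trainControl(method = "cv", number = 3))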
# predict on the test set with each base model
pred1 <- predict(mod1, testing)
pred2 <- predict(mod2, testing)
# plot the two models' predictions against each other, colored by actual wage:
# points off the diagonal show where the models disagree, and the color shows
# how far each prediction is from the true value
qplot(pred1, pred2, color = wage, data = testing)
# combine the predictors: fit a GAM on the two sets of test-set predictions,
# then predict wage again from the combined model
predDF <- data.frame(pred1, pred2, wage = testing$wage)
combModFit <- train(wage ~ ., method = "gam", data = predDF)
combPred <- predict(combModFit, predDF)
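# A minimal sketch, not part of the original gist: a plain average of the base
# predictions is a simple baseline to compare the fitted GAM combiner against.
# `average_predictions` is a hypothetical helper name introduced here.
average_predictions <- function(...) {
  # bind the prediction vectors column-wise and average across each row
  rowMeans(cbind(...))
}
# usage (assumes pred1 and pred2 from above):
# avgPred <- average_predictions(pred1, pred2)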
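# compare sqrt(SSE), the square root of the sum of squared errors, on the test
# set. The combined model is expected to score lowest here, partly because it
# was fit on these same test-set predictions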
sqrt(sum((pred1 - testing$wage)^2))
sqrt(sum((pred2 - testing$wage)^2))
sqrt(sum((combPred - testing$wage)^2))
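# evaluate on the validation set: predict with each base model, then feed those
# predictions (under the same column names) to the combined model
pred1V <- predict(mod1, validation)
pred2V <- predict(mod2, validation)
predVDF <- data.frame(pred1 = pred1V, pred2 = pred2V)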
combPredV <- predict(combModFit, predVDF)
# compare sqrt(SSE) again on the untouched validation set: this is the honest
# check of whether combining the two models reduces the error
sqrt(sum((pred1V - validation$wage)^2))
sqrt(sum((pred2V - validation$wage)^2))
sqrt(sum((combPredV - validation$wage)^2))
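# A minimal sketch, not part of the original gist: sqrt(SSE) as computed above
# grows with the number of rows, so the test-set and validation-set figures are
# not directly comparable. RMSE (mean instead of sum) is. `rmse` and
# `compare_rmse` are hypothetical helper names introduced here.
rmse <- function(pred, obs) {
  sqrt(mean((pred - obs)^2))
}
compare_rmse <- function(preds, obs) {
  # preds: named list of prediction vectors, each the same length as obs
  sapply(preds, rmse, obs = obs)
}
# usage (assumes the objects created above):
# compare_rmse(list(glm = pred1V, rf = pred2V, combined = combPredV), validation$wage)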