Skip to content

Instantly share code, notes, and snippets.

@primaryobjects
Last active October 13, 2017 13:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save primaryobjects/f551e121ad8c2db41ce021813c0f7f70 to your computer and use it in GitHub Desktop.
Save primaryobjects/f551e121ad8c2db41ce021813c0f7f70 to your computer and use it in GitHub Desktop.
Plotting a learning curve from image recognition data.
library(caret)
library(reshape2)
# Build the modelling data frames from `trainData` / `testData`.
# NOTE(review): neither object is defined in this file. They are presumably
# lists holding a feature matrix `x` (one row per image) and a label vector
# `y`; confirm against whatever script creates them. caret's LogitBoost is a
# classification model, so `y` presumably needs to be a factor.
# data.frame() expands the matrix `x` into columns x.1, x.2, ...
dataTrain <- local({
  # Only the first 30,000 training rows are used for the learning curve.
  keepRows <- 1:30000
  data.frame(x = trainData$x[keepRows, ], y = trainData$y[keepRows])
})
dataTest <- data.frame(x = testData$x, y = testData$y)
# Train LogitBoost on growing slices of `dataTrain` and plot a learning curve:
# accuracy on the slice itself ("Train") versus accuracy on `dataTest`
# (labelled "CV" in the plot). The plot is redrawn after every slice so
# progress is visible while the loop runs.
#
# Leaves in the global environment:
#   results - data.frame, one row per finished slice, columns Train and CV
#   r       - long-format plot data with columns Set, Accuracy, Count
#   fit     - the caret model trained on the last (largest) slice
# NOTE(review): `ggplot()` is available only because library(caret) attaches
# ggplot2; consider loading ggplot2 explicitly.

# Slice sizes are stepSize, 2 * stepSize, ..., nSteps * stepSize rows.
stepSize <- 1000
nSteps <- 30
# Fail early rather than training on NA rows produced by out-of-range indexing.
stopifnot(nSteps * stepSize <= nrow(dataTrain))

# Tuning resampling for caret::train: 5-fold CV, repeated once.
trainctrl <- trainControl(verboseIter = TRUE, number = 5, repeats = 1,
                          method = 'repeatedcv')

# Fraction of rows whose prediction equals the true label. An NA prediction
# (if predict() yields one) counts as wrong, matching length(which(...)) / n;
# mean() would instead return NA.
predictionAccuracy <- function(predicted, actual) {
  sum(predicted == actual, na.rm = TRUE) / length(actual)
}

# Preallocate instead of growing a data.frame with rbind() on every pass.
counts <- stepSize * seq_len(nSteps)
trainAcc <- rep(NA_real_, nSteps)
testAcc <- rep(NA_real_, nSteps)

for (i in seq_len(nSteps)) {
  partialSet <- dataTrain[seq_len(counts[i]), ]
  fit <- train(y ~ ., data = partialSet, method = 'LogitBoost',
               trControl = trainctrl)

  # Record accuracy history.
  trainAcc[i] <- predictionAccuracy(predict(fit, partialSet), partialSet$y)
  testAcc[i] <- predictionAccuracy(predict(fit, dataTest), dataTest$y)

  # Plot learning curve for the slices finished so far.
  done <- seq_len(i)
  results <- data.frame(Train = trainAcc[done], CV = testAcc[done])
  # An explicit id column avoids melt()'s "No id variables" message and the
  # implicit recycling previously used to attach Count.
  r <- melt(data.frame(Count = counts[done], results), id.vars = 'Count',
            variable.name = 'Set', value.name = 'Accuracy')
  r <- r[, c('Set', 'Accuracy', 'Count')]
  print(
    ggplot(data = r, aes(x = Count, y = Accuracy, colour = Set)) +
      geom_line() +
      # formula given explicitly so geom_smooth does not message on each redraw.
      geom_smooth(method = 'lm', formula = y ~ x, se = FALSE)
  )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment