Last active
November 16, 2018 02:41
-
-
Save mcmtroffaes/709908 to your computer and use it in GitHub Desktop.
k-fold cross validation script for R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
kfcv.sizes <- function(n, k = 10) {
  # Compute the sample size of each fold for k-fold cross validation on a
  # data set of size n.
  #
  # author: Matthias C. M. Troffaes
  # date: 22 Nov 2010
  # license: GPLv3
  #
  # usage:
  #
  #   kfcv.sizes(n, k=...)
  #
  # arguments:
  #   n  total number of observations (a single number >= 0)
  #   k  number of folds (a single number >= 1)
  #
  # value:
  #   numeric vector with one entry per fold; for whole numbers n and k
  #   the entries sum to n and differ from each other by at most one
  #   (some entries are 0 when k > n).
  stopifnot(
    is.numeric(n), length(n) == 1L, !is.na(n), n >= 0,
    is.numeric(k), length(k) == 1L, !is.na(k), k >= 1
  )
  # Fold i covers positions ((i - 1) * n) %/% k + 1 through (i * n) %/% k,
  # so its size is the difference between two consecutive boundaries.
  # Multiplying in doubles avoids integer overflow of i * n for huge n.
  boundaries <- ((0:k) * as.numeric(n)) %/% k
  diff(boundaries)
}
kfcv.testing <- function(n, k = 10) {
  # Randomly split the indices 1..n into k disjoint testing folds for
  # k-fold cross validation on a data set of size n.
  #
  # author: Matthias C. M. Troffaes
  # date: 22 Nov 2010
  # license: GPLv3
  #
  # usage:
  #
  #   kfcv.testing(n, k=...)
  #
  # arguments:
  #   n  total number of observations
  #   k  number of folds
  #
  # value:
  #   list with one integer vector of testing indices per fold; the fold
  #   sizes are those returned by kfcv.sizes(n, k), and no index appears
  #   in more than one fold.
  sizes <- kfcv.sizes(n, k = k)
  indices <- vector("list", length(sizes))
  # indices that have not been assigned to a fold yet
  values <- seq_len(n)
  for (i in seq_along(sizes)) {
    # Draw positions instead of calling sample(values, size) directly:
    # when a single number v is left, sample(v, size) draws from 1:v and
    # could hand out an index that already belongs to an earlier fold.
    s <- values[sample.int(length(values), sizes[i])]
    # store the random sample as fold i
    indices[[i]] <- s
    # remove the sample from the pool of remaining indices
    values <- setdiff(values, s)
  }
  indices
}
kfcv.classifier <- function(data, attribs, class, make.classifier, k = 10) {
  # Run k-fold cross validation with an arbitrary classifier.
  #
  # author: Matthias C. M. Troffaes
  # date: 27 July 2018
  # license: GPLv3
  #
  # usage:
  #
  #   kfcv.classifier(data, attribs, class, make.classifier, k=...)
  #
  # arguments:
  #   data             the data frame (each column is an attribute, and
  #                    each row is an observation)
  #   attribs          column indices of the attributes to predict from;
  #                    passed through to make.classifier unchanged
  #   class            column index of the attribute to be predicted
  #   make.classifier  function which takes a training set, attribute
  #                    column indices, and a class column index; it
  #                    returns a function which takes a single row from a
  #                    test set and returns a list of test results (e.g.
  #                    the predicted class, whether the classifier
  #                    predicted correctly, utility for
  #                    misclassification, ...)
  #   k                number of folds
  #
  # value:
  #   the per-row test results of all folds, combined with
  #   rbind.data.frame (one row per tested observation, in fold order).
  all.rows <- seq_len(nrow(data))
  evaluate.fold <- function(testingindices) {
    # Pick the training rows by positive index: data[-testingindices, ]
    # would select *no* rows for an empty fold, which occurs whenever k
    # exceeds the number of rows.
    training <- data[setdiff(all.rows, testingindices), , drop = FALSE]
    classifier <- make.classifier(training, attribs, class)
    # classify each held-out row separately and stack the result lists
    do.call(
      rbind.data.frame,
      lapply(
        testingindices,
        function(rowid) classifier(data[rowid, , drop = FALSE])
      )
    )
  }
  do.call(
    rbind.data.frame,
    lapply(kfcv.testing(nrow(data), k = k), evaluate.fold)
  )
}
make.classifier.naivebayes <- function(train, attribs, class, laplace = 0) {
  # Simple example of a classifier factory: fits a naive Bayes model from
  # the e1071 package and returns a function that classifies one row.
  #
  # arguments:
  #   train    training data frame
  #   attribs  column indices of the predictor attributes
  #   class    column index of the attribute to be predicted
  #   laplace  Laplace smoothing value handed to e1071::naiveBayes
  #
  # value:
  #   function(testrow) returning a list with
  #     class  the predicted class, as a factor over the training levels
  #     acc    whether the prediction equals the class stored in testrow
  #     prob   the largest class probability reported by the model
  #
  # Fail loudly when the dependency is missing; require() would only
  # return FALSE and let the call below fail with an unrelated message.
  if (!requireNamespace("e1071", quietly = TRUE)) {
    stop("package 'e1071' is required by make.classifier.naivebayes",
         call. = FALSE)
  }
  # drop = FALSE keeps a single predictor as a named data frame column
  model <- e1071::naiveBayes(
    train[, attribs, drop = FALSE],
    train[, class],
    laplace = laplace
  )
  classes <- levels(train[, class])
  function(testrow) {
    actualclass <- testrow[1, class]
    probs <- predict(model, testrow, type = "raw")
    # NOTE(review): assumes the raw probabilities come back in the same
    # order as levels(train[, class]) -- confirm against e1071.
    predictedclass <- factor(classes[which.max(probs)], classes)
    list(
      class = predictedclass,
      acc = (predictedclass == actualclass),
      prob = max(probs)
    )
  }
}
kfcv.sizes.test <- function() {
  # Self-test for kfcv.sizes; stops with an error at the first failed
  # check and returns nothing useful otherwise.
  #
  # test simple cases; all.equal (rather than ==) also catches a result
  # of the wrong length, which == would silently recycle
  stopifnot(isTRUE(all.equal(kfcv.sizes(10, k = 2), c(5, 5))))
  stopifnot(isTRUE(all.equal(kfcv.sizes(10, k = 5), c(2, 2, 2, 2, 2))))
  stopifnot(isTRUE(all.equal(kfcv.sizes(12, k = 5), c(2, 2, 3, 2, 3))))
  # test that there is one size per fold and that the sizes add up to
  # the total sample size
  for (k in 1:10) {
    for (n in 1:100) {
      sizes <- kfcv.sizes(n, k = k)
      stopifnot(length(sizes) == k)
      stopifnot(sum(sizes) == n)
    }
  }
}
kfcv.testing.test <- function() {
  # Self-test for kfcv.testing; stops with an error at the first failed
  # check. Note that it reseeds the global random number generator.
  #
  # set seed so test is deterministic
  set.seed(10)
  # 3 fold sample from 10 indices
  indices <- kfcv.testing(10, k = 3)
  stopifnot(length(indices) == 3)
  # properties that must hold for every random draw: fold sizes 3, 3, 4
  # and every index from 1 to 10 used exactly once
  stopifnot(isTRUE(all.equal(lengths(indices), c(3, 3, 4))))
  stopifnot(isTRUE(all.equal(sort(unlist(indices)), 1:10)))
  # NOTE(review): the exact draws below depend on the sampling algorithm
  # of the running R version (the default changed in R 3.6.0) -- confirm
  # that these expected values still hold on the targeted R version.
  # all.equal (rather than ==) also catches folds of the wrong length.
  stopifnot(isTRUE(all.equal(indices[[1]], c(6, 3, 4))))
  stopifnot(isTRUE(all.equal(indices[[2]], c(8, 1, 2))))
  stopifnot(isTRUE(all.equal(indices[[3]], c(7, 5, 10, 9))))
}
kfcv.example <- function() {
  # Demonstration of kfcv.testing: prints, for each of 4 folds over a
  # small character vector, the testing sample, then the matching
  # training sample, then NULL as a separator.
  words <- c(
    "hello", "world", "dog", "cat", "fox", "rabbit",
    "gnome", "orc", "imp", "zombie", "vampire"
  )
  # one vector of testing indices per fold
  folds <- kfcv.testing(length(words), k = 4)
  for (fold in folds) {
    held.out <- words[fold]
    kept <- words[-fold]
    print(held.out)
    print(kept)
    print(NULL)
  }
}
kfcv.example2 <- function() {
  # Demonstration of kfcv.classifier: cross-validates the naive Bayes
  # example classifier on a discretized copy of the iris data and returns
  # the column means of the "prob" and "acc" results.
  #
  # the data
  data(iris)
  # discretize: TRUE where a measurement is at or above its column mean
  disc <- function(xs) xs >= mean(xs)
  iris[1:4] <- lapply(iris[1:4], disc)
  # columns 1:4 are the predictors, column 5 is the class to predict
  result <- kfcv.classifier(iris, 1:4, 5, make.classifier.naivebayes)
  colMeans(result[, c("prob", "acc")])
}
#kfcv.sizes.test() | |
#kfcv.testing.test() | |
#kfcv.example() | |
#kfcv.example2() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment