Skip to content

Instantly share code, notes, and snippets.

@mcmtroffaes
Last active November 16, 2018 02:41
Show Gist options
  • Save mcmtroffaes/709908 to your computer and use it in GitHub Desktop.
k-fold cross validation script for R
kfcv.sizes = function(n, k=10) {
  # generate sample sizes for k-fold cross validation on a data set of
  # size n
  # author: Matthias C. M. Troffaes
  # date: 22 Nov 2010
  # license: GPLv3
  # usage:
  #
  # kfcv.sizes(n, k=...)
  #
  # returns a numeric vector of length k whose entries sum to n; fold i
  # covers positions ((i - 1) * n) %/% k + 1 to (i * n) %/% k, so the fold
  # sizes differ by at most one (some folds are empty when k > n)
  # k < 1 has no meaningful answer (the old 1:k loop silently produced
  # NaN/-Inf for k = 0), so reject it explicitly
  stopifnot(length(k) == 1, k >= 1)
  # fold boundaries b_0 = 0 <= b_1 <= ... <= b_k = n; fold i has size
  # b_i - b_(i-1), computed for all folds at once instead of growing a
  # vector inside a loop
  boundaries = (0:k * n) %/% k
  diff(boundaries)
}
kfcv.testing = function(n, k=10) {
  # generate testing sample indices for k-fold cross validation on a
  # data set of size n
  # author: Matthias C. M. Troffaes
  # date: 22 Nov 2010
  # license: GPLv3
  # usage:
  #
  # kfcv.testing(n, k=...)
  #
  # returns a list of k integer vectors which together partition
  # seq_len(n); element i has length kfcv.sizes(n, k=k)[i]
  indices = vector("list", k)
  sizes = kfcv.sizes(n, k=k)
  values = seq_len(n)
  for (i in seq_len(k)) {
    # take a random sample of given size from the remaining indices;
    # sample positions with sample.int rather than calling
    # sample(values, ...), because sample() treats a single remaining
    # number v as the range 1:v and could return an index that already
    # belongs to an earlier fold (when length(values) > 1 both forms draw
    # exactly the same random numbers)
    s = values[sample.int(length(values), sizes[i])]
    indices[[i]] = s
    # remove sample from values
    values = setdiff(values, s)
  }
  indices
}
kfcv.classifier = function(data, attribs, class, make.classifier, k=10) {
  # run k-fold cross validation with an arbitrary classifier
  # author: Matthias C. M. Troffaes
  # date: 27 July 2018
  # license: GPLv3
  # usage:
  #
  # kfcv.classifier(data, attribs, class, make.classifier, k=...)
  #
  # where data is the data frame (each column is an attribute, and
  # each row is an observation), class is the column index for the
  # attribute to be predicted, and attribs is passed unchanged to
  # make.classifier (e.g. the column indices of the predicting attributes)
  # make.classifier is a function which takes a training set,
  # attribute column indices, and a class column index; it returns a
  # function which takes a single row from a test set and returns a
  # list of test results (e.g. the predicted class, whether the
  # classifier predicted correctly, utility for misclassification,
  # ...)
  # the result is a data frame with one row per observation of data
  # (rows are grouped by fold), one column per element of the result list
  run.fold = function(testingindices) {
    # an empty fold (possible when k > nrow(data)) contributes no rows;
    # skipping it also avoids data[-integer(0), ], which selects *no* rows
    # and would hand make.classifier an empty training set
    if (length(testingindices) == 0) {
      return(NULL)
    }
    # drop=FALSE keeps single-column data frames as data frames
    classifier = make.classifier(
      data[-testingindices, , drop=FALSE], attribs, class)
    do.call(
      rbind.data.frame,
      lapply(
        testingindices,
        function(rowid) classifier(data[rowid, , drop=FALSE])
      ))
  }
  # rbind.data.frame ignores the NULL results of empty folds
  do.call(
    rbind.data.frame,
    lapply(kfcv.testing(nrow(data), k=k), run.fold))
}
make.classifier.naivebayes = function(train, attribs, class, laplace=0) {
  # simple example of a classifier for kfcv.classifier: fits a naive
  # Bayes model (package e1071) on the attribute columns attribs of the
  # training set, with class column class and Laplace smoothing laplace,
  # and returns a function mapping a one-row test data frame to
  # list(class = predicted class as a factor with the training levels,
  #      acc = whether the prediction equals the row's actual class,
  #      prob = the highest predicted class probability)
  # fail with a clear message instead of require()'s FALSE-plus-warning,
  # and avoid attaching e1071 to the search path as a side effect
  if (!requireNamespace("e1071", quietly=TRUE)) {
    stop("make.classifier.naivebayes needs the 'e1071' package", call.=FALSE)
  }
  # drop=FALSE keeps a (named) data frame even for a single attribute;
  # otherwise naiveBayes renames the lone column and predict(), which
  # matches attributes by name, no longer finds it in the test rows
  model = e1071::naiveBayes(
    train[, attribs, drop=FALSE], train[, class], laplace=laplace)
  classes = levels(train[, class])
  function(testrow) {
    actualclass = testrow[1, class]
    probs = predict(model, testrow, type="raw")
    predictedclass = factor(classes[which.max(probs)], classes)
    list(
      class=predictedclass,
      acc=(predictedclass == actualclass),
      prob=max(probs))
  }
}
kfcv.sizes.test = function() {
  # test simple cases; all.equal also compares lengths, whereas == inside
  # stopifnot would silently recycle a result of the wrong length
  stopifnot(isTRUE(all.equal(kfcv.sizes(10, k=2), c(5, 5))))
  stopifnot(isTRUE(all.equal(kfcv.sizes(10, k=5), c(2, 2, 2, 2, 2))))
  stopifnot(isTRUE(all.equal(kfcv.sizes(12, k=5), c(2, 2, 3, 2, 3))))
  # test that there are k folds, that the sum of sample sizes is the total
  # sample size (also for an empty data set), and that fold sizes differ
  # by at most one
  for (k in 1:10) {
    for (n in 0:100) {
      sizes = kfcv.sizes(n, k=k)
      stopifnot(length(sizes) == k)
      stopifnot(sum(sizes) == n)
      stopifnot(max(sizes) - min(sizes) <= 1)
    }
  }
}
kfcv.testing.test = function() {
  # set seed so test is deterministic
  set.seed(10)
  # 3 fold sample from 10 indices; the exact indices drawn depend on the
  # R version's sampling algorithm (R 3.6.0 changed the default
  # sample.kind), so check the structure of the folds rather than
  # hard-coded index values
  indices = kfcv.testing(10, k=3)
  stopifnot(length(indices) == 3)
  stopifnot(isTRUE(all.equal(lengths(indices), c(3L, 3L, 4L))))
  stopifnot(identical(sort(unlist(indices)), 1:10))
  # the folds must partition 1:n (no index repeated or lost), including
  # small n where a fold is drawn from a single remaining index or is empty
  for (k in 1:6) {
    for (n in 0:20) {
      indices = kfcv.testing(n, k=k)
      stopifnot(length(indices) == k)
      stopifnot(identical(sort(unlist(indices)), seq_len(n)))
    }
  }
}
kfcv.example = function() {
  # demonstrate kfcv.testing on a small character vector: split it into
  # 4 folds and, for each fold in turn, print the held-out (testing)
  # items, then the remaining (training) items, then NULL as a separator
  words = c("hello", "world", "dog", "cat", "fox", "rabbit", "gnome", "orc", "imp", "zombie", "vampire")
  folds = kfcv.testing(length(words), k=4)
  for (fold in seq_along(folds)) {
    held.out = folds[[fold]]
    print(words[held.out])
    print(words[-held.out])
    print(NULL)
  }
  invisible(NULL)
}
kfcv.example2 = function() {
  # example: k-fold cross validation (default k of kfcv.classifier) of the
  # naive Bayes classifier on the iris data, after turning each of the four
  # numeric attributes into a logical "at or above the column mean"
  # returns the mean of the "prob" and "acc" result columns
  # read iris from the datasets package directly; data(iris) would also
  # write a copy of iris into the global environment
  iris = datasets::iris
  disc = function(xs) xs >= mean(xs)
  iris[1:4] = lapply(iris[1:4], disc)
  # (a classifier trained on the full data was built here before and never
  # used; cross validation builds its own classifier per fold)
  result = kfcv.classifier(iris, 1:4, 5, make.classifier.naivebayes)
  colMeans(result[, c("prob", "acc")])
}
#kfcv.sizes.test()
#kfcv.testing.test()
#kfcv.example()
#kfcv.example2()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment