Skip to content

Instantly share code, notes, and snippets.

@jake-westfall
Created May 20, 2017 00:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jake-westfall/00557e170cfc219d15f3af029e3b6b21 to your computer and use it in GitHub Desktop.
Save jake-westfall/00557e170cfc219d15f3af029e3b6b21 to your computer and use it in GitHub Desktop.
# function to fold data into k folds. this returns a list of k matrices where
# the 1st column in each is the response and all other columns are predictors.
# rows are assigned to folds in contiguous blocks, keeping their original
# order; when n is not a multiple of k the fold sizes differ by at most one.
#
# y: response, a vector (or n x 1 matrix) of length n
# X: predictors, with one row per element of y
# k: number of folds, between 1 and n (k = n gives leave-one-out)
fold <- function(y, X, k) {
  n <- length(y)
  stopifnot(k >= 1, k <= n)
  dat <- cbind(y, X)
  # fold index of each row: 1 for the first ~n/k rows, 2 for the next, ...
  # (for n divisible by k this is exactly blocks of n/k consecutive rows)
  fold_id <- ceiling(seq_len(n) * k / n)
  # drop = FALSE keeps single-row folds (e.g. leave-one-out) as matrices
  lapply(seq_len(k), function(j) dat[fold_id == j, , drop = FALSE])
}
# function to compute cross-validated MSE for one simulated dataset that is
# folded several different ways.
#
# n:  number of observations to simulate
# p:  number of (non-intercept) predictors
# ks: fold counts to compare; the default compares 2-, 5-, 10-fold and
#     leave-one-out (k = n) cross-validation
#
# returns a numeric vector with one CV estimate of MSE per element of ks
sim <- function(n = 50, p = 3, ks = c(2, 5, 10, n)) {
  # create the full dataset. predictors are uniform in [0,1],
  # uncorrelated, and with slopes evenly spaced from 0 to 2
  # (the leading column of 1s is an intercept whose true value is 0)
  X <- cbind(1, matrix(runif(p * n), ncol = p))
  y <- X %*% cbind(c(0, seq(0, 2, length.out = p))) + rnorm(n)
  # fold the same data each requested way and get the CV MSE for each
  vapply(ks, function(k) {
    dat <- fold(y, X, k)
    # test-set MSE for each held-out fold; averaged over all k folds below
    fold_mse <- vapply(seq_len(k), function(i) {
      # create the training set by stacking every fold except the ith;
      # do.call() binds them in one step instead of pairwise
      train <- do.call(rbind, dat[-i])
      A <- train[, -1, drop = FALSE]
      # fit by least squares via the normal equations (for speed);
      # solving A'A b = A'y avoids forming an explicit inverse
      beta <- solve(crossprod(A), crossprod(A, train[, 1]))
      # matrix() restores the matrix shape in case the held-out fold was
      # a single row that R simplified to a plain vector
      test <- matrix(dat[[i]], ncol = ncol(train))
      e2 <- (test[, 1] - test[, -1, drop = FALSE] %*% beta)^2
      # mean squared error across the held-out observations
      mean(e2)
    }, numeric(1))
    mean(fold_mse)
  }, numeric(1))
}
# simulation settings shared by the timing run and the full run
n_obs <- 50
n_pred <- 3
# estimate sim time: linear extrapolation from 100 iterations
system.time(replicate(100, sim(n = n_obs, p = n_pred)))
# timing recorded when this script was written: 1.901 seconds,
# i.e. an estimated 3.16 minutes for 10000 iterations (machine-dependent)
# run full simulation (10000 iterations); seed fixed so results reproduce
set.seed(20170520)
results <- replicate(10000, sim(n = n_obs, p = n_pred))
# sim() returns its estimates in the order k = 2, 5, 10, n (leave-one-out)
rownames(results) <- paste("k =", c(2, 5, 10, n_obs))
# print means and variances of MSEs for each k
round(rbind(mean = rowMeans(results),
            variance = apply(results, 1, var)), 3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment