@zmjones
Last active December 31, 2015 06:09
Repeated k-fold cross-validation using generic fitting and loss functions
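```r
library(parallel)

# Repeated k-fold CV: `model(train)` returns a fit, `loss(fit, test)` a scalar.
# Returns a list (one element per resample) of lists of per-fold losses.
# Note: mclapply() only supports cores = 1 on Windows.
validate.cv <- function(df, folds, resamples, model, loss, cores) {
  mclapply(1:resamples, function(x) {
    # randomly assign each row to one of `folds` near-equal-sized folds
    df$folds <- sample(rep(1:folds, length.out = nrow(df)))
    lapply(1:folds, function(test) {
      # fit on the other folds, score on the held-out fold
      fit <- model(df[df$folds != test, ])
      loss(fit, df[df$folds == test, ])
    })
  }, mc.cores = cores)
}
```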
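The fold assignment above relies on `sample(rep(1:folds, length.out = nrow(df)))`: `rep` lays out the fold labels cyclically, so fold sizes differ by at most one row, and `sample` then shuffles which rows get which label. A small check of that property, with n = 23 rows and 5 folds chosen arbitrarily for illustration:

```r
folds <- 5
n <- 23
# cyclic labels 1..5, 1..5, ... truncated to n, then randomly permuted
fold.id <- sample(rep(1:folds, length.out = n))
# count rows per fold; shuffling never changes these counts
sizes <- as.vector(table(fold.id))
print(sizes)
# [1] 5 5 5 4 4
```

Since 23 = 4 * 5 + 3, the first three fold labels receive the extra rows regardless of the random seed; only the mapping of rows to folds changes between resamples.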
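```r
# simulate data from a linear model with two predictors
n <- 1000
x1 <- rnorm(n)
x2 <- rnorm(n)
y <- 2 * x1 + 2 * x2 + rnorm(n)
df <- data.frame(y, x1, x2)

# root mean squared error of the predictions on the held-out fold
loss.rmse <- function(fit, df) sqrt(mean((predict(fit, newdata = df) - df$y)^2))
# additive linear model fit on the training folds
model.add <- function(df) lm(y ~ x1 + x2, df)
```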
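Because `validate.cv` only ever calls `model(train)` and `loss(fit, test)`, any pair of functions with those shapes can be swapped in. A minimal sketch with two hypothetical alternatives, a mean-absolute-error loss (`loss.mae`) and a model with an `x1:x2` interaction (`model.int`), checked here on a single train/test split of freshly simulated data so the block runs on its own:

```r
# hypothetical alternative loss: mean absolute error on the held-out rows
loss.mae <- function(fit, df) mean(abs(predict(fit, newdata = df) - df$y))
# hypothetical alternative model: additive terms plus an x1:x2 interaction
model.int <- function(df) lm(y ~ x1 * x2, df)

set.seed(1)
n.sim <- 200
sim <- data.frame(x1 = rnorm(n.sim), x2 = rnorm(n.sim))
sim$y <- 2 * sim$x1 + 2 * sim$x2 + rnorm(n.sim)

# one split standing in for a single fold: rows 1-150 train, rows 151-200 test
fit <- model.int(sim[1:150, ])
mae <- loss.mae(fit, sim[151:200, ])
print(mae)
```

The same pair could then be passed as `validate.cv(df, 10, 100, model.int, loss.mae, cores)` to get repeated-CV estimates for the interaction model under absolute-error loss.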
```r
# 10 folds x 100 resamples: 1000 held-out RMSE values in total
cv.add <- validate.cv(df, 10, 100, model.add, loss.rmse, detectCores())
cv <- as.numeric(unlist(cv.add))
```