Skip to content

Instantly share code, notes, and snippets.

@smc77
Created November 9, 2011 03:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save smc77/1350234 to your computer and use it in GitHub Desktop.
Save smc77/1350234 to your computer and use it in GitHub Desktop.
Generalization
#
# Let's look at how the different models generalize between different datasets
#
# Sample sizes: a small set to fit the models on, a larger one to evaluate them.
n.training <- 10
n.test <- 100

# Half the sum of squared differences between predictions and observations.
#   y:      observed values
#   y.pred: predicted values
error.function <- function(y, y.pred) {
  delta <- y.pred - y
  sum(delta^2) / 2
}

# Root-mean-square error, built on error.function: the factor 2 undoes the
# halving done there, and dividing by length(y) averages over observations.
#   y:      observed values
#   y.pred: predicted values
e.rms <- function(y, y.pred) {
  half.sse <- error.function(y=y, y.pred=y.pred)
  sqrt(2 * half.sse / length(y))
}
# Simulate n noisy observations of sin(2 * pi * x) on an evenly spaced grid.
#
# Args:
#   n:  number of observations; x runs from 0 to 1 in n evenly spaced steps.
#   sd: standard deviation of the Gaussian noise added to the curve.
#       Defaults to 0.2, the value that used to be hard-coded.
#
# Returns:
#   A data.frame with columns y (noisy response) and x (grid positions).
build.data <- function(n, sd=0.2) {
  f <- function(x) sin(2 * pi * x)
  # Spell out length.out instead of relying on partial matching of `length`.
  x <- seq(0, 1, length.out=n)
  y <- f(x) + rnorm(n, sd=sd)
  data.frame(y=y, x=x)
}
# Draw the two samples used below: a small one for fitting and a larger one
# for evaluation. The training set is drawn first, then the test set; no seed
# is set in this section, so the samples differ between runs.
training <- build.data(n.training)
test <- build.data(n.test)
# Fit one polynomial regression per requested degree on `training` and report
# the RMS error of each fit on both the training and the test data.
#
# Args:
#   training:    data.frame with columns y and x used to fit each model.
#   test:        data.frame with columns y and x used to evaluate each model.
#   polynomials: vector of polynomial degrees to fit (default 1:9).
#
# Returns:
#   A data.frame with one row per entry of `polynomials` and columns
#   polynomial, training.error and test.error.
test.poly.error <- function(training, test, polynomials=1:9) {
  n.models <- length(polynomials)
  errors.training <- numeric(n.models)
  errors.test <- numeric(n.models)
  # Index results by position, not by degree: storing at index `degree` left
  # NA gaps (and a row-count mismatch in data.frame) for any degree vector
  # other than 1:k, e.g. c(3, 5).
  for (k in seq_along(polynomials)) {
    degree <- polynomials[k]
    fit <- lm(y ~ poly(x, degree, raw=TRUE), data=training)
    errors.training[k] <- e.rms(training$y, predict(fit))
    errors.test[k] <- e.rms(test$y, predict(fit, newdata=test))
  }
  data.frame(polynomial=polynomials, training.error=errors.training, test.error=errors.test)
}
library(ggplot2)
# Compute per-degree errors, then reshape from wide (one column per dataset)
# to long (one row per degree/dataset pair) so ggplot can draw one line each.
errors <- test.poly.error(training, test)
# Base-R reshape: the previous melt(errors, x) call needed a reshape package
# that is not loaded here and an id variable `x` that is not defined at top
# level. The dataset labels and their order match the original column names.
dataset.labels <- c("training.error", "test.error")
errors <- data.frame(
  polynomial=rep(errors$polynomial, times=length(dataset.labels)),
  dataset=factor(rep(dataset.labels, each=nrow(errors)), levels=dataset.labels),
  error=c(errors$training.error, errors$test.error)
)
# `group` is the ggplot2 aesthetic name; `grouping` is not a recognised one.
p <- ggplot(errors, aes(x=polynomial, y=error, group=dataset, colour=dataset)) + geom_line()
p
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment