Skip to content

Instantly share code, notes, and snippets.

@ramhiser
Last active July 24, 2018 15:06
Show Gist options
  • Save ramhiser/f411c6e86aced32232d0e160eff9aad7 to your computer and use it in GitHub Desktop.
Save ramhiser/f411c6e86aced32232d0e160eff9aad7 to your computer and use it in GitHub Desktop.
An illustration of underfitting and overfitting when fitting an unknown curve, compared with a random forest fit
library(tidyverse)
library(randomForest)
library(rpart)
# ---- Training sample: 20 points from the "unknown" curve ----
# Fixed seed so the sampled x positions are reproducible.
set.seed(42)
num_points <- 20
x <- sort(runif(num_points, min = -5, max = 6))
# Noise-free curve; the noise term is left disabled on purpose.
y <- x^2 / 5 + sin(3 * x) # + rnorm(num_points, sd=0.1)
# tibble() replaces the deprecated data_frame() alias.
df <- tibble(x = x, y = y)
p <- ggplot(df, aes(x = x, y = y)) +
  geom_point() +
  xlim(-5, 6) +
  ylim(-1, 7)
p
# Name the plot explicitly rather than relying on ggplot2's last_plot().
ggsave("~/Desktop/scatterplot-20points.png", plot = p)
# ---- Dense evaluation grid and polynomial fits ----
# 1000 near-noise-free points over the same x range; used both as a dense
# picture of the curve and as the grid on which fitted models are evaluated.
set.seed(42)
num_points <- 1000
x <- sort(runif(num_points, min = -5, max = 6))
y <- x^2 / 5 + sin(3 * x) + rnorm(num_points, sd = 0.0001)
# tibble() replaces the deprecated data_frame() alias.
df_fits <- tibble(x = x, y = y)
# Each model is trained on the current `df` sample and predicted on the grid.
df_fits$linear_fit <- predict(lm(y ~ x, data = df), newdata = df_fits)
df_fits$quadratic_fit <- predict(
  lm(y ~ poly(x, 2), data = df),
  newdata = df_fits
)
df_fits$polynomial_fit <- predict(
  lm(y ~ poly(x, 10), data = df),
  newdata = df_fits
)
# ---- Plots of the linear, quadratic and degree-10 polynomial fits ----

# Build a scatter plot of `points` with one fitted curve drawn on top.
#   points   data frame with columns x and y, drawn as the scatter layer
#   fits     data frame holding the fitted values (one row per grid point)
#   fit_aes  aes() mapping that selects which column of `fits` is the curve
# Returns the ggplot object; printing and saving are left to the caller.
build_fit_plot <- function(points, fits, fit_aes) {
  ggplot(points, aes(x = x, y = y)) +
    geom_point() +
    xlim(-5, 6) +
    ylim(-1, 7) +
    geom_line(data = fits, mapping = fit_aes, color = "blue")
}

# Linear fit over the training sample, then over the dense grid.
p <- build_fit_plot(df, df_fits, aes(x = x, y = linear_fit))
p
ggsave("~/Desktop/linear-20points.png", plot = p)

p <- build_fit_plot(df_fits, df_fits, aes(x = x, y = linear_fit))
p
ggsave("~/Desktop/linear-1000points.png", plot = p)

# Quadratic fit.
p <- build_fit_plot(df, df_fits, aes(x = x, y = quadratic_fit))
p
ggsave("~/Desktop/quadratic-20points.png", plot = p)

p <- build_fit_plot(df_fits, df_fits, aes(x = x, y = quadratic_fit))
p
ggsave("~/Desktop/quadratic-1000points.png", plot = p)

# Degree-10 polynomial fit.
p <- build_fit_plot(df, df_fits, aes(x = x, y = polynomial_fit))
p
ggsave("~/Desktop/polynomial-20points.png", plot = p)

p <- build_fit_plot(df_fits, df_fits, aes(x = x, y = polynomial_fit))
p
ggsave("~/Desktop/polynomial-1000points.png", plot = p)
# ---- Random forest trained on 20 points ----
set.seed(42)
num_points <- 20
x <- sort(runif(num_points, min = -5, max = 6))
y <- x^2 / 5 + sin(3 * x) + rnorm(num_points, sd = 0.0001)
# tibble() replaces the deprecated data_frame() alias.
df <- tibble(x = x, y = y)
# Keep the forest after the sampling calls so the seeded random stream is
# consumed in the same order as before.
rf_10 <- randomForest(x = as.matrix(x), y = y, ntree = 1000)
# Evaluate the forest on the dense grid so it can be drawn as a curve.
df_fits$predict_rf20 <- predict(rf_10, newdata = as.matrix(df_fits$x))
p <- ggplot(df, aes(x = x, y = y)) +
  geom_point() +
  xlim(-5, 6) +
  ylim(-1, 7) +
  geom_line(data = df_fits, aes(x = x, y = predict_rf20), color = "blue")
p
# Name the plot explicitly rather than relying on ggplot2's last_plot().
ggsave("~/Desktop/rf-20points.png", plot = p)
# ---- Random forest trained on 100 points ----
set.seed(42)
num_points <- 100
x <- sort(runif(num_points, min = -5, max = 6))
y <- x^2 / 5 + sin(3 * x) + rnorm(num_points, sd = 0.0001)
# tibble() replaces the deprecated data_frame() alias.
df <- tibble(x = x, y = y)
# Keep the forest after the sampling calls so the seeded random stream is
# consumed in the same order as before.
rf_10 <- randomForest(x = as.matrix(x), y = y, ntree = 1000)
# The column name is historical: here it holds the 100-point forest's
# predictions on the dense grid.
df_fits$predict_rf20 <- predict(rf_10, newdata = as.matrix(df_fits$x))
p <- ggplot(df, aes(x = x, y = y)) +
  geom_point() +
  xlim(-5, 6) +
  ylim(-1, 7) +
  geom_line(data = df_fits, aes(x = x, y = predict_rf20), color = "blue")
p
# Name the plot explicitly rather than relying on ggplot2's last_plot().
ggsave("~/Desktop/rf-100points.png", plot = p)
# ---- Random forest trained on 500 points ----
set.seed(42)
num_points <- 500
# The training x range stops at 4 here, while the plot still extends to 6.
x <- sort(runif(num_points, min = -5, max = 4))
y <- x^2 / 5 + sin(3 * x) + rnorm(num_points, sd = 0.0001)
# tibble() replaces the deprecated data_frame() alias.
df <- tibble(x = x, y = y)
# Keep the forest after the sampling calls so the seeded random stream is
# consumed in the same order as before.
rf_10 <- randomForest(x = as.matrix(x), y = y, ntree = 1000)
# Quadratic model fitted on the same 500-point sample.
lm_quadratic <- lm(y ~ poly(x, 2), data = df)
# The column name is historical: here it holds the 500-point forest's
# predictions on the dense grid.
df_fits$predict_rf20 <- predict(rf_10, newdata = as.matrix(df_fits$x))
p <- ggplot(df, aes(x = x, y = y)) +
  geom_point() +
  xlim(-5, 6) +
  ylim(-1, 7) +
  geom_line(data = df_fits, aes(x = x, y = predict_rf20), color = "blue")
p
# Name the plot explicitly rather than relying on ggplot2's last_plot().
ggsave("~/Desktop/rf-500points.png", plot = p)
# ---- Random forest vs. quadratic fit over a wider x range ----
# Line colours, assigned to the Model levels in alphabetical order.
color_palette <- c("red", "blue")
set.seed(42)
num_points <- 1000
x <- sort(runif(num_points, min = -7, max = 7))
y <- x^2 / 5 + sin(3 * x) + rnorm(num_points, sd = 0.0001)
# Score both previously fitted models (rf_10, lm_quadratic) on the new sample,
# then reshape to one row per (point, model) for a colour-mapped line layer.
# tibble() replaces the deprecated data_frame() alias.
df <- tibble(x = x, y = y) %>%
  mutate(
    `Random Forest` = predict(rf_10, newdata = as.matrix(.$x)),
    Quadratic = predict(lm_quadratic, newdata = .)
  ) %>%
  gather(Model, Prediction, -x, -y)
# Limits and breaks are set on a single x scale. Adding xlim() and then
# scale_x_continuous() makes ggplot2 replace the first scale, which silently
# discarded the requested limits.
p <- ggplot(df, aes(x = x, y = y)) +
  geom_point(size = 2) +
  geom_line(aes(x = x, y = Prediction, color = Model), size = 1.0) +
  scale_color_manual(values = color_palette) +
  scale_x_continuous(limits = c(-7, 7), breaks = seq(-7, 7, 2)) +
  ylim(-1, 7) +
  theme(
    axis.title = element_text(face = "bold", size = 20),
    axis.text = element_text(vjust = 0.5, size = 16),
    legend.title = element_text(size = 16, face = "bold"),
    legend.text = element_text(size = 14)
  )
p
# Save the themed plot explicitly rather than relying on last_plot().
ggsave("~/Desktop/rf-shortcoming.png", plot = p)
# Extrapolation via Taylor Series
# Rebuild the wide-range sample and attach the random forest's predictions.
set.seed(42)
num_points <- 1000
x <- sort(runif(num_points, min = -7, max = 7))
y <- x^2 / 5 + sin(3 * x) + rnorm(num_points, sd = 0.0001)
# tibble() replaces the deprecated data_frame() alias.
df <- tibble(x = x, y = y) %>%
  mutate(rf_prediction = predict(rf_10, newdata = as.matrix(x)))
# Predictions nearest x = 4
# Anchor for the upper extrapolation: the row with the largest x below 4.
rf_upper <- df %>%
  filter(x < 4) %>%
  slice(which.max(x))
x_upper <- rf_upper$x
yhat_upper <- rf_upper$rf_prediction

# Predictions nearest x = -5
# Anchor for the lower extrapolation: the row with the smallest x above -5.01.
# NOTE(review): the lower cutoff carries a 0.01 margin while the upper cutoff
# is exactly 4 -- confirm the asymmetry is intentional.
rf_lower <- df %>%
  filter(x > -5.01) %>%
  slice(which.min(x))
x_lower <- rf_lower$x
yhat_lower <- rf_lower$rf_prediction
# Taylor Series approximation (2nd-order)
# Fit y = a * x^2 with no intercept and no linear term; the single fitted
# coefficient `a` is what gets passed on as the "second derivative".
# NOTE(review): the analytic second derivative of a * x^2 is 2 * a, and the
# callers pass first_derivative = 0 -- confirm this scaling is intended.
lm_quadratic <- lm(y ~ 0 + I(x^2), data = df)
#ts_first_derivative <- coef(lm_quadratic)[2]
ts_second_derivative <- coef(lm_quadratic)
# Second-order Taylor-style expansion around x0.
#   y_hat              value at the expansion point x0
#   x                  point(s) at which to evaluate the expansion
#   x0                 expansion point
#   first_derivative   multiplier of the linear term (x - x0)
#   second_derivative  multiplier of the quadratic term (x - x0)^2 / 2
# Returns y_hat + first_derivative * (x - x0)
#         + second_derivative * (x - x0)^2 / 2, vectorised over x.
ts_approximation <- function(y_hat, x, x0, first_derivative, second_derivative) {
  offset <- x - x0
  linear_term <- first_derivative * offset
  quadratic_term <- second_derivative * offset^2 / 2
  y_hat + linear_term + quadratic_term
}
# Extrapolated values above x = 4, expanded from the upper anchor point.
# The linear term is switched off (first_derivative = 0).
df_upper <- df %>%
  filter(x > 4) %>%
  transmute(
    x,
    yhat_taylor = ts_approximation(
      y_hat = yhat_upper,
      x = x,
      x0 = x_upper,
      first_derivative = 0,
      second_derivative = ts_second_derivative
    )
  )

# Extrapolated values below x = -5, expanded from the lower anchor point.
df_lower <- df %>%
  filter(x < -5) %>%
  transmute(
    x,
    yhat_taylor = ts_approximation(
      y_hat = yhat_lower,
      x = x,
      x0 = x_lower,
      first_derivative = 0,
      second_derivative = ts_second_derivative
    )
  )
# Add Taylor extrapolation to df
# The two extrapolated ranges are stacked first and joined once on x.
# Two chained left_join() calls without `by` made the second join key on both
# x and yhat_taylor, so no upper rows matched and the upper extrapolation was
# silently dropped.
df_taylor <- bind_rows(df_lower, df_upper)
df <- df %>%
  left_join(df_taylor, by = "x") %>%
  # Inside the extrapolation-free range, fall back to the forest's prediction.
  mutate(
    yhat_taylor = ifelse(is.na(yhat_taylor), rf_prediction, yhat_taylor)
  ) %>%
  rename(
    `Random Forest` = rf_prediction,
    `Taylor Extrapolation` = yhat_taylor
  ) %>%
  # One row per (point, model) so the line layer can map colour to Model.
  gather(Model, Prediction, -x, -y)
# ---- Random forest vs. Taylor extrapolation ----
# Limits and breaks are set on a single x scale. Adding xlim() and then
# scale_x_continuous() makes ggplot2 replace the first scale, which silently
# discarded the requested limits.
p <- ggplot(df, aes(x = x, y = y)) +
  geom_point(size = 2) +
  geom_line(aes(x = x, y = Prediction, color = Model), size = 1.0) +
  scale_color_manual(values = color_palette) +
  scale_x_continuous(limits = c(-7, 7), breaks = seq(-7, 7, 2)) +
  ylim(-1, 7) +
  theme(
    axis.title = element_text(face = "bold", size = 20),
    axis.text = element_text(vjust = 0.5, size = 16),
    legend.title = element_text(size = 16, face = "bold"),
    legend.text = element_text(size = 14)
  )
# Display the themed plot (this section does not write a file).
p
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment