library(tidyverse)
library(rpart)
# Simulated data set. The seed is written like a date, but R evaluates the
# arithmetic: 2020 - 05 - 21 is 1994.
set.seed(2020 - 05 - 21)

# 100 sorted x values between -0.5 and 0.5.
x <- sort(runif(100) - 0.5)

# Piecewise response, scaled by 10: x^2 on the first 50 points, 2x on the
# next 25, -x^2 on the last 25; plus Gaussian noise with sd 0.5.
df <- tibble(
  x = x,
  y = 10 * c(x[1:50]^2, 2 * x[51:75], -x[76:100]^2) +
    0.5 * rnorm(100)
)

# Regression tree of y on x; its splits matrix is used later to mark
# rpart's chosen split on the plot.
df_rp <- rpart(y ~ x, data = df)
# Sum of squared deviations of `data[[y_var]]` about its mean, i.e. the
# residual sum of squares of a mean-only fit.
#
# data:  a data frame (or tibble) containing the response column.
# y_var: name of the response column, as a string.
# Returns a single number. With fewer than two rows the result is 0, because
# a single point (or none) has no spread. var() would return NA here and
# turn every split that leaves one observation on a side into NA.
# Errors if `y_var` is not a column of `data`.
sum_sq <- function(data, y_var) {
  y <- data[[y_var]]
  if (is.null(y)) {
    stop("Column '", y_var, "' not found in `data`.", call. = FALSE)
  }
  n <- length(y)
  if (n < 2L) {
    return(0)
  }
  var(y) * (n - 1)
}
# Total sum of squares of y over the full data set.
ss_t <- sum_sq(df, "y")

# Reduction in the sum of squares of `y_var` achieved by splitting a data set
# into `left` and `right`: SS_total - (SS_left + SS_right). This equals the
# between-group sum of squares of a one-way ANOVA that treats the two sides as
# groups, so larger values indicate a better split.
#
# left, right: data frames holding the two sides of the split.
# y_var:       name of the response column, as a string.
# ss_total:    total sum of squares of the combined data. When NULL (the
#              default) it is computed from the rows of `left` and `right`,
#              which keeps the function self-contained instead of reading a
#              global. Pass it explicitly to avoid recomputing it per split.
compute_anova <- function(left, right, y_var, ss_total = NULL) {
  if (is.null(ss_total)) {
    ss_total <- sum_sq(rbind(left[y_var], right[y_var]), y_var)
  }
  ss_l <- sum_sq(left, y_var)
  ss_r <- sum_sq(right, y_var)
  ss_total - (ss_l + ss_r)
}
# Score every candidate split of df. Split i puts rows 1..(i - 1) on the left
# and rows i..n on the right. This assumes df is ordered by x, which holds
# because x was created with sort().
# For each split:
#   x: the threshold halfway between the two neighbouring x values;
#   f: the reduction in sum of squares returned by compute_anova().
#
# seq_len(n)[-1] is empty when df has fewer than two rows. By contrast,
# 2:nrow(df) would count down and index rows that do not exist. Building each
# column with map_dbl() avoids seeding aov_f with placeholder values and
# overwriting it one cell at a time.
n_obs <- nrow(df)
split_rows <- seq_len(n_obs)[-1]
aov_f <- tibble(
  x = map_dbl(split_rows, function(i) mean(df$x[c(i - 1, i)])),
  f = map_dbl(split_rows, function(i) {
    compute_anova(df[seq_len(i - 1), ], df[i:n_obs, ], "y")
  })
)
# Running total of y, one value per observation (accumulate() runs forward by
# default). Cumulative sums of y, and of y^2, are one possible way to get the
# left- and right-side sums of squares for every split in a single pass,
# rather than re-subsetting df for each split.
df$y %>% accumulate(sum)
#> [1] 2.541811 5.187000 6.862554 9.116096 11.333949 12.556044
#> [7] 14.631407 16.445138 17.886618 19.703371 21.955921 23.115689
#> [13] 24.481332 25.729406 26.879628 28.014668 28.209359 29.089742
#> [19] 30.241354 30.885161 31.450041 32.452772 32.918260 33.105286
#> [25] 34.095481 35.128649 35.960553 36.334196 36.110644 36.462569
#> [31] 36.965526 37.340241 37.380008 36.783398 37.759678 38.435722
#> [37] 38.185789 38.871216 39.263064 39.110255 38.636636 38.409121
#> [43] 39.745407 40.110402 40.671673 40.866249 41.142651 40.632770
#> [49] 40.435210 40.493216 41.360741 41.665337 42.830184 44.344256
#> [55] 45.612297 47.612840 49.790733 52.191716 53.697596 55.161432
#> [61] 57.132571 60.815032 63.144292 66.398033 70.741568 74.860452
#> [67] 78.481951 82.427153 86.980777 91.596964 95.888646 101.295836
#> [73] 106.989854 112.069443 116.298744 116.385507 115.164947 115.198878
#> [79] 115.225320 114.221198 113.734495 112.861147 112.240967 111.281808
#> [85] 109.811725 107.632856 107.120648 105.788005 103.100719 101.142670
#> [91] 99.547995 97.867741 96.249498 93.338117 92.076367 90.599512
#> [97] 88.519454 86.375368 83.842920 80.810707
# Top panel: the raw simulated data.
p1 <- ggplot(df, aes(x = x, y = y)) +
  geom_point(alpha = 0.5) +
  scale_x_continuous(breaks = seq(-0.5, 0.5, 0.1))

# Bottom panel: split score for every candidate threshold. A dashed line marks
# the split point stored in the first row of df_rp$splits ("index" is the
# split-point column of rpart's splits matrix).
p2 <- ggplot(data = aov_f) +
  geom_line(mapping = aes(x = x, y = f), colour = "hotpink") +
  geom_vline(
    xintercept = df_rp$splits[1, "index"],
    colour = "hotpink",
    linetype = 2
  )
library(gridExtra)
#>
#> Attaching package: 'gridExtra'
#> The following object is masked from 'package:dplyr':
#>
#> combine
# Stack the two panels in one column: data on top, split scores below.
grid.arrange(grobs = list(p1, p2), ncol = 1)
#> Warning: Removed 2 row(s) containing missing values (geom_path).
print(aov_f)
#> # A tibble: 99 x 2
#> x f
#> <dbl> <dbl>
#> 1 -0.485 NA
#> 2 -0.469 6.51
#> 3 -0.458 6.77
#> 4 -0.455 9.01
#> 5 -0.442 11.2
#> 6 -0.426 10.5
#> 7 -0.417 12.4
#> 8 -0.414 13.5
#> 9 -0.410 13.8
#> 10 -0.405 15.0
#> # … with 89 more rows
# Created on 2020-05-22 by the reprex package (v0.3.0)