# GitHub gist njtierney/9dbcc813fafe8c007f559e1d94bed4d4
# Created by @njtierney on May 22, 2020.
library(tidyverse)
library(rpart)
# Simulate 100 sorted x values between -0.5 and 0.5 and a piecewise mean
# for y: quadratic over the lowest 50 x values, linear (slope 2) over the
# next 25, inverted quadratic over the top 25, all scaled by 10, plus
# normal noise with sd 0.5.
# Note: the seed is the arithmetic value 2020 - 5 - 21 = 1994, not a date.
set.seed(2020 - 05 - 21)
x <- runif(100) - 0.5
x <- sort(x)
df <- tibble(
  x,
  y = 10 * c(
    x[1:50]^2,
    2 * x[51:75],
    -(x[76:100]^2)
  ) + 0.5 * rnorm(100)
)

# Reference tree: its first split is compared with the hand-rolled scan.
df_rp <- rpart(y ~ x, data = df)
# Sum of squared deviations of the response column from its mean.
#
# data:  a data frame (or tibble) containing the response column.
# y_var: name of the response column, as a string.
# Returns a single number: var(y) * (n - 1).
#
# A node with fewer than two rows has no spread, so it contributes 0.
# var() returns NA in that case. Previously the intended special case
# was dead code (immediately overwritten), so single-row nodes produced NA.
sum_sq <- function(data, y_var) {
  n <- nrow(data)
  if (n < 2) {
    return(0)
  }
  var(data[[y_var]]) * (n - 1)
}

# Total sum of squares of y over the full data set (the unsplit node).
ss_t <- sum_sq(data = df, y_var = "y")

# Reduction in sum of squares obtained by splitting a node into `left`
# and `right`. This is the criterion the script compares against rpart's
# chosen split.
#
# left, right: data frames holding the two halves of the node; both must
#   contain the response column `y_var`.
# y_var:       name of the response column, as a string.
# ss_total:    sum of squares of the parent node. Defaults to the SS of
#   `left` and `right` stacked together, so the function no longer reads
#   a global total that is only valid when left + right is the full data.
# Returns ss_total - (SS(left) + SS(right)).
compute_anova <- function(left, right, y_var,
                          ss_total = sum_sq(rbind(left, right), y_var)) {
  ss_total - (sum_sq(left, y_var) + sum_sq(right, y_var))
}

# Scan every candidate split of the x-sorted data. For each row i >= 2,
# rows 1..(i - 1) go left and rows i..n go right. Record the cut-point
# (midpoint of the two neighbouring x values) as `x` and the SS reduction
# it achieves as `f`.
# seq_len(nrow(df))[-1] is empty when there are fewer than two rows, so the
# scan yields a zero-row tibble instead of iterating over 2:1.
split_rows <- seq_len(nrow(df))[-1]
aov_f <- tibble(
  x = (df$x[split_rows - 1] + df$x[split_rows]) / 2,
  f = map_dbl(split_rows, function(i) {
    compute_anova(
      left = df[seq_len(i - 1), ],
      right = df[i:nrow(df), ],
      y_var = "y"
    )
  })
)

# Exploration: running sum of y, accumulated left to right (purrr's default
# direction). Presumably a first step toward computing left/right sums of
# squares incrementally instead of re-subsetting `df` for every split.
# TODO: confirm intent.
df$y %>%
  accumulate(sum)
#>   [1]   2.541811   5.187000   6.862554   9.116096  11.333949  12.556044
#>   [7]  14.631407  16.445138  17.886618  19.703371  21.955921  23.115689
#>  [13]  24.481332  25.729406  26.879628  28.014668  28.209359  29.089742
#>  [19]  30.241354  30.885161  31.450041  32.452772  32.918260  33.105286
#>  [25]  34.095481  35.128649  35.960553  36.334196  36.110644  36.462569
#>  [31]  36.965526  37.340241  37.380008  36.783398  37.759678  38.435722
#>  [37]  38.185789  38.871216  39.263064  39.110255  38.636636  38.409121
#>  [43]  39.745407  40.110402  40.671673  40.866249  41.142651  40.632770
#>  [49]  40.435210  40.493216  41.360741  41.665337  42.830184  44.344256
#>  [55]  45.612297  47.612840  49.790733  52.191716  53.697596  55.161432
#>  [61]  57.132571  60.815032  63.144292  66.398033  70.741568  74.860452
#>  [67]  78.481951  82.427153  86.980777  91.596964  95.888646 101.295836
#>  [73] 106.989854 112.069443 116.298744 116.385507 115.164947 115.198878
#>  [79] 115.225320 114.221198 113.734495 112.861147 112.240967 111.281808
#>  [85] 109.811725 107.632856 107.120648 105.788005 103.100719 101.142670
#>  [91]  99.547995  97.867741  96.249498  93.338117  92.076367  90.599512
#>  [97]  88.519454  86.375368  83.842920  80.810707


# Shared x-axis breaks so the two stacked panels can be compared by eye.
x_breaks <- seq(-0.5, 0.5, 0.1)

# Top panel: the simulated data.
p1 <- ggplot(df, aes(x = x, y = y)) +
  geom_point(alpha = 0.5) +
  scale_x_continuous(breaks = x_breaks)

# Bottom panel: SS reduction for every candidate split, with a dashed line
# at rpart's split point. Column "index" (column 4) of `splits` holds the
# cut-point; row 1 is presumably the root node's primary split
# (see ?rpart.object).
p2 <- ggplot(data = aov_f) +
  geom_line(aes(x = x, y = f), colour = "hotpink") +
  geom_vline(
    xintercept = df_rp$splits[1, "index"],
    colour = "hotpink",
    linetype = 2
  ) +
  scale_x_continuous(breaks = x_breaks)

library(gridExtra)
#> 
#> Attaching package: 'gridExtra'
#> The following object is masked from 'package:dplyr':
#> 
#>     combine
# Stack the data panel above the split-criterion panel.
grid.arrange(grobs = list(p1, p2), ncol = 1)
#> Warning: Removed 2 row(s) containing missing values (geom_path).

# Inspect the per-split criterion table.
print(aov_f)
#> # A tibble: 99 x 2
#>         x     f
#>     <dbl> <dbl>
#>  1 -0.485 NA   
#>  2 -0.469  6.51
#>  3 -0.458  6.77
#>  4 -0.455  9.01
#>  5 -0.442 11.2 
#>  6 -0.426 10.5 
#>  7 -0.417 12.4 
#>  8 -0.414 13.5 
#>  9 -0.410 13.8 
#> 10 -0.405 15.0 
#> # … with 89 more rows

# Created on 2020-05-22 by the reprex package (v0.3.0)