Last active
August 22, 2019 15:06
-
-
Save StaffanBetner/3adeb49fd60174222788509b904f0d01 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse) | |
library(mgcv) | |
# Estimate the 1-D smooths of a fitted GAM and their first derivatives, with
# pointwise and simultaneous confidence bands.
#
# Every smooth in `gam_object$smooth` whose `dim` is 1 is evaluated with
# `PredictMat()` on `n_eval` evenly spaced points spanning the observed range
# of its covariate in `model.frame(gam_object)`. The first derivative is a
# forward finite difference of the design matrix with step `eps`. Standard
# errors use `vcov(gam_object, unconditional = TRUE)`. The simultaneous
# critical value `crit` is the 0.95 quantile (type 8) of max_grid |X b / se|
# over `n_sim` draws b ~ N(0, V) from `MASS::mvrnorm()`, so it depends on the
# RNG state (set a seed for reproducible bands).
#
# Arguments:
#   gam_object  Fitted model; presumably an mgcv::gam (it must support
#               model.frame(), vcov(unconditional = TRUE), coef(), and carry a
#               `$smooth` list usable with PredictMat()).
#   n_eval      Number of grid points per smooth.
#   n_sim       Number of simulated coefficient draws for simultaneous bands.
#   eps         Step size of the forward finite difference.
#
# Returns:
#   A tibble with class c("splines", ...), one row per (smooth, grid point):
#   `label`, `x` (covariate value), `est`, `Se`, the simultaneous half-width
#   `error_margin` (= crit * Se), the pointwise half-width
#   `error_margin_point` (= 1.96 * Se), p-values `pval` / `pval_point`, and
#   the matching `*_deriv` columns for the derivative. For `by` smooths the
#   by variable appears as a column held at 1. Intermediate list columns
#   (smooth objects, design matrices, draws, ...) are kept.
#
# NOTE(review): the p-values are 1 - pnorm(|z|), the upper-tail probability
# of |z|; a conventional two-sided p-value would be twice that - confirm the
# intended definition. For factor `by` smooths, confirm that PredictMat()
# yields the intended curve when the by column is set to 1.
splines_and_derivative <- function(gam_object, n_eval = 200, n_sim = 100,
                                   eps = 0.0000001) {
  model_data <- model.frame(gam_object)
  beta <- coef(gam_object)
  V <- vcov(gam_object, unconditional = TRUE)
  # Approximately 1 / qnorm(0.975): rescales a 95% half-width to one SE unit.
  inv_z975 <- 0.5102137

  # Grid of `n_eval` evenly spaced values over range(values), plus `shift`.
  # Columns are named after the covariate and, for `by` smooths, the by
  # variable (held at 1).
  make_grid <- function(values, term, by, shift = 0) {
    grid <- tibble(
      value = seq(min(values), max(values), length.out = n_eval) + shift
    )
    names(grid) <- term
    if (by != "NA") {
      grid[[by]] <- 1
    }
    grid
  }

  # Pointwise standard errors of X %*% b when b has covariance `Vs`.
  pointwise_se <- function(X, Vs) {
    sqrt(rowSums(X * (X %*% Vs)))
  }

  # 0.95 quantile (type 8), across simulated draws (rows of `draws`), of the
  # maximum absolute standardized deviation over the grid.
  simultaneous_crit <- function(X, draws, se) {
    max_dev <- apply(abs((X %*% t(draws)) / se), 2, max)
    quantile(max_dev, probs = 0.95, type = 8)
  }

  output <- tibble(id = seq_along(gam_object$smooth)) %>%
    mutate(
      smooth_obj = map(id, ~ gam_object$smooth[[.x]]),
      dims = map_dbl(smooth_obj, ~ .x$dim)
    ) %>%
    filter(dims == 1) %>%
    mutate(
      label = map_chr(smooth_obj, ~ .x$label),
      term = map_chr(smooth_obj, ~ .x$term),
      by = map_chr(smooth_obj, ~ .x$by),
      data = map(term, ~ tibble(x = model_data[[.x]])),
      grid = pmap(
        list(data, term, by),
        function(d, tm, bv) make_grid(d$x, tm, bv)
      ),
      # Same grid shifted by `eps`, for the forward finite difference.
      grid2 = pmap(
        list(data, term, by),
        function(d, tm, bv) make_grid(d$x, tm, bv, shift = eps)
      ),
      first = map_dbl(smooth_obj, ~ .x$first.para),
      last = map_dbl(smooth_obj, ~ .x$last.para),
      coefs = map2(first, last, ~ beta[.x:.y]),
      # drop = FALSE keeps a matrix even for a one-coefficient smooth.
      Vc = map2(first, last, ~ V[.x:.y, .x:.y, drop = FALSE]),
      Bu = map(
        Vc,
        ~ MASS::mvrnorm(n = n_sim, mu = rep(0, nrow(.x)), Sigma = .x)
      ),
      knotvalues = map2(smooth_obj, grid, ~ PredictMat(.x, .y)),
      knotvalues_deriv = pmap(
        list(smooth_obj, grid2, knotvalues),
        function(sm, g, X) (PredictMat(sm, g) - X) / eps
      ),
      Se = map2(knotvalues, Vc, pointwise_se),
      Se_deriv = map2(knotvalues_deriv, Vc, pointwise_se),
      crit = pmap_dbl(list(knotvalues, Bu, Se), simultaneous_crit),
      crit_deriv = pmap_dbl(
        list(knotvalues_deriv, Bu, Se_deriv),
        simultaneous_crit
      ),
      # as.vector() turns the n_eval x 1 products into plain numeric vectors.
      est = map2(knotvalues, coefs, ~ as.vector(.x %*% .y)),
      est_deriv = map2(knotvalues_deriv, coefs, ~ as.vector(.x %*% .y)),
      error_margin = map2(crit, Se, ~ .x * .y),
      error_margin_deriv = map2(crit_deriv, Se_deriv, ~ .x * .y),
      pval = map2(
        est, error_margin,
        ~ 1 - pnorm(abs(.x / (.y * inv_z975)))
      ),
      pval_deriv = map2(
        est_deriv, error_margin_deriv,
        ~ 1 - pnorm(abs(.x / (.y * inv_z975)))
      ),
      # Use a common `x` column name for the covariate across smooths.
      grid = map2(grid, term, ~ rename(.x, x = all_of(.y)))
    ) %>%
    unnest(c(
      grid, Se, Se_deriv, est, est_deriv,
      error_margin, error_margin_deriv, pval, pval_deriv
    )) %>%
    mutate(
      error_margin_point = 1.96 * Se,
      error_margin_point_deriv = 1.96 * Se_deriv,
      pval_point = 1 - pnorm(abs(est / Se)),
      pval_point_deriv = 1 - pnorm(abs(est_deriv / Se_deriv))
    )

  class(output) <- c("splines", class(output))
  output
}
# Plot method for objects returned by splines_and_derivative().
#
# For every smooth `label`, draws the estimated curve (`type = "spline"`) and
# its finite-difference derivative (`type = "deriv"`) in free-scale facets laid
# out over two rows, with spline panels ordered before derivative panels.
# Each panel shows the estimate as a line, the band est +/- error_margin as a
# ribbon with alpha 0.4, and the band est +/- error_margin_point as a ribbon
# with alpha 0.3.
#
# Arguments:
#   splines_obj  Output of splines_and_derivative().
#   ...          Ignored; accepted for compatibility with the plot() generic.
#
# Returns:
#   A ggplot object.
plot.splines <- function(splines_obj, ...) {
  deriv_vars <- c(
    "crit_deriv",
    "Se_deriv",
    "est_deriv",
    "error_margin_deriv",
    "pval_deriv",
    "error_margin_point_deriv",
    "pval_point_deriv"
  )
  spline_vars <- c(
    "crit",
    "Se",
    "est",
    "error_margin",
    "pval",
    "error_margin_point",
    "pval_point"
  )

  # Derivative rows: drop the spline columns and give the *_deriv columns the
  # spline names so both halves share one schema.
  deriv_data <- splines_obj %>%
    select(-all_of(spline_vars)) %>%
    rename(
      crit = crit_deriv,
      Se = Se_deriv,
      est = est_deriv,
      error_margin = error_margin_deriv,
      pval = pval_deriv,
      error_margin_point = error_margin_point_deriv,
      pval_point = pval_point_deriv
    ) %>%
    mutate(type = "deriv")

  spline_data <- splines_obj %>%
    select(-all_of(deriv_vars)) %>%
    mutate(type = "spline")

  # Both halves have identical columns and differ in `type`, so they are
  # simply stacked; no join is needed.
  bind_rows(spline_data, deriv_data) %>%
    mutate(type = factor(type, levels = c("spline", "deriv"))) %>%
    ggplot(aes(x = x, y = est)) +
    geom_line() +
    facet_wrap(type ~ label, scales = "free", nrow = 2) +
    geom_ribbon(
      alpha = 0.4,
      mapping = aes(ymin = est - error_margin, ymax = est + error_margin)
    ) +
    geom_ribbon(
      alpha = 0.3,
      mapping = aes(
        ymin = est - error_margin_point,
        ymax = est + error_margin_point
      )
    )
}
# Example (assumes `gam_object` is a fitted mgcv::gam model):
# splines_and_derivative(gam_object, n_eval = 500) %>%
#   plot()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment