library(mgcv)
#> Loading required package: nlme
#> This is mgcv 1.8-40. For overview type 'help("mgcv-package")'.
library(tidyverse)
library(ggplot2)
# Seed for reproducibility. The original date-like expression 2022-08-08 is
# evaluated arithmetically (2022 - 8 - 8), so the effective seed is 2006.
set.seed(seed = 2006)
# The fake univariate age effect we would like the model to recover: a
# periodic cubic spline through hand-picked knots that sits at 0.1 up to age
# 20, rises to 0.9 between ages 25 and 50, and falls back to 0.1 from age 55.
# With method = "periodic" the curve repeats with period 80 (the knot range),
# so ages above 80 wrap around to the start of the cycle.
f <- function(age) {
  knot_age <- c(0, 5, 10, 15, 20, 23, 24, 25, 27, 30, 35, 40,
                45, 48, 50, 51, 52, 55, 60, 65, 70, 75, 80)
  knot_value <- c(0.1, 0.1, 0.1, 0.1, 0.1, 0.5, 0.7, 0.9, 0.9, 0.9, 0.9, 0.9,
                  0.9, 0.9, 0.9, 0.7, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1)
  smooth_curve <- splinefun(knot_age, knot_value, method = "periodic")
  smooth_curve(age)
}
# Separable interaction: the product of the same sine wave evaluated at each
# age.
g <- function(age_from, age_to) {
  wave <- function(age) sin(age / 10 - 1)
  wave(age_from) * wave(age_to)
}
# A negative term that depends only on the absolute age gap
# |age_from - age_to|, so it is symmetric in its two arguments.
h <- function(age_from, age_to) {
  scaled_gap <- abs(age_from - age_to) / 5
  -0.1 * exp(sin(scaled_gap))^2
}
# True linear predictor: intercept 1, a main effect f() of each age, the
# product interaction g() and the age-gap term h(). Reduce() folds from the
# left, so the terms are added in exactly the order 1 + a + b + c + d.
eta <- function(age_from, age_to) {
  components <- list(
    1,
    f(age_from),
    f(age_to),
    g(age_from, age_to),
    h(age_from, age_to)
  )
  Reduce(`+`, components)
}
# Simulated contact data on every (age_from, age_to) pair in 1:90 x 1:90,
# together with the symmetric covariates used as smooth inputs.
sim_data <- expand.grid(
  age_from = 1:90,
  age_to = 1:90
) %>%
  as_tibble() %>%
  mutate(
    # the true log contact rate per participant:
    # eta = 1 + f(x) + f(y) + g(x, y) + h(x, y)
    term_eta = eta(age_from, age_to),
    participants = 500,
    # Poisson counts whose mean scales with the number of participants, which
    # matches a Poisson model with offset log(participants). Multiplying a
    # Poisson(exp(eta)) draw by `participants` instead gives the same mean but
    # a variance `participants` times larger than the Poisson model assumes.
    response = rpois(n(), participants * exp(term_eta)),
    # symmetric transformations of the age pair (ages are positive, so the
    # abs() around the product and sum does not change their values)
    gam_age_offdiag = abs(age_from - age_to),
    gam_age_offdiag_2 = abs(age_from - age_to)^2,
    gam_age_diag_prod = abs(age_from * age_to),
    gam_age_diag_sum = abs(age_from + age_to),
    gam_age_pmax = pmax(age_from, age_to),
    gam_age_pmin = pmin(age_from, age_to)
  )
# Poisson GAM with one univariate smooth per symmetric transformation of the
# age pair. These covariates are functionally related: pmax + pmin equals the
# sum, pmax - pmin equals the off-diagonal distance, and offdiag_2 is
# offdiag squared. The separate smooths can therefore trade off against each
# other, and only their combination is well determined.
# NOTE(review): per ?mgcv::gam, an offset passed via the `offset` argument
# (rather than inside the formula) is ignored by predict() - confirm this is
# the intended behaviour for any link-scale predictions.
sim_m <- gam(
  formula = response ~ s(gam_age_offdiag) + s(gam_age_offdiag_2) +
    s(gam_age_diag_prod) + s(gam_age_diag_sum) +
    s(gam_age_pmax) + s(gam_age_pmin),
  data = sim_data,
  family = poisson,
  offset = log(participants)
)
###
# Return the right-hand-side term labels of a fitted model's formula,
# e.g. "s(gam_age_offdiag)".
#
# model: any object with a `$formula` element (such as an mgcv gam fit).
# Returns a character vector with one entry per term. The vector is empty when
# the formula has no terms.
#
# Uses the "term.labels" attribute rather than dropping fixed positions from
# the "variables" attribute. term.labels never contains the response or any
# offset() term, so it works without a response and with formula offsets.
get_formulas_terms <- function(model){
  attr(terms(model$formula), "term.labels")
}
# smooth terms of the fitted model, assigned and printed in one step
(formula_terms <- get_formulas_terms(sim_m))
#> [1] "s(gam_age_offdiag)" "s(gam_age_offdiag_2)" "s(gam_age_diag_prod)"
#> [4] "s(gam_age_diag_sum)" "s(gam_age_pmax)" "s(gam_age_pmin)"
# Build prediction column names from model term labels by stripping the
# outermost call: "s(gam_age_pmax)" -> "fitted_gam_age_pmax".
#
# x: character vector of term labels.
# Returns a glue vector of the same length as `x`.
#
# The greedy (.*) keeps nested calls intact ("s(log(x))" -> "fitted_log(x)").
# sub() returns its input unchanged when the pattern does not match, so a
# parametric term such as "x1" becomes "fitted_x1". A list-based extraction
# would produce "fitted_character(0)" for it instead.
extract_term_name <- function(x){
  term <- sub("^[^(]*\\((.*)\\)$", "\\1", x)
  glue::glue("fitted_{term}")
}
# the column names the per-term predictions will receive
formula_terms %>% extract_term_name()
#> fitted_gam_age_offdiag
#> fitted_gam_age_offdiag_2
#> fitted_gam_age_diag_prod
#> fitted_gam_age_diag_sum
#> fitted_gam_age_pmax
#> fitted_gam_age_pmin
# Contribution of the selected model term(s) for every row of `data`.
#
# predict(type = "terms") returns a matrix with one column per requested term.
# It is flattened to a plain numeric vector with all attributes dropped, with
# columns stacked one after another when several terms are requested.
predict_gam_term <- function(model, data, terms){
  term_matrix <- predict(
    model,
    newdata = data,
    type = "terms",
    terms = terms
  )
  as.vector(term_matrix)
}
# first and last few fitted values of the first smooth term
predict_gam_term(sim_m, sim_data, formula_terms[1]) %>% head()
#> [1] -444.4278 -444.4258 -444.3999 -444.3189 -444.1496 -443.8621
predict_gam_term(sim_m, sim_data, formula_terms[1]) %>% tail()
#> [1] -443.8621 -444.1496 -444.3189 -444.3999 -444.4258 -444.4278
# Add the model intercept as a constant column `fitted_intercept`.
#
# data:  data frame / tibble to extend.
# model: fitted model whose coefficients are read via coef().
# Returns `data` with the extra column.
#
# The intercept is looked up by name, not by position. A model fitted without
# an intercept has no "(Intercept)" coefficient, and taking the first
# coefficient would then silently use a smooth basis coefficient. In that case
# the column is 0. `[[` also drops the "(Intercept)" name, so the column holds
# a plain number.
add_intercept <- function(data, model){
  coefs <- coef(model)
  has_intercept <- "(Intercept)" %in% names(coefs)
  intercept <- if (has_intercept) coefs[["(Intercept)"]] else 0
  mutate(
    .data = data,
    fitted_intercept = intercept
  )
}
# Predict a single model term and return it as a one-column tibble named after
# the term, e.g. "s(gam_age_pmax)" -> column `fitted_gam_age_pmax`.
tidy_predict_term <- function(data,
                              model,
                              term){
  predictions <- predict_gam_term(model, data, term)
  out <- tibble(predictions)
  names(out) <- extract_term_name(term)
  out
}
# Add `fitted_overall`: the row-wise sum of every column whose name starts
# with "fitted" (for example the intercept plus each per-term prediction).
#
# An existing `fitted_overall` column is excluded from the sum. Without that,
# calling this twice would add the previous total into the new one.
add_fitted_overall <- function(data){
  data %>%
    mutate(
      fitted_overall = rowSums(across(
        .cols = starts_with("fitted") & !any_of("fitted_overall")
      ))
    )
}
# Add per-term predictions, the intercept and their overall sum to `data`.
#
# data:  data frame with the model covariates.
# model: fitted model (e.g. an mgcv gam).
# term:  character vector of term labels to predict, one "fitted_*" column
#        each. Defaults to every term of the model formula.
# Returns `data` with `fitted_intercept`, one `fitted_<term>` column per term,
# and `fitted_overall`.
add_gam_predictions <- function(data,
                                model,
                                term = get_formulas_terms(model)) {
  # iterate over the supplied terms, not a global variable
  predictions <- map_dfc(
    .x = term,
    .f = tidy_predict_term,
    data = data,
    model = model
  )
  data %>%
    add_intercept(model) %>%
    bind_cols(predictions) %>%
    add_fitted_overall()
}
# simulated data augmented with every fitted component of the model
data_gam_preds <- sim_data %>%
  add_gam_predictions(model = sim_m, term = formula_terms)
data_gam_preds %>% names()
#> [1] "age_from" "age_to"
#> [3] "term_eta" "participants"
#> [5] "response" "gam_age_offdiag"
#> [7] "gam_age_offdiag_2" "gam_age_diag_prod"
#> [9] "gam_age_diag_sum" "gam_age_pmax"
#> [11] "gam_age_pmin" "fitted_intercept"
#> [13] "fitted_gam_age_offdiag" "fitted_gam_age_offdiag_2"
#> [15] "fitted_gam_age_diag_prod" "fitted_gam_age_diag_sum"
#> [17] "fitted_gam_age_pmax" "fitted_gam_age_pmin"
#> [19] "fitted_overall"
# Long format for plotting: one row per (age_from, age_to, term). The true
# predictor column `term_eta` becomes term "eta"; the "fitted*" columns keep
# their full names.
data_pred_terms <- data_gam_preds %>%
  # drop the derived smooth covariates; they are not plotted
  select(!starts_with("gam_")) %>%
  pivot_longer(
    cols = c(starts_with("term"), starts_with("fitted")),
    names_to = "term",
    names_prefix = "term_",
    values_to = "value"
  )
# Stacking the true `eta` and the overall fit in the same column as the
# individual terms is a bit inelegant, but I couldn't think of a nicer pattern
# that wasn't a lot of code.
data_pred_terms %>%
  pull(term) %>%
  unique()
#> [1] "eta" "fitted_intercept"
#> [3] "fitted_gam_age_offdiag" "fitted_gam_age_offdiag_2"
#> [5] "fitted_gam_age_diag_prod" "fitted_gam_age_diag_sum"
#> [7] "fitted_gam_age_pmax" "fitted_gam_age_pmin"
#> [9] "fitted_overall"
# plot the truth and fitted model
# Side-by-side heatmaps of exp(eta) (truth) and exp(fitted_overall) (fit);
# exp() puts both surfaces back on the rate scale.
data_pred_terms %>%
  filter(term %in% c("eta", "fitted_overall")) %>%
  ggplot(aes(x = age_from, y = age_to, group = term, fill = exp(value))) +
  geom_tile() +
  facet_wrap(vars(term)) +
  coord_fixed() +
  scale_fill_viridis_c() +
  theme_minimal()
# Heatmap of each individual fitted component (intercept and per-term
# contributions, on the linear-predictor scale). All panels share one fill
# scale, so components with a large range dominate the colours.
data_pred_terms %>%
  filter(!term %in% c("eta", "fitted_overall")) %>%
  ggplot(aes(x = age_from, y = age_to, group = term, fill = value)) +
  geom_tile() +
  facet_wrap(vars(term)) +
  coord_fixed() +
  scale_fill_viridis_c() +
  theme_minimal()
# Plot the 1D smooths on a single page. `pages = 1` makes plot.gam choose the
# panel layout itself and reset it afterwards. Setting par(mfrow = ...) by
# hand and then forcing c(1, 1) would discard whatever layout was active
# before.
plot(sim_m, pages = 1)
# Created on 2022-08-08 by the reprex package (v2.0.1)
# Session info
# record R and package versions so the reprex can be reproduced
print(sessioninfo::session_info())
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.2.0 (2022-04-22)
#> os macOS Monterey 12.3.1
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_AU.UTF-8
#> ctype en_AU.UTF-8
#> tz Australia/Perth
#> date 2022-08-08
#> pandoc 2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.2.0)
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.2.0)
#> broom 0.8.0 2022-04-13 [1] CRAN (R 4.2.0)
#> cellranger 1.1.0 2016-07-27 [1] CRAN (R 4.2.0)
#> cli 3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#> colorspace 2.0-3 2022-02-21 [1] CRAN (R 4.2.0)
#> crayon 1.5.1 2022-03-26 [1] CRAN (R 4.2.0)
#> curl 4.3.2 2021-06-23 [1] CRAN (R 4.2.0)
#> DBI 1.1.2 2021-12-20 [1] CRAN (R 4.2.0)
#> dbplyr 2.1.1 2021-04-06 [1] CRAN (R 4.2.0)
#> digest 0.6.29 2021-12-01 [1] CRAN (R 4.2.0)
#> dplyr * 1.0.9 2022-04-28 [1] CRAN (R 4.2.0)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.2.0)
#> evaluate 0.15 2022-02-18 [1] CRAN (R 4.2.0)
#> fansi 1.0.3 2022-03-24 [1] CRAN (R 4.2.0)
#> farver 2.1.0 2021-02-28 [1] CRAN (R 4.2.0)
#> fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.2.0)
#> forcats * 0.5.1 2021-01-27 [1] CRAN (R 4.2.0)
#> fs 1.5.2 2021-12-08 [1] CRAN (R 4.2.0)
#> generics 0.1.2 2022-01-31 [1] CRAN (R 4.2.0)
#> ggplot2 * 3.3.6 2022-05-03 [1] CRAN (R 4.2.0)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.2.0)
#> haven 2.5.0 2022-04-15 [1] CRAN (R 4.2.0)
#> highr 0.9 2021-04-16 [1] CRAN (R 4.2.0)
#> hms 1.1.1 2021-09-26 [1] CRAN (R 4.2.0)
#> htmltools 0.5.2 2021-08-25 [1] CRAN (R 4.2.0)
#> httr 1.4.3 2022-05-04 [1] CRAN (R 4.2.0)
#> jsonlite 1.8.0 2022-02-22 [1] CRAN (R 4.2.0)
#> knitr 1.39 2022-04-26 [1] CRAN (R 4.2.0)
#> labeling 0.4.2 2020-10-20 [1] CRAN (R 4.2.0)
#> lattice 0.20-45 2021-09-22 [1] CRAN (R 4.2.0)
#> lifecycle 1.0.1 2021-09-24 [1] CRAN (R 4.2.0)
#> lubridate 1.8.0 2021-10-07 [1] CRAN (R 4.2.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.2.0)
#> Matrix 1.4-1 2022-03-23 [1] CRAN (R 4.2.0)
#> mgcv * 1.8-40 2022-03-29 [1] CRAN (R 4.2.0)
#> mime 0.12 2021-09-28 [1] CRAN (R 4.2.0)
#> modelr 0.1.8 2020-05-19 [1] CRAN (R 4.2.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.2.0)
#> nlme * 3.1-157 2022-03-25 [1] CRAN (R 4.2.0)
#> pillar 1.7.0 2022-02-01 [1] CRAN (R 4.2.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.2.0)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.2.0)
#> R.cache 0.15.0 2021-04-30 [1] CRAN (R 4.2.0)
#> R.methodsS3 1.8.1 2020-08-26 [1] CRAN (R 4.2.0)
#> R.oo 1.24.0 2020-08-26 [1] CRAN (R 4.2.0)
#> R.utils 2.11.0 2021-09-26 [1] CRAN (R 4.2.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.2.0)
#> readr * 2.1.2 2022-01-30 [1] CRAN (R 4.2.0)
#> readxl 1.4.0 2022-03-28 [1] CRAN (R 4.2.0)
#> reprex 2.0.1 2021-08-05 [1] CRAN (R 4.2.0)
#> rlang 1.0.4 2022-07-12 [1] CRAN (R 4.2.0)
#> rmarkdown 2.14 2022-04-25 [1] CRAN (R 4.2.0)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.2.0)
#> rvest 1.0.2 2021-10-16 [1] CRAN (R 4.2.0)
#> scales 1.2.0 2022-04-13 [1] CRAN (R 4.2.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.0)
#> stringi 1.7.6 2021-11-29 [1] CRAN (R 4.2.0)
#> stringr * 1.4.0 2019-02-10 [1] CRAN (R 4.2.0)
#> styler 1.7.0 2022-03-13 [1] CRAN (R 4.2.0)
#> tibble * 3.1.7 2022-05-03 [1] CRAN (R 4.2.0)
#> tidyr * 1.2.0 2022-02-01 [1] CRAN (R 4.2.0)
#> tidyselect 1.1.2 2022-02-21 [1] CRAN (R 4.2.0)
#> tidyverse * 1.3.1 2021-04-15 [1] CRAN (R 4.2.0)
#> tzdb 0.3.0 2022-03-28 [1] CRAN (R 4.2.0)
#> utf8 1.2.2 2021-07-24 [1] CRAN (R 4.2.0)
#> vctrs 0.4.1 2022-04-13 [1] CRAN (R 4.2.0)
#> viridisLite 0.4.0 2021-04-13 [1] CRAN (R 4.2.0)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.0)
#> xfun 0.31 2022-05-10 [1] CRAN (R 4.2.0)
#> xml2 1.3.3 2021-11-30 [1] CRAN (R 4.2.0)
#> yaml 2.3.5 2022-02-21 [1] CRAN (R 4.2.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────