Skip to content

Instantly share code, notes, and snippets.

@njtierney
Last active August 8, 2022 07:05
Show Gist options
  • Save njtierney/f6e0dad240492be978bce9c4a514bfee to your computer and use it in GitHub Desktop.
Save njtierney/f6e0dad240492be978bce9c4a514bfee to your computer and use it in GitHub Desktop.
library(mgcv)
#> Loading required package: nlme
#> This is mgcv 1.8-40. For overview type 'help("mgcv-package")'.
library(tidyverse)
library(ggplot2)
# NOTE: 2022-08-08 is evaluated as arithmetic (2022 - 8 - 8), so the seed
# actually used is 2006, not a date.
set.seed(2022-08-08)
# Synthetic "true" one-dimensional age effect that the model should mirror.
#
# A fixed set of (age, value) knots is interpolated with a periodic spline.
# The knot values are 0.1 for ages 0-20, climb through 0.5 and 0.7 to 0.9 at
# age 25, stay at 0.9 up to age 50, fall through 0.7 and 0.5, and are 0.1
# again for ages 55-80. The first and last knot values are equal (0.1).
#
# age: numeric vector of ages at which to evaluate the curve.
# Returns a numeric vector with one interpolated value per element of `age`.
f <- function(age) {
  knot_age <- c(
    0, 5, 10, 15, 20,
    23, 24,
    25, 27, 30, 35, 40, 45, 48, 50,
    51, 52,
    55, 60, 65, 70, 75, 80
  )
  knot_value <- c(
    rep(0.1, 5),
    0.5, 0.7,
    rep(0.9, 8),
    0.7, 0.5,
    rep(0.1, 6)
  )

  age_curve <- splinefun(x = knot_age, y = knot_value, method = "periodic")
  age_curve(age)
}

# Synthetic interaction term: the product of the same shifted sine wave
# evaluated at each of the two ages, so it is symmetric in its arguments.
#
# age_from, age_to: numeric vectors of ages.
# Returns sin(age_from / 10 - 1) * sin(age_to / 10 - 1), elementwise.
g <- function(age_from, age_to) {
  wave_from <- sin(age_from / 10 - 1)
  wave_to <- sin(age_to / 10 - 1)
  wave_from * wave_to
}

# Synthetic off-diagonal term: depends only on the absolute age difference.
#
# age_from, age_to: numeric vectors of ages.
# Returns -0.1 * exp(sin(|age_from - age_to| / 5))^2, elementwise.
h <- function(age_from, age_to) {
  age_gap <- abs(age_from - age_to)
  ripple <- exp(sin(age_gap / 5))
  -0.1 * ripple^2
}

# True linear predictor of the simulation: a constant of 1, the age effect f()
# at each of the two ages, the product term g() and the age-gap term h().
#
# age_from, age_to: numeric vectors of ages.
# Returns the elementwise sum of the five components.
eta <- function(age_from, age_to) {
  effect_from <- f(age_from)
  effect_to <- f(age_to)
  product_term <- g(age_from, age_to)
  gap_term <- h(age_from, age_to)

  1 + effect_from + effect_to + product_term + gap_term
}

# Simulated data on the full grid of age pairs (1..90 x 1..90), plus the
# symmetric covariates used as smooth inputs in the model below.
sim_data <- as_tibble(
  expand.grid(age_from = 1:90, age_to = 1:90)
) %>%
  mutate(
    # true linear predictor: 1 + f(x) + f(y) + g(x, y) + h(x, y)
    term_eta = eta(age_from, age_to),
    participants = 500,
    # fake data: a Poisson draw at rate exp(term_eta), multiplied by the
    # number of participants afterwards (population offsets ignored for now)
    response = rpois(n(), exp(term_eta)) * participants,
    # symmetric functions of the two ages
    gam_age_offdiag = abs(age_from - age_to),
    gam_age_offdiag_2 = gam_age_offdiag^2,
    gam_age_diag_prod = abs(age_from * age_to),
    gam_age_diag_sum = abs(age_from + age_to),
    gam_age_pmax = pmax(age_from, age_to),
    gam_age_pmin = pmin(age_from, age_to)
  )

# Fit a Poisson GAM with one univariate smooth per symmetric covariate.
# NOTE(review): the mgcv documentation says an offset passed through the
# `offset` argument (rather than in the formula) is ignored by predict();
# confirm that is the intended behaviour here.
sim_m <- gam(
  formula = response ~
    s(gam_age_offdiag) +
    s(gam_age_offdiag_2) +
    s(gam_age_diag_prod) +
    s(gam_age_diag_sum) +
    s(gam_age_pmax) +
    s(gam_age_pmin),
  data = sim_data,
  family = poisson,
  offset = log(participants)
)

###

# List the right-hand-side variables of a model formula as character strings.
#
# model: an object with a `formula` element (e.g. a fitted gam).
# Returns the deparsed entries of the formula's "variables" attribute with the
# first two dropped: the leading `list` call head and the response.
get_formulas_terms <- function(model) {
  model_terms <- terms(model$formula)
  variable_labels <- as.character(attr(model_terms, "variables"))
  variable_labels[-c(1, 2)]
}

# Labels of the model's smooth terms, in formula order; reused further down
# when predicting each term separately.
formula_terms <- get_formulas_terms(sim_m)

formula_terms
#> [1] "s(gam_age_offdiag)"   "s(gam_age_offdiag_2)" "s(gam_age_diag_prod)"
#> [4] "s(gam_age_diag_sum)"  "s(gam_age_pmax)"      "s(gam_age_pmin)"

# Build the output column name for a model term.
#
# The text between the first "(" and the next ")" of each term label is taken
# as the variable name, e.g. "s(gam_age_pmax)" -> "fitted_gam_age_pmax".
# Labels without parentheses (e.g. a parametric term "x") are used as they
# are, giving "fitted_x" rather than a deparsed empty match.
#
# x: character vector of term labels.
# Returns a glue character vector of the same length, each prefixed "fitted_".
extract_term_name <- function(x) {
  # str_extract() returns the first match per element, or NA when there is
  # none, so the result always lines up one-to-one with `x`
  term <- stringr::str_extract(x, "(?<=\\().+?(?=\\))")
  no_parens <- is.na(term)
  term[no_parens] <- x[no_parens]

  glue::glue("fitted_{term}")
}

# column names that the per-term predictions will be stored under
extract_term_name(formula_terms)
#> fitted_gam_age_offdiag
#> fitted_gam_age_offdiag_2
#> fitted_gam_age_diag_prod
#> fitted_gam_age_diag_sum
#> fitted_gam_age_pmax
#> fitted_gam_age_pmin

# Predict the contribution of selected model terms as a plain vector.
#
# model: fitted model with a predict() method supporting type = "terms".
# data:  data frame to predict on.
# terms: character vector of term labels passed on to predict().
# Returns the "terms" prediction matrix flattened with c(), which drops its
# dimensions; for a single term this is one value per row of `data`.
predict_gam_term <- function(model, data, terms) {
  term_matrix <- predict(
    object = model,
    newdata = data,
    type = "terms",
    terms = terms
  )

  c(term_matrix)
}

# First and last few contributions of the first smooth; the tail mirrors the
# head.
# NOTE(review): the magnitude (about -444) is far from the scale of the linear
# predictor. Presumably s(gam_age_offdiag) and s(gam_age_offdiag_2) are
# confounded (one covariate is the square of the other) and their terms
# largely cancel. Confirm before interpreting single terms.
head(predict_gam_term(sim_m, sim_data, formula_terms[1]))
#> [1] -444.4278 -444.4258 -444.3999 -444.3189 -444.1496 -443.8621
tail(predict_gam_term(sim_m, sim_data, formula_terms[1]))
#> [1] -443.8621 -444.1496 -444.3189 -444.3999 -444.4258 -444.4278

# Add the model's first coefficient to every row as `fitted_intercept`.
# NOTE(review): assumes the first coefficient is the intercept; verify for
# models fitted without one.
#
# data:  data frame to extend.
# model: fitted model with a `coefficients` element.
# Returns `data` with the extra `fitted_intercept` column.
add_intercept <- function(data, model) {
  data %>%
    mutate(fitted_intercept = model$coefficients[1])
}

# Predict a single model term and return it as a one-column tibble.
#
# data:  data frame to predict on.
# model: fitted model.
# term:  label of the term to predict, e.g. "s(gam_age_pmax)".
# Returns a tibble whose only column holds the term's predictions and is named
# by extract_term_name(term).
tidy_predict_term <- function(data, model, term) {
  term_values <- predict_gam_term(model, data, term)
  column_name <- extract_term_name(term)

  setNames(tibble(x = term_values), column_name)
}

# Add `fitted_overall`: the row-wise sum of all fitted component columns.
#
# Every column whose name starts with "fitted" is summed, except an already
# existing `fitted_overall`. Without that exclusion, calling the function on
# data that already carries the total would add the old total into the new one.
#
# data: data frame with one or more `fitted*` columns.
# Returns `data` with `fitted_overall` added (or recomputed).
add_fitted_overall <- function(data) {
  data %>%
    mutate(
      fitted_overall = rowSums(across(
        .cols = starts_with("fitted") & !any_of("fitted_overall")
      ))
    )
}

# Add the intercept, one column per predicted model term, and their row-wise
# total to `data`.
#
# data:  data frame to predict on and extend.
# model: fitted gam.
# term:  character vector of term labels to predict. Defaults to all terms in
#        the model formula, so callers that omit it keep working.
# Returns `data` with `fitted_intercept`, one `fitted_<variable>` column per
# term, and `fitted_overall`.
add_gam_predictions <- function(data,
                                model,
                                term = get_formulas_terms(model)) {
  # iterate over the `term` argument rather than a global variable
  predictions <- map_dfc(
    .x = term,
    .f = function(one_term) {
      tidy_predict_term(data = data, model = model, term = one_term)
    }
  )

  data %>%
    add_intercept(model) %>%
    bind_cols(predictions) %>%
    add_fitted_overall()
}

# Simulated data with the intercept, each smooth's contribution and their
# total appended as `fitted_*` columns.
data_gam_preds <- add_gam_predictions(
  data = sim_data,
  model = sim_m,
  term = formula_terms
)

names(data_gam_preds)
#>  [1] "age_from"                 "age_to"                  
#>  [3] "term_eta"                 "participants"            
#>  [5] "response"                 "gam_age_offdiag"         
#>  [7] "gam_age_offdiag_2"        "gam_age_diag_prod"       
#>  [9] "gam_age_diag_sum"         "gam_age_pmax"            
#> [11] "gam_age_pmin"             "fitted_intercept"        
#> [13] "fitted_gam_age_offdiag"   "fitted_gam_age_offdiag_2"
#> [15] "fitted_gam_age_diag_prod" "fitted_gam_age_diag_sum" 
#> [17] "fitted_gam_age_pmax"      "fitted_gam_age_pmin"     
#> [19] "fitted_overall"

# Long format for plotting: one row per age pair and term. The `term_` prefix
# is stripped from the names, so `term_eta` becomes "eta" while the `fitted_*`
# names are kept as they are.
data_pred_terms <- data_gam_preds %>%
  # drop the derived gam_* covariates
  select(!starts_with("gam_")) %>%
  pivot_longer(
    cols = starts_with(c("term", "fitted")),
    names_prefix = "term_",
    names_to = "term",
    values_to = "value"
  )

# putting eta and fitted in with the terms is a bit inelegant, but I couldn't
# think of a nicer pattern that wasn't a lot of code.

# terms available in the long data
unique(data_pred_terms$term)
#> [1] "eta"                      "fitted_intercept"        
#> [3] "fitted_gam_age_offdiag"   "fitted_gam_age_offdiag_2"
#> [5] "fitted_gam_age_diag_prod" "fitted_gam_age_diag_sum" 
#> [7] "fitted_gam_age_pmax"      "fitted_gam_age_pmin"     
#> [9] "fitted_overall"
# plot the truth (eta) next to the fitted total, both exponentiated
ggplot(
  data = filter(data_pred_terms, term %in% c("eta", "fitted_overall")),
  mapping = aes(
    x = age_from,
    y = age_to,
    group = term,
    fill = exp(value)
  )
) +
  geom_tile() +
  facet_wrap(~term) +
  scale_fill_viridis_c() +
  coord_fixed() +
  theme_minimal()

# plot each fitted component (intercept and individual smooths) on the scale
# of the linear predictor, one facet per term
ggplot(
  data = filter(data_pred_terms, !term %in% c("eta", "fitted_overall")),
  mapping = aes(
    x = age_from,
    y = age_to,
    group = term,
    fill = value
  )
) +
  geom_tile() +
  facet_wrap(~term) +
  scale_fill_viridis_c() +
  coord_fixed() +
  theme_minimal()

# plot the 1D smooths on a 2 x 3 grid
# par() returns the previous settings, so keep them and restore exactly those
# afterwards instead of assuming the layout was 1 x 1 before
old_par <- par(mfrow = c(2, 3))
plot(sim_m)

par(old_par)

Created on 2022-08-08 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-08-08
#>  pandoc   2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  assertthat    0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  broom         0.8.0      2022-04-13 [1] CRAN (R 4.2.0)
#>  cellranger    1.1.0      2016-07-27 [1] CRAN (R 4.2.0)
#>  cli           3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>  colorspace    2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  curl          4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>  DBI           1.1.2      2021-12-20 [1] CRAN (R 4.2.0)
#>  dbplyr        2.1.1      2021-04-06 [1] CRAN (R 4.2.0)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  dplyr       * 1.0.9      2022-04-28 [1] CRAN (R 4.2.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate      0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>  fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  farver        2.1.0      2021-02-28 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  forcats     * 0.5.1      2021-01-27 [1] CRAN (R 4.2.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  generics      0.1.2      2022-01-31 [1] CRAN (R 4.2.0)
#>  ggplot2     * 3.3.6      2022-05-03 [1] CRAN (R 4.2.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  gtable        0.3.0      2019-03-25 [1] CRAN (R 4.2.0)
#>  haven         2.5.0      2022-04-15 [1] CRAN (R 4.2.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.2      2021-08-25 [1] CRAN (R 4.2.0)
#>  httr          1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  labeling      0.4.2      2020-10-20 [1] CRAN (R 4.2.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  lubridate     1.8.0      2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  mgcv        * 1.8-40     2022-03-29 [1] CRAN (R 4.2.0)
#>  mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>  modelr        0.1.8      2020-05-19 [1] CRAN (R 4.2.0)
#>  munsell       0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nlme        * 3.1-157    2022-03-25 [1] CRAN (R 4.2.0)
#>  pillar        1.7.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  purrr       * 0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache       0.15.0     2021-04-30 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.1      2020-08-26 [1] CRAN (R 4.2.0)
#>  R.oo          1.24.0     2020-08-26 [1] CRAN (R 4.2.0)
#>  R.utils       2.11.0     2021-09-26 [1] CRAN (R 4.2.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  readr       * 2.1.2      2022-01-30 [1] CRAN (R 4.2.0)
#>  readxl        1.4.0      2022-03-28 [1] CRAN (R 4.2.0)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>  rlang         1.0.4      2022-07-12 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>  rvest         1.0.2      2021-10-16 [1] CRAN (R 4.2.0)
#>  scales        1.2.0      2022-04-13 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi       1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>  stringr     * 1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>  styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>  tibble      * 3.1.7      2022-05-03 [1] CRAN (R 4.2.0)
#>  tidyr       * 1.2.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  tidyselect    1.1.2      2022-02-21 [1] CRAN (R 4.2.0)
#>  tidyverse   * 1.3.1      2021-04-15 [1] CRAN (R 4.2.0)
#>  tzdb          0.3.0      2022-03-28 [1] CRAN (R 4.2.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  viridisLite   0.4.0      2021-04-13 [1] CRAN (R 4.2.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>  xml2          1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment