#> Loading required package: nlme
#> This is mgcv 1.8-40. For overview type 'help("mgcv-package")'.
# the fake function we would like to mirror
f <- function(age) {
  # knots of the "true" single-age effect: flat at 0.1 up to age 20, a
  # plateau of 0.9 between ages 25 and 50, and back to 0.1 from age 55
  points <- tribble(
    ~x, ~y,
    0, 0.1,
    5, 0.1,
    10, 0.1,
    15, 0.1,
    20, 0.1,
    23, 0.5,
    24, 0.7,
    25, 0.9,
    27, 0.9,
    30, 0.9,
    35, 0.9,
    40, 0.9,
    45, 0.9,
    48, 0.9,
    50, 0.9,
    51, 0.7,
    52, 0.5,
    55, 0.1,
    60, 0.1,
    65, 0.1,
    70, 0.1,
    75, 0.1,
    80, 0.1
  )
  # a periodic spline needs y[1] == y[n] (both 0.1 here); it has period
  # max(x) - min(x) = 80, so ages above 80 wrap around to the start
  interp <- splinefun(
    x = points$x,
    y = points$y,
    method = "periodic"
  )
  interp(age)
}

# Multiplicative interaction between the two ages; zero whenever either
# age equals 10.
g <- function(age_from, age_to) {
  sin(age_from / 10 - 1) * sin(age_to / 10 - 1)
}

# Effect that depends only on the absolute age gap |age_from - age_to|;
# equals -0.1 on the diagonal (zero gap).
h <- function(age_from, age_to) {
  -0.1 * exp(sin(abs(age_from - age_to) / 5))^2
}

# "True" linear predictor (log mean contacts) used to simulate data:
# intercept 1 + per-age effects f() + interaction g() + gap effect h().
eta <- function(age_from, age_to) {
  1 + f(age_from) + f(age_to) + g(age_from, age_to) + h(age_from, age_to)
}

sim_data <- expand.grid(
  age_from = 1:90,
  age_to = 1:90
) %>%
  as_tibble() %>%
  mutate(
    # the fake log number of contacts, as if we had the model:
    # smooth(x, y) = 1 + f(x) + f(y) + g(x, y) + h(x, y)
    term_eta = eta(age_from, age_to),
    participants = 500,
    # fake data: a Poisson draw at rate exp(eta), scaled up by the number of
    # participants (population offsets are ignored for now)
    response = rpois(n(), exp(term_eta)) * participants,
    # symmetric transformations of the two ages, one per candidate smooth
    gam_age_offdiag = abs(age_from - age_to),
    gam_age_offdiag_2 = abs(age_from - age_to)^2,
    gam_age_diag_prod = abs(age_from * age_to),
    gam_age_diag_sum = abs(age_from + age_to),
    gam_age_pmax = pmax(age_from, age_to),
    gam_age_pmin = pmin(age_from, age_to)
  )

# fit with various symmetric smooth types
# (participants enter through the offset; it is not part of the formula)
sim_m <- gam(
  response ~
    s(gam_age_offdiag) +
    s(gam_age_offdiag_2) +
    s(gam_age_diag_prod) +
    s(gam_age_diag_sum) +
    s(gam_age_pmax) +
    s(gam_age_pmin),
  family = poisson,
  offset = log(participants),
  data = sim_data
)


# Term labels of a fitted model's formula, e.g. "s(gam_age_offdiag)".
# The response and any offset() terms are excluded, and a one-sided formula
# keeps all of its terms. These are the labels accepted by the `terms`
# argument of predict(..., type = "terms").
#
# model: a fitted model object that stores its formula in `model$formula`.
get_formulas_terms <- function(model) {
  attr(terms(model$formula), "term.labels")
}

# term labels of the fitted model, used to request per-term predictions
formula_terms <- get_formulas_terms(model = sim_m)

#> [1] "s(gam_age_offdiag)"   "s(gam_age_offdiag_2)" "s(gam_age_diag_prod)"
#> [4] "s(gam_age_diag_sum)"  "s(gam_age_pmax)"      "s(gam_age_pmin)"

# Turn model term labels into fitted-value column names, e.g.
# "s(gam_age_offdiag)" -> "fitted_gam_age_offdiag".
#
# x: character vector of term labels. The text between the first "(" and
#   the following ")" is used. A label without parentheses (a plain
#   parametric term such as "age") is used as-is, giving "fitted_age".
# Returns a character vector the same length as `x`.
extract_term_name <- function(x) {
  # lazy match stops at the first ")" after the first "("
  term <- sub("^[^(]*\\((.+?)\\).*$", "\\1", x, perl = TRUE)
  paste0("fitted_", term)
}

#> fitted_gam_age_offdiag
#> fitted_gam_age_offdiag_2
#> fitted_gam_age_diag_prod
#> fitted_gam_age_diag_sum
#> fitted_gam_age_pmax
#> fitted_gam_age_pmin

# Per-row contribution of a single model term to the linear predictor.
#
# model: a fitted model whose predict() method supports type = "terms"
#   (e.g. an mgcv::gam fit).
# data:  new data containing the covariates the term needs.
# terms: one term label such as "s(gam_age_offdiag)". If more labels are
#   given, their columns are concatenated end to end, so pass one at a time.
# Returns a plain numeric vector with one value per row of `data`.
predict_gam_term <- function(model, data, terms) {
  term_matrix <- predict(
    object = model,
    newdata = data,
    type = "terms",
    terms = terms
  )
  # predict(type = "terms") returns a matrix (one column per term); flatten
  # it so it can be stored as an ordinary tibble column
  as.numeric(term_matrix)
}

# peek at both ends of the first term's fitted contribution
predict_gam_term(sim_m, sim_data, formula_terms[1]) %>%
  head()
#> [1] -444.4278 -444.4258 -444.3999 -444.3189 -444.1496 -443.8621
predict_gam_term(sim_m, sim_data, formula_terms[1]) %>%
  tail()
#> [1] -443.8621 -444.1496 -444.3189 -444.3999 -444.4258 -444.4278

# Add the model intercept (first coefficient) as a constant column
# `fitted_intercept`, so it is included when fitted_* columns are summed.
add_intercept <- function(data, model) {
  mutate(
    .data = data,
    # `[[` drops the "(Intercept)" name so the column is a plain number
    fitted_intercept = coef(model)[[1]]
  )
}

# One-column tibble with the fitted contribution of `term`, named by
# extract_term_name() (e.g. "fitted_gam_age_offdiag").
#
# `term` is the last formal because purrr::map_*() passes each element
# positionally, while `data` and `model` are supplied by name.
tidy_predict_term <- function(data, model, term) {
  term_name <- extract_term_name(term)
  dat_term <- tibble(x = predict_gam_term(model, data, term))
  setNames(dat_term, term_name)
}

# Add `fitted_overall`: the row-wise sum of every fitted_* column (intercept
# plus all term contributions), i.e. the linear predictor without offset.
# An existing fitted_overall column is excluded, so re-running the
# function does not double count.
add_fitted_overall <- function(data) {
  data %>%
    mutate(
      fitted_overall = rowSums(across(
        .cols = starts_with("fitted") & !any_of("fitted_overall")
      ))
    )
}

# Append per-term fitted values, the intercept and their total to `data`.
#
# data:  data the model is evaluated on.
# model: fitted model (e.g. an mgcv::gam fit).
# term:  character vector of term labels to predict; defaults to every
#   term in the model's formula.
# Returns `data` with fitted_intercept, one fitted_<term> column per term,
# and fitted_overall.
add_gam_predictions <- function(data, model, term = get_formulas_terms(model)) {
  # use the `term` argument, not a global, so any model/term set works
  predictions <- map_dfc(
    .x = term,
    .f = tidy_predict_term,
    data = data,
    model = model
  )
  data %>%
    add_intercept(model) %>%
    bind_cols(predictions) %>%
    add_fitted_overall()
}

# simulated data plus the intercept, per-term and overall fitted values
data_gam_preds <- add_gam_predictions(
  data = sim_data,
  model = sim_m,
  term = formula_terms
)

#>  [1] "age_from"                 "age_to"                  
#>  [3] "term_eta"                 "participants"            
#>  [5] "response"                 "gam_age_offdiag"         
#>  [7] "gam_age_offdiag_2"        "gam_age_diag_prod"       
#>  [9] "gam_age_diag_sum"         "gam_age_pmax"            
#> [11] "gam_age_pmin"             "fitted_intercept"        
#> [13] "fitted_gam_age_offdiag"   "fitted_gam_age_offdiag_2"
#> [15] "fitted_gam_age_diag_prod" "fitted_gam_age_diag_sum" 
#> [17] "fitted_gam_age_pmax"      "fitted_gam_age_pmin"     
#> [19] "fitted_overall"

# pivot_longer for plotting
data_pred_terms <- data_gam_preds %>%
  # remove computed variables
  select(-starts_with("gam_")) %>%
  # one row per (age_from, age_to, term); "term_eta" becomes "eta"
  pivot_longer(
    cols = starts_with(match = c("term", "fitted")),
    names_to = "term",
    values_to = "value",
    names_prefix = "term_"
  )

# Putting eta and the fitted values in alongside the model terms is a bit
# inelegant, but I couldn't think of a nicer pattern that wasn't a lot of code.

# plot the truth and fitted model
#> [1] "eta"                      "fitted_intercept"        
#> [3] "fitted_gam_age_offdiag"   "fitted_gam_age_offdiag_2"
#> [5] "fitted_gam_age_diag_prod" "fitted_gam_age_diag_sum" 
#> [7] "fitted_gam_age_pmax"      "fitted_gam_age_pmin"     
#> [9] "fitted_overall"
data_pred_terms %>%
  filter(term %in% c("eta", "fitted_overall")) %>%
  ggplot(
    aes(
      x = age_from,
      y = age_to,
      group = term,
      # both series are on the log (link) scale; exponentiate for display
      fill = exp(value)
    )
  ) +
  facet_wrap(~term) +
  geom_tile() +
  scale_fill_viridis_c() +
  theme_minimal()

# plot the parts of the model (link scale, one panel per fitted component)
data_pred_terms %>%
  filter(!(term %in% c("eta", "fitted_overall"))) %>%
  ggplot(
    aes(
      x = age_from,
      y = age_to,
      group = term,
      fill = value
    )
  ) +
  facet_wrap(~term) +
  geom_tile() +
  scale_fill_viridis_c() +
  theme_minimal()

# plot the 1D smooths (six smooths -> a 2 x 3 grid)
# par() returns the previous settings, so restore exactly what was there
old_par <- par(mfrow = c(2, 3))
plot(sim_m)
par(old_par)

