# Comparison of the Prem et al. and conmat contact matrices for Germany,
# using an age-structured SIR model:

# World population by country, year and age group from socialmixr, with every
# age group from 75 upwards collapsed into a single open-ended 75+ group.
world_data <- socialmixr::wpp_age() %>%
  mutate(
    new_lower_age = if_else(lower.age.limit >= 75, 75L, lower.age.limit)
  ) %>%
  as_tibble() %>%
  summarise(
    population = sum(population),
    .by = c(new_lower_age, country, year)
  )

# The 2015 population of Germany by age group (75+ is the last group), as a
# conmat population object.
germany_2015 <- age_population(
  data = world_data,
  location_col = country,
  location = "Germany",
  age_col = new_lower_age,
  year_col = year,
  year = 2015
)
germany_2015

#> # A tibble: 16 × 6 (conmat_population)
#>  - age: lower.age.limit
#>  - population: population
#>    new_lower_age country  year population lower.age.limit upper.age.limit
#>            <int> <chr>   <int>      <dbl>           <dbl>           <dbl>
#>  1             0 Germany  2015    3517800               0               4
#>  2             5 Germany  2015    3507779               5               9
#>  3            10 Germany  2015    3693474              10              14
#>  4            15 Germany  2015    4101901              15              19
#>  5            20 Germany  2015    4571184              20              24
#>  6            25 Germany  2015    5213301              25              29
#>  7            30 Germany  2015    5058402              30              34
#>  8            35 Germany  2015    4782027              35              39
#>  9            40 Germany  2015    5190881              40              44
#> 10            45 Germany  2015    6805444              45              49
#> 11            50 Germany  2015    6920702              50              54
#> 12            55 Germany  2015    5998415              55              59
#> 13            60 Germany  2015    5091795              60              64
#> 14            65 Germany  2015    4216896              65              69
#> 15            70 Germany  2015    4225748              70              74
#> 16            75 Germany  2015    8812050              75              79

# Age-group boundaries: 5-year bands from 0 to 75, plus an open-ended 75+ group.
age_breaks_socialmixr <- c(seq(0, 75, by = 5), Inf)

# Setting-specific contact matrices for the German 2015 population, computed
# with conmat's extrapolate_polymod().
germany_contacts <- extrapolate_polymod(
  population = germany_2015,
  age_breaks = age_breaks_socialmixr
)

# Number of age groups: one per finite lower break (16 for the breaks above).
n_finite_states <- length(age_breaks_socialmixr) - 1

# Uniform transmission probability used for every pair of age groups.
# NOTE(review): the origin of this value is not shown here - confirm source.
transmission_probability <- 0.1761765
socialmixr_matrix <- matrix(
  transmission_probability,
  nrow = n_finite_states,
  ncol = n_finite_states
)

# Per-setting transmission probabilities. The same uniform matrix is used for
# every setting, so the settings differ only through their contact matrices.
transmission_matrix <- transmission_probability_matrix(
  home = socialmixr_matrix,
  work = socialmixr_matrix,
  school = socialmixr_matrix,
  other = socialmixr_matrix,
  age_breaks = age_breaks_socialmixr
)

# Model parameters. The state vector is laid out as S for every age group,
# then I for every age group, then R for every age group; the *_indexes
# entries give those positions.
parameters <- list(
  "transmission_matrix" = transmission_matrix,
  "homogeneous_contact" = germany_contacts,
  "gamma" = 1,
  "s_indexes" = seq_len(n_finite_states),
  "i_indexes" = n_finite_states + seq_len(n_finite_states),
  "r_indexes" = 2 * n_finite_states + seq_len(n_finite_states)
)

# Start with the whole population susceptible, plus one infected person in
# every age group and nobody recovered.
S0 <- germany_2015$population
I0 <- rep(1, times = n_finite_states)
R0 <- rep(0, times = n_finite_states)
initial_condition <- c(S0, I0, R0)

# Label each state "<compartment>_<lower age limit>", e.g. "S0_0" or "I0_75".
# Without the age-limit vector, paste() would produce only bare "S0"/"I0"/"R0".
lower_age_limits <- age_breaks_socialmixr[seq_len(n_finite_states)]
names(initial_condition) <- paste(
  rep(c("S0", "I0", "R0"), each = n_finite_states),
  rep(lower_age_limits, times = 3),
  sep = "_"
)

# Right-hand side of the age-structured SIR model, in the form deSolve::ode()
# expects: func(time, state, parameters) returns a list whose first element is
# the vector of derivatives.
#
# @param time Current time. It is not used, because the model is autonomous.
# @param state Numeric state vector holding S, I and R for each age group, at
#   the positions given by `parameters$s_indexes`, `i_indexes` and `r_indexes`.
# @param parameters List with:
#   - `transmission_matrix`: named list of per-setting transmission matrices.
#   - `homogeneous_contact`: list of contact matrices, matched to the
#     transmission settings by position.
#   - `gamma`: recovery rate.
#   - `s_indexes`, `i_indexes`, `r_indexes`.
# @return `list(c(dSdt, dIdt, dRdt))`.
age_structured_sir <- function(time, state, parameters) {
  s_idx <- parameters$s_indexes
  i_idx <- parameters$i_indexes
  r_idx <- parameters$r_indexes

  # Population of each age group: S + I + R for that group.
  N_by_age <- state[s_idx] + state[i_idx] + state[r_idx]

  # Fraction of each age group that is currently infected.
  N_infected_by_age <- state[i_idx] / N_by_age

  # For each setting, take the element-wise product of the transmission and
  # contact matrices. Matrix-multiply it by the infected fractions to get that
  # setting's force of infection. Contact matrices are matched to settings by
  # position.
  settings <- names(parameters$transmission_matrix)
  contacts <- parameters$homogeneous_contact[seq_along(settings)]
  betas <- Map(`*`, parameters$transmission_matrix, contacts)
  lambdas <- lapply(betas, function(beta) beta %*% N_infected_by_age)

  # Combine all settings into one force of infection per age group.
  lambda_total <- Reduce(`+`, lambdas)

  dSdt <- -lambda_total * state[s_idx]
  dIdt <- lambda_total * state[s_idx] - parameters$gamma * state[i_idx]
  dRdt <- parameters$gamma * state[i_idx]

  list(c(dSdt, dIdt, dRdt))
}

# Solve the conmat-based model from t = 0 to t = 100 in steps of 0.1.
times <- seq(0, 100, by = 0.1)
germany_soln <- ode(
  y = initial_condition,
  times = times,
  func = age_structured_sir,
  parms = parameters
)

# ode() returns a deSolve matrix. Convert it to a tibble so it can be reshaped
# and plotted.
germany_soln <- as_tibble(as.data.frame(germany_soln))
head(germany_soln)
tail(germany_soln)

#> # A tibble: 6 × 49
#>    time     S0_0    S0_5  S0_10  S0_15  S0_20  S0_25  S0_30  S0_35  S0_40  S0_45
#>   <dbl>    <dbl>   <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
#> 1   0   3517800   3.51e6 3.69e6 4.10e6 4.57e6 5.21e6 5.06e6 4.78e6 5.19e6 6.81e6
#> 2   0.1 3517800.  3.51e6 3.69e6 4.10e6 4.57e6 5.21e6 5.06e6 4.78e6 5.19e6 6.81e6
#> 3   0.2 3517800.  3.51e6 3.69e6 4.10e6 4.57e6 5.21e6 5.06e6 4.78e6 5.19e6 6.81e6
#> 4   0.3 3517800.  3.51e6 3.69e6 4.10e6 4.57e6 5.21e6 5.06e6 4.78e6 5.19e6 6.81e6
#> 5   0.4 3517799.  3.51e6 3.69e6 4.10e6 4.57e6 5.21e6 5.06e6 4.78e6 5.19e6 6.81e6
#> 6   0.5 3517799.  3.51e6 3.69e6 4.10e6 4.57e6 5.21e6 5.06e6 4.78e6 5.19e6 6.81e6
#> # … with 38 more variables: S0_50 <dbl>, S0_55 <dbl>, S0_60 <dbl>, S0_65 <dbl>,
#> #   S0_70 <dbl>, S0_75 <dbl>, I0_0 <dbl>, I0_5 <dbl>, I0_10 <dbl>, I0_15 <dbl>,
#> #   I0_20 <dbl>, I0_25 <dbl>, I0_30 <dbl>, I0_35 <dbl>, I0_40 <dbl>,
#> #   I0_45 <dbl>, I0_50 <dbl>, I0_55 <dbl>, I0_60 <dbl>, I0_65 <dbl>,
#> #   I0_70 <dbl>, I0_75 <dbl>, R0_0 <dbl>, R0_5 <dbl>, R0_10 <dbl>, R0_15 <dbl>,
#> #   R0_20 <dbl>, R0_25 <dbl>, R0_30 <dbl>, R0_35 <dbl>, R0_40 <dbl>,
#> #   R0_45 <dbl>, R0_50 <dbl>, R0_55 <dbl>, R0_60 <dbl>, R0_65 <dbl>, …
#> # A tibble: 6 × 49
#>    time     S0_0    S0_5  S0_10  S0_15  S0_20  S0_25  S0_30  S0_35  S0_40  S0_45
#>   <dbl>    <dbl>   <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
#> 1  99.5 1103717. 463182. 2.46e5 2.79e5 5.09e5 4.81e5 4.57e5 4.64e5 4.75e5 4.60e5
#> 2  99.6 1103717. 463182. 2.46e5 2.79e5 5.09e5 4.81e5 4.57e5 4.64e5 4.75e5 4.60e5
#> 3  99.7 1103717. 463182. 2.46e5 2.79e5 5.09e5 4.81e5 4.57e5 4.64e5 4.75e5 4.60e5
#> 4  99.8 1103717. 463182. 2.46e5 2.79e5 5.09e5 4.81e5 4.57e5 4.64e5 4.75e5 4.60e5
#> 5  99.9 1103717. 463182. 2.46e5 2.79e5 5.09e5 4.81e5 4.57e5 4.64e5 4.75e5 4.60e5
#> 6 100   1103717. 463182. 2.46e5 2.79e5 5.09e5 4.81e5 4.57e5 4.64e5 4.75e5 4.60e5
#> # … with 38 more variables: S0_50 <dbl>, S0_55 <dbl>, S0_60 <dbl>, S0_65 <dbl>,
#> #   S0_70 <dbl>, S0_75 <dbl>, I0_0 <dbl>, I0_5 <dbl>, I0_10 <dbl>, I0_15 <dbl>,
#> #   I0_20 <dbl>, I0_25 <dbl>, I0_30 <dbl>, I0_35 <dbl>, I0_40 <dbl>,
#> #   I0_45 <dbl>, I0_50 <dbl>, I0_55 <dbl>, I0_60 <dbl>, I0_65 <dbl>,
#> #   I0_70 <dbl>, I0_75 <dbl>, R0_0 <dbl>, R0_5 <dbl>, R0_10 <dbl>, R0_15 <dbl>,
#> #   R0_20 <dbl>, R0_25 <dbl>, R0_30 <dbl>, R0_35 <dbl>, R0_40 <dbl>,
#> #   R0_45 <dbl>, R0_50 <dbl>, R0_55 <dbl>, R0_60 <dbl>, R0_65 <dbl>, …

# ODE output is tidied the same way several times, so wrap it in a function.
#
# Collapse a wide ODE solution into a long tibble with one row per time point
# and compartment. The input has a `time` column plus one column per
# compartment/age-group state, named like "S0_0" or "I0_75".
#
# @param ode_soln A data frame or tibble with a `time` column. Every other
#   column is a state whose first character ("S", "I" or "R") names its
#   compartment.
# @return A tibble with columns `time`, `name` (the compartment letter) and
#   `value` (the state summed over all age groups).
tidy_ode <- function(ode_soln) {
  ode_soln %>%
    pivot_longer(cols = -time) %>%
    mutate(parent_state = substr(name, 1, 1)) %>%
    # Summarising with `.by` returns an ungrouped result; no ungroup() needed.
    summarise(
      value = sum(value),
      .by = c(time, parent_state)
    ) %>%
    rename(name = parent_state)
}

# Total S, I and R over all age groups at each time point, labelled as the
# age-structured model run.
germany_soln_long <- mutate(tidy_ode(germany_soln), type = "age_structured")

#> # A tibble: 3,003 × 4
#>     time name        value type          
#>    <dbl> <chr>       <dbl> <chr>         
#>  1   0   S     81707799    age_structured
#>  2   0   I           16    age_structured
#>  3   0   R            0    age_structured
#>  4   0.1 S     81707795.   age_structured
#>  5   0.1 I           18.5  age_structured
#>  6   0.1 R            1.72 age_structured
#>  7   0.2 S     81707790.   age_structured
#>  8   0.2 I           21.5  age_structured
#>  9   0.2 R            3.72 age_structured
#> 10   0.3 S     81707784.   age_structured
#> # … with 2,993 more rows

# Proportion of the total population in each compartment over time.
gg_germany_sir <- ggplot(
  germany_soln_long,
  aes(x = time, y = value / sum(initial_condition), colour = name)
) +
  geom_line() +
  labs(
    x = "Time",
    y = "Proportion",
    colour = "Compartment"
  ) +
  # Show the legend in model order S, I, R rather than alphabetically.
  scale_colour_discrete(limits = c("S", "I", "R"))


# Repeat the same process with the Prem et al. contact matrices. Every other
# parameter is kept identical so that only the contact matrices differ.
parameters_prem <- parameters
parameters_prem$homogeneous_contact <- prem_germany_contact_matrices

prem_soln <- ode(
  y = initial_condition,
  times = times,
  func = age_structured_sir,
  parms = parameters_prem
)

# ode() returns a deSolve matrix. Convert it to a tibble so it can be reshaped
# and plotted.
prem_soln <- as_tibble(as.data.frame(prem_soln))
tail(prem_soln)

#> # A tibble: 6 × 49
#>    time     S0_0    S0_5  S0_10  S0_15  S0_20  S0_25  S0_30  S0_35  S0_40  S0_45
#>   <dbl>    <dbl>   <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
#> 1  99.5 2757450.  2.53e6 1.86e6 1.01e6 1.79e6 2.64e6 2.48e6 1.84e6 1.87e6 3.01e6
#> 2  99.6 2757450.  2.53e6 1.86e6 1.01e6 1.79e6 2.64e6 2.48e6 1.84e6 1.87e6 3.01e6
#> 3  99.7 2757450.  2.53e6 1.86e6 1.01e6 1.79e6 2.64e6 2.48e6 1.84e6 1.87e6 3.01e6
#> 4  99.8 2757450.  2.53e6 1.86e6 1.01e6 1.79e6 2.64e6 2.48e6 1.84e6 1.87e6 3.01e6
#> 5  99.9 2757450.  2.53e6 1.86e6 1.01e6 1.79e6 2.64e6 2.48e6 1.84e6 1.87e6 3.01e6
#> 6 100   2757450.  2.53e6 1.86e6 1.01e6 1.79e6 2.64e6 2.48e6 1.84e6 1.87e6 3.01e6
#> # … with 38 more variables: S0_50 <dbl>, S0_55 <dbl>, S0_60 <dbl>, S0_65 <dbl>,
#> #   S0_70 <dbl>, S0_75 <dbl>, I0_0 <dbl>, I0_5 <dbl>, I0_10 <dbl>, I0_15 <dbl>,
#> #   I0_20 <dbl>, I0_25 <dbl>, I0_30 <dbl>, I0_35 <dbl>, I0_40 <dbl>,
#> #   I0_45 <dbl>, I0_50 <dbl>, I0_55 <dbl>, I0_60 <dbl>, I0_65 <dbl>,
#> #   I0_70 <dbl>, I0_75 <dbl>, R0_0 <dbl>, R0_5 <dbl>, R0_10 <dbl>, R0_15 <dbl>,
#> #   R0_20 <dbl>, R0_25 <dbl>, R0_30 <dbl>, R0_35 <dbl>, R0_40 <dbl>,
#> #   R0_45 <dbl>, R0_50 <dbl>, R0_55 <dbl>, R0_60 <dbl>, R0_65 <dbl>, …

# Both models are age-stratified, so sum every compartment over all age groups
# before comparing them.
germany_aggregated <- germany_soln %>% tidy_ode()
prem_aggregated <- prem_soln %>% tidy_ode()

# Stack the two models, record the source in `type`, and order the
# compartments S, I, R.
conmat_prem_soln <- bind_rows(
  list(conmat = germany_aggregated, prem = prem_aggregated),
  .id = "type"
) %>%
  mutate(name = factor(name, levels = c("S", "I", "R")))

#> # A tibble: 6 × 4
#>   type    time name        value
#>   <chr>  <dbl> <fct>       <dbl>
#> 1 conmat   0   S     81707799   
#> 2 conmat   0   I           16   
#> 3 conmat   0   R            0   
#> 4 conmat   0.1 S     81707795.  
#> 5 conmat   0.1 I           18.5 
#> 6 conmat   0.1 R            1.72
#> # A tibble: 6 × 4
#>   type   time name    value
#>   <chr> <dbl> <fct>   <dbl>
#> 1 prem   99.9 S     5.02e+7
#> 2 prem   99.9 I     2.82e-4
#> 3 prem   99.9 R     3.15e+7
#> 4 prem  100   S     5.02e+7
#> 5 prem  100   I     2.73e-4
#> 6 prem  100   R     3.15e+7

#> # A tibble: 6,006 × 4
#>    type    time name        value
#>    <chr>  <dbl> <fct>       <dbl>
#>  1 conmat   0   S     81707799   
#>  2 conmat   0   I           16   
#>  3 conmat   0   R            0   
#>  4 conmat   0.1 S     81707795.  
#>  5 conmat   0.1 I           18.5 
#>  6 conmat   0.1 R            1.72
#>  7 conmat   0.2 S     81707790.  
#>  8 conmat   0.2 I           21.5 
#>  9 conmat   0.2 R            3.72
#> 10 conmat   0.3 S     81707784.  
#> # … with 5,996 more rows
# Compare the aggregated compartments of both models over the first 40 time
# units.
conmat_prem_soln_40 <- conmat_prem_soln %>% filter(time <= 40)

gg_conmat_prem <- ggplot(
  conmat_prem_soln_40,
  aes(x = time, y = value, colour = type)
) +
  geom_line() +
  labs(x = "Time", y = "Population", colour = "Model") +
  # S, I and R span very different ranges (e.g. I starts at 16 while S is about
  # 8e7), so give each facet its own y axis.
  facet_wrap(~name, ncol = 1, scales = "free_y") +
  scale_y_continuous(
    labels = scales::label_number(big.mark = ","),
    n.breaks = 3
  ) +
  theme(legend.position = "right") +
  scale_colour_brewer(palette = "Dark2")
gg_conmat_prem
