# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @tmellan
# Created May 15, 2020 16:19
# Show Gist options
#   • Save tmellan/3cb7449c6c151349a7cd916a77ba7aa3 to your computer and use it in GitHub Desktop.
# Save tmellan/3cb7449c6c151349a7cd916a77ba7aa3 to your computer and use it in GitHub Desktop.
# code
library(rstan)
library(matrixStats)
library(data.table)
library(lubridate)
library(gdata)
library(dplyr)
library(tidyr)
library(EnvStats)
library(scales)
library(tidyverse)
library(dplyr)
library(abind)
library(xtable)
library(ggplot2)
library(gridExtra)
library(ggpubr)
library(bayesplot)
library(cowplot)
#library(svglite)
library(openxlsx)
source("geom-stepribbon.r")
source("gammaAlt.r")
source("Brazil/xlsx_preprocessing_subnation_brazil.R")
# Infection fatality ratio (IFR) per state; this run uses the
# "IFR_Peru_poorer" scenario column.
weight_fatality <- read.csv("Brazil/IFRS_all.csv")[c("X", "State", "IFR_Peru_poorer")]
# Alternative IFR scenarios. These select only two columns, so the
# three-name renaming below must be adjusted if one of them is enabled.
# weight_fatality <- read.csv("Brazil/IFRS_all.csv")[c("State", "IFR_UK_poorer")]
# weight_fatality <- read.csv("Brazil/IFRS_all.csv")[c("State", "IFR_Peru")]
# weight_fatality <- read.csv("Brazil/IFRS_all.csv")[c("State", "IFR_UK")]

# Rename to the names the model-building loop looks up
# (`region`, `weighted_fatality`); the first column is left blank.
cfr.by.country <- setNames(weight_fatality, c(" ", "region", "weighted_fatality"))
cfr.by.country
for (scene_i in 1:1){
####################################################################
#### Parameters to input:
scenario <- scene_i
# TRUE runs a very short Stan sampling pass (for testing the pipeline);
# FALSE runs the full sampler configuration.
DEBUG <- TRUE
# DEBUG <- FALSE
START_TIME <- as.Date("2020-02-19")
END_TIME <- as.Date("2020-05-30")
RANGE_TIME <- seq(START_TIME, END_TIME, by = "1 day")

# States (UF codes) to fit.
countries <- c("RJ", "SP", "PE", "CE", "AM", "BA", "ES", "MA", "MG", "PR",
               "PA", "RN", "RS", "SC", "AL", "PB")
# Smaller subsets used previously:
# countries <- c("RJ", "SP")
# countries <- c("RJ", "SP", "PE", "CE", "AM")
# countries <- c("RJ", "SP", "PE", "CE", "AM", "BA")
# Full set of states:
# countries <- c('MS','MT','TO','RR','SE','DF','RO','AC','PI','GO','SC','AP','RN','RS','PR','MG',
#                'PB','AL','ES','BA','MA','PA','AM','PE','CE','RJ','SP')

# Stan model(s), by file name under Brazil/stan-models/.
models <- c("base-general-half")
# Other variants used previously:
# models = c("base-general-half","base-general-full",'base-general-half-underreport-25',
#            'base-general-half-underreport-50',"base-general-half-underreport-100","base-general-half-underreport-200")
# models = c("base-general-half","base-general-full",'base-general-half-underreport-25pc','base-general-half-underreport-50pc',
#            "base-general-half-underreport-100pc","base-general-half-underreport-200pc")
# models = c("base-general-half","base-general-full")
# models = c("base-general-full",'base-general-half-underreport-25pc','base-general-half-underreport-50pc',
#            "base-general-half-underreport-100pc","base-general-half-underreport-200pc")

# Mean of the onset-to-death delay distribution (alternatives: 16.9, 20.7).
ONSET_to_DEATH <- 18.8

# Multiplier applied to the forecast mobility covariates
# (alternative: c(0.75, 1, 1.25)).
conterfactual_rate <- c(1)
####################################################################A

# One row per (counterfactual rate, model) combination; `scenario` selects
# the row used in this pass. Var2 is a factor, hence as.character().
pars <- expand.grid(conterfactual_rate, models)
pars
conterfactual_rate <- pars$Var1[scenario]
StanModel <- as.character(pars$Var2[scenario])
# Fill missing values (used for mobility series and region labels) with the
# value of the nearest non-missing neighbour; on a tie the left-hand
# neighbour wins.
#
# dat: an atomic vector (numeric or character) that may contain NA.
# Returns `dat` with every NA replaced. If `dat` has no NA, or is entirely
# NA, it is returned unchanged.
f1 <- function(dat) {
  N <- length(dat)
  na.pos <- which(is.na(dat))
  if (length(na.pos) %in% c(0, N)) {
    return(dat)
  }
  non.na.pos <- which(!is.na(dat))
  n.obs <- length(non.na.pos)
  # With a single observed value findInterval() has no interval to bracket,
  # and the neighbour lookup below would index past the end (giving NA).
  # Every gap is nearest to that one value, so fill directly.
  if (n.obs == 1L) {
    dat[na.pos] <- dat[non.na.pos]
    return(dat)
  }
  # all.inside = TRUE keeps intervals in 1..(n.obs - 1), so both neighbours
  # exist. For NAs before the first or after the last observation, one of
  # the two distances is negative, which still selects the closest end.
  intervals <- findInterval(na.pos, non.na.pos, all.inside = TRUE)
  left.pos <- non.na.pos[pmax(1L, intervals)]
  right.pos <- non.na.pos[pmin(n.obs, intervals + 1L)]
  left.dist <- na.pos - left.pos
  right.dist <- right.pos - na.pos
  dat[na.pos] <- ifelse(left.dist <= right.dist,
                        dat[left.pos], dat[right.pos])
  dat
}
# Run from the base directory. `df` and `path` are not defined in this
# script; presumably they come from the sourced
# Brazil/xlsx_preprocessing_subnation_brazil.R (TODO confirm).
d <- df
# Set model manually to run in notebook
print(sprintf("Running %s", StanModel))
serial.interval <- read.csv("data/serial_interval.csv")

# Intervention start dates per region (day/month/year text in columns 2-5).
# Currently not used as model covariates.
interventions <- read.csv2(file.path(path, "brazil_interventions.csv"), sep = ";")
interventions[2:5] <- lapply(interventions[2:5],
                             function(col) dmy(as.character(col)))
colnames(interventions) <- c("region", "Emergency", "Retail and Service",
                             "Transport", "School Closing")

# Length of the modelled time axis (observed days plus forecast days).
# The previous DEBUG/non-DEBUG branches assigned the same value.
N2 <- length(RANGE_TIME)
if (length(serial.interval$fit) < N2) {
  stop(sprintf("data/serial_interval.csv provides %d values but N2 = %d are needed",
               length(serial.interval$fit), N2), call. = FALSE)
}

dates <- list()
reported_cases <- list()
# Earlier variant also carried intervention covariates 5-7:
# stan_data = list(M=length(countries),N=NULL,covariate1=NULL,covariate2=NULL,covariate3=NULL,covariate4=NULL,covariate5=NULL,covariate6=NULL,covariate7=NULL,deaths=NULL,f=NULL,
#                  N0=6,cases=NULL,SI=serial.interval$fit[1:N2],
#                  EpidemicStart = NULL, pop = NULL) # N0 = 6 to make it consistent with Rayleigh
# Per-state fields start as NULL and are appended to column by column in
# the state loop.
stan_data <- list(M = length(countries), N = NULL, covariate1 = NULL,
                  covariate2 = NULL, covariate3 = NULL, covariate4 = NULL,
                  deaths = NULL, f = NULL,
                  N0 = 6, cases = NULL, SI = serial.interval$fit[1:N2],
                  EpidemicStart = NULL, pop = NULL) # N0 = 6 to make it consistent with Rayleigh
deaths_by_country <- list()

# Delay distributions. rgammaAlt (from gammaAlt.r) is called with a mean
# and what is presumably a coefficient of variation (TODO confirm).
mean1 <- 5.1; cv1 <- 0.86  # infection to onset
mean2 <- ONSET_to_DEATH; cv2 <- 0.45  # onset to death
x1 <- rgammaAlt(1e6, mean1, cv1)  # infection-to-onset draws
x2 <- rgammaAlt(1e6, mean2, cv2)  # onset-to-death draws
# Empirical CDF of the total infection-to-death delay.
ecdf.saved <- ecdf(x1 + x2)
aux.epidemicStart <- NULL
# For each state:
#   - pad its daily series with zero rows back to 2020-01-01;
#   - attach gap-filled mobility and 0/1 intervention indicators;
#   - trim the series to start 30 days before the 10th cumulative death;
#   - append the state's columns to `stan_data`: deaths, cases, the
#     infection-to-death weights `f`, and four mobility covariates
#     extended to N2 days.
for (Country in countries) {
  # IFR is the overall probability of dying given infection.
  IFR <- cfr.by.country$weighted_fatality[which(cfr.by.country$region == Country)]
  if (length(IFR) != 1L || is.na(IFR)) {
    stop(sprintf("expected exactly one non-missing IFR for %s", Country),
         call. = FALSE)
  }
  d1 <- d[d$region == Country, ]
  if (nrow(d1) == 0) {
    stop(sprintf("no case/death data for %s", Country), call. = FALSE)
  }
  d1_pop <- df_pop[df_pop$region == Country, ]

  # Pad with zero-count rows from 2020-01-01 up to the first reported day.
  # seq.Date() includes its `to` endpoint, so dates already present in d1
  # are dropped to avoid adding that day a second time.
  pad_to <- max(as.Date("2020-01-01"), min(d1$DateRep))
  aux.dates <- seq.Date(from = as.Date("2020-01-01"), to = pad_to, by = "1 d")
  aux.dates <- aux.dates[!(aux.dates %in% d1$DateRep)]
  if (length(aux.dates) > 0) {
    n_pad <- length(aux.dates)
    mat.aux <- cbind.data.frame(rep(Country, n_pad), aux.dates,
                                rep(d1_pop$population[1], n_pad),
                                rep(0, n_pad), rep(0, n_pad),
                                rep(0, n_pad), rep(0, n_pad))
    colnames(mat.aux) <- colnames(d1)
    d1 <- rbind.data.frame(d1, mat.aux)
  }
  d1 <- d1[order(d1$DateRep), ]

  # Align mobility with the padded series; days with no mobility record get
  # NA from the join and are then filled from the nearest observed day.
  mobility1 <- mobility[mobility$region == Country, ]
  mobility1$date <- as.Date(mobility1$date)
  mobility1 <- mobility1[order(mobility1$date), ]
  aux <- left_join(d1, mobility1, by = c("DateRep" = "date"))
  aux$region.y <- f1(as.character(aux$region.y))
  mobility_cols <- c("grocery_pharmacy", "parks", "residential",
                     "retail_recreation", "transitstations", "workplace")
  idx <- which(colnames(aux) %in% mobility_cols)
  aux[idx] <- lapply(aux[idx], f1)
  mobility1 <- aux[, c("region.x", "DateRep", mobility_cols)]
  colnames(mobility1)[1:2] <- c("county", "date")
  mobility1 <- mobility1[order(as.Date(mobility1$date)), ]

  # 0/1 intervention indicators: 1 from the start date onwards, 0 before it
  # or when the region has no date. Comparing dates also works when the
  # start date is not one of d1's dates. These columns are not passed to
  # stan_data below.
  aux.int <- interventions[which(interventions$region == Country), ]
  intervention_cols <- c(Emergency = "Emergency", Retail = "Retail and Service",
                         Transport = "Transport", Schools = "School Closing")
  for (col_name in names(intervention_cols)) {
    int_date <- aux.int[[intervention_cols[[col_name]]]]
    d1[[col_name]] <- if (length(int_date) > 0 && !is.na(int_date[1])) {
      as.numeric(as.Date(d1$DateRep) >= as.Date(int_date[1]))
    } else {
      rep(0, nrow(d1))
    }
  }

  # Start the modelled series 30 days before cumulative deaths reach 10.
  index <- which(d1$Cases > 0)[1]
  index1 <- which(cumsum(d1$Deaths) >= 10)[1] # also 5
  if (is.na(index1)) {
    stop(sprintf("%s never reaches 10 cumulative deaths", Country), call. = FALSE)
  }
  index2 <- index1 - 30
  if (index2 < 1) {
    stop(sprintf("%s: fewer than 30 days precede the 10th death", Country),
         call. = FALSE)
  }
  print(sprintf("First non-zero cases is on day %d, and 30 days before 10 deaths is day %d", index, index2))
  d1 <- d1[index2:nrow(d1), ]
  aux.epidemicStart <- c(aux.epidemicStart, d1$DateRep[index1 + 1 - index2])
  stan_data$EpidemicStart <- c(stan_data$EpidemicStart, index1 + 1 - index2)
  stan_data$pop <- c(stan_data$pop, d1_pop$population)
  mobility1 <- mobility1[index2:nrow(mobility1), ]
  dates[[Country]] <- d1$DateRep

  # hazard estimation
  N <- length(d1$Cases)
  N0 <- N
  print(sprintf("%s has %d days of data", Country, N))
  forecast <- N2 - N
  if (forecast < 0) {
    print(sprintf("%s: %d", Country, N))
    print("ERROR!!!! increasing N2")
    # NOTE(review): stan_data$SI and earlier states' columns were built with
    # the previous N2, so they will not match the new length.
    N2 <- N
    forecast <- N2 - N
  }

  # f[i] = probability of dying on day i given infection:
  # convolution(i + 0.5) - convolution(i - 0.5), with f[1] starting at 0.
  convolution <- function(u) (IFR * ecdf.saved(u))
  f <- diff(convolution(c(0, seq_len(N2) + 0.5)))

  # Observed series, padded with -1 over the forecast horizon.
  reported_cases[[Country]] <- as.vector(as.numeric(d1$Cases))
  deaths <- c(as.vector(as.numeric(d1$Deaths)), rep(-1, forecast))
  cases <- c(as.vector(as.numeric(d1$Cases)), rep(-1, forecast))
  deaths_by_country[[Country]] <- as.vector(as.numeric(d1$Deaths))

  # Mobility covariates, extended to N2 days with ARIMA forecasts scaled by
  # the counterfactual multiplier. The fits are kept in `arima_fits` so they
  # do not overwrite the `models` vector of Stan model names. With no
  # forecast days there is nothing to extend; the old
  # `length((N+1):(N+forecast))` was 2 in that case and overwrote row N.
  library(forecast)
  covariates2 <- as.data.frame(mobility1[, mobility_cols])
  if (forecast > 0) {
    arima_fits <- lapply(covariates2, function(x) auto.arima(x, seasonal = TRUE))
    mat.forecast <- lapply(arima_fits, function(x) forecast(x, h = forecast)$mean)
    covariates2[(N + 1):(N + forecast), ] <- conterfactual_rate *
      cbind(mat.forecast$grocery_pharmacy, mat.forecast$parks,
            mat.forecast$residential, mat.forecast$retail_recreation,
            mat.forecast$transitstations, mat.forecast$workplace)
  }
  #covariates2[(N+1):(N+forecast),] <- conterfactual_rate*covariates2[N,]
  average <- (covariates2$grocery_pharmacy + covariates2$retail_recreation +
                covariates2$workplace) / 3
  stan_data$covariate1 <- cbind(stan_data$covariate1, covariates2$residential) # Residential
  stan_data$covariate2 <- cbind(stan_data$covariate2, covariates2$transitstations) # Transitstations
  stan_data$covariate3 <- cbind(stan_data$covariate3, average) # Mean of Grocery, Retail, workplace
  stan_data$covariate4 <- cbind(stan_data$covariate4, covariates2$parks) # Parks
  # covariates for interventions and week effects (disabled)
  #covariates3 <- as.data.frame(d1[,c("Emergency","Retail","Transport","Schools","Week","Weekend")])
  # covariates3 <- as.data.frame(d1[,c("Emergency","Retail","Transport","Schools")])
  # covariates3[N:(N+forecast),] <- covariates3[N,]
  # stan_data$covariate4 = cbind(stan_data$covariate4,covariates3[,1])
  # stan_data$covariate5 = cbind(stan_data$covariate5,covariates3[,2])
  # stan_data$covariate6 = cbind(stan_data$covariate6,covariates3[,3])
  # stan_data$covariate7 = cbind(stan_data$covariate7,covariates3[,4])
  # stan_data$covariate8 = cbind(stan_data$covariate8,covariates3[,5])
  # stan_data$covariate9 = cbind(stan_data$covariate9,covariates3[,6])
  stan_data$N <- c(stan_data$N, N)
  stan_data$f <- cbind(stan_data$f, f)
  stan_data$deaths <- cbind(stan_data$deaths, deaths)
  stan_data$cases <- cbind(stan_data$cases, cases)
  stan_data$N2 <- N2
  stan_data$x <- 1:N2
  # Keep N an array even when only one state is fitted.
  if (length(stan_data$N) == 1) {
    stan_data$N <- as.array(stan_data$N)
  }
}
# Compile the selected Stan model and draw posterior samples.
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)
m <- stan_model(paste0('Brazil/stan-models/', StanModel, '.stan'))

# Covariates are passed as one unnamed list `X` holding one matrix per
# covariate, plus its length `P`. An earlier variant also passed
# covariate5..covariate7:
# stan_data$X = list(stan_data$covariate1,stan_data$covariate2,stan_data$covariate3,stan_data$covariate4,stan_data$covariate5,
#                    stan_data$covariate6,stan_data$covariate7)
stan_data$X <- unname(stan_data[c("covariate1", "covariate2",
                                  "covariate3", "covariate4")])
stan_data$P <- length(stan_data$X)

# DEBUG: a tiny run to exercise the pipeline; otherwise the full run.
fit <- if (DEBUG) {
  sampling(m, data = stan_data, iter = 20, warmup = 10, chains = 3)
} else {
  sampling(m, data = stan_data, iter = 1500, warmup = 500, chains = 8,
           thin = 1, control = list(adapt_delta = 0.95, max_treedepth = 15))
}
# Pull the posterior draws needed downstream, then snapshot the workspace.
out <- rstan::extract(fit)
prediction <- out$prediction
estimated.deaths <- out$E_deaths
estimated.deaths.cf <- out$E_deaths0

# Tag output files with PBS_JOBID when it is set, otherwise with a random
# integer drawn from R's RNG.
JOBID <- Sys.getenv("PBS_JOBID")
if (!nzchar(JOBID)) {
  JOBID <- as.character(abs(round(rnorm(1) * 1000000)))
}
print(sprintf("Jobid = %s", JOBID))
filename <- paste0(StanModel, '-', JOBID)
# StanModel appears twice in the saved file name: once directly and once
# inside `filename`.
save.image(paste0('Brazil/results/', StanModel, '-', filename,
                  "counterfactual_", conterfactual_rate, ".Rdata"))
# Per-state summaries and figures.
# For each state, over the observed days, compute posterior means and
# 50%/95% intervals of:
#   - daily infections (`prediction`);
#   - expected deaths (`estimated.deaths`);
#   - Rt (`out$Rt_adj`).
# Append one row of cumulative totals to `table_paper`. Print a three-panel
# figure (infections | deaths | Rt with intervention dates) and save it as
# a PDF under Brazil/figures/.
table_paper <- NULL
# Ribbon keys are factors with these levels, so each fill scale's labels
# c("50%", "95%") are matched to "fifty" and "ninetyfive" by level order.
band_levels <- c("fifty", "ninetyfive")
for (i in seq_along(countries)) {
  print(i)
  N <- length(dates[[i]])
  country <- countries[[i]]
  predicted_cases <- colMeans(prediction[, 1:N, i])
  predicted_cases_li <- colQuantiles(prediction[, 1:N, i], probs = .025)
  predicted_cases_ui <- colQuantiles(prediction[, 1:N, i], probs = .975)
  predicted_cases_li2 <- colQuantiles(prediction[, 1:N, i], probs = .25)
  predicted_cases_ui2 <- colQuantiles(prediction[, 1:N, i], probs = .75)
  estimated_deaths <- colMeans(estimated.deaths[, 1:N, i])
  estimated_deaths_li <- colQuantiles(estimated.deaths[, 1:N, i], probs = .025)
  estimated_deaths_ui <- colQuantiles(estimated.deaths[, 1:N, i], probs = .975)
  estimated_deaths_li2 <- colQuantiles(estimated.deaths[, 1:N, i], probs = .25)
  estimated_deaths_ui2 <- colQuantiles(estimated.deaths[, 1:N, i], probs = .75)
  rt <- colMeans(out$Rt_adj[, 1:N, i])
  rt_li <- colQuantiles(out$Rt_adj[, 1:N, i], probs = .025)
  rt_ui <- colQuantiles(out$Rt_adj[, 1:N, i], probs = .975)
  rt_li2 <- colQuantiles(out$Rt_adj[, 1:N, i], probs = .25)
  rt_ui2 <- colQuantiles(out$Rt_adj[, 1:N, i], probs = .75)
  data_country <- data.frame("time" = as_date(as.character(dates[[i]])),
                             "country" = rep(country, length(dates[[i]])),
                             "reported_cases" = reported_cases[[i]],
                             "reported_cases_c" = cumsum(reported_cases[[i]]),
                             "predicted_cases_c" = cumsum(predicted_cases),
                             "predicted_min_c" = cumsum(predicted_cases_li),
                             "predicted_max_c" = cumsum(predicted_cases_ui),
                             "predicted_cases" = predicted_cases,
                             "predicted_min" = predicted_cases_li,
                             "predicted_max" = predicted_cases_ui,
                             "predicted_min2" = predicted_cases_li2,
                             "predicted_max2" = predicted_cases_ui2,
                             "deaths" = deaths_by_country[[i]],
                             "deaths_c" = cumsum(deaths_by_country[[i]]),
                             "estimated_deaths_c" = cumsum(estimated_deaths),
                             "death_min_c" = cumsum(estimated_deaths_li),
                             "death_max_c" = cumsum(estimated_deaths_ui),
                             "estimated_deaths" = estimated_deaths,
                             "death_min" = estimated_deaths_li,
                             "death_max" = estimated_deaths_ui,
                             "death_min2" = estimated_deaths_li2,
                             "death_max2" = estimated_deaths_ui2,
                             "rt" = rt,
                             "rt_min" = rt_li,
                             "rt_max" = rt_ui,
                             "rt_min2" = rt_li2,
                             "rt_max2" = rt_ui2)

  # Cumulative totals for the summary table. data.frame() derives the column
  # names of `aux2` from these argument expressions, so keep them unchanged.
  aux <- data_country[, c("reported_cases", "predicted_cases", "predicted_min2",
                          "predicted_max2", "deaths", "death_min2", "death_max2")]
  aux2 <- data.frame(as.character(country),
                     tail(apply(aux, 2, cumsum), 1),
                     (df_pop[which(df_pop$region==country),2])
  )
  table_paper <- rbind.data.frame(table_paper, aux2)
  # aux_a = data_country[,c("reported_cases","predicted_cases","predicted_min2","predicted_max2")]
  # aux_a2 = data.frame(as.character(country),
  # tail(apply(aux_a, 2, cumsum),1),
  # df_pop[which(df_pop$region==country),2]
  # )
  # table_paper_cases = rbind.data.frame(table_paper_cases,aux_a2)

  # Panel 1: daily infections with 50% / 95% ribbons.
  data_cases_95 <- data.frame(data_country$time, data_country$predicted_min,
                              data_country$predicted_max)
  names(data_cases_95) <- c("time", "cases_min", "cases_max")
  data_cases_95$key <- rep("ninetyfive", length(data_cases_95$time))
  data_cases_50 <- data.frame(data_country$time, data_country$predicted_min2,
                              data_country$predicted_max2)
  names(data_cases_50) <- c("time", "cases_min", "cases_max")
  data_cases_50$key <- rep("fifty", length(data_cases_50$time))
  data_cases <- rbind(data_cases_95, data_cases_50)
  data_cases$key <- factor(data_cases$key, levels = band_levels)
  p1 <- ggplot(data_country) +
    geom_bar(data = data_country, aes(x = time, y = reported_cases),
             fill = "coral4", stat = 'identity', alpha = 0.5) +
    geom_ribbon(data = data_cases,
                aes(x = time, ymin = cases_min, ymax = cases_max, fill = key)) +
    xlab("") +
    ylab("Daily number of infections\n") +
    scale_x_date(date_breaks = "2 weeks", labels = date_format("%e %b")) +
    scale_y_continuous(expand = c(0, 0), labels = comma) +
    scale_fill_manual(name = "", labels = c("50%", "95%"),
                      values = c(alpha("deepskyblue4", 0.55),
                                 alpha("deepskyblue4", 0.45))) +
    theme_pubr() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          legend.position = "None") + ggtitle(df_region_codes[which(df_region_codes[,1]==country),2]) +
    guides(fill = guide_legend(ncol = 1))

  # Panel 2: daily deaths with 50% / 95% ribbons.
  data_deaths_95 <- data.frame(data_country$time, data_country$death_min,
                               data_country$death_max)
  names(data_deaths_95) <- c("time", "death_min", "death_max")
  data_deaths_95$key <- rep("ninetyfive", length(data_deaths_95$time))
  data_deaths_50 <- data.frame(data_country$time, data_country$death_min2,
                               data_country$death_max2)
  names(data_deaths_50) <- c("time", "death_min", "death_max")
  data_deaths_50$key <- rep("fifty", length(data_deaths_50$time))
  data_deaths <- rbind(data_deaths_95, data_deaths_50)
  data_deaths$key <- factor(data_deaths$key, levels = band_levels)
  p2 <- ggplot(data_country, aes(x = time)) +
    geom_bar(data = data_country, aes(y = deaths, fill = "reported"),
             fill = "coral4", stat = 'identity', alpha = 0.5) +
    geom_ribbon(
      data = data_deaths,
      aes(ymin = death_min, ymax = death_max, fill = key)) +
    scale_x_date(date_breaks = "2 weeks", labels = date_format("%e %b")) +
    scale_y_continuous(expand = c(0, 0), labels = comma) +
    scale_fill_manual(name = "", labels = c("50%", "95%"),
                      values = c(alpha("deepskyblue4", 0.55),
                                 alpha("deepskyblue4", 0.45))) +
    ylab("Daily number of deaths\n") +
    xlab("") +
    theme_pubr() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          legend.position = "None") +
    guides(fill = guide_legend(ncol = 1))

  # Panel 3: Rt ribbons plus intervention markers.
  data_rt_95 <- data.frame(data_country$time,
                           data_country$rt_min, data_country$rt_max)
  names(data_rt_95) <- c("time", "rt_min", "rt_max")
  data_rt_95$key <- rep("ninetyfive", length(data_rt_95$time))
  data_rt_50 <- data.frame(data_country$time, data_country$rt_min2,
                           data_country$rt_max2)
  names(data_rt_50) <- c("time", "rt_min", "rt_max")
  data_rt_50$key <- rep("fifty", length(data_rt_50$time))
  data_rt <- rbind(data_rt_95, data_rt_50)
  data_rt$key <- factor(data_rt$key, levels = band_levels)

  # Intervention dates for this state, in long form (key = intervention).
  covariates_country <- interventions[which(interventions$region == country), -1]
  covariates_country_long <- gather(covariates_country, key = "key",
                                    value = "value")
  # Marker heights. The column is initialised so it exists before the
  # per-row assignments (assigning into a missing column fails unless the
  # first row written is row 1, e.g. when the first date is missing).
  covariates_country_long$x <- rep(0, length(covariates_country_long$key))
  un_dates <- unique(covariates_country_long$value)
  for (k in seq_along(un_dates)) {
    idxs <- which(covariates_country_long$value == un_dates[k])
    # Interventions sharing a date are stacked downwards in 0.3 steps.
    max_val <- round(max(rt_ui)) + 0.3
    for (j in idxs) {
      covariates_country_long$x[j] <- max_val
      max_val <- max_val - 0.3
    }
  }
  covariates_country_long$value <- as_date(covariates_country_long$value)
  covariates_country_long$country <- rep(country,
                                         length(covariates_country_long$value))
  # Labels follow the alphabetical order of the intervention keys.
  # plot_labels <- c("Emergency","Retail and Service","Transport","School Closing")
  plot_labels <- c("Emergency", "Retail and Service", "School Closing", "Transport")
  p3 <- ggplot(data_country) +
    geom_ribbon(data = data_rt, aes(x = time, ymin = rt_min, ymax = rt_max,
                                    group = key,
                                    fill = key)) +
    geom_hline(yintercept = 1, color = 'black', size = 0.1) +
    geom_segment(data = covariates_country_long,
                 aes(x = value, y = 0, xend = value, yend = max(x)),
                 linetype = "dashed", colour = "grey", alpha = 0.75) +
    geom_point(data = covariates_country_long, aes(x = value,
                                                   y = x,
                                                   group = key,
                                                   shape = key,
                                                   col = key), size = 2) +
    xlab("") +
    ylab(expression(R[t])) +
    scale_fill_manual(name = "", labels = c("50%", "95%"),
                      values = c(alpha("seagreen", 0.75), alpha("seagreen", 0.5))) +
    scale_shape_manual(name = "Interventions", labels = plot_labels,
                       values = c(21, 22, 23, 24, 25, 12)) +
    scale_colour_discrete(name = "Interventions", labels = plot_labels) +
    scale_x_date(date_breaks = "weeks", labels = date_format("%e %b"),
                 limits = c(data_country$time[1],
                            data_country$time[length(data_country$time)])) +
    scale_y_continuous(expand = expansion(mult = c(0, 0.1))) +
    theme_pubr() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
    theme(legend.position = "right")

  ptmp <- plot_grid(p1, p2, p3, ncol = 3, rel_widths = c(0.75, 0.75, 1))
  print(ptmp)
  #save_plot(filename = paste0("Brazil/figures/", country, "_three_pannel_", filename2, ".png"), p, base_width = 14)
  #ggsave(ptmp, file=paste0("Brazil/figures/", country, "_three_pannel_", JOBID,'-',filename2, ".png"), width = 14)
  # NOTE(review): the literal ",_three_pannel_" puts a comma in the file
  # name; kept as-is in case other scripts match these names.
  ggsave(filename = paste0("Brazil/figures/",country, "counterfactual_",conterfactual_rate,",_three_pannel_", JOBID,'-',StanModel, ".pdf"),
         plot = ptmp, width = 14, height = 5)
}
}
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment