Base hierarchical ordinal regression for analysis of single case design data: a Stan script followed by example use in R.
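// hier_spline_cuts_spl.stan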
data {
  int N;                                            // number of observations
  int Ncase;                                        // number of cases
  int Nmax;                                         // number of unique outcome values (ordinal levels)
  array[N] int<lower = 1, upper = Ncase> case_id;   // case ID per observation
  array[N] int<lower = 1, upper = Nmax> y_ord;      // outcome coded as ordinal ranks
  array[N] int treat;                               // phase indicator (0 = baseline, 1 = treatment)
  int Nspl;                                         // number of B-spline basis functions for time
  matrix[N, Nspl] SplMat;                           // B-spline basis of time, by observation
  int<lower = 1> Ncut_spl;                          // number of I-spline basis functions for thresholds
  matrix[Nmax - 1, Ncut_spl] CutMat;                // I-spline basis for ordered thresholds
  vector[Nmax] y_s;                                 // unique outcome values on the original scale
  int Ncount;                                       // number of case-phase combinations
  vector[Ncount] count;                             // observations per case-phase
}
parameters {
  real intercept;
  real<lower = 0> cut_spl_sd;                       // scale of threshold spline weights
  vector<lower = 0>[Ncut_spl] cut_spl;              // non-negative I-spline weights (keep thresholds ordered)
  vector<lower = 0>[2] spl_sds;                     // SDs of shared and case-specific time splines
  vector<lower = 0>[2] beta_sds;                    // SDs of average and case-specific treatment effects
  real<multiplier = beta_sds[1]> beta;                            // average treatment effect
  vector<offset = beta, multiplier = beta_sds[2]>[Ncase] beta_s;  // case-specific treatment effects
  vector<multiplier = spl_sds[1]>[Nspl] spl_vec;                  // shared time-spline coefficients
  matrix<multiplier = spl_sds[2]>[Nspl, Ncase] spl_vec_p;         // case-specific time-spline deviations
}
model {
  intercept ~ normal(0, 2.5);
  cut_spl_sd ~ student_t(3, 0, 1.0);
  cut_spl ~ std_normal();  // half-normal given the lower = 0 constraint
  spl_sds ~ student_t(3, 0, 1.0);
  spl_vec ~ normal(0, spl_sds[1]);
  to_vector(spl_vec_p) ~ normal(0, spl_sds[2]);
  beta_sds ~ student_t(3, 0, 1.0);
  beta ~ normal(0, beta_sds[1]);
  beta_s ~ normal(beta, beta_sds[2]);
  {
    vector[N] case_spl;
    for (i in 1:N) {
      case_spl[i] = treat[i] * beta_s[case_id[i]] +
        SplMat[i, ] * (spl_vec + spl_vec_p[, case_id[i]]);
    }
    // thresholds: non-negative weights on a monotone I-spline basis keep the cutpoints ordered
    y_ord ~ ordered_logistic(case_spl, -intercept + CutMat * (cut_spl * cut_spl_sd));
  }
}
generated quantities {
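  // Per case and phase (A = baseline, B = treatment), the quantities below average the
  // implied ordinal PMFs over timepoints and summarize them on the original outcome scale:
  //   mean_s, median_s: phase means and medians (two columns per case, A then B)
  //   mean_diff, median_diff, log_mean_ratio, log_median_ratio: B vs. A contrasts
  //   nap: P(B > A) + .5 * P(B == A) (non-overlap of all pairs); tau = 2 * nap - 1
  //   pem: probability mass of B above the A median, counting mass at the median as 1/2
  //   y_hat, ord_sim, y_sim: fitted values and posterior predictive draws per observation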
  vector[Ncase * 2] mean_s;
  vector[Ncase * 2] median_s;
  vector[Ncase] mean_diff;
  vector[Ncase] median_diff;
  vector[Ncase] log_mean_ratio;
  vector[Ncase] log_median_ratio;
  vector[Ncase] nap;
  vector[Ncase] tau;
  vector[Ncase] pem;
  vector[N] y_hat;
  array[N] int ord_sim;
  vector[N] y_sim;
  {
    vector[N] case_spl;
    matrix[Nmax, Ncase * 2] pmf_mat = rep_matrix(0.0, Nmax, Ncase * 2);
    vector[Nmax] pmf_vec;
    vector[Nmax - 1] cuts = -intercept + CutMat * (cut_spl * cut_spl_sd);
    vector[Nmax - 1] p_prob;
    int mat_col_id;
    int col_id;
    vector[Nmax] cdf = rep_vector(0.0, Nmax);
    vector[Nmax] d0_vec;
    vector[Nmax] d1_vec;
    vector[Nmax] ccd1_vec;
    for (i in 1:N) {
      mat_col_id = (case_id[i] - 1) * 2 + treat[i] + 1;
      case_spl[i] = treat[i] * beta_s[case_id[i]] +
        SplMat[i, ] * (spl_vec + spl_vec_p[, case_id[i]]);
      // implied ordinal PMF for this observation
      p_prob = inv_logit(cuts - case_spl[i]);
      for (j in 1:(Nmax - 1)) {
        if (j == 1) {
          pmf_vec[j] = p_prob[j];
        } else {
          pmf_vec[j] = p_prob[j] - p_prob[j - 1];
        }
      }
      pmf_vec[Nmax] = 1 - p_prob[Nmax - 1];
      y_hat[i] = sum(pmf_vec .* y_s);
      ord_sim[i] = ordered_logistic_rng(case_spl[i], cuts);
      y_sim[i] = y_s[ord_sim[i]];
      pmf_mat[, mat_col_id] = pmf_mat[, mat_col_id] + pmf_vec;
    }
    for (i in 1:Ncase) {
      real med_a = 0;
      real med_b = 0;
      for (j in 1:2) {
        // j = 1: baseline phase (A); j = 2: treatment phase (B)
        col_id = (i - 1) * 2 + j;
        pmf_mat[, col_id] = pmf_mat[, col_id] / count[col_id];
        mean_s[col_id] = sum(pmf_mat[, col_id] .* y_s);
        // interpolated median from the phase-average PMF
        cdf = cumulative_sum(pmf_mat[, col_id]);
        if (cdf[1] >= .5) {
          median_s[col_id] = y_s[1];
        } else {
          for (k in 2:Nmax) {
            if (cdf[k] == .5) {
              median_s[col_id] = y_s[k];
            } else if (cdf[k - 1] < .5 && cdf[k] > .5) {
              median_s[col_id] = y_s[k - 1] + (.5 - cdf[k - 1]) *
                (y_s[k] - y_s[k - 1]) / (cdf[k] - cdf[k - 1]);
            }
          }
        }
      }
      // after the loop, col_id indexes phase B and col_id - 1 indexes phase A
      mean_diff[i] = mean_s[col_id] - mean_s[col_id - 1];
      median_diff[i] = median_s[col_id] - median_s[col_id - 1];
      log_mean_ratio[i] = log(mean_s[col_id]) - log(mean_s[col_id - 1]);
      log_median_ratio[i] = log(median_s[col_id]) - log(median_s[col_id - 1]);
      d0_vec = pmf_mat[, col_id - 1];
      d1_vec = pmf_mat[, col_id];
      ccd1_vec = 1.0 - cumulative_sum(d1_vec);
      nap[i] = sum(d0_vec .* ccd1_vec + .5 * d0_vec .* d1_vec);
      tau[i] = nap[i] * 2.0 - 1.0;
      for (k in 1:Nmax) {
        if (median_s[col_id - 1] == y_s[k]) {
          med_a = .5 * pmf_mat[k, col_id];
        }
        if (median_s[col_id - 1] < y_s[k]) {
          med_b += pmf_mat[k, col_id];
        }
      }
      pem[i] = med_a + med_b;
    }
  }
}
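### Example use in R (analysis of the SingleCaseES::Schmidt2012 data)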
cb_palette <- c(
"#999999", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2",
"#D55E00", "#CC79A7"
)
### Block 01 -- Install required packages
list_of_packages <- c(
"SingleCaseES", "data.table", "ggplot2", "splines2", "patchwork",
"ggdist", "scales", "posterior"
)
new_packages <- list_of_packages[
!(list_of_packages %in% installed.packages()[, "Package"])
]
if (length(new_packages)) install.packages(new_packages)
rm(list_of_packages, new_packages)
# install CmdStanR
if (!"cmdstanr" %in% installed.packages()[, "Package"]) {
# install the cmdstanr R package (CmdStan itself is installed below):
install.packages(
"cmdstanr",
repos = c("https://mc-stan.org/r-packages/", getOption("repos"))
)
}
# the next block installs CmdStan itself; skip it if CmdStan is already installed
if (is.null(cmdstanr::cmdstan_version(error_on_NA = FALSE))) {
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(
cores = max(c(1, round(parallel::detectCores() / 3)))
)
}
### Block 02 -- Load packages
library(data.table)
library(ggplot2)
library(ggdist)
library(patchwork)
library(scales)
library(splines2)
### Prelims
extract_es <- function(fit, es, case_names, confidence = .95) {
stopifnot(length(es) == 1)
stopifnot(es %in% c(
"nap", "mean_diff", "median_diff", "log_mean_ratio",
"log_median_ratio", "tau", "pem"
))
lo <- (1 - confidence) / 2
ret <- data.table::as.data.table(fit$summary(
es,
~ quantile(.x, c(.5, lo, 1 - lo)), sd,
posterior::default_convergence_measures()
))
ret$person <- case_names
ret$ES <- toupper(es)
setnames(
ret, c("50%", "sd", paste0(100 * c(lo, 1 - lo), "%")),
c("Est", "SE", "CI_lower", "CI_upper")
)
ret <- ret[, .(person, ES, Est, SE, CI_lower, CI_upper, rhat, ess_bulk)]
return(ret)
}
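# usage sketch (hypothetical until a model has been fit below), e.g.:
# extract_es(fit, "nap", case_names = dat[order(case), .N, list(Case)]$Case)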
gen_theme <- theme_bw() +
theme(
legend.position = "top", strip.background = element_blank(),
axis.title = element_blank(), panel.border = element_blank(),
panel.spacing.x = unit(.5, "cm"), panel.grid.major.y = element_blank(),
axis.ticks.y = element_blank(), axis.line.x = element_line(linewidth = .25)
)
dat_full <- as.data.table(SingleCaseES::Schmidt2012)
dat_full[, person := Case]
dat_full[, time := as.integer(Session_num)]
dat_full[, case := as.integer(factor(person))]
dat_full
head(dat_full[Behavior == "Conversation"])
head(dat_full[Behavior == "Initiations"])
head(dat_full[Behavior == "Responses"])
ggplot(dat_full, aes(time, Outcome)) +
geom_point(aes(col = factor(Trt))) +
geom_line() +
facet_grid(Behavior ~ person, scales = "free_y") +
theme_bw() +
theme(legend.position = "top")
ggsave("schmidt2012.png", width = 6.5, height = 5)
# ----
dat <- dat_full[Behavior == "Initiations"]
dat_list <- list(
N = nrow(dat), # number of rows of data
Ncase = max(dat$case), # number of cases
Ntime = max(dat$time), # maximum number of timepoints
case_id = dat$case, # case ID variable (integers)
time_id = dat$time, # time variable (can be continuous)
y_s = unique(sort(dat$Outcome)), # unique data points in outcome
y_ord = as.integer(ordered(dat$Outcome)), # outcome transformed to ranks
treat = as.integer(dat$Trt == 1) # treatment phase indicator
)
dat_list$Nmax <- max(dat_list$y_ord) # number of unique data points in outcome
# convert time to B-spline
dat_list$SplMat <- splines2::bSpline(
dat_list$time_id,
df = 15 + 3, degree = 3
)
dat_list$Nspl <- ncol(dat_list$SplMat) # number of basis functions
# I-spline for ordered thresholds
# df is the smaller of 18 and (number of unique outcome values - 2)
dat_list$CutMat <- -.5 + splines2::iSpline(
1:(dat_list$Nmax - 1),
df = min(15 + 3, dat_list$Nmax - 2), degree = 2
)
dat_list$Ncut_spl <- ncol(dat_list$CutMat) # number of basis functions
dat_list
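# optional sketch (not part of the original workflow): plot the spline bases to
# see what the model smooths over; uses base graphics and the dat_list built above
ord_t <- order(dat_list$time_id)
matplot(dat_list$time_id[ord_t], dat_list$SplMat[ord_t, ],
type = "l", lty = 1, xlab = "time", ylab = "B-spline basis")
matplot(dat_list$CutMat, type = "l", lty = 1,
xlab = "threshold index", ylab = "I-spline basis (shifted by -.5)")
rm(ord_t)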
# number of data points per case per phase, ordered by phase within case
(gm_count <- dat[
, .N, list(person, case, x = as.integer(Trt == 1) + 1)
][order(case, x)])
dat_list$Ncount <- nrow(gm_count) # number of case-phases
dat_list$count <- gm_count$N # data points per case-phase (from gm_count)
dat_list
# compile Stan model (only necessary first time)
hier_spl_mod <- cmdstanr::cmdstan_model("hier_spline_cuts_spl.stan")
# fit model with some default choices
fit <- hier_spl_mod$sample(
data = dat_list, seed = 12345, iter_warmup = 1e3, iter_sampling = 1e3,
chains = 4, parallel_chains = 4
)
# assess sampler problems
fit$cmdstan_diagnose()
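# (newer cmdstanr versions also provide fit$diagnostic_summary(), which
# summarizes divergences, treedepth, and E-BFMI)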
# assess parameter convergence
print(fit$summary(c(
"intercept", "cut_spl_sd", "cut_spl", "spl_sds",
"spl_vec", "spl_vec_p",
"beta_sds", "beta", "beta_s"
)), n = 100)
bayesplot::mcmc_hist(fit$draws(c(
"intercept", "cut_spl_sd", "cut_spl", "spl_sds",
"beta_sds", "beta", "beta_s"
)))
bayesplot::mcmc_trace(fit$draws(c(
"intercept", "cut_spl_sd", "cut_spl", "spl_sds",
"beta_sds", "beta", "beta_s"
)))
# substantively interesting parameters (printed by case)
print(fit$summary(c(
"mean_s",
"mean_diff", "log_mean_ratio",
"median_s",
"median_diff", "log_median_ratio",
"nap", "pem", "tau"
)), n = 40)
# predictions for each case-timepoint
yhat_s <- fit$summary("y_hat")
dat_merge <- cbind(dat, yhat_s[, c("median", "q5", "q95")])
setnames(dat_merge, c("median", "q5", "q95"), c("md", "lo", "hi"))
dat_merge
# predictions with 90% intervals
ggplot(dat_merge, aes(time, Outcome)) +
geom_point(aes(col = factor(Trt))) +
geom_line(linewidth = .125) +
facet_wrap(~person, ncol = 2) +
theme_bw() +
geom_line(aes(y = md), linetype = 2) +
theme(legend.position = "top") +
geom_ribbon(aes(ymin = lo, ymax = hi), alpha = .2)
# posterior predictive distribution for each case-timepoint
ysim_s <- fit$summary("ord_sim")
dat_sims <- cbind(dat, ysim_s[, c("median", "q5", "q95")])
setnames(dat_sims, c("median", "q5", "q95"), c("md", "lo", "hi"))
dat_sims[]
# create ranked version of outcome
dat_sims[, ord_outcome := frank(Outcome, ties.method = "dense")]
# distribution with 90% intervals
ggplot(dat_sims, aes(time, ord_outcome)) +
geom_point(aes(col = factor(Trt))) +
geom_line(linewidth = .125) +
facet_wrap(~person, ncol = 2) +
theme_bw() +
geom_line(aes(y = md), linetype = 2) +
theme(legend.position = "top") +
geom_ribbon(aes(ymin = lo, ymax = hi), alpha = .2) +
labs(y = "Ranked outcome")
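# optional posterior predictive check (a sketch; bayesplot is already used above
# for histograms and traces): compare observed ordinal ranks to draws of ord_sim
ppc_draws <- posterior::as_draws_matrix(fit$draws("ord_sim"))
bayesplot::ppc_bars(dat_list$y_ord, ppc_draws[1:200, ])
rm(ppc_draws)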
# case aggregated data
tmp_gm <- gm_count[, .(x = sum(x), N = sum(N)), list(person, case)]
# collect effect-size summaries, in the order listed in c() below
(gmean_s <- fit$summary(
c(
"mean_s", "mean_diff", "log_mean_ratio",
"median_s", "median_diff", "log_median_ratio",
"nap", "pem", "tau"
),
~ quantile(.x, probs = c(.025, .1, .5, .9, .975)),
posterior::default_convergence_measures()
)[, c("2.5%", "10%", "50%", "90%", "97.5%", "rhat", "ess_bulk", "ess_tail")])
gmean_sum <- as.data.table(cbind(
# row-bind case-phase counts (gm_count) and case totals (tmp_gm), repeated to match
# the order of the nine effect sizes above, then column-bind with the effect-size table
rbindlist(
list(
gm_count, tmp_gm, tmp_gm,
gm_count, tmp_gm, tmp_gm,
tmp_gm, tmp_gm, tmp_gm
),
idcol = "metric"
),
gmean_s
))
gmean_sum[] # the metric column indexes the effect sizes in the order above
gmean_sum[, metric_t := c(
"means", "mean diff", "log(mean ratio)",
"medians", "median diff", "log(median ratio)",
"NAP", "PEM", "TAU"
)[metric]]
# label phases
gmean_sum[, x_lab := factor(c("A", "B", "B > A")[x], c("A", "B", "B > A"))]
# set null conditions for effect sizes
gmean_sum[, null := c(NA, 0, 0, NA, 0, 0, .5, .5, 0)[metric]]
gmean_sum
ggplot(gmean_sum, aes(reorder(person, -case), y = `50%`, col = x_lab)) +
# Pointrange w/ 95% interval
geom_pointrange(
aes(ymin = `2.5%`, ymax = `97.5%`),
position = position_dodge(.5)
) +
# Bar w/ 80% interval
geom_linerange(aes(ymin = `10%`, ymax = `90%`),
alpha = .25, linewidth = 3,
position = position_dodge(.5)
) +
# Null conditions
geom_hline(aes(yintercept = null), linetype = 2, alpha = .125) +
geom_text(aes(label = number(`50%`, .01)),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(2, 5)]
) +
geom_text(aes(label = percent(`50%`, 1)),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(7:8)]
) +
geom_text(aes(label = paste0(ifelse(`50%` > 0, "+", ""), percent(`50%`, 1))),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(3, 6, 9)]
) +
coord_flip() +
gen_theme +
scale_color_manual(values = cb_palette[c(2, 3, 1)], name = "") +
facet_wrap(~ reorder(metric_t, metric), scales = "free_x", nrow = 3)
ggsave("schmidt2012_init.png", width = 6.5, height = 5.25)
frq_dt <- rbindlist(list(
dat[, SingleCaseES::NAP(Outcome[Trt == 0], Outcome[Trt == 1]), person],
dat[, SingleCaseES::PEM(Outcome[Trt == 0], Outcome[Trt == 1]), person],
dat[, SingleCaseES::LRRi(Outcome[Trt == 0], Outcome[Trt == 1]), person],
dat[, SingleCaseES::LRM(Outcome[Trt == 0], Outcome[Trt == 1]), person]
), fill = TRUE)
frq_dt[]
ord_es_dt <- rbindlist(
lapply(c("nap", "pem", "log_mean_ratio", "log_median_ratio"), function(es) {
extract_es(fit, es, dat[order(case), .N, list(Case)]$Case, .95)
})
)
ord_es_dt[]
joined_es_dt <- rbindlist(
list(frq_dt, ord_es_dt),
fill = TRUE, idcol = "approach"
)
joined_es_dt[]
joined_es_dt[, approach_t := factor(
approach, 2:1,
c("Bayesian Ordinal (hierarchical)", "Frequentist (non-hierarchical)")
)]
joined_es_dt[]
joined_es_dt[grepl("MEAN_RA", ES), ES := "LRRi"]
joined_es_dt[grepl("MEDIAN_RA", ES), ES := "LRM"]
joined_es_dt[]
joined_es_dt[, .N, ES]
joined_es_dt[, null := fcase(
ES %in% c("NAP", "PEM"), .5,
grepl("LR", ES), 0
)]
pt_dodge <- position_dodge(width = .8)
ggplot(joined_es_dt, aes(
reorder(person, -as.integer(factor(person))),
Est, ymin = CI_lower, ymax = CI_upper, group = approach_t,
shape = approach_t, colour = approach_t, label = percent(Est, 1)
)) +
geom_text(position = pt_dodge, vjust = -.5, size = 2.5) +
geom_point(position = pt_dodge) +
geom_linerange(position = pt_dodge) +
geom_hline(aes(yintercept = null), linetype = 2) +
scale_color_manual(values = cb_palette[2:1]) +
scale_shape_manual(values = c(4, 1)) +
scale_y_continuous(labels = percent_format()) +
facet_wrap(~ factor(ES, c("NAP", "PEM", "LRRi", "LRM")), scales = "free_x") +
coord_flip() +
guides(
col = guide_legend(reverse = TRUE), shape = guide_legend(reverse = TRUE)
) +
gen_theme +
labs(col = "", shape = "")
ggsave("schmidt2012_init_comp.png", width = 6.5, height = 4)
# ----
dat <- dat_full[Behavior == "Conversation"]
dat_list <- list(
N = nrow(dat), Ncase = max(dat$case), Ntime = max(dat$time),
case_id = dat$case, time_id = dat$time, y_s = unique(sort(dat$Outcome)),
y_ord = as.integer(ordered(dat$Outcome)),
treat = as.integer(dat$Trt == 1)
)
dat_list$Nmax <- max(dat_list$y_ord)
dat_list$SplMat <- splines2::bSpline(dat_list$time_id, df = 15 + 3, degree = 3)
dat_list$Nspl <- ncol(dat_list$SplMat)
dat_list$CutMat <- -.5 + splines2::iSpline(
1:(dat_list$Nmax - 1),
df = min(15 + 3, dat_list$Nmax - 2), degree = 2
)
dat_list$Ncut_spl <- ncol(dat_list$CutMat)
dat_list
(gm_count <- dat[
, .N, list(person, case, x = as.integer(Trt == 1) + 1)
][order(case, x)])
dat_list$Ncount <- nrow(gm_count)
dat_list$count <- gm_count$N
dat_list
hier_spl_mod <- cmdstanr::cmdstan_model("hier_spline_cuts_spl.stan")
fit <- hier_spl_mod$sample(
data = dat_list, seed = 12345, iter_warmup = 1e3, iter_sampling = 1e3,
chains = 4, parallel_chains = 4, adapt_delta = .9
)
fit$cmdstan_diagnose()
print(fit, c(
"intercept", "cut_spl_sd", "cut_spl", "spl_sds", "beta_sds",
"beta", "beta_s"
))
print(fit$summary(c(
# "mean_s",
"mean_diff", "log_mean_ratio",
# "median_s",
"median_diff", "log_median_ratio",
"nap", "pem", "tau"
)), n = 40)
yhat_s <- fit$summary("y_hat")
dat_merge <- cbind(dat, yhat_s[, c("median", "q5", "q95")])
setnames(dat_merge, c("median", "q5", "q95"), c("md", "lo", "hi"))
dat_merge
ggplot(dat_merge, aes(time, Outcome)) +
geom_point(aes(col = factor(Trt))) +
geom_line(linewidth = .125) +
facet_wrap(~person, ncol = 2) +
theme_bw() +
geom_line(aes(y = md), linetype = 2) +
theme(legend.position = "top") +
geom_ribbon(aes(ymin = lo, ymax = hi), alpha = .2)
tmp_gm <- gm_count[, .(x = sum(x), N = sum(N)), list(person, case)]
(gmean_s <- fit$summary(
c(
"mean_s", "mean_diff", "log_mean_ratio",
"median_s", "median_diff", "log_median_ratio",
"nap", "pem", "tau"
),
~ quantile(.x, probs = c(.025, .1, .5, .9, .975)),
posterior::default_convergence_measures()
)[, c("2.5%", "10%", "50%", "90%", "97.5%", "rhat", "ess_bulk", "ess_tail")])
gmean_sum <- as.data.table(cbind(
rbindlist(
list(
gm_count, tmp_gm, tmp_gm, gm_count, tmp_gm, tmp_gm,
tmp_gm, tmp_gm, tmp_gm
),
idcol = "metric"
),
gmean_s
))
gmean_sum[, metric_t := c(
"means", "mean diff", "log(mean ratio)",
"medians", "median diff", "log(median ratio)",
"NAP", "PEM", "TAU"
)[metric]]
gmean_sum[, x_lab := factor(c("A", "B", "B > A")[x], c("A", "B", "B > A"))]
gmean_sum[, null := c(NA, 0, 0, NA, 0, 0, .5, .5, 0)[metric]]
gmean_sum
ggplot(gmean_sum, aes(reorder(person, -case), y = `50%`, col = x_lab)) +
# Pointrange w/ 95% interval
geom_pointrange(
aes(ymin = `2.5%`, ymax = `97.5%`),
position = position_dodge(.5)
) +
# Bar w/ 80% interval
geom_linerange(aes(ymin = `10%`, ymax = `90%`),
alpha = .25, linewidth = 3,
position = position_dodge(.5)
) +
# Null conditions
geom_hline(aes(yintercept = null), linetype = 2, alpha = .125) +
geom_text(aes(label = number(`50%`, .01)),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(2, 5)]
) +
geom_text(aes(label = percent(`50%`, 1)),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(7:8)]
) +
geom_text(aes(label = paste0(ifelse(`50%` > 0, "+", ""), percent(`50%`, 1))),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(3, 6, 9)]
) +
coord_flip() +
gen_theme +
scale_color_manual(values = cb_palette[c(2, 3, 1)], name = "") +
facet_wrap(~ reorder(metric_t, metric), scales = "free_x", nrow = 3)
ggsave("schmidt2012_conv.png", width = 6.5, height = 5.25)
dat[, SingleCaseES::NAP(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::PEM(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::Tau(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::LRRi(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::LRM(Outcome[Trt == 0], Outcome[Trt == 1]), person]
# ----
dat <- dat_full[Behavior == "Responses"]
dat_list <- list(
N = nrow(dat), Ncase = max(dat$case), Ntime = max(dat$time),
case_id = dat$case, time_id = dat$time, y_s = unique(sort(dat$Outcome)),
y_ord = as.integer(ordered(dat$Outcome)),
treat = as.integer(dat$Trt == 1)
)
dat_list$Nmax <- max(dat_list$y_ord)
dat_list$SplMat <- splines2::bSpline(dat_list$time_id, df = 15 + 3, degree = 3)
dat_list$Nspl <- ncol(dat_list$SplMat)
dat_list$CutMat <- -.5 + splines2::iSpline(
1:(dat_list$Nmax - 1),
df = min(15 + 3, dat_list$Nmax - 2), degree = 2
)
dat_list$Ncut_spl <- ncol(dat_list$CutMat)
dat_list
(gm_count <- dat[
, .N, list(person, case, x = as.integer(Trt == 1) + 1)
][order(case, x)])
dat_list$Ncount <- nrow(gm_count)
dat_list$count <- gm_count$N
dat_list
hier_spl_mod <- cmdstanr::cmdstan_model("hier_spline_cuts_spl.stan")
fit <- hier_spl_mod$sample(
data = dat_list, seed = 12345, iter_warmup = 1e3, iter_sampling = 1e3,
chains = 4, parallel_chains = 4
)
fit$cmdstan_diagnose()
print(fit, c(
"intercept", "cut_spl_sd", "cut_spl", "spl_sds", "beta_sds",
"beta", "beta_s"
))
print(fit$summary(c(
# "mean_s",
"mean_diff", "log_mean_ratio",
# "median_s",
"median_diff", "log_median_ratio",
"nap", "pem", "tau"
)), n = 40)
yhat_s <- fit$summary("y_hat")
dat_merge <- cbind(dat, yhat_s[, c("median", "q5", "q95")])
setnames(dat_merge, c("median", "q5", "q95"), c("md", "lo", "hi"))
dat_merge
ggplot(dat_merge, aes(time, Outcome)) +
geom_point(aes(col = factor(Trt))) +
geom_line(linewidth = .125) +
facet_wrap(~person, ncol = 2) +
theme_bw() +
geom_line(aes(y = md), linetype = 2) +
theme(legend.position = "top") +
geom_ribbon(aes(ymin = lo, ymax = hi), alpha = .2)
tmp_gm <- gm_count[, .(x = sum(x), N = sum(N)), list(person, case)]
(gmean_s <- fit$summary(
c(
"mean_s", "mean_diff", "log_mean_ratio",
"median_s", "median_diff", "log_median_ratio",
"nap", "pem", "tau"
),
~ quantile(.x, probs = c(.025, .1, .5, .9, .975)),
posterior::default_convergence_measures()
)[, c("2.5%", "10%", "50%", "90%", "97.5%", "rhat", "ess_bulk", "ess_tail")])
gmean_sum <- as.data.table(cbind(
rbindlist(
list(
gm_count, tmp_gm, tmp_gm, gm_count, tmp_gm, tmp_gm,
tmp_gm, tmp_gm, tmp_gm
),
idcol = "metric"
),
gmean_s
))
gmean_sum[, metric_t := c(
"means", "mean diff", "log(mean ratio)",
"medians", "median diff", "log(median ratio)",
"NAP", "PEM", "TAU"
)[metric]]
gmean_sum[, x_lab := factor(c("A", "B", "B > A")[x], c("A", "B", "B > A"))]
gmean_sum[, null := c(NA, 0, 0, NA, 0, 0, .5, .5, 0)[metric]]
gmean_sum
ggplot(gmean_sum, aes(reorder(person, -case), y = `50%`, col = x_lab)) +
# Pointrange w/ 95% interval
geom_pointrange(
aes(ymin = `2.5%`, ymax = `97.5%`),
position = position_dodge(.5)
) +
# Bar w/ 80% interval
geom_linerange(aes(ymin = `10%`, ymax = `90%`),
alpha = .25, linewidth = 3,
position = position_dodge(.5)
) +
# Null conditions
geom_hline(aes(yintercept = null), linetype = 2, alpha = .125) +
geom_text(aes(label = number(`50%`, .01)),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(2, 5)]
) +
geom_text(aes(label = percent(`50%`, 1)),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(7:8)]
) +
geom_text(aes(label = paste0(ifelse(`50%` > 0, "+", ""), percent(`50%`, 1))),
position = position_dodge(.5), vjust = -.7,
size = 3, data = gmean_sum[metric %in% c(3, 6, 9)]
) +
coord_flip() +
gen_theme +
scale_color_manual(values = cb_palette[c(2, 3, 1)], name = "") +
facet_wrap(~ reorder(metric_t, metric), scales = "free_x", nrow = 3)
ggsave("schmidt2012_resp.png", width = 6.5, height = 5.25)
dat[, SingleCaseES::NAP(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::PEM(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::Tau(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::LRRi(Outcome[Trt == 0], Outcome[Trt == 1]), person]
dat[, SingleCaseES::LRM(Outcome[Trt == 0], Outcome[Trt == 1]), person]