Last active
May 5, 2021 16:56
-
-
Save stonegold546/c17092b8bae4a090bab79609f70bc538 to your computer and use it in GitHub Desktop.
Polychoric correlations with Stan (accounts for missingness)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: polychoric (and partial polychoric) correlations estimated with
# Stan via cmdstanr, compared against lavaan::lavCor(). The second half
# repeats the analysis after injecting missing values.
library(lavaan)
library(rstan)
library(cmdstanr)
set_cmdstan_path("~/cmdstan/")
source("create_poly_stan_missing.R")

HS9 <- HolzingerSwineford1939[, paste0("x", 1:9)]

# Pearson correlations
cor(HS9)

# Ordinal version: each item is cut into five equal-width bins coded 1..5
HS9ord <- as.data.frame(lapply(HS9, cut, 5, labels = FALSE))
HS9ord

# Data must have minimum values of 1
# Create data format for Stan
dat.list <- create_dat_list_miss(HS9ord)

# Create Stan script
create_poly_stan_miss(dat.list$Ni, dat.list$n_ord, "stan_scripts/hs9.stan")

# Compile script
hs9_example <- cmdstan_model("stan_scripts/hs9.stan")

hs9.fit <- hs9_example$sample(
  data = dat.list, seed = 12345, iter_warmup = 5e2,
  iter_sampling = 5e2, chains = 3, parallel_chains = 3)
hs9.fit$cmdstan_diagnose()
# Convert to an rstan stanfit object for printing / extraction below
hs9.fit <- read_stan_csv(hs9.fit$output_files())

# View thresholds
print(hs9.fit, paste0("cutpoints_", 1:dat.list$Ni),
      probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "th") # Results from lavaan

# View correlation matrix (lower-triangular elements only)
print(hs9.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list$Ni, dat.list$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("R[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "est")[10:45, ] # Results from lavaan

# Show correlation matrix (interval = .001 is so narrow that almost nothing
# gets zeroed)
filter_interval(as.data.frame(hs9.fit, "R"), interval = .001)
# Show correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.fit, "R"), interval = .95)

# View partial correlation matrix
print(hs9.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list$Ni, dat.list$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("P[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
# Partial correlation matrix
filter_interval(as.data.frame(hs9.fit, "P"), interval = .001)
# Partial correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.fit, "P"), interval = .95)
# Sample Gaussian graphical model
qgraph::qgraph(filter_interval(as.data.frame(hs9.fit, "P"), interval = .95))

# -------
# HS9 with missing
# Blank out 20 randomly chosen responses in every item
HS9ord_miss <- apply(HS9ord, 2, function (x) { x[sample(seq_along(x), 20)] <- NA; x })
dim(na.omit(HS9ord_miss))
HS9ord_miss[5, ] <- NA # sample row completely missing, gets dropped
HS9ord_miss

dat.list.miss <- create_dat_list_miss(HS9ord_miss)

# Reuse the Stan program generated above; this relies on the items and their
# numbers of categories being the same as in the complete data.
hs9_example <- cmdstan_model("stan_scripts/hs9.stan")

hs9.miss.fit <- hs9_example$sample(
  data = dat.list.miss, seed = 12345, iter_warmup = 5e2,
  iter_sampling = 5e2, chains = 3, parallel_chains = 3)
hs9.miss.fit$cmdstan_diagnose()
hs9.miss.fit <- read_stan_csv(hs9.miss.fit$output_files())

# View thresholds (complete-data fit, missing-data fit, then lavaan)
print(hs9.fit, paste0("cutpoints_", 1:dat.list$Ni),
      probs = c(.025, .5, .975), digits_summary = 3)
print(hs9.miss.fit, paste0("cutpoints_", 1:dat.list.miss$Ni),
      probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "th") # Results from lavaan

# View correlation matrix
print(hs9.miss.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list.miss$Ni, dat.list.miss$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("R[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "est")[10:45, ] # Results from lavaan

# Show correlation matrix
filter_interval(as.data.frame(hs9.miss.fit, "R"), interval = .001)
# Show correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.miss.fit, "R"), interval = .95)

# View partial correlation matrix
print(hs9.miss.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list.miss$Ni, dat.list.miss$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("P[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
# Partial correlation matrix
filter_interval(as.data.frame(hs9.miss.fit, "P"), interval = .001)
# Partial correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.miss.fit, "P"), interval = .95)
# Sample Gaussian graphical model
qgraph::qgraph(filter_interval(as.data.frame(hs9.miss.fit, "P"), interval = .95))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate a Stan program for polychoric correlations with item-specific
# cutpoints and latent imputation of missing responses, and write it to `file`.
#
# Arguments:
#   Ni    number of items (a single positive number).
#   n_ord numeric vector of length Ni: number of observed categories per item.
#         Every item needs at least 2 categories, otherwise there is no
#         cutpoint to estimate and the generated code would be invalid.
#   file  path of the .stan file to write. Its directory is created if needed.
#
# The file is only rewritten when its contents would change, so an unchanged
# script keeps its modification time.
#
# Returns NULL invisibly; called for its side effect.
#
# NOTE(review): the generated program uses the older array declaration form
# (`int n_ord[Ni];`), which recent Stan releases no longer accept -- confirm
# against the CmdStan version in use.
create_poly_stan_miss <- function(Ni, n_ord,
                                  file = "stan_scripts/polychoric_varying_cut.stan") {
  stopifnot(
    is.numeric(Ni), length(Ni) == 1L, Ni >= 1,
    is.numeric(n_ord), length(n_ord) == Ni,
    all(n_ord >= 2)
  )
  items <- seq_len(Ni)

  preamble <- c(
    "// Script produced by other script based on number of levels per item and number of items",
    "functions {",
    "  matrix cov2cor (matrix C_mat) {",
    "    int p = dims(C_mat)[1];",
    "    vector[p] S_i = 1 ./ sqrt(diagonal(C_mat));",
    "    matrix[p, p] R_mat = -quad_form_diag(C_mat, S_i);",
    "    for (i in 1:p) R_mat[i, i] = 1;",
    "    return R_mat;",
    "  }",
    "}",
    "data {",
    "  real<lower = 0> shape_phi_c;  // lkj prior shape for phi",
    "  int<lower = 0> Np;  // N_persons",
    "  int<lower = 0> Ni;  // N_items",
    "  int n_ord[Ni];  // number of levels per item",
    "  int<lower = -1> input_data[Np, Ni];  // input matrix, -1 for missing, real data begin at 1",
    "  int count_ord[Ni, max(n_ord) + 1];",
    "  int<lower = 0, upper = 1> ggm;",
    "}",
    "transformed data {",
    "  int Ni_ggm = 0;",
    "  int N_miss = sum(count_ord[, max(n_ord) + 1]);",
    "  if (ggm) Ni_ggm = Ni;",
    "}",
    "parameters {",
    "  // cutpoints"
  )

  # One ordered cutpoint vector per item
  cutpoint_decl <- sprintf("  ordered[n_ord[%d] - 1] cutpoints_%d;", items, items)

  # One bounded latent vector per item x category: the lowest category is only
  # bounded above, the highest only below, the rest on both sides.
  latent_decl <- unlist(lapply(items, function(i) {
    vapply(seq_len(n_ord[i]), function(j) {
      bounds <- c(
        if (j > 1) sprintf("lower = cutpoints_%d[%d]", i, j - 1L),
        if (j < n_ord[i]) sprintf("upper = cutpoints_%d[%d]", i, j)
      )
      sprintf("  vector<%s>[count_ord[%d,%d]] z%d_%d;",
              paste(bounds, collapse = ", "), i, j, i, j)
    }, character(1))
  }))

  model_open <- c(
    "  vector[N_miss] z_miss;  // latent missing data",
    "  cholesky_factor_corr[Ni] R_chol;  // Cor_mat cholesky",
    "}",
    "model {",
    "  int pos[sum(n_ord)] = rep_array(0, sum(n_ord));",
    "  int pos_miss = 0;",
    "  int loc1;",
    "  int loc2;",
    "  vector[Ni] z[Np];"
  )

  # Stack the per-category latent vectors of each item into one vector,
  # as right-nested append_row() calls: append_row(z_1,append_row(z_2,z_3))
  stack_decl <- vapply(items, function(i) {
    stacked <- Reduce(
      function(piece, acc) sprintf("append_row(%s,%s)", piece, acc),
      sprintf("z%d_%d", i, seq_len(n_ord[i])),
      right = TRUE
    )
    sprintf("  vector[sum(count_ord[%d, 1:n_ord[%d]])] z%d = %s;", i, i, i, stacked)
  }, character(1))

  priors <- c(
    "",
    "  R_chol ~ lkj_corr_cholesky(shape_phi_c);",
    sprintf("  cutpoints_%d ~ std_normal();", items),
    ""
  )

  # For every person and item, pick the next unused latent value: from z_miss
  # when the response is -1, otherwise from the item's stacked vector at the
  # position belonging to the observed category.
  fill_blocks <- unlist(lapply(items, function(j) {
    offset <- if (j == 1L) "0" else sprintf("sum(n_ord[1:%d])", j - 1L)
    obs <- sprintf("input_data[i, %d]", j)
    c(
      sprintf("    if (%s == -1) {", obs),
      "      pos_miss += 1;",
      sprintf("      z[i, %d] = z_miss[pos_miss];", j),
      "    } else {",
      sprintf("      loc1 = %s;", offset),
      sprintf("      pos[loc1 + %s] += 1;", obs),
      sprintf(
        "      loc2 = sum(count_ord[%d, 1:%s]) - count_ord[%d, %s] + pos[loc1 + %s];",
        j, obs, j, obs, obs
      ),
      sprintf("      z[i, %d] = z%d[loc2];", j, j),
      "    }"
    )
  }))

  closing <- c(
    "  }",
    "",
    "  z ~ multi_normal_cholesky(rep_vector(0, Ni), R_chol);",
    "}",
    "generated quantities {",
    "  matrix[Ni, Ni] R = multiply_lower_tri_self_transpose(R_chol);",
    "  matrix[Ni_ggm, Ni_ggm] P;",
    "  if (ggm) P = cov2cor(inverse_spd(R));",
    "}"
  )

  code <- paste(
    c(
      preamble, cutpoint_decl,
      "  // latent variables", latent_decl,
      model_open, stack_decl, priors,
      "  for (i in 1:Np) {", fill_blocks,
      closing
    ),
    collapse = "\n"
  )

  # Make sure the target directory exists before writing
  dir.create(dirname(file), recursive = TRUE, showWarnings = FALSE)

  # Skip the write when the file already holds exactly this program
  unchanged <- file.exists(file) &&
    identical(paste0(readLines(con = file), collapse = "\n"), code)
  if (!unchanged) {
    writeLines(code, con = file)
  }
  invisible(NULL)
}
# Build the data list expected by the Stan program written by
# create_poly_stan_miss().
#
# Arguments:
#   data  data frame or matrix of ordinal responses, one column per item;
#         NA marks a missing response.
#   ggm   passed through unchanged as the `ggm` data entry (default 1).
#
# Each column is recoded to consecutive integer codes 1..k following the sort
# order of its observed values. Rows with no observed response are dropped.
#
# Returns a list with
#   shape_phi_c  sqrt(number of items)
#   Np, Ni       number of retained rows / number of items
#   n_ord        number of distinct observed codes per item
#   input_data   Np x Ni matrix of codes, missing values coded -1
#   ggm          the `ggm` argument
#   count_ord    Ni x (max code + 1) matrix: per item, the count of each code
#                followed by the count of missing responses in the last column
create_dat_list_miss <- function(data, ggm = 1) {
  data <- as.matrix(data)

  # Recode column by column into an integer matrix so the result stays a
  # matrix whatever the number of rows or columns.
  X <- matrix(
    NA_integer_, nrow = nrow(data), ncol = ncol(data),
    dimnames = list(NULL, colnames(data))
  )
  for (j in seq_len(ncol(data))) {
    X[, j] <- as.integer(as.ordered(data[, j]))
  }

  # Drop rows without any observed response; drop = FALSE keeps a single
  # surviving row as a 1-row matrix.
  X <- X[rowSums(!is.na(X)) > 0, , drop = FALSE]
  if (nrow(X) == 0L) {
    stop("`data` contains no observed responses.", call. = FALSE)
  }

  item_idx <- seq_len(ncol(X))
  n_lvl <- max(X, na.rm = TRUE)

  # Per item: counts of codes 1..n_lvl, then the number of missing responses
  count_ord <- matrix(0L, nrow = ncol(X), ncol = n_lvl + 1L)
  for (j in item_idx) {
    count_ord[j, ] <- c(tabulate(X[, j], nbins = n_lvl), sum(is.na(X[, j])))
  }
  rownames(count_ord) <- colnames(X)

  input_data <- X
  input_data[is.na(input_data)] <- -1L

  list(
    shape_phi_c = sqrt(ncol(X)), Np = nrow(X), Ni = ncol(X),
    n_ord = vapply(
      item_idx, function(j) length(unique(X[!is.na(X[, j]), j])), integer(1)
    ),
    input_data = input_data, ggm = ggm,
    count_ord = count_ord
  )
}
# Summarise posterior draws of a square matrix, zeroing uncertain entries.
#
# Arguments:
#   post      data frame or matrix of draws, one column per matrix element;
#             the number of columns must be a perfect square.
#   interval  width of the central quantile interval, between 0 and 1.
#   mdn       if truthy, summarise each element by its median, otherwise by
#             its mean (the default).
#
# Returns an n x n numeric matrix (n = sqrt(ncol(post))), filled column by
# column from the per-column summaries. An entry is set to 0 when the lower
# and upper interval bounds have opposite signs, i.e. the interval contains 0.
filter_interval <- function(post, interval = .95, mdn = 0) {
  stopifnot(
    is.numeric(interval), length(interval) == 1L,
    interval >= 0, interval <= 1
  )
  # Coerce once instead of once per apply() call
  post <- as.matrix(post)

  n_var <- sqrt(ncol(post))
  if (n_var != round(n_var)) {
    stop("`post` must have a square number of columns.", call. = FALSE)
  }
  if (ncol(post) == 0L) {
    return(matrix(numeric(0), 0, 0))
  }

  lower_prob <- (1 - interval) / 2
  upper_prob <- 1 - lower_prob
  lower <- apply(post, 2, quantile, probs = lower_prob)
  upper <- apply(post, 2, quantile, probs = upper_prob)

  centre <- if (mdn) apply(post, 2, median) else colMeans(post)

  # Opposite-signed bounds mean the interval straddles zero
  centre[(lower * upper) < 0] <- 0
  matrix(centre, n_var)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
create_poly_stan.R
contains the needed functions. The other file is an example of their use with CmdStan.