@stonegold546
Last active May 5, 2021 16:56
Polychoric correlations with Stan (accounts for missingness)
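# Example script: estimate polychoric correlations for the Holzinger-Swineford items with CmdStan
# (helper functions are defined in create_poly_stan_missing.R, included further below)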
library(lavaan)
library(rstan)
library(cmdstanr)
set_cmdstan_path("~/cmdstan/")
source("create_poly_stan_missing.R")
HS9 <- HolzingerSwineford1939[, paste0("x", 1:9)]
# Pearson correlations
cor(HS9)
# ordinal version, with five categories
HS9ord <- as.data.frame(lapply(HS9, cut, 5, labels = FALSE))
HS9ord
# Data must have minimum values of 1
# Create data format for Stan
dat.list <- create_dat_list_miss(HS9ord)
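# dat.list carries the data the generated Stan script expects: Np (persons), Ni (items),
# n_ord (levels per item), input_data (integer matrix with -1 for missing),
# count_ord (per-item counts of each category plus a final column of missing counts),
# shape_phi_c (LKJ prior shape), and ggm (1 = also return the partial correlation matrix)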
# Create Stan script
create_poly_stan_miss(dat.list$Ni, dat.list$n_ord, "stan_scripts/hs9.stan")
# Compile script
hs9_example <- cmdstan_model("stan_scripts/hs9.stan")
hs9.fit <- hs9_example$sample(
  data = dat.list, seed = 12345, iter_warmup = 5e2,
  iter_sampling = 5e2, chains = 3, parallel_chains = 3)
hs9.fit$cmdstan_diagnose()
hs9.fit <- read_stan_csv(hs9.fit$output_files())
# View thresholds
print(hs9.fit, paste0("cutpoints_", 1:dat.list$Ni), probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "th") # Results from lavaan
# View correlation matrix
print(hs9.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list$Ni, dat.list$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("R[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "est")[10:45, ] # Results from lavaan
# Show correlation matrix
filter_interval(as.data.frame(hs9.fit, "R"), interval = .001)
# Show correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.fit, "R"), interval = .95)
# View partial correlation matrix
print(hs9.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list$Ni, dat.list$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("P[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
# Partial correlation matrix
filter_interval(as.data.frame(hs9.fit, "P"), interval = .001)
# Partial correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.fit, "P"), interval = .95)
# Sample Gaussian graphical model
qgraph::qgraph(filter_interval(as.data.frame(hs9.fit, "P"), interval = .95))
# -------
# HS9 with missing
HS9ord_miss <- apply(HS9ord, 2, function (x) { x[sample(1:length(x), 20)] <- NA; x })
dim(na.omit(HS9ord_miss))
HS9ord_miss[5, ] <- NA # make one row completely missing; it gets dropped by create_dat_list_miss
HS9ord_miss
dat.list.miss <- create_dat_list_miss(HS9ord_miss)
hs9_example <- cmdstan_model("stan_scripts/hs9.stan")
hs9.miss.fit <- hs9_example$sample(
  data = dat.list.miss, seed = 12345, iter_warmup = 5e2,
  iter_sampling = 5e2, chains = 3, parallel_chains = 3)
hs9.miss.fit$cmdstan_diagnose()
hs9.miss.fit <- read_stan_csv(hs9.miss.fit$output_files())
# View thresholds (complete-data fit vs. missing-data fit)
print(hs9.fit, paste0("cutpoints_", 1:dat.list$Ni), probs = c(.025, .5, .975), digits_summary = 3)
print(hs9.miss.fit, paste0("cutpoints_", 1:dat.list.miss$Ni), probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "th") # Results from lavaan
# View correlation matrix
print(hs9.miss.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list.miss$Ni, dat.list.miss$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("R[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
lavCor(HS9ord, ordered = names(HS9ord), output = "est")[10:45, ] # Results from lavaan
# Show correlation matrix
filter_interval(as.data.frame(hs9.miss.fit, "R"), interval = .001)
# Show correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.miss.fit, "R"), interval = .95)
# View partial correlation matrix
print(hs9.miss.fit, c(
  apply(which(lower.tri(matrix(NA, dat.list.miss$Ni, dat.list.miss$Ni)), arr.ind = TRUE), 1,
        function (x) { paste0("P[", paste0(x, collapse = ","), "]") })),
  probs = c(.025, .5, .975), digits_summary = 3)
# Partial correlation matrix
filter_interval(as.data.frame(hs9.miss.fit, "P"), interval = .001)
# Partial correlation matrix setting to 0 correlations whose 95% intervals include 0
filter_interval(as.data.frame(hs9.miss.fit, "P"), interval = .95)
# Sample Gaussian graphical model
qgraph::qgraph(filter_interval(as.data.frame(hs9.miss.fit, "P"), interval = .95))
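# -------
# create_poly_stan_missing.R: helper functions sourced by the example above
# (Stan script generator, Stan data-list builder, posterior-interval filter)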
create_poly_stan_miss <- function (Ni, n_ord, file = "stan_scripts/polychoric_varying_cut.stan") {
  # Builds a Stan script whose parameter block depends on the number of items
  # and the number of levels per item, then writes it to `file`.
  code <- "// Script produced by other script based on number of levels per item and number of items
functions {
  matrix cov2cor (matrix C_mat) {
    // off-diagonals are negated; applied to a precision matrix this returns partial correlations
    int p = dims(C_mat)[1];
    vector[p] S_i = 1 ./ sqrt(diagonal(C_mat));
    matrix[p, p] R_mat = -quad_form_diag(C_mat, S_i);
    for (i in 1:p) R_mat[i, i] = 1;
    return R_mat;
  }
}
data {
  real<lower = 0> shape_phi_c; // lkj prior shape for phi
  int<lower = 0> Np; // N_persons
  int<lower = 0> Ni; // N_items
  int n_ord[Ni]; // number of levels per item
  int<lower = -1> input_data[Np, Ni]; // input matrix, -1 for missing, real data begin at 1
  int count_ord[Ni, max(n_ord) + 1];
  int<lower = 0, upper = 1> ggm;
}
transformed data {
  int Ni_ggm = 0;
  int N_miss = sum(count_ord[, max(n_ord) + 1]);
  if (ggm) Ni_ggm = Ni;
}
parameters {
"
  # One ordered cutpoint vector per item
  cutpoints <- paste(sapply(1:Ni, function (i) {
    paste0("  ordered[n_ord[", i, "] - 1] cutpoints_", i, ";")
  }), collapse = "\n")
  code <- paste0(code, "  // cutpoints\n", cutpoints)
  rm(cutpoints)
  # writeLines(code)
  # One truncated latent-response vector per item-by-category cell
  z.vars <- "\n  // latent variables"
  for (i in 1:Ni) {
    for (j in 1:n_ord[i]) {
      if (j == 1) {
        z.vars <- paste0(
          z.vars, "\n",
          paste0("  vector<upper = cutpoints_", i, "[", j, "]>[count_ord[", i, ",", j, "]] z", i, "_", j, ";"))
      } else if (j == n_ord[i]) {
        z.vars <- paste0(
          z.vars, "\n",
          paste0("  vector<lower = cutpoints_", i, "[", j - 1, "]>[count_ord[", i, ",", j, "]] z", i, "_", j, ";"))
      } else {
        z.vars <- paste0(
          z.vars, "\n",
          paste0("  vector<lower = cutpoints_", i, "[", j - 1, "], upper = cutpoints_", i, "[", j, "]>[count_ord[", i, ",", j, "]] z", i, "_", j, ";"))
      }
    }
  }
  rm(i, j)
  # writeLines(z.vars)
  code <- paste0(code, z.vars)
  rm(z.vars)
  # writeLines(code)
  code <- paste0(code, "\n  vector[N_miss] z_miss; // latent missing data")
  code <- paste0(code, "\n  cholesky_factor_corr[Ni] R_chol; // Cor_mat cholesky")
  # writeLines(code)
  code <- paste0(code, "\n}
model {
  int pos[sum(n_ord)] = rep_array(0, sum(n_ord));
  int pos_miss = 0;
  int loc1;
  int loc2;
  vector[Ni] z[Np];
")
  # Stitch each item's per-category latent vectors back into one vector
  z.consolidate <- ""
  for (i in 1:Ni) {
    part.1 <- paste0("  vector[sum(count_ord[", i, ", 1:n_ord[", i, "]])] z", i, " = ")
    for (j in n_ord[i]:2) {
      if (j == n_ord[i]) {
        mid <- paste0("append_row(z", i, "_", j - 1, ",z", i, "_", j, ")")
      } else {
        mid <- paste0("append_row(z", i, "_", j - 1, ",", mid, ")")
      }
    }
    z.consolidate <- paste0(z.consolidate, paste0(part.1, mid, ";\n"))
  }
  rm(part.1, mid, i, j)
  # writeLines(z.consolidate)
  code <- paste0(code, z.consolidate)
  rm(z.consolidate)
  # writeLines(code)
  code <- paste0(code, "\n  R_chol ~ lkj_corr_cholesky(shape_phi_c);\n")
  # writeLines(code)
  cutpoints.prior <- paste(sapply(1:Ni, function (i) {
    paste0("  cutpoints_", i, " ~ std_normal();")
  }), collapse = "\n")
  code <- paste0(code, cutpoints.prior)
  rm(cutpoints.prior)
  # writeLines(code)
  # Fill the person-by-item latent matrix: missing cells draw from z_miss (unbounded),
  # observed cells pick the matching element of that item's latent vector
  fill_lat_var <- "\n\n  for (i in 1:Np) {"
  for (j in 1:Ni) {
    part.1 <- paste0("loc1 = ", ifelse(j == 1, "0;", paste0("sum(n_ord[1:", j - 1, "]);")))
    part.2 <- paste0("pos[loc1 + input_data[i, ", j, "]] += 1;")
    part.3 <- paste0("loc2 = sum(count_ord[", j, ", 1:input_data[i, ", j, "]]) - count_ord[", j, ", input_data[i, ", j, "]] + pos[loc1 + input_data[i, ", j, "]];")
    part.4 <- paste0("z[i, ", j, "] = z", j, "[loc2];")
    fill_lat_var <- paste0(
      fill_lat_var,
      "\n    if (input_data[i, ", j, "] == -1) {",
      "\n      pos_miss += 1;",
      "\n      z[i, ", j, "] = z_miss[pos_miss];",
      "\n    } else {",
      paste0("\n      ", part.1, "\n      ", part.2, "\n      ", part.3, "\n      ", part.4),
      "\n    }\n")
  }
  rm(part.1, part.2, part.3, part.4, j)
  fill_lat_var <- paste0(fill_lat_var, "  }")
  # writeLines(fill_lat_var)
  code <- paste0(code, fill_lat_var)
  rm(fill_lat_var)
  # writeLines(code)
  code <- paste0(code,
    "\n\n  z ~ multi_normal_cholesky(rep_vector(0, Ni), R_chol);
}
generated quantities {
  matrix[Ni, Ni] R = multiply_lower_tri_self_transpose(R_chol);
  matrix[Ni_ggm, Ni_ggm] P;
  if (ggm) P = cov2cor(inverse_spd(R));
}")
  # writeLines(code)
  # Only rewrite the file if the generated code changed (avoids needless recompilation)
  if (file.exists(file)) {
    old_code <- paste0(readLines(con = file), collapse = "\n")
    if (old_code != code) {
      writeLines(code, con = file)
    }
  } else {
    writeLines(code, con = file)
  }
}
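# For illustration (an added sketch, not something the gist itself prints): with Ni = 2 and
# n_ord = c(2, 3), the parameters block written by create_poly_stan_miss() looks roughly like
#   ordered[n_ord[1] - 1] cutpoints_1;
#   ordered[n_ord[2] - 1] cutpoints_2;
#   vector<upper = cutpoints_1[1]>[count_ord[1,1]] z1_1;
#   vector<lower = cutpoints_1[1]>[count_ord[1,2]] z1_2;
#   vector<upper = cutpoints_2[1]>[count_ord[2,1]] z2_1;
#   vector<lower = cutpoints_2[1], upper = cutpoints_2[2]>[count_ord[2,2]] z2_2;
#   vector<lower = cutpoints_2[2]>[count_ord[2,3]] z2_3;
# i.e. one appropriately truncated latent-response vector per item-by-category cell,
# plus z_miss and R_chol.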
create_dat_list_miss <- function (data, ggm = 1) {
  # Recode each item to consecutive integers starting at 1 (NAs preserved)
  X <- apply(data, 2, function (x)
    as.integer(as.ordered(x)))
  # Drop persons with no observed responses at all
  X <- X[rowSums(X, na.rm = TRUE) > 0, ]
  res <- list(
    shape_phi_c = sqrt(ncol(X)), Np = nrow(X), Ni = ncol(X),
    n_ord = unname(apply(X, 2, function (x) length(na.omit(unique(x))))),
    input_data = X, ggm = ggm)
  # count_ord: one row per item with counts of categories 1..max(n_ord),
  # plus a final column counting that item's missing values
  res$count_ord <- t(apply(X, 2, function (x) {
    sapply(1:max(X, na.rm = TRUE), function (m) sum(x == m, na.rm = TRUE))
  }))
  res$count_ord <- cbind(res$count_ord, apply(X, 2, function (x) sum(is.na(x))))
  res$input_data[is.na(res$input_data)] <- -1  # Stan cannot hold NA; code missing as -1
  res
}
filter_interval <- function (post, interval = .95, mdn = 0) {
  # post: posterior draws of a square matrix, e.g. as.data.frame(fit, "R").
  # Returns the matrix of posterior means (medians if mdn = 1), with entries
  # whose central `interval` credible interval includes 0 set to 0.
  n_var <- sqrt(ncol(post))
  ll <- (1 - interval) / 2
  ul <- 1 - ll
  q.lo <- apply(post, 2, quantile, probs = ll)
  if (mdn) {
    q.mdn <- apply(post, 2, median)
  } else {
    q.mdn <- colMeans(post)
  }
  q.hi <- apply(post, 2, quantile, probs = ul)
  q.mdn[(q.lo * q.hi) < 0] <- 0  # interval straddles 0 when the bounds differ in sign
  matrix(q.mdn, n_var)
}

stonegold546 commented Jan 12, 2021

create_poly_stan_missing.R contains the needed functions. The other file is an example of use with CmdStan.

  • Works for binary and ordinal variables
  • A single dataset can contain both binary and ordinal data (see the sketch below)
  • Also returns the partial correlation matrix
  • May break sometimes :).
  • Should account for missing data; missing responses are treated as latent variables without bounds.
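
A minimal sketch of mixing binary and ordinal items (assuming the gist's functions are sourced and HS9 is the data frame from the example; object names here are illustrative only):

    HS9mix <- data.frame(
      x1_bin = cut(HS9$x1, 2, labels = FALSE),  # binary item (2 levels)
      x2_ord = cut(HS9$x2, 5, labels = FALSE))  # ordinal item (5 levels)
    dat.mix <- create_dat_list_miss(HS9mix)
    dat.mix$n_ord  # 2 5: one cutpoint is generated for x1_bin, four for x2_ord
    create_poly_stan_miss(dat.mix$Ni, dat.mix$n_ord, "stan_scripts/hs9_mix.stan")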
