Skip to content

Instantly share code, notes, and snippets.

@MatsuuraKentaro
Last active November 30, 2023 09:23
Show Gist options
  • Save MatsuuraKentaro/feca27d68c2c739483632459b1af4c1f to your computer and use it in GitHub Desktop.
Save MatsuuraKentaro/feca27d68c2c739483632459b1af4c1f to your computer and use it in GitHub Desktop.
Test of Bayesian IPW implemented by A. Jordan Nafa
#include <ostream>
// Counter shared by add_iter() and get_iter(); it starts at 1.
// `static` gives it internal linkage, so every translation unit that
// includes this header gets its own independent copy.
static int iteration_index = 1;

// Advance the counter by one. The stream pointer is accepted but not used.
inline void add_iter(std::ostream* pstream__) {
  ++iteration_index;
}

// Report the current counter value without modifying it.
// The stream pointer is accepted but not used.
inline int get_iter(std::ostream* pstream__) {
  return iteration_index;
}
// Logistic regression of the intervention A on the covariate L.
// Generated quantities expose the fitted probabilities and the
// corresponding inverse probability weights for every observation.
data {
  int<lower=1> N;                      // length of A and L
  array[N] int<lower=0, upper=1> A;    // intervention
  vector<lower=0, upper=1>[N] L;       // covariate
}
transformed data {
  // Real-valued copy of A, needed for the element-wise arithmetic below.
  vector[N] A_vec = to_vector(A);
}
parameters {
  vector[2] b;  // b[1]: intercept, b[2]: coefficient on L (no explicit prior)
}
model {
  A ~ bernoulli_logit(b[1] + b[2] * L);
}
generated quantities {
  // Fitted probability that A == 1 for each observation.
  vector[N] e = inv_logit(b[1] + b[2] * L);
  // Weight is 1 / e where A == 1 and 1 / (1 - e) where A == 0.
  vector[N] ipw = A_vec ./ e + (1 - A_vec) ./ (1 - e);
}
// Weighted normal regression of the outcome Y on the intervention A.
// Each observation's log density is multiplied by a weight taken from one
// column of IPW; the column is chosen by an externally defined counter.
functions {
  // Declarations only: the definitions are not in this file and must be
  // supplied externally when the model is compiled.
  void add_iter();
  int get_iter();
}
data {
  int<lower=1> N;
  vector[N] Y;                         // outcome
  array[N] int<lower=0, upper=1> A;    // intervention
  int<lower=1> N_ms;                   // number of weight columns in IPW
  matrix[N, N_ms] IPW;                 // weights: one row per observation
}
parameters {
  vector[2] b;          // b[1]: intercept, b[2]: coefficient on A
  real<lower=0> sigma;  // residual standard deviation
}
model {
  // Column of IPW to use for this evaluation of the log density.
  int col_idx = get_iter();
  for (i in 1:N) {
    real mu = b[1] + b[2] * A[i];
    target += IPW[i, col_idx] * normal_lpdf(Y[i] | mu, sigma);
  }
}
generated quantities {
  // Side effect only: advance the external counter.
  // NOTE(review): presumably this runs once per saved draw, so the weight
  // column advances draw by draw - confirm against the sampler's behavior,
  // and confirm the counter can never exceed N_ms.
  add_iter();
}
library(dplyr)
library(cmdstanr)
# Simulate data ----
# Binary covariate L, binary intervention A whose probability depends on L,
# and a continuous outcome Y that depends on both A and L.
set.seed(123)
N <- 400
L <- sample.int(n = 2, size = N, replace = TRUE) - 1
A <- rbinom(n = N, size = 1, prob = plogis(-3 + 3 * L))
Y <- rnorm(n = N, mean = -0.4 + A * 1.0 + L * 1.5, sd = 0.2)
d <- data.frame(PersonID = seq_len(N), A = A, Y = Y, L = L)

# Stage 1: intervention model ----
data_int <- list(N = N, A = A, L = L)
model_int <- cmdstan_model(stan_file = "model/model-intervention.stan")
fit_int <- model_int$sample(
  data = data_int,
  seed = 123,
  parallel_chains = 4,
  iter_sampling = 2000,
  iter_warmup = 1000
)

# Transpose the draws so rows are observations and columns are draws,
# matching the N x N_ms matrix passed to the second model as IPW.
ipw_ms <- t(fit_int$draws("ipw", format = "matrix"))
N_ms <- ncol(ipw_ms)

# Stage 2: weighted outcome model ----
# The user header supplies C++ code compiled together with the model.
data_MSM <- list(N = N, Y = Y, A = A, N_ms = N_ms, IPW = ipw_ms)
model_MSM <- cmdstan_model(
  stan_file = "model/model-MSM.stan",
  user_header = "model/iterfuns.hpp"
)
fit_MSM <- model_MSM$sample(
  data = data_MSM,
  seed = 123,
  parallel_chains = 4,
  iter_sampling = 2000,
  iter_warmup = 1000
)
print(fit_MSM)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment