# Originally published as a GitHub Gist.
# Author: @yguoyguo
# Last active: August 29, 2015 14:23
# Gist: yguoyguo/5ac1d64251b22180ac9a (1 star, 1 fork)
#
# Repeated Re-randomization Simulations ----
require(data.table)
require(ggplot2)
require(reshape2)
require(MASS)
require(gridExtra)
# Simulate one randomized experiment with three outcome metrics.
#
# Draws `n` rows from a trivariate normal with means 1, 2, 3 and covariance
# `Sigma`, then assigns each row to arm "0" or "1" uniformly at random,
# independently of the outcomes (so the true treatment effect is zero).
#
# Args:
#   n:         number of simulated units (rows).
#   Sigma:     3 x 3 covariance matrix of (metricX, metricY, metricZ).
#   lognormal: if TRUE, replace metricY by exp(metricY) (a skewed metric).
#   discrete:  if TRUE, replace metricZ by its quartile bin index (1 to 4).
#
# Returns: a data.table with numeric columns metricX, metricY, metricZ and a
#   factor column `trt` with levels "0" and "1".
sim_data <- function(n = 50000, Sigma = diag(3), lognormal = FALSE,
                     discrete = FALSE) {
  m <- data.table(mvrnorm(n, mu = 1:3, Sigma = Sigma))
  setnames(m, c("V1", "V2", "V3"), c("metricX", "metricY", "metricZ"))
  m[, trt := factor(sample(0:1, size = n, replace = TRUE))]
  if (lognormal) {
    m[, metricY := exp(metricY)]
  }
  if (discrete) {
    # Bin metricZ at its quartiles into codes 1..4. include.lowest = TRUE
    # keeps the minimum in the first bin. The previous version tried to
    # widen the outer breaks with cuts[0] (a no-op: R is 1-indexed) and
    # shifted the 75% quantile instead of the maximum, which produced
    # uneven bins and could leave the minimum as NA.
    cuts <- m[, quantile(metricZ, (0:4) / 4)]
    m[, metricZ := as.numeric(cut(metricZ, breaks = unique(cuts),
                                  include.lowest = TRUE, labels = FALSE))]
  }
  m
}
# Summarise a simulated experiment metric by metric.
#
# For every metric column, computes the difference in means between arm "1"
# and arm "0", its standard error, a normal-approximation confidence interval
# and a two-sided z-test p-value.
#
# Args:
#   m: data.table with a `trt` column (arms "0" and "1") and one numeric
#      column per metric, as produced by sim_data().
#   z: critical value for the interval half-width (1.96 gives ~95%).
#
# Returns: a data.table with one row per metric and columns
#   metric, mu_0, mu_1, se_0, se_1, delta, se, lower, upper, pv.
sum_stats <- function(m, z = 1.96) {
  # Long format (one row per unit x metric), then the mean and the standard
  # error of the mean within each arm.
  per_arm <- melt(m, id.vars = "trt", variable.name = "metric")[
    , list(mu = mean(value), se = sd(value) / sqrt(.N)),
    by = .(trt, metric)
  ]
  # Wide format: columns named <statistic>_<arm>, i.e. mu_0, mu_1, se_0, se_1.
  stats <- dcast.data.table(melt(per_arm, id.vars = c("trt", "metric")),
                            metric ~ variable + trt)
  stats[, delta := mu_1 - mu_0]
  # The two arms are independent samples, so the variances of the means add.
  stats[, se := sqrt(se_0^2 + se_1^2)]
  stats[, lower := delta - z * se]
  stats[, upper := delta + z * se]
  # Two-sided p-value; `lower.tail` is spelled out instead of relying on
  # partial argument matching.
  stats[, pv := 2 * pnorm(abs(delta / se), lower.tail = FALSE)]
  stats
}
# Run K simulated experiments and collect their per-metric estimates.
#
# With rerandomize = TRUE, each experiment is redrawn (data and assignment,
# via sim_data()) until no metric's p-value falls below `threshold`; NA
# p-values are ignored in that check. Under independent null metrics each
# draw is accepted with probability roughly (1 - threshold)^3, so a high
# threshold can require many redraws per experiment.
#
# Args:
#   rerandomize: use the repeated re-randomization acceptance rule.
#   K:           number of experiments to keep (must be >= 1).
#   threshold, Sigma, lognormal, discrete, n: passed to sim_data() / used in
#                the acceptance rule as described above.
#
# Returns: a long data.table with columns type, i, pv, metric, variable
#   (one of delta / lower / upper) and value.
sim_K <- function(rerandomize = FALSE, K = 100, threshold = .8,
                  Sigma = diag(3), lognormal = FALSE, discrete = FALSE,
                  n = 50000) {
  stopifnot(K >= 1)

  # Draw one accepted experiment and return its sum_stats() table.
  draw_one <- function() {
    repeat {
      stats <- sum_stats(sim_data(n, Sigma, lognormal, discrete))
      if (!rerandomize || !stats[, any(pv < threshold, na.rm = TRUE)]) {
        return(stats)
      }
    }
  }

  # Collect the K per-experiment tables in a list and bind once, instead of
  # growing a table with rbind() inside the loop.
  statsm <- rbindlist(lapply(seq_len(K), function(i) cbind(i, draw_one())))

  rand_method <- if (rerandomize) {
    "Repeated Re-randomization"
  } else {
    "Single Randomization"
  }
  cbind(type = rand_method,
        melt(statsm[, list(i, delta, lower, upper, metric, pv)],
             id.vars = c("i", "pv", "metric")))
}
# Forest-style plot of one metric's simulated estimates for one design.
#
# Each simulation `i` is drawn as a horizontal segment at height i, running
# through its `value` points (the delta / lower / upper rows produced by
# sim_K()). Segments whose p-value is below .05 are coloured differently and
# drawn more opaque.
#
# Args:
#   dt:         long data.table with columns type, i, pv, metric, value.
#   metrici:    name of the metric to plot.
#   ranges:     data.table with columns metric, min, max; sets shared x limits.
#   plot.title: if TRUE, use the design name as the plot title.
#   show.ylab:  if TRUE, label the y axis.
#   typei:      design to plot (value of `type`). Defaults to the first type
#               found in `dt`; the previous self-referential default
#               (`typei = typei`) failed whenever the argument was omitted.
#
# Returns: a ggplot object.
plot_result <- function(dt, metrici, ranges, plot.title = FALSE,
                        show.ylab = FALSE, typei = NULL) {
  if (is.null(typei)) {
    typei <- as.character(dt$type[1])
  }
  ggplot(dt[metric == metrici & type == typei], aes(value, i)) +
    geom_point(aes(color = pv < .05)) +
    # `width` is not a geom_line aesthetic: ggplot2 ignored it with a warning,
    # so it is omitted here without changing the rendered plot.
    geom_line(aes(group = i, color = pv < .05, alpha = (pv < .05))) +
    geom_vline(xintercept = 0, color = "gray") +
    scale_alpha_discrete(range = c(.6, 1)) +
    guides(color = "none", alpha = "none") +
    theme_minimal() +
    scale_y_continuous(breaks = NULL) +
    xlim(as.numeric(ranges[metric == metrici, list(min, max)])) +
    xlab(metrici) +
    ylab(if (show.ylab) "Number of simulations" else "") +
    ggtitle(if (plot.title) typei else "")
}
# Compare single randomization with repeated re-randomization side by side.
#
# Runs sim_K() once per design with the same settings, then builds a 2-row
# grid of plot_result() panels (row 1: Single Randomization, row 2: Repeated
# Re-randomization; one column per metric). The x-axis limits are shared per
# metric across both rows.
#
# Args: K, threshold, Sigma, lognormal, discrete, n are forwarded to sim_K().
#
# Returns: the arranged grob (also drawn on a new page of the current
#   graphics device).
comp_plot <- function(K = 100, threshold = .8, Sigma = diag(3),
                      lognormal = FALSE, discrete = FALSE, n = 50000) {
  # Arguments are passed by name so that a change to sim_K()'s parameter
  # order cannot silently mismatch them.
  without_rr <- sim_K(rerandomize = FALSE, K = K, threshold = threshold,
                      Sigma = Sigma, lognormal = lognormal,
                      discrete = discrete, n = n)
  with_rr <- sim_K(rerandomize = TRUE, K = K, threshold = threshold,
                   Sigma = Sigma, lognormal = lognormal,
                   discrete = discrete, n = n)
  dt <- rbind(without_rr, with_rr)
  ranges <- dt[, list(min = min(value), max = max(value)), by = metric]

  rerands <- c("Single Randomization", "Repeated Re-randomization")
  metrics <- as.character(sort(unique(dt$metric)))
  plots <- list()
  for (j in seq_along(rerands)) {
    for (metrici in metrics) {
      plots[[paste0("plot_", metrici, "_", j)]] <-
        plot_result(dt, metrici, ranges,
                    plot.title = metrici == "metricY",
                    show.ylab = metrici == "metricX",
                    typei = rerands[j])
    }
  }
  plots[["nrow"]] <- 2

  # Build the grob once and draw it (what grid.arrange() does internally),
  # rather than building it twice via grid.arrange() and then arrangeGrob().
  grob <- do.call(arrangeGrob, plots)
  grid::grid.newpage()
  grid::grid.draw(grob)
  grob
}
# Run the full simulation study and write four comparison plots to ./plots.
#
# Scenarios (each saved as a PNG via ggsave()):
#   Independent-Normal.png   - independent normal metrics
#   Lognormal-Y.png          - metricY exponentiated (skewed)
#   Discrete-Z.png           - metricZ binned into quartile codes
#   Covariance-Structure.png - correlated normal metrics
#
# Args:
#   K:    number of simulated experiments per design (passed to comp_plot()).
#   seed: RNG seed set once at the start, so the four scenarios are drawn
#         from a single reproducible stream.
#
# Returns: the value of the last ggsave() call.
main <- function(K = 100, seed = 12345) {
  set.seed(seed)
  # Re-running the script should not warn because the directory exists.
  dir.create("plots", showWarnings = FALSE)

  # Save one comp_plot() scenario under plots/<filename>.
  save_comp <- function(filename, Sigma, lognormal = FALSE,
                        discrete = FALSE) {
    ggsave(filename = filename, path = "plots",
           plot = comp_plot(K, threshold = .8, Sigma = Sigma,
                            lognormal = lognormal, discrete = discrete))
  }

  independent <- diag(3)
  save_comp("Independent-Normal.png", independent)
  save_comp("Lognormal-Y.png", independent, lognormal = TRUE)
  save_comp("Discrete-Z.png", independent, discrete = TRUE)

  correlated <- matrix(c(1, .5, .2, .5, 1, -.1, .2, -.1, 1), nrow = 3)
  save_comp("Covariance-Structure.png", correlated)
}
# Script entry point: run the whole study with 20 simulated experiments per
# design (main() uses its default seed of 12345).
main(K = 20)
# End of gist.