# Nested varying effects in ulam example
# GitHub gist by @rmcelreath (last active April 21, 2020)
# nested varying effects in ulam
# simulate an example
# people in nations, allowing varying among people to vary by nation
# NOTE: no RNG seed is set here, so each run simulates a different data set;
# call set.seed() before this section if reproducible output is needed.
N_nations <- 10
N_id_per_nation <- sample( 5:100 , size=N_nations , replace=TRUE )
N_id <- sum(N_id_per_nation)
# variation among individuals in each nation
# (one log-normal draw per nation: meanlog = log(1) = 0, sdlog = scale_nations)
scale_nations <- 0.3
sigma_id <- rlnorm( N_nations , log(1) , scale_nations )
# mean in each nation
a_bar <- rnorm( N_nations )
# now generate individuals
# nation[i] is the nation index of individual i
nation <- rep( 1:N_nations , times=N_id_per_nation )
a <- rnorm( N_id , a_bar[nation] , sigma_id[nation] )
# now generate observations
# id[j] is the individual index of observation j
N_per_id <- sample( 1:30 , size=N_id , replace=TRUE )
id <- rep( 1:N_id , times=N_per_id )
N <- sum(N_per_id)
# Draw all N observations in one vectorized call instead of growing y inside
# a loop. Observations are ordered by individual, exactly as the loop produced
# them, and rnorm() consumes the random stream in the same order, so the
# result matches the per-individual loop for a given RNG state.
# NOTE(review): a[i] is already centred on a_bar[nation[i]], so the nation mean
# enters the observation mean twice, and the observation sd reuses
# sigma_id[nation]. Kept as in the original -- confirm this is intended.
y <- rnorm( N , ( a_bar[nation] + a )[id] , sigma_id[nation][id] )
# ---- model 1: log-normal prior on the per-nation standard deviations ----
library(rethinking)

# nation index per observation: look up the nation of each observation's person
nation_id <- nation[id]
dat <- list(y = y, nation = nation_id, id = id)

m <- ulam(
  alist(
    # likelihood: one shared observation scale tau
    y ~ normal(mu, tau),
    # z_id are standardized person offsets, scaled by their nation's sigma_id
    mu <- a0 + a_nat[nation] + z_id[id] * sigma_id[nation],
    a0 ~ normal(0, 1),
    tau ~ exponential(1),
    # person level
    z_id[id] ~ normal(0, 1),
    # nation level: means
    a_nat[nation] ~ normal(0, sigma_nations),
    sigma_nations ~ exponential(1),
    # nation level: among-person standard deviations, pooled across nations
    sigma_id[nation] ~ dlnorm(mu_sigma_id, scale_sigma_id),
    mu_sigma_id ~ normal(0, 1),
    scale_sigma_id ~ exponential(1)
  ),
  data = dat,
  chains = 4,
  cores = 4
)

# posterior summary at depth 2, with omit = "z" as in the original call
precis(m, depth = 2, omit = "z")
# ---- model 2: Gaussian prior on the log standard deviations ----
# The more traditional approach: each nation's mean and log standard deviation
# get a joint multivariate normal prior. exp() in the linear model maps the
# log standard deviation back to a positive scale.
m2 <- ulam(
  alist(
    # likelihood: one shared observation scale tau
    y ~ normal(mu, tau),
    # z_id are standardized person offsets, scaled by exp(log_sigma_id)
    mu <- a0 + a_nat[nation] + z_id[id] * exp(log_sigma_id[nation]),
    a0 ~ normal(0, 1),
    tau ~ exponential(1),
    # person level
    z_id[id] ~ normal(0, 1),
    # nation level: mean and log standard deviation drawn jointly
    c(a_nat, log_sigma_id)[nation] ~
      multi_normal(c(0, mu_sigma_id), Rho_nations, Sigma_nations),
    mu_sigma_id ~ normal(0, 1),
    Rho_nations ~ lkj_corr(2),
    Sigma_nations ~ exponential(1)
  ),
  data = dat,
  chains = 4,
  cores = 4
)

# posterior summary at depth 2, with omit = "z" as in the original call
precis(m2, depth = 2, omit = "z")
# End of gist (GitHub page footer omitted).