Skip to content

Instantly share code, notes, and snippets.

@bblais
Created January 27, 2018 16:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bblais/24e863d561affc3c3977ed6e1e345e94 to your computer and use it in GitHub Desktop.
Save bblais/24e863d561affc3c3977ed6e1e345e94 to your computer and use it in GitHub Desktop.
# taken from https://github.com/fonnesbeck/stan_workshop_2016
# removed plotting and data loading to make self-contained
from numpy import array,int64
import numpy as np
# Stan `data` block: the number of observations N and the two observed
# length-N vectors, predictor x and response y. Written as adjacent string
# literals so the exact program text (including the leading and trailing
# newline) is explicit when the three pieces are concatenated.
pooled_data = (
    "\n"
    "data {\n"
    "int<lower=0> N;\n"
    "vector[N] x;\n"
    "vector[N] y;\n"
    "}\n"
)
# Next we declare the model parameters: the two linear-model coefficients (`beta`) and the scale of the normal likelihood (`sigma`). Notice that `sigma` is constrained to be positive.
# In[9]:
# Stan `parameters` block: a 2-vector of regression coefficients `beta`
# and a scale parameter `sigma` bounded below by zero via `<lower=0>`.
pooled_parameters = (
    "\n"
    "parameters {\n"
    "vector[2] beta;\n"
    "real<lower=0> sigma;\n"
    "}\n"
)
# Finally, we model the log-radon measurements `y` as a normal sample whose mean is a linear function of the floor measurement `x` (intercept `beta[1]`, slope `beta[2]`) and whose standard deviation is `sigma`.
# In[ ]:
# Stan `model` block: every y[i] is normal with mean beta[1] + beta[2] * x[i]
# (intercept plus slope times the predictor) and standard deviation sigma.
# No prior statements appear, so beta and sigma get Stan's implicit
# (improper) uniform priors over their declared support.
pooled_model = (
    "\n"
    "model {\n"
    "y ~ normal(beta[1] + beta[2] * x, sigma);\n"
    "}\n"
)
# The observed data follows. After it, the combined Stan program is compiled with `pystan.StanModel` and sampled with the model's `sampling` method, which takes the data, the number of iterations, and the number of parallel chains. Here, we sample 2 chains of 1000 iterations each.
# Observed data for the pooled regression, in the dict form passed to
# PyStan's `sampling(data=...)`. The keys match the Stan `data` block:
#   'N' -- number of observations (919); must equal len(x) and len(y).
#   'x' -- predictor (the "floor measurement"), stored as 0/1 integers.
#          Presumably a basement / first-floor indicator -- confirm against
#          the original radon dataset.
#   'y' -- response (the log-radon measurements), as floats.
# The fitting loop at the end of the file uses prefixes x[:N], y[:N] of
# these arrays.
original_data = {'N': 919,
# NOTE(review): x is int64 while Stan declares `vector[N] x` (real);
# PyStan 2 is expected to convert it -- confirm.
'x': array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0,
1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1,
0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0,
0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=int64),
# y: response values (log-radon), one per observation, aligned with x.
'y': array([ 0.83, 0.83, 1.1 , 0.1 , 1.16, 0.96, 0.47, 0.1 , -0.22,
0.26, 0.26, 0.34, 0.41, -0.69, 0.18, 1.53, 0.34, 0.79,
1.79, 1.22, 0.64, 1.7 , 1.86, 0.69, 1.9 , 1.16, 1.93,
1.96, 2.05, 1.67, 1.53, 1.5 , 1.06, 2.1 , 0.53, 1.46,
1.7 , 1.41, 0.88, 1.1 , 0.41, 1.22, 1.1 , 0.64, -1.2 ,
0.92, 0.18, 0.83, -0.36, 0.59, 1.1 , 0.83, 0.59, 0.41,
0.69, 0.64, 0.26, 1.48, 1.53, 1.86, 1.55, 1.76, 0.83,
-0.69, 1.55, 1.5 , 1.9 , 1.03, 1.1 , 1.1 , 1.99, 1.63,
0.99, 1.63, 2.57, 1.99, 1.93, 2.56, 1.77, 2.26, 1.81,
1.36, 2.67, 0.64, 1.95, 1.57, 2.26, 0.96, 1.92, 1.41,
2.32, 0.83, 0.64, 1.25, 1.74, 1.48, 1.39, 0.34, 1.46,
-0.11, 0.74, 0.53, 2.56, 2.69, 1.57, 2.27, -2.3 , 1.34,
2.01, 0.69, 1.69, 1.41, 2.05, 0.41, 2.31, 2.25, -0.11,
1.5 , 1.63, 0.79, 0.59, 2.1 , 0. , 2.56, 0.99, 1.28,
3.28, 0.47, 2.57, 2.19, 2.98, 0.96, 2.21, 2.58, 1.31,
1.95, 1.59, 1.25, 0. , 1.25, 1.03, 0.41, 1.93, 2.42,
-2.3 , 0.96, 0.64, 0.53, 0.1 , 0. , 1.1 , 1.5 , 0.47,
1.44, 0.96, 1.92, 1.48, 1.72, 1.31, 1.06, 2.69, 1.92,
2.09, 0.99, 1.06, 1.5 , 0.59, 0.74, 0.74, 0.47, 2.27,
2.1 , 1.28, -0.11, 1.65, 1.19, 2.39, 2.12, 1.86, 1.59,
1.81, 0.18, 2.17, 2.19, 1.93, 0.88, 0.53, 1.06, 1.89,
0.59, 1.55, 1.22, 1.5 , 3.06, 2.22, 0. , 1.61, 1.63,
0.18, 2.04, 1.7 , 1.31, 1.61, 1.57, 0.41, 1.25, 1.46,
0.96, 0.41, 0.41, 0.69, 1.59, 0.41, 1.36, 2.19, 1.48,
1.5 , 1.53, 0.83, -0.51, 1.77, 1.7 , 1.99, 1.76, 2.01,
1.59, 1.93, 1.87, 1.34, 1.72, 2.07, 1.5 , 1.03, 1.25,
1.46, 0.88, 0.34, 1.67, -1.61, 0.96, 1.19, 1.19, 2.27,
1.46, 2.21, 1.86, 3.49, 2.59, 0.83, 1.74, 2.67, 1.95,
2.04, 2.29, 0.99, 3.78, 1.61, 1.61, 1.28, 1.59, 1.74,
1.28, 1.39, 1.92, 2.08, 1.22, 0.79, 0.53, 1.41, 0.64,
0.96, 2.42, 0.99, 1.39, 2.01, 0.34, 0. , -0.69, 0.96,
1.81, 0.74, 1.7 , 1.13, 1.1 , 1.72, 1.44, 1.39, 2.71,
1.99, 0.88, 1.06, 1.5 , 0.47, 2.16, 1.74, 2.16, 1.36,
0.64, 0.69, 1.72, 0.96, -0.11, 0.79, 1.06, 1.39, 1.48,
1.57, 1.06, 1.44, 0.53, 1.48, -0.22, 1.72, 1.22, 1.72,
0.96, 1.03, 2.14, 1.22, 1.19, 2.16, 0.59, 1.76, 2.57,
1.03, 1.57, 1.74, 2.63, 2.04, 1.76, 1.55, 2.04, 0.99,
1.53, 1.79, 0.83, 0.92, 1.41, 1.55, 1.55, 2.4 , 2.04,
1.13, 0.47, 0.53, 2.81, 1.16, 1.65, 1.61, 1.81, 0. ,
0.64, 1.39, 1.74, -0.69, 0.99, 1.31, 1.84, 3.17, 1.39,
1.1 , 1.13, 1.57, 1.13, 1.46, 1.36, 1.13, 1.48, 1.1 ,
1.25, 2.15, 2.21, 1.59, 1.31, 0.83, 1.06, -0.11, 0.47,
1.55, 1.34, 1.31, 1.13, 0.83, 0.69, 0.99, 0.64, 0.92,
1.48, 0.99, 0.18, 1.22, 0.96, 2.25, 0.34, 2.14, 1.63,
1.1 , 2.58, 2.73, 0.64, 1.36, 2.08, 0.99, 2.43, 1.44,
2.52, 1.92, 1.95, 1.53, 0. , 0.59, 0.41, 0.74, 0.1 ,
0.1 , 1.06, 0.34, 2.43, 2.78, 0.34, 0.34, 0.53, 0. ,
1.06, -0.51, 0.47, 1.97, -0.51, 2.32, 1.48, 1.22, 1.1 ,
2.53, 1.46, 1.53, 1.39, 1.22, 2.87, 2.37, 2.08, 1.28,
1.89, 1.95, 1.65, 2.49, 1.65, 2.2 , 1.77, 1.55, 1.39,
0.47, 3.17, 0. , 0.41, 0.18, 1.06, 3.88, 0. , 2.13,
1.44, -0.51, 1.92, 2.03, 2.23, -0.51, 0.47, 2.34, 1.39,
0.64, 2.3 , 0.88, 1.5 , 1.06, 0.18, 0.26, 0.53, 3.24,
-2.3 , 2.37, 0.88, 1.39, 1.99, 0.79, 1.19, -0.51, 1.76,
0.41, 0.79, 1.5 , 0.92, 1.61, 1.13, 1.13, 1.06, 1.39,
2.4 , 1.87, 0.74, 1.13, 1.53, 0.79, 2.09, 0.34, 2.23,
0.18, 2.37, 3.18, 2.22, 2.5 , 2.1 , 2.39, 1.46, 2.76,
1.7 , 1.84, 2.28, 2.1 , 0.53, 0.53, 1.87, 1.5 , 2.42,
2.31, 1.53, 2.09, 0.88, 1.19, 1.63, 1.44, 0.18, 0.74,
0.18, 1.1 , 0.79, 2.07, 1.36, 0.96, 1.1 , 0.59, 0.96,
2.25, -0.36, 1.03, 0.18, 0.79, 2.49, 2.54, 1.19, 1.46,
1.36, 1.34, 1.77, -0.92, 1.44, 1.06, 0.69, 0.26, 0.26,
0.47, 2.25, 0.59, 2.5 , 1.48, 1.95, 0.41, 0.96, 2.27,
1.36, 1.25, 1.93, 1.31, 0.83, 0.99, 0.79, 1.96, 0.26,
1.36, 1.28, 1.46, 0.53, 1.06, 2.16, 1.84, 1.67, 1.03,
0.26, 1.28, 1.72, 2.32, 1.72, 0.26, 1.61, 1.41, 1.28,
0.96, 0.26, 1.03, 0.59, 1.16, -0.22, 0.1 , 0.69, 1.36,
2.2 , 2.01, 3.03, 1.81, 0.79, 1.77, 2.28, 1.87, 1.55,
1.74, 2.95, 0.92, 1.13, 1.65, 2.05, 2.1 , 1.57, 2.14,
0.53, 1.81, 0.18, 2.44, 1.48, 1.31, 2.34, 1.25, 1.16,
1.31, 1.03, 1.41, 0.26, 0.59, 1.46, 2.97, 2.22, 0.74,
2.44, 2.33, 0.79, 0.26, 1.19, 0.74, 1.48, 0.83, 1.7 ,
3.23, 1.65, 0.88, 1.19, 0.96, 1.06, 1.16, 0.53, 1.57,
1.41, 1.63, 0.47, 1.59, -0.11, -0.51, 0.92, 0.88, 1.55,
2.41, 2.71, 2.16, 1.53, 0.47, 1.39, 0.64, 0.53, -0.51,
-0.69, -0.51, 2.17, 0.53, 0.41, 2.17, 2.42, 0.47, 0.18,
0. , -0.22, 1.46, 1.25, 0.79, 1.1 , 0.64, 0.64, 0.92,
0.59, -0.11, 2.47, 0.64, 1.06, 1.28, 1.31, 1.28, 1.13,
1.19, 1.16, 1.22, 0.59, 1.74, 1.25, 0.47, 3.48, 0.18,
0.79, -0.11, 0.47, 0.34, 1.16, 1.99, 0.41, 0.34, 0.47,
1.63, 0.88, 0.92, 0.26, 1.7 , 0.18, 0.41, 1.99, 0.18,
1.22, 1.19, 0.47, 1.31, -0.11, 0.53, 0.41, 1.03, 1.22,
0. , -0.36, 0.74, 0.69, 0. , 1.7 , 0.47, 1.16, 0.64,
0. , 1.22, 0.59, 1.16, -0.22, 1.48, 0.41, 0.64, 0.47,
0.83, 0.92, 1.03, 0.59, 0.18, 0.64, -1.2 , 0.83, 1.55,
0.79, 0.74, -0.22, 1.87, 1.13, 0.74, 0. , 1.22, 0.64,
0.64, 0.83, 1.48, 2.03, 1.87, 2.13, 0.79, 1.22, 0.34,
1.63, 0.1 , 1.96, 1.76, 2.32, 1.9 , 0.99, 1.22, 0.47,
1.63, 2.01, 2.68, 0.64, 2.01, 0.99, 1.34, 0.69, 0.83,
1.63, 2. , 1.34, 1.1 , 1.5 , 2.14, 1.65, 1.31, 0.47,
2.16, 2.37, 2.09, 1.53, 1.13, 0.92, 0.47, 1.59, 1.93,
0.79, 1.81, 1.1 , 1.92, 2.97, 1.41, 1.79, 2.21, 2.14,
0.18, 1.16, 2.45, 2.27, 1.1 , -0.22, 1.19, 1.57, 1.59,
-0.69, 2.24, 0.59, 0. , 2.33, 2.05, 0.83, 1.89, 2.51,
1.55, 1.84, 1.89, 1.06, 0.69, 0.26, 0.92, 0.1 , 0.26,
0.53, -0.11, 0.59, 1.57, 0.59, 1.22, -0.11, 2.29, 1.69,
2.15, 0.69, 1.9 , 1.36, 1.79, 1.61, 0.96, 2.38, 0.92,
0.79, 1.57, 1.34, 2.6 , 1.1 , 1.48, 1.36, 0.64, 0.47,
0.64, 0.34, 1.9 , 3.02, 1.81, 2.63, 2.33, 1.76, 2.24,
1.25, 1.44, 2.46, 1.99, 1.57, 0.64, -0.22, 1.57, 2.33,
2.43, 2.04, 2.48, -0.51, 1.92, 1.69, 1.16, 0.79, 2. ,
1.65, 0.83, 0.88, 2.77, 2.26, 1.87, 1.53, 1.63, 1.34,
1.1 ])}
import pystan

# Compile the combined Stan program (data + parameters + model blocks) once;
# the compiled model `sm` is reused for every sample size below.
sm = pystan.StanModel(model_code=pooled_data + pooled_parameters + pooled_model)

# Refit the pooled regression on the first N observations for a range of N
# and print the posterior means of the intercept and slope, showing how the
# estimates change as more data is used.
for N in [10, 20, 40, 80, 90, 100, 120, 140, 180, 200, 250, 500, 900]:
    # Stan checks that `N` matches the lengths of `x` and `y`, so every N in
    # the list must be <= len(original_data['y']).
    pooled_data_dict = {
        'N': N,
        'x': original_data['x'][:N],
        'y': original_data['y'][:N],
    }
    print("N=", pooled_data_dict['N'])

    # Sample 2 chains of 1000 iterations each.
    pooled_fit = sm.sampling(data=pooled_data_dict, iter=1000, chains=2)

    # With permuted=True, extract() returns a dict mapping each parameter
    # name to an array of posterior draws.
    pooled_sample = pooled_fit.extract(permuted=True)

    # Assumes the draws lie along axis 0 of pooled_sample['beta'], with one
    # column per coefficient. Transposing and averaging over axis 1 then gives
    # the posterior means of beta[1] (intercept, b0) and beta[2] (slope, m0).
    b0, m0 = pooled_sample['beta'].T.mean(1)
    print(b0, m0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment