Created
January 27, 2018 16:57
-
-
Save bblais/24e863d561affc3c3977ed6e1e345e94 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pooled linear-regression example for PyStan, taken from
# https://github.com/fonnesbeck/stan_workshop_2016 .
# Plotting and data loading were removed to make the script self-contained.
from numpy import array, int64
import numpy as np

# The Stan program is written as three fragments (data, parameters, model)
# that are concatenated into a single program before compilation.

# `data` block: N observations of a predictor x and an outcome y.
pooled_data = """
data {
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
}
"""

# `parameters` block: the linear-model coefficients (beta[1] is the intercept,
# beta[2] the slope) and the normal scale parameter. `sigma` is constrained to
# be positive. These are declarations only; the sampler picks initial values.
pooled_parameters = """
parameters {
    vector[2] beta;
    real<lower=0> sigma;
}
"""

# `model` block: the log-radon measurements y are modeled as normal draws
# whose mean is a linear function of the floor measurement x. No explicit
# priors are given, so Stan applies its implicit (improper, flat) priors over
# each parameter's declared support.
pooled_model = """
model {
    y ~ normal(beta[1] + beta[2] * x, sigma);
}
"""
# Embedded copy of the data set, so the script needs no external data file.
# Keys match the names declared in the Stan `data` block (N, x, y). Per the
# model description, `y` holds log-radon measurements and `x` the floor
# measurement, stored here as a 0/1 integer indicator.
# 'N' is hard-coded to 919, the number of entries in both `x` and `y`; it
# describes the full data set only. The sampling loop slices the first N
# entries of `x` and `y` and passes its own 'N' to Stan.
#
# The sampling step compiles the concatenated Stan program once with
# `pystan.StanModel` and then calls `sampling(..., iter=1000, chains=2)`,
# i.e. 2 chains of 1000 iterations each, for several sample sizes N.
original_data = {'N': 919,
    # Predictor: 0/1 floor indicator, one entry per observation.
    'x': array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
        1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0,
        0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=int64),
    # Outcome: log-radon measurement, one entry per observation.
    'y': array([ 0.83, 0.83, 1.1 , 0.1 , 1.16, 0.96, 0.47, 0.1 , -0.22,
        0.26, 0.26, 0.34, 0.41, -0.69, 0.18, 1.53, 0.34, 0.79,
        1.79, 1.22, 0.64, 1.7 , 1.86, 0.69, 1.9 , 1.16, 1.93,
        1.96, 2.05, 1.67, 1.53, 1.5 , 1.06, 2.1 , 0.53, 1.46,
        1.7 , 1.41, 0.88, 1.1 , 0.41, 1.22, 1.1 , 0.64, -1.2 ,
        0.92, 0.18, 0.83, -0.36, 0.59, 1.1 , 0.83, 0.59, 0.41,
        0.69, 0.64, 0.26, 1.48, 1.53, 1.86, 1.55, 1.76, 0.83,
        -0.69, 1.55, 1.5 , 1.9 , 1.03, 1.1 , 1.1 , 1.99, 1.63,
        0.99, 1.63, 2.57, 1.99, 1.93, 2.56, 1.77, 2.26, 1.81,
        1.36, 2.67, 0.64, 1.95, 1.57, 2.26, 0.96, 1.92, 1.41,
        2.32, 0.83, 0.64, 1.25, 1.74, 1.48, 1.39, 0.34, 1.46,
        -0.11, 0.74, 0.53, 2.56, 2.69, 1.57, 2.27, -2.3 , 1.34,
        2.01, 0.69, 1.69, 1.41, 2.05, 0.41, 2.31, 2.25, -0.11,
        1.5 , 1.63, 0.79, 0.59, 2.1 , 0. , 2.56, 0.99, 1.28,
        3.28, 0.47, 2.57, 2.19, 2.98, 0.96, 2.21, 2.58, 1.31,
        1.95, 1.59, 1.25, 0. , 1.25, 1.03, 0.41, 1.93, 2.42,
        -2.3 , 0.96, 0.64, 0.53, 0.1 , 0. , 1.1 , 1.5 , 0.47,
        1.44, 0.96, 1.92, 1.48, 1.72, 1.31, 1.06, 2.69, 1.92,
        2.09, 0.99, 1.06, 1.5 , 0.59, 0.74, 0.74, 0.47, 2.27,
        2.1 , 1.28, -0.11, 1.65, 1.19, 2.39, 2.12, 1.86, 1.59,
        1.81, 0.18, 2.17, 2.19, 1.93, 0.88, 0.53, 1.06, 1.89,
        0.59, 1.55, 1.22, 1.5 , 3.06, 2.22, 0. , 1.61, 1.63,
        0.18, 2.04, 1.7 , 1.31, 1.61, 1.57, 0.41, 1.25, 1.46,
        0.96, 0.41, 0.41, 0.69, 1.59, 0.41, 1.36, 2.19, 1.48,
        1.5 , 1.53, 0.83, -0.51, 1.77, 1.7 , 1.99, 1.76, 2.01,
        1.59, 1.93, 1.87, 1.34, 1.72, 2.07, 1.5 , 1.03, 1.25,
        1.46, 0.88, 0.34, 1.67, -1.61, 0.96, 1.19, 1.19, 2.27,
        1.46, 2.21, 1.86, 3.49, 2.59, 0.83, 1.74, 2.67, 1.95,
        2.04, 2.29, 0.99, 3.78, 1.61, 1.61, 1.28, 1.59, 1.74,
        1.28, 1.39, 1.92, 2.08, 1.22, 0.79, 0.53, 1.41, 0.64,
        0.96, 2.42, 0.99, 1.39, 2.01, 0.34, 0. , -0.69, 0.96,
        1.81, 0.74, 1.7 , 1.13, 1.1 , 1.72, 1.44, 1.39, 2.71,
        1.99, 0.88, 1.06, 1.5 , 0.47, 2.16, 1.74, 2.16, 1.36,
        0.64, 0.69, 1.72, 0.96, -0.11, 0.79, 1.06, 1.39, 1.48,
        1.57, 1.06, 1.44, 0.53, 1.48, -0.22, 1.72, 1.22, 1.72,
        0.96, 1.03, 2.14, 1.22, 1.19, 2.16, 0.59, 1.76, 2.57,
        1.03, 1.57, 1.74, 2.63, 2.04, 1.76, 1.55, 2.04, 0.99,
        1.53, 1.79, 0.83, 0.92, 1.41, 1.55, 1.55, 2.4 , 2.04,
        1.13, 0.47, 0.53, 2.81, 1.16, 1.65, 1.61, 1.81, 0. ,
        0.64, 1.39, 1.74, -0.69, 0.99, 1.31, 1.84, 3.17, 1.39,
        1.1 , 1.13, 1.57, 1.13, 1.46, 1.36, 1.13, 1.48, 1.1 ,
        1.25, 2.15, 2.21, 1.59, 1.31, 0.83, 1.06, -0.11, 0.47,
        1.55, 1.34, 1.31, 1.13, 0.83, 0.69, 0.99, 0.64, 0.92,
        1.48, 0.99, 0.18, 1.22, 0.96, 2.25, 0.34, 2.14, 1.63,
        1.1 , 2.58, 2.73, 0.64, 1.36, 2.08, 0.99, 2.43, 1.44,
        2.52, 1.92, 1.95, 1.53, 0. , 0.59, 0.41, 0.74, 0.1 ,
        0.1 , 1.06, 0.34, 2.43, 2.78, 0.34, 0.34, 0.53, 0. ,
        1.06, -0.51, 0.47, 1.97, -0.51, 2.32, 1.48, 1.22, 1.1 ,
        2.53, 1.46, 1.53, 1.39, 1.22, 2.87, 2.37, 2.08, 1.28,
        1.89, 1.95, 1.65, 2.49, 1.65, 2.2 , 1.77, 1.55, 1.39,
        0.47, 3.17, 0. , 0.41, 0.18, 1.06, 3.88, 0. , 2.13,
        1.44, -0.51, 1.92, 2.03, 2.23, -0.51, 0.47, 2.34, 1.39,
        0.64, 2.3 , 0.88, 1.5 , 1.06, 0.18, 0.26, 0.53, 3.24,
        -2.3 , 2.37, 0.88, 1.39, 1.99, 0.79, 1.19, -0.51, 1.76,
        0.41, 0.79, 1.5 , 0.92, 1.61, 1.13, 1.13, 1.06, 1.39,
        2.4 , 1.87, 0.74, 1.13, 1.53, 0.79, 2.09, 0.34, 2.23,
        0.18, 2.37, 3.18, 2.22, 2.5 , 2.1 , 2.39, 1.46, 2.76,
        1.7 , 1.84, 2.28, 2.1 , 0.53, 0.53, 1.87, 1.5 , 2.42,
        2.31, 1.53, 2.09, 0.88, 1.19, 1.63, 1.44, 0.18, 0.74,
        0.18, 1.1 , 0.79, 2.07, 1.36, 0.96, 1.1 , 0.59, 0.96,
        2.25, -0.36, 1.03, 0.18, 0.79, 2.49, 2.54, 1.19, 1.46,
        1.36, 1.34, 1.77, -0.92, 1.44, 1.06, 0.69, 0.26, 0.26,
        0.47, 2.25, 0.59, 2.5 , 1.48, 1.95, 0.41, 0.96, 2.27,
        1.36, 1.25, 1.93, 1.31, 0.83, 0.99, 0.79, 1.96, 0.26,
        1.36, 1.28, 1.46, 0.53, 1.06, 2.16, 1.84, 1.67, 1.03,
        0.26, 1.28, 1.72, 2.32, 1.72, 0.26, 1.61, 1.41, 1.28,
        0.96, 0.26, 1.03, 0.59, 1.16, -0.22, 0.1 , 0.69, 1.36,
        2.2 , 2.01, 3.03, 1.81, 0.79, 1.77, 2.28, 1.87, 1.55,
        1.74, 2.95, 0.92, 1.13, 1.65, 2.05, 2.1 , 1.57, 2.14,
        0.53, 1.81, 0.18, 2.44, 1.48, 1.31, 2.34, 1.25, 1.16,
        1.31, 1.03, 1.41, 0.26, 0.59, 1.46, 2.97, 2.22, 0.74,
        2.44, 2.33, 0.79, 0.26, 1.19, 0.74, 1.48, 0.83, 1.7 ,
        3.23, 1.65, 0.88, 1.19, 0.96, 1.06, 1.16, 0.53, 1.57,
        1.41, 1.63, 0.47, 1.59, -0.11, -0.51, 0.92, 0.88, 1.55,
        2.41, 2.71, 2.16, 1.53, 0.47, 1.39, 0.64, 0.53, -0.51,
        -0.69, -0.51, 2.17, 0.53, 0.41, 2.17, 2.42, 0.47, 0.18,
        0. , -0.22, 1.46, 1.25, 0.79, 1.1 , 0.64, 0.64, 0.92,
        0.59, -0.11, 2.47, 0.64, 1.06, 1.28, 1.31, 1.28, 1.13,
        1.19, 1.16, 1.22, 0.59, 1.74, 1.25, 0.47, 3.48, 0.18,
        0.79, -0.11, 0.47, 0.34, 1.16, 1.99, 0.41, 0.34, 0.47,
        1.63, 0.88, 0.92, 0.26, 1.7 , 0.18, 0.41, 1.99, 0.18,
        1.22, 1.19, 0.47, 1.31, -0.11, 0.53, 0.41, 1.03, 1.22,
        0. , -0.36, 0.74, 0.69, 0. , 1.7 , 0.47, 1.16, 0.64,
        0. , 1.22, 0.59, 1.16, -0.22, 1.48, 0.41, 0.64, 0.47,
        0.83, 0.92, 1.03, 0.59, 0.18, 0.64, -1.2 , 0.83, 1.55,
        0.79, 0.74, -0.22, 1.87, 1.13, 0.74, 0. , 1.22, 0.64,
        0.64, 0.83, 1.48, 2.03, 1.87, 2.13, 0.79, 1.22, 0.34,
        1.63, 0.1 , 1.96, 1.76, 2.32, 1.9 , 0.99, 1.22, 0.47,
        1.63, 2.01, 2.68, 0.64, 2.01, 0.99, 1.34, 0.69, 0.83,
        1.63, 2. , 1.34, 1.1 , 1.5 , 2.14, 1.65, 1.31, 0.47,
        2.16, 2.37, 2.09, 1.53, 1.13, 0.92, 0.47, 1.59, 1.93,
        0.79, 1.81, 1.1 , 1.92, 2.97, 1.41, 1.79, 2.21, 2.14,
        0.18, 1.16, 2.45, 2.27, 1.1 , -0.22, 1.19, 1.57, 1.59,
        -0.69, 2.24, 0.59, 0. , 2.33, 2.05, 0.83, 1.89, 2.51,
        1.55, 1.84, 1.89, 1.06, 0.69, 0.26, 0.92, 0.1 , 0.26,
        0.53, -0.11, 0.59, 1.57, 0.59, 1.22, -0.11, 2.29, 1.69,
        2.15, 0.69, 1.9 , 1.36, 1.79, 1.61, 0.96, 2.38, 0.92,
        0.79, 1.57, 1.34, 2.6 , 1.1 , 1.48, 1.36, 0.64, 0.47,
        0.64, 0.34, 1.9 , 3.02, 1.81, 2.63, 2.33, 1.76, 2.24,
        1.25, 1.44, 2.46, 1.99, 1.57, 0.64, -0.22, 1.57, 2.33,
        2.43, 2.04, 2.48, -0.51, 1.92, 1.69, 1.16, 0.79, 2. ,
        1.65, 0.83, 0.88, 2.77, 2.26, 1.87, 1.53, 1.63, 1.34,
        1.1 ])}
# Fit the pooled model to growing prefixes of the embedded data set and print
# the posterior mean intercept and slope for each prefix length N.
# NOTE: this uses the PyStan 2.x API (`pystan.StanModel`, `.sampling`,
# `.extract`); PyStan 3 replaced it with `stan.build(...).sample(...)`.
import pystan

# Compile the concatenated Stan program once and reuse the compiled model for
# every sample size below, so compilation is not repeated inside the loop.
sm = pystan.StanModel(model_code=pooled_data + pooled_parameters + pooled_model)

# Each fit uses only the first N observations, so N must not exceed the number
# of rows actually present in both x and y.
n_available = min(len(original_data['x']), len(original_data['y']))

for N in [10, 20, 40, 80, 90, 100, 120, 140, 180, 200, 250, 500, 900]:
    if N > n_available:
        # Slicing would silently return fewer than N rows, and Stan would then
        # reject the data because 'N' no longer matches len(x) / len(y).
        raise ValueError(
            "N={} exceeds the {} available observations".format(N, n_available)
        )

    pooled_data_dict = {
        'N': N,
        'x': original_data['x'][:N],
        'y': original_data['y'][:N],
    }
    print("N=", pooled_data_dict['N'])

    # 2 chains of 1000 iterations each.
    pooled_fit = sm.sampling(data=pooled_data_dict, iter=1000, chains=2)

    # The sample can be extracted for plotting and summarization.
    pooled_sample = pooled_fit.extract(permuted=True)

    # 'beta' holds one row per draw and one column per coefficient; transposing
    # and averaging over axis 1 gives the posterior means of beta[1]
    # (intercept, b0) and beta[2] (slope, m0).
    b0, m0 = pooled_sample['beta'].T.mean(1)
    print(b0, m0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment