Last active
April 16, 2017 20:36
-
-
Save bbbales2/34602cacc6cddeae68c68228670882f3 to your computer and use it in GitHub Desktop.
Constrain sigma to be [0.0, inf)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy
import pystan
# Deliberately use only a handful of datapoints: with too much data the
# effect of the missing adjustment is no longer visible in the results.
y = numpy.random.lognormal(mean=0.7, sigma=1.0, size=5)
# Model 3: samples logsigma and derives sigma = exp(logsigma), but does NOT
# add the adjustment for that change of variables, so it is incorrect.
model_code = """
data {
int<lower=1> N;
vector<lower=0.0>[N] y;
}
parameters {
real mu;
real logsigma;
}
transformed parameters {
real sigma;
sigma = exp(logsigma);
}
model {
y ~ lognormal(mu, sigma);
}
"""
sm3 = pystan.StanModel(model_code=model_code)
# Model 4: same parameterization as the uncorrected model (sample logsigma,
# derive sigma = exp(logsigma)), but with the adjustment added by hand, so it
# is correct.
model_code = """
data {
int<lower=1> N;
vector<lower=0.0>[N] y;
}
parameters {
real mu;
real logsigma;
}
transformed parameters {
real sigma;
sigma = exp(logsigma);
}
model {
y ~ lognormal(mu, sigma);
// Adjustment for sigma = exp(logsigma): log(sigma) == logsigma.
// sigma is always positive here, so log(fabs(sigma)) reduces to logsigma.
// Note the sign: the term is added, not subtracted.
target += logsigma;
}
"""
sm4 = pystan.StanModel(model_code=model_code)
# Model 5: sigma is declared directly with a lower bound of 0, so Stan handles
# the adjustment itself.
model_code = """
data {
int<lower=1> N;
vector<lower=0.0>[N] y;
}
parameters {
real mu;
real<lower=0.0> sigma;
}
model {
y ~ lognormal(mu, sigma);
}
"""
sm5 = pystan.StanModel(model_code=model_code)
# Sample all three models on the same data with identical settings; the
# results for sigma won't be the same.
# Each call gets its own freshly built data dict, as in the original script.
fit3 = sm3.sampling(data={'N': len(y), 'y': y}, iter=100000, warmup=1000)
fit4 = sm4.sampling(data={'N': len(y), 'y': y}, iter=100000, warmup=1000)
fit5 = sm5.sampling(data={'N': len(y), 'y': y}, iter=100000, warmup=1000)
# Print the fit summaries. The single-argument call form of print behaves the
# same under Python 2 and Python 3, unlike the print statement.
print(fit3)
print(fit4)
print(fit5)
# Overlay the sigma draws from the three fits. The bin edges computed for the
# first histogram are reused for the other two so the bars line up.
_, bins, _ = plt.hist(fit3.extract()['sigma'], alpha=0.333, bins=50)
plt.hist(fit4.extract()['sigma'], alpha=0.333, bins=bins)
plt.hist(fit5.extract()['sigma'], alpha=0.333, bins=bins)
plt.legend(['No correction', 'Correction', 'Stan'])
plt.show(block=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment