Last active
February 16, 2023 02:59
-
-
Save luiarthur/c8ac41aa02dbeeb919cf0d107d1a6b54 to your computer and use it in GitHub Desktop.
Gibbs sampling with ESS and an analytic full conditional in Turing (Julia)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Julia: v1.7 | |
""" | |
Demo of using `Gibbs` sampling with `ESS` update for one set of parameters and | |
`GibbsConditional` for another. Demo model is a standard multiple linear | |
regression model with Gaussian priors for the coefficients and Inverse Gamma | |
prior for the variance of the error terms. Weakly / non-informative priors are | |
used. Results are benchmarked against the inference from a model sampled | |
entirely using NUTS. | |
""" | |
using Turing # v0.20.4 | |
using Random | |
using LinearAlgebra | |
"""
    Prior(; a, b)
    Prior(a, b)

Hyperparameters of the Inverse-Gamma prior placed on the error variance
`sigmasq`: `a` is passed as the first and `b` as the second argument of
`InverseGamma`. Both fields may be any `Real` type; the concrete types are
captured as type parameters.
"""
Base.@kwdef struct Prior{Ta<:Real, Tb<:Real}
    a::Ta
    b::Tb
end
# Linear regression model.
#
#   beta    ~ MvNormal(0, 10^2 * I)          (one coefficient per column of X)
#   sigmasq ~ InverseGamma(prior.a, prior.b)
#   y[i]    ~ Normal((X * beta)[i], sqrt(sigmasq))
@model function linear_regression(X, y, prior)
    num_predictors = size(X, 2)

    # Priors.
    # NOTE: beta has a MvNormal prior with diagonal covariance, each diagonal
    # element being 10^2. Diagonal matrices are inverted element-wise, so there
    # is minimal performance hit here. Using a Normal/MvNormal prior also lets us
    # update beta with elliptical slice sampling (ESS), which requires one.
    # i.e., the following is valid:
    #   beta ~ filldist(Normal(0, 10), num_predictors)
    # but is not allowed by the ESS sampler.
    # The mean and covariance are given explicitly because the
    # `MvNormal(σ::AbstractVector)` (vector of standard deviations) constructor
    # is deprecated in Distributions.jl.
    prior_cov = Diagonal(fill(10.0^2, num_predictors))
    beta ~ MvNormal(zeros(num_predictors), prior_cov)
    sigmasq ~ InverseGamma(prior.a, prior.b)

    # Transform parameters.
    mu = X * beta
    sigma = sqrt(sigmasq)

    # Likelihood.
    y .~ Normal.(mu, sigma)
end
# Generate data: simulate a synthetic regression data set
#   y = X * beta + noise,   noise[i] ~ Normal(0, sigma),   sigma = 0.1
# with 500 observations and 5 predictors. The global RNG is reseeded first so
# the simulated data are reproducible.
data = let
    Random.seed!(0)
    n, p = 500, 5
    design = randn(n, p)
    coefs = randn(p)
    noise_sd = 0.1
    response = [rand(Normal(m, noise_sd)) for m in design * coefs]
    (y = response, X = design, beta = coefs, sigma = noise_sd,
     sigmasq = noise_sd ^ 2, nobs = n, num_predictors = p)
end
# Hyperparameters of the Inverse-Gamma prior on σ²; small values keep the
# prior weakly informative.
const prior = Prior(a = 0.01, b = 0.01)

# Condition the regression model on the simulated design matrix and response.
const model = linear_regression(data.X, data.y, prior)

# Reference posterior: sample with NUTS, discarding the first 1000 iterations.
@time chain = let
    nuts_sampler = NUTS()
    sample(model, nuts_sampler, 2000; discard_initial = 1000)
end
@time chain_gibbs = let
    # Bind the observed data to local variables. The closure below captures these
    # (assigned once, concretely typed) instead of reading the non-`const` global
    # `data` on every call, which Julia cannot type-infer.
    y, X, nobs = data.y, data.X, data.nobs

    # Full conditional of sigmasq given beta (conjugate Inverse-Gamma update):
    #   sigmasq | beta, y ~ InverseGamma(a + nobs / 2, b + sum((y - X * beta) .^ 2) / 2)
    # `state` is the current state, as a NamedTuple; only `state.beta` is read.
    function cond_sigmasq(prior::Prior, state)
        a_new = prior.a + nobs / 2
        # `sum(abs2, r)` gives the residual sum of squares without the extra
        # temporary array that `sum(r .^ 2)` would allocate.
        b_new = prior.b + sum(abs2, y - X * state.beta) / 2
        return InverseGamma(a_new, b_new)
    end
    cond_sigmasq(state) = cond_sigmasq(prior, state)

    # The Gibbs sampler iterates between
    # - sampling from beta's full conditional using ESS, and
    # - sampling from sigmasq's full conditional, analytically.
    gibbs_sampler = Gibbs(
        ESS(:beta),
        GibbsConditional(:sigmasq, cond_sigmasq),
    )

    # Sample from posterior.
    sample(model, gibbs_sampler, 2000, discard_initial=1000, thinning=5)
end
# Print posterior means from both samplers, followed by the true parameter
# values used to simulate the data.
let
    println("NUTS:")
    show(stdout, "text/plain", mean(chain))
    # `show` is not guaranteed to end with a newline; emit one so the next
    # header always starts on its own line instead of being appended to the
    # last row of the previous table.
    println()
    println("Gibbs (ESS with Analytic update of σ²):")
    show(stdout, "text/plain", mean(chain_gibbs))
    println()
    for param in (:beta, :sigmasq)
        println("True $(param): $(data[param])")
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment