@luiarthur · Last active February 16, 2023 02:59
Gibbs sampling with ESS and analytic full conditional in Turing (julia)
# Julia: v1.7
"""
Demo of using `Gibbs` sampling with `ESS` update for one set of parameters and
`GibbsConditional` for another. Demo model is a standard multiple linear
regression model with Gaussian priors for the coefficients and Inverse Gamma
prior for the variance of the error terms. Weakly / non-informative priors are
used. Results are benchmarked against the inference from a model sampled
entirely using NUTS.
"""
using Turing # v0.20.4
using Random
using LinearAlgebra
Base.@kwdef struct Prior{A<:Real, B<:Real}
    a::A
    b::B
end
# Linear regression model.
@model function linear_regression(X, y, prior)
    num_predictors = size(X, 2)

    # Priors.
    # NOTE: beta gets a zero-mean MvNormal prior with diagonal covariance,
    # each diagonal element being 10^2 (the vector passed in holds standard
    # deviations, hence the name `sqrt_diag`). Diagonal covariance matrices
    # are inverted element-wise, so the multivariate prior costs little. A
    # Normal/MvNormal prior is also exactly what elliptical slice sampling
    # (ESS) requires. The equivalent
    #     beta ~ filldist(Normal(0, 10), num_predictors)
    # is valid Turing syntax, but the ESS sampler does not accept it.
    sqrt_diag = fill(10.0, num_predictors)
    beta ~ MvNormal(sqrt_diag)
    sigmasq ~ InverseGamma(prior.a, prior.b)

    # Transformed parameters.
    mu = X * beta
    sigma = sqrt(sigmasq)

    # Likelihood.
    y .~ Normal.(mu, sigma)
end
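
# NOTE (version assumption): the zero-mean `MvNormal(stds)` constructor used
# above is accepted (possibly with a deprecation warning) on the pinned
# Turing v0.20.4 stack, but is deprecated in newer Distributions releases,
# where the equivalent spelling would be something like
#     beta ~ MvNormal(zeros(num_predictors), 10.0^2 * I)
# (with `I` from LinearAlgebra, imported above).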
# Generate data.
data = let
    Random.seed!(0)
    nobs, num_predictors = 500, 5
    X = randn(nobs, num_predictors)
    beta = randn(num_predictors)
    mu = X * beta
    sigma = 0.1
    y = rand.(Normal.(mu, sigma))
    (y = y, X = X, beta = beta, sigma = sigma, sigmasq = sigma ^ 2,
     nobs = nobs, num_predictors = num_predictors)
end
# Prior parameters for σ².
const prior = Prior(a = 0.01, b = 0.01)
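# (InverseGamma(0.01, 0.01) is a common weakly informative choice for a
# variance, matching the docstring.)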
# Instantiate the Turing model.
const model = linear_regression(data.X, data.y, prior)
@time chain = let
    # Sample from the posterior using NUTS; `discard_initial` drops the first
    # 1000 draws as burn-in.
    sample(model, NUTS(), 2000, discard_initial=1000)
end
@time chain_gibbs = let
    # Convenience function.
    sqsum(x::AbstractVector{<:Real}) = sum(x .^ 2)

    # Define the full conditional distribution of sigmasq.
    # `state` is the current state of the sampler, as a NamedTuple.
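    # By Normal/Inverse-Gamma conjugacy,
    #   p(σ² | β, y) ∝ p(y | β, σ²) p(σ²)
    #                ∝ (σ²)^(-n/2) exp(-‖y - Xβ‖² / (2σ²))
    #                  × (σ²)^(-a-1) exp(-b / σ²),
    # which is the kernel of InverseGamma(a + n/2, b + ‖y - Xβ‖² / 2).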
    function cond_sigmasq(prior::Prior, state)
        a_new = prior.a + data.nobs / 2
        b_new = prior.b + sqsum(data.y - data.X * state.beta) / 2
        return InverseGamma(a_new, b_new)
    end
    cond_sigmasq(state) = cond_sigmasq(prior, state)
    # The Gibbs sampler alternates between
    # - updating beta from its full conditional via ESS, and
    # - drawing sigmasq directly from its analytic full conditional.
    gibbs_sampler = Gibbs(
        ESS(:beta),
        GibbsConditional(:sigmasq, cond_sigmasq)
    )

    # Sample from the posterior, keeping every 5th draw after burn-in.
    sample(model, gibbs_sampler, 2000, discard_initial=1000, thinning=5)
end
# Print summaries.
let
    println("NUTS:")
    show(stdout, "text/plain", mean(chain))
    println()
    println("Gibbs (ESS with analytic update of σ²):")
    show(stdout, "text/plain", mean(chain_gibbs))
    println()
    for param in (:beta, :sigmasq)
        println("True $(param): $(data[param])")
    end
end
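
# Optional diagnostic (not in the original gist): MCMCChains' `summarystats`
# reports ess and rhat per parameter, handy for comparing the efficiency of
# the two samplers. This assumes the MCMCChains version bundled with Turing
# v0.20.4 exports `summarystats`.
let
    println("\nNUTS summary stats:")
    show(stdout, "text/plain", summarystats(chain))
    println("\nGibbs summary stats:")
    show(stdout, "text/plain", summarystats(chain_gibbs))
end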