Skip to content

Instantly share code, notes, and snippets.

@mattiasvillani
Created June 8, 2024 16:38
Gibbs sampling for a finite mixture of normals
using Distributions, LinearAlgebra, Statistics
"""
    GibbsMixtureOfNormals(x, K, nIter, compPrior, α = ones(K), store_alloc = false)

Simulate a Gibbs sample of size `nIter` from the posterior of a mixture of `K` normals

    p(x) = ∑ₖ ωₖ ⋅ N(x | μₖ, σₖ²)

under the priors

    (μₖ, σₖ²) ∼ NormalInverseChisq(μ₀, σ²₀, κ₀, ν₀)   for each k = 1,…,K
    ω ∼ Dirichlet(α),   α = [α₁,…,α_K] on the mixture weights.

# Arguments
- `x`: vector of observations (must be non-empty).
- `K`: number of mixture components (must be ≥ 1).
- `nIter`: number of Gibbs iterations, i.e. the number of rows in each returned draw matrix.
- `compPrior`: component-prior hyperparameters, unpacked in the order `(μ₀, σ²₀, κ₀, ν₀)`.
  This is the same argument order in which `NormalInverseChisq` is constructed.
  The exact parameterization is defined by `NormalInverseChisq`/`PostNormalModel`
  (not shown here).
- `α`: Dirichlet prior parameters for the weights; a length-`K` vector (default `ones(K)`).
- `store_alloc`: if `true`, also store the allocation vector from every iteration.

# Returns
A tuple `(μdraws, σ²draws, ωdraws, adraws)`:
- `μdraws`, `σ²draws`, `ωdraws`: `nIter × K` matrices with one draw per row.
- `adraws`: an `nIter × length(x)` `Int8` matrix of component allocations when
  `store_alloc` is `true`, otherwise `nothing`.

Components are never relabeled, so labels may switch across iterations.

# Throws
- `ArgumentError` if `x` is empty, if `K < 1`, or if `store_alloc` is `true` and `K > 127`.
  The last case is rejected because allocations are stored as `Int8`.
- `DimensionMismatch` if `length(α) != K`.

Initial values: means equally spaced between `minimum(x)` and `maximum(x)` (the sample
mean when `K == 1`), variances `var(x)/K`, and equal weights `1/K`.
"""
function GibbsMixtureOfNormals(x, K, nIter, compPrior, α = ones(K), store_alloc = false)
    n = length(x)
    n > 0 || throw(ArgumentError("x must contain at least one observation"))
    K >= 1 || throw(ArgumentError("K must be at least 1, got $K"))
    length(α) == K ||
        throw(DimensionMismatch("length(α) = $(length(α)) must equal K = $K"))
    # Allocations are stored as Int8 to save memory. Fail fast here instead of hitting
    # an InexactError on the first sweep that assigns an observation to component > 127.
    if store_alloc && K > typemax(Int8)
        throw(ArgumentError(
            "store_alloc = true supports at most $(typemax(Int8)) components, got K = $K"))
    end
    μ₀, σ²₀, κ₀, ν₀ = compPrior

    # Initial values
    a = zeros(Int, n)          # mixture allocation vector, a[i] ∈ {1,…,K}
    logωpost = zeros(K)        # buffer: unnormalized log allocation probabilities
    ωpost = zeros(K)           # buffer: normalized allocation probabilities
    nInComps = zeros(Int, K)   # number of observations currently in each component
    if K == 1
        μ = [mean(x)]
    else
        μ = collect(range(extrema(x)..., length = K)) # equally spaced between min and max
    end
    σ² = (var(x) / K) * ones(K)
    ω = (1 / K) * ones(K)

    # Storage
    μdraws = zeros(nIter, K)
    σ²draws = zeros(nIter, K)
    ωdraws = zeros(nIter, K)
    adraws = store_alloc ? zeros(Int8, nIter, n) : nothing

    for j = 1:nIter
        # Update component allocations. Work on the log scale and subtract the maximum
        # before exponentiating: for an observation far from every component, the raw
        # densities all underflow to 0, and normalizing would give a NaN probability
        # vector that Categorical rejects.
        logω = log.(ω)
        σ = sqrt.(σ²)   # loop-invariant over observations within a sweep
        for i = 1:n
            logωpost .= logω .+ logpdf.(Normal.(μ, σ), x[i])
            ωpost .= exp.(logωpost .- maximum(logωpost))
            ωpost ./= sum(ωpost)
            # Categorical wraps ωpost without copying; it is sampled immediately,
            # so reusing the buffer on the next observation is safe.
            a[i] = rand(Categorical(ωpost))
        end
        if store_alloc
            adraws[j, :] = a
        end

        # Count observations per component in a single pass over a.
        fill!(nInComps, 0)
        for ai in a
            nInComps[ai] += 1
        end

        # Update component parameters from the observations allocated to each component.
        for k = 1:K
            μₙ, σₙ², κₙ, νₙ = PostNormalModel(x[a .== k], μ₀, σ²₀, κ₀, ν₀)
            μ[k], σ²[k] = rand(NormalInverseChisq(μₙ, σₙ², κₙ, νₙ))
        end
        μdraws[j, :] = μ
        σ²draws[j, :] = σ²

        # Update component probabilities (conjugate Dirichlet update).
        ω = rand(Dirichlet(α + nInComps))
        ωdraws[j, :] = ω
    end
    return μdraws, σ²draws, ωdraws, adraws
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment