Skip to content

Instantly share code, notes, and snippets.

@trappmartin
Last active August 28, 2020 17:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save trappmartin/6551df058eef18cbe2591e1bf52bd0a3 to your computer and use it in GitHub Desktop.
Save trappmartin/6551df058eef18cbe2591e1bf52bd0a3 to your computer and use it in GitHub Desktop.
infinite mixture model in Turing
using Turing
using Turing.RandomMeasures
# Infinite mixture model: a Dirichlet-process mixture of univariate Gaussians.
#
# Arguments
# - `y`: observations; each `y[i]` is modelled as Normal(μ[z[i]], sqrt(s[z[i]])),
#   so `s` holds cluster variances and `μ` cluster means.
# - `α`: concentration parameter of the Dirichlet process.
# - `T`: container type used for μ and s (default `Vector{Float64}`); presumably a type
#   argument so gradient-based samplers (HMC) can substitute another element type.
#
# All cluster assignments `z` are drawn first from a Chinese restaurant process; the
# cluster parameters are allocated afterwards, once the number of clusters is known.
@model function imm(y, α, ::Type{T}=Vector{Float64}) where {T}
    N = length(y)
    # `tzeros` arrays, presumably so particle samplers (PG) can copy them per particle.
    nk = tzeros(Int, N)   # nk[k]: number of observations assigned to cluster k so far
    z = tzeros(Int, N)    # z[i]: cluster label of observation i
    for i in 1:N
        z[i] ~ ChineseRestaurantProcess(DirichletProcess(α), nk)
        nk[z[i]] += 1
    end

    # Highest occupied cluster label. `findlast` returns `nothing` when no cluster is
    # occupied (empty `y`); use 0 so no cluster parameters are allocated.
    K = something(findlast(!iszero, nk), 0)
    μ = T(undef, K)
    s = T(undef, K)
    for k in 1:K
        μ[k] ~ Normal()
        s[k] ~ InverseGamma(2, 3)
    end

    # Likelihood. `y` is a model argument, so these statements condition on the data
    # rather than introducing new random variables.
    for i in 1:N
        y[i] ~ Normal(μ[z[i]], sqrt(s[z[i]]))
    end
end
# Data from R. Neal's paper.
x = [-1.48, -1.40, -1.16, -1.08, -1.02,
     0.14, 0.51, 0.53, 0.78];
model = imm(x, 10.0)
# Gibbs sampler: HMC updates the cluster parameters μ and s, particle Gibbs updates the
# cluster assignments z.
chn = let sampler = Gibbs(HMC(0.05, 5, :μ, :s), PG(10, :z))
    sample(model, sampler, 1000)
end;
# Some plotting.
using UnicodePlots
# Number of occupied clusters per draw: each row of `Array(chn[:z])` holds the cluster
# labels of one draw, so count its distinct labels.
nclusters = [length(unique(row)) for row in eachrow(Array(chn[:z]))]
histogram(nclusters)
#=
┌ ┐
[ 3.0, 3.5) ┤ 4
[ 3.5, 4.0) ┤ 0
[ 4.0, 4.5) ┤▇▇▇ 31
[ 4.5, 5.0) ┤ 0
[ 5.0, 5.5) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 128
[ 5.5, 6.0) ┤ 0
[ 6.0, 6.5) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 271
[ 6.5, 7.0) ┤ 0
[ 7.0, 7.5) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 319
[ 7.5, 8.0) ┤ 0
[ 8.0, 8.5) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 200
[ 8.5, 9.0) ┤ 0
[ 9.0, 9.5) ┤▇▇▇▇▇ 47
└ ┘
Frequency
=#
# Infinite mixture model variant better suited to particle sampling: the cluster
# parameters μ and s are drawn lazily, at the moment a cluster is first opened, rather
# than after all cluster assignments have been made.
#
# Arguments
# - `y`: observations; each `y[i]` is modelled as Normal(μ[z[i]], sqrt(s[z[i]])),
#   so `s` holds cluster variances and `μ` cluster means.
# - `α`: concentration parameter of the Dirichlet process.
# - `T`: container type used for μ and s (default `Vector{Float64}`); presumably a type
#   argument so gradient-based samplers (HMC) can substitute another element type.
@model function imm(y, α, ::Type{T}=Vector{Float64}) where {T}
    N = length(y)
    nk = tzeros(Int, N)   # nk[k]: number of observations assigned to cluster k so far
    z = tzeros(Int, N)    # z[i]: cluster label of observation i
    μ = T(undef, 0)       # cluster means, one entry appended per opened cluster
    s = T(undef, 0)       # cluster variances, one entry appended per opened cluster
    # NOTE(review): μ and s are ordinary containers grown with push!, whereas nk and z are
    # created with tzeros; confirm the particle sampler copies μ/s correctly per particle.
    for i in 1:N
        z[i] ~ ChineseRestaurantProcess(DirichletProcess(α), nk)
        # One entry is appended per opened cluster and then indexed with z[i], which
        # requires a new label to be exactly length(μ) + 1. Hence `z[i] > length(μ)`
        # identifies a newly opened cluster (including the very first observation)
        # without scanning nk on every iteration.
        if z[i] > length(μ)
            push!(μ, 0.0)   # placeholder, replaced by the draw below
            push!(s, 1.0)   # placeholder, replaced by the draw below
            μ[z[i]] ~ Normal()
            s[z[i]] ~ InverseGamma(2, 3)
        end
        nk[z[i]] += 1
        # Likelihood. `y` is a model argument, so this conditions on the data rather
        # than introducing a new random variable.
        y[i] ~ Normal(μ[z[i]], sqrt(s[z[i]]))
    end
end
model = imm(x, 10.0)
# Same Gibbs scheme as before: HMC for μ and s, particle Gibbs for z.
chn = let sampler = Gibbs(HMC(0.05, 5, :μ, :s), PG(10, :z))
    sample(model, sampler, 1000)
end;
# Number of occupied clusters per draw: each row of `Array(chn[:z])` holds the cluster
# labels of one draw, so count its distinct labels.
nclusters = [length(unique(row)) for row in eachrow(Array(chn[:z]))]
histogram(nclusters)
#=
┌ ┐
[ 2.0, 3.0) ┤ 1
[ 3.0, 4.0) ┤▇ 5
[ 4.0, 5.0) ┤▇▇▇▇ 36
[ 5.0, 6.0) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 124
[ 6.0, 7.0) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 275
[ 7.0, 8.0) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 303
[ 8.0, 9.0) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 209
[ 9.0, 10.0) ┤▇▇▇▇▇ 47
└ ┘
Frequency
=#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment