Last active April 30, 2020 14:50
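```julia
using Turing, Turing.RandomMeasures

# Dirichlet process mixture of unit-variance Gaussians, written via the
# Chinese restaurant process (CRP) representation.
@model dpmixture(x) = begin
    # DP prior with concentration parameter α = 1.0
    rpm = DirichletProcess(1.0)
    # n[k]: number of observations currently assigned to cluster k
    n = zeros(Int, length(x))
    # z[i]: cluster assignment of observation i
    z = zeros(Int, length(x))
    for i in eachindex(x)
        # Draw the cluster of observation i given the counts of earlier observations
        z[i] ~ ChineseRestaurantProcess(rpm, n)
        n[z[i]] += 1
    end
    # Number of clusters in use: index of the last non-empty cluster
    K = findlast(!iszero, n)
    # Cluster means with a standard normal prior
    m ~ MvNormal(fill(0.0, K), 1.0)
    # Each observation is normal (std 1.0) around the mean of its cluster
    x ~ MvNormal(m[z], 1.0)
end

# Synthetic data: two clusters of 100 points each, centred at -2 and +2
x = vcat(randn(100) .- 2, randn(100) .+ 2);
model = dpmixture(x);

# Particle Gibbs (20 particles) for the discrete assignments z,
# HMC (step size 0.05, 5 leapfrog steps) for the cluster means m
alg = Gibbs(PG(20, :z), HMC(0.05, 5, :m));
chn = sample(model, alg, 1_000);
```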
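In the model, each `z[i]` is drawn from `ChineseRestaurantProcess(rpm, n)`, so the number of clusters grows with the data instead of being fixed in advance. Under the standard CRP predictive rule, observation i joins an existing cluster k with probability n_k / (i − 1 + α) and opens a new cluster with probability α / (i − 1 + α), where α is the concentration passed to `DirichletProcess` (1.0 here). The sketch below simulates that rule in plain Julia without Turing, mirroring the `n`/`z` bookkeeping of `dpmixture`; `crp_draw` and `crp_sample` are hypothetical helpers, and it is an assumption (not checked against the Turing source) that `Turing.RandomMeasures` uses this same rule for a `DirichletProcess`.

```julia
using Random

# Hypothetical helper (not part of Turing): draw the cluster of the next observation
# given current counts `n` and concentration `α`, using the standard CRP rule
#   P(cluster k) = n[k] / (sum(n) + α),   P(new cluster) = α / (sum(n) + α).
function crp_draw(rng::AbstractRNG, n::Vector{Int}, α::Float64)
    last = findlast(!iszero, n)       # `nothing` before the first observation
    K = last === nothing ? 0 : last   # number of occupied clusters
    u = rand(rng) * (sum(n) + α)      # uniform on [0, total mass)
    acc = 0.0
    for k in 1:K
        acc += n[k]                   # cumulative mass of existing clusters
        u < acc && return k
    end
    return K + 1                      # remaining mass α: open a new cluster
end

# Hypothetical helper: sequentially assign N observations, as the loop in `dpmixture` does.
function crp_sample(rng::AbstractRNG, N::Int, α::Float64)
    n = zeros(Int, N)   # cluster counts
    z = zeros(Int, N)   # cluster assignments
    for i in 1:N
        z[i] = crp_draw(rng, n, α)
        n[z[i]] += 1
    end
    return z, n
end

rng = MersenneTwister(1)
z, n = crp_sample(rng, 200, 1.0)
K = findlast(!iszero, n)
println("occupied clusters: ", K)
```

Scanning cumulative counts keeps each draw O(K) and avoids building an explicit probability vector; since a new cluster always gets index K + 1, `findlast(!iszero, n)` recovers the number of clusters, as in the model.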
trappmartin (author) commented on Apr 15, 2020:
Number of occupied clusters K per iteration (`lineplot` is from UnicodePlots.jl; `r` is not defined in this snippet and is assumed to hold the sampled assignments `z` as an iterations × observations matrix):

```julia
julia> lineplot(map(it -> length(unique(r[it,:])), 1:1_000), xlabel="iteration", ylabel="K")
      ┌────────────────────────────────────────┐
    9 │⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
      │⣀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
      │⣿⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
      │⣿⣀⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
      │⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
      │⣿⣿⣿⠀⢠⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
      │⣿⣿⣿⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
    K │⣿⣿⣿⣤⣼⡄⠀⢠⡄⡄⡄⡄⠀⠀⠀⠀⠀⠀⠀⠀⢠⠀⠀⠀⠀⣤⣤⠀⠀⢠⠀⡄⠀⣤⣤⣤⣤⣤⣤⣤│
      │⣿⣿⣿⣿⣿⡇⠀⢸⡇⡇⡇⡇⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⣿⣿⠀⠀⢸⠀⡇⠀⣿⣿⣿⣿⣿⣿⣿│
      │⣿⣿⣿⣿⣿⣷⣶⣾⣷⣷⣷⣷⣶⣶⣶⣶⣶⣶⣶⣶⣾⣶⣶⣶⣶⣿⣿⣶⣶⣾⣶⣷⣶⣿⣿⣿⣿⣿⣿⣿│
      │⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿│
      │⡟⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢻⣿⣿⣿⣿⣿⣿⢻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢻⣿⣿⣿⡟⣿│
      │⡇⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢸⣿⣿⣿⣿⣿⣿⢸⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢸⣿⣿⣿⡇⣿│
      │⠁⢹⠉⡏⢹⠉⣿⣿⠉⡏⢹⢹⢸⢹⣿⡏⣿⡏⠉⢸⢹⡏⢹⡏⠉⣿⢹⡏⢹⡏⠉⡏⢹⠉⠈⢹⡏⢹⠁⣿│
    1 │⠀⢸⠀⡇⢸⠀⢸⣿⠀⡇⢸⢸⢸⢸⣿⡇⣿⡇⠀⢸⢸⡇⢸⡇⠀⣿⢸⡇⢸⡇⠀⡇⢸⠀⠀⢸⡇⢸⠀⣿│
      └────────────────────────────────────────┘
      0                                     1000
                      iteration
```
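The plotted quantity is the number of distinct labels in each row of `r`. As a reference point, under the CRP prior alone the expected number of occupied clusters after N observations with concentration α is E[K_N] = Σ_{i=1}^{N} α / (α + i − 1), since observation i opens a new cluster with probability α / (α + i − 1). The sketch below computes both quantities in plain Julia; `expected_clusters` and `clusters_per_iteration` are hypothetical helper names, and `r_toy` is made-up toy data, not output from the gist.

```julia
# Hypothetical helpers (not from Turing or the gist).

# Prior mean of the number of occupied clusters after N observations under a
# CRP with concentration α: E[K_N] = Σ_{i=1}^{N} α / (α + i - 1).
expected_clusters(α::Real, N::Integer) = sum(α / (α + i - 1) for i in 1:N)

# Number of distinct cluster labels in each row of an iterations × observations
# matrix of sampled assignments, i.e. the quantity passed to `lineplot` above.
clusters_per_iteration(r::AbstractMatrix) = [length(unique(view(r, it, :))) for it in axes(r, 1)]

println(expected_clusters(1.0, 200))   # α = 1.0 and 200 observations, as in the gist

r_toy = [1 1 2 2; 1 2 3 1; 1 1 1 1]    # made-up assignments: 3 iterations × 4 observations
println(clusters_per_iteration(r_toy)) # [2, 3, 1]
```

For α = 1 the sum is the harmonic number H_N, about 5.88 for N = 200, which gives a rough prior baseline against which the K trace can be read.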
```julia
julia> chn[:m]
Object of type Chains, with data of type 1000×9×1 Array{Union{Missing, Float64},3}

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = m[1], m[2], m[3], m[4], m[5], m[6], m[7], m[8], m[9]

2-element Array{ChainDataFrame,1}

Summary Statistics
  parameters     mean     std  naive_se    mcse      ess   r_hat
  ──────────  ───────  ──────  ────────  ──────  ───────  ──────
        m[1]  -0.0984  0.2812    0.0089  0.0077  20.0729  1.0013
        m[2]  -0.0954  0.5023    0.0159  0.0524  33.5777  1.0045
        m[3]  -0.1827  0.8570    0.0271  0.1166  20.5320  1.0169
        m[4]   0.0657  1.1112    0.0351  0.2279  13.2401  1.0701
        m[5]   4.2691  2.2366    0.0707  0.6882   5.6262  1.0119
        m[6]   6.7992  2.7605    0.0873  0.8468   6.7059  1.3040
        m[7]  12.5963  9.6755    0.3067  3.0375   3.6064  2.9494
        m[8]  -0.1241  1.1384    0.0361  0.2588   4.9286  1.0661
        m[9]   6.8390  4.6361    0.1505  1.5761   3.6064  2.0920
```