@trappmartin
Last active April 30, 2020 14:50
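using Turing, Turing.RandomMeasures

# Dirichlet process mixture of unit-variance Gaussians
@model dpmixture(x) = begin
    rpm = DirichletProcess(1.0)      # DP with concentration α = 1.0
    n = zeros(Int, length(x))        # n[k]: number of observations assigned to cluster k
    z = zeros(Int, length(x))        # z[i]: cluster assignment of observation i
    # Draw the assignments sequentially from the Chinese restaurant process
    for i in eachindex(x)
        z[i] ~ ChineseRestaurantProcess(rpm, n)
        n[z[i]] += 1
    end
    # K: index of the last non-empty cluster, used as the number of components
    K = findlast(!iszero, n)
    m ~ MvNormal(fill(0.0, K), 1.0)  # cluster means with a standard normal prior
    x ~ MvNormal(m[z], 1.0)          # each observation is N(m[z[i]], 1)
end

# Synthetic data: two groups of 100 points centred at -2 and +2
x = vcat(randn(100) .- 2, randn(100) .+ 2);
model = dpmixture(x);
# Particle Gibbs (20 particles) for z; HMC (step size 0.05, 5 leapfrog steps) for m
alg = Gibbs(PG(20, :z), HMC(0.05, 5, :m));
chn = sample(model, alg, 1_000);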
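In the model, `z[i] ~ ChineseRestaurantProcess(rpm, n)` draws each assignment given the current cluster counts `n`. Under a Dirichlet process with concentration α, the standard Chinese restaurant process rule is: observation i joins an existing cluster k with probability n[k] / (i - 1 + α) and opens a new cluster with probability α / (i - 1 + α). The following is a minimal plain-Julia sketch of that rule for intuition only, using just the standard library; `crp_assignments` is a hypothetical helper, not Turing's implementation.

```julia
using Random

# Hypothetical helper (not part of Turing): sequentially seat N customers in a
# Chinese restaurant process with concentration α.
# Customer i joins table k with probability n[k] / (i - 1 + α)
# and opens a new table with probability α / (i - 1 + α).
function crp_assignments(rng::AbstractRNG, N::Integer, α::Real)
    z = zeros(Int, N)   # z[i]: table (cluster) label of customer i
    n = Int[]           # n[k]: number of customers currently at table k
    for i in 1:N
        u = rand(rng) * (i - 1 + α)   # uniform on [0, i - 1 + α)
        k = 0
        acc = 0.0
        for j in eachindex(n)         # walk the cumulative table sizes
            acc += n[j]
            if u < acc
                k = j
                break
            end
        end
        if k == 0                     # u fell into the α-sized slice: new table
            push!(n, 0)
            k = length(n)
        end
        n[k] += 1
        z[i] = k
    end
    return z, n
end

z, n = crp_assignments(MersenneTwister(1), 200, 1.0)
```

In this sketch a new cluster always receives the next unused label, so the occupied labels are exactly `1:length(n)`.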
trappmartin (author) commented Apr 15, 2020

julia> r = Array(chn[:z]);

julia> histogram(map(it -> length(unique(r[it,:])), 1:1_000))
                ┌                                        ┐
   [ 1.0,  2.0) ┤▇▇▇▇ 46
   [ 2.0,  3.0) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 160
   [ 3.0,  4.0) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 304
   [ 4.0,  5.0) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 373
   [ 5.0,  6.0) ┤▇▇▇▇▇▇▇▇ 89
   [ 6.0,  7.0) ┤▇ 15
   [ 7.0,  8.0) ┤▇ 8
   [ 8.0,  9.0) ┤ 4
   [ 9.0, 10.0) ┤ 1
                └                                        ┘
                                Frequency
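For context when reading the histogram: under the CRP prior alone, the expected number of occupied clusters after N assignments is E[K_N] = Σ_{i=1}^{N} α / (α + i - 1). With α = 1.0 and N = 200 (the length of `x`) this is the harmonic number H_200 ≈ 5.88, while 837 of the 1000 posterior draws shown use 2 to 4 clusters. A small sketch of the prior expectation; `expected_clusters` is a hypothetical helper name.

```julia
# Prior expected number of occupied clusters after N draws from a CRP with
# concentration α: E[K_N] = Σ_{i=1}^{N} α / (α + i - 1).
expected_clusters(N::Integer, α::Real) = sum(α / (α + i - 1) for i in 1:N)

expected_clusters(200, 1.0)   # ≈ 5.88, the harmonic number H_200
```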

trappmartin (author) commented Apr 15, 2020

julia> lineplot(map(it -> length(unique(r[it,:])), 1:1_000), xlabel="iteration", ylabel="K")
       ┌────────────────────────────────────────┐
     9 │⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
       │⣀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
       │⣿⠀⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
       │⣿⣀⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
       │⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
       │⣿⣿⣿⠀⢠⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
       │⣿⣿⣿⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
   K   │⣿⣿⣿⣤⣼⡄⠀⢠⡄⡄⡄⡄⠀⠀⠀⠀⠀⠀⠀⠀⢠⠀⠀⠀⠀⣤⣤⠀⠀⢠⠀⡄⠀⣤⣤⣤⣤⣤⣤⣤│
       │⣿⣿⣿⣿⣿⡇⠀⢸⡇⡇⡇⡇⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⣿⣿⠀⠀⢸⠀⡇⠀⣿⣿⣿⣿⣿⣿⣿│
       │⣿⣿⣿⣿⣿⣷⣶⣾⣷⣷⣷⣷⣶⣶⣶⣶⣶⣶⣶⣶⣾⣶⣶⣶⣶⣿⣿⣶⣶⣾⣶⣷⣶⣿⣿⣿⣿⣿⣿⣿│
       │⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿│
       │⡟⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢻⣿⣿⣿⣿⣿⣿⢻⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢻⣿⣿⣿⡟⣿│
       │⡇⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢸⣿⣿⣿⣿⣿⣿⢸⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⢸⣿⣿⣿⡇⣿│
       │⠁⢹⠉⡏⢹⠉⣿⣿⠉⡏⢹⢹⢸⢹⣿⡏⣿⡏⠉⢸⢹⡏⢹⡏⠉⣿⢹⡏⢹⡏⠉⡏⢹⠉⠈⢹⡏⢹⠁⣿│
     1 │⠀⢸⠀⡇⢸⠀⢸⣿⠀⡇⢸⢸⢸⢸⣿⡇⣿⡇⠀⢸⢸⡇⢸⡇⠀⣿⢸⡇⢸⡇⠀⡇⢸⠀⠀⢸⡇⢸⠀⣿│
       └────────────────────────────────────────┘
       0                                     1000
                       iteration
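The one-liner `map(it -> length(unique(r[it,:])), 1:1_000)` behind both plots counts the distinct cluster labels in each draw. A standard-library sketch of the same computation, plus a frequency tally of the kind the histogram displays; `clusters_per_draw` and `tally` are hypothetical helper names, and the small assignment matrix is made up for illustration.

```julia
# Hypothetical helper: number of distinct cluster labels in each draw.
# Z is an (iterations × observations) matrix of assignments, e.g. Array(chn[:z]).
clusters_per_draw(Z::AbstractMatrix) = [length(unique(view(Z, it, :))) for it in axes(Z, 1)]

# Hypothetical helper: how many draws used each number of clusters,
# returned as (K => frequency) pairs sorted by K.
function tally(ks::AbstractVector{<:Integer})
    counts = Dict{Int,Int}()
    for k in ks
        counts[k] = get(counts, k, 0) + 1
    end
    return sort(collect(counts))
end

# Made-up example: 3 draws of assignments for 4 observations
Z = [1 1 2 2;
     1 2 3 3;
     1 1 1 1]
ks = clusters_per_draw(Z)   # [2, 3, 1]
tally(ks)                   # [1 => 1, 2 => 1, 3 => 1]
```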

trappmartin (author) commented Apr 15, 2020

julia> chn[:m]
Object of type Chains, with data of type 1000×9×1 Array{Union{Missing, Float64},3}

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = m[1], m[2], m[3], m[4], m[5], m[6], m[7], m[8], m[9]

2-element Array{ChainDataFrame,1}

Summary Statistics
  parameters     mean     std  naive_se    mcse      ess   r_hat
  ──────────  ───────  ──────  ────────  ──────  ───────  ──────
        m[1]  -0.0984  0.2812    0.0089  0.0077  20.0729  1.0013
        m[2]  -0.0954  0.5023    0.0159  0.0524  33.5777  1.0045
        m[3]  -0.1827  0.8570    0.0271  0.1166  20.5320  1.0169
        m[4]   0.0657  1.1112    0.0351  0.2279  13.2401  1.0701
        m[5]   4.2691  2.2366    0.0707  0.6882   5.6262  1.0119
        m[6]   6.7992  2.7605    0.0873  0.8468   6.7059  1.3040
        m[7]  12.5963  9.6755    0.3067  3.0375   3.6064  2.9494
        m[8]  -0.1241  1.1384    0.0361  0.2588   4.9286  1.0661
        m[9]   6.8390  4.6361    0.1505  1.5761   3.6064  2.0920
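When reading these statistics it may help to check how often each label is occupied at all: with a varying number of clusters, `m[k]` is only meaningful in draws that actually use label k, and rarely occupied or label-switching components are one plausible reason for the small `ess` and large `r_hat` of `m[5]` to `m[9]`. This is a hedged interpretation, not something the output establishes. A minimal sketch computing per-label occupancy from an assignment matrix such as `Array(chn[:z])`, assuming positive integer-valued labels and no `missing` entries; `occupancy` is a hypothetical helper and the matrix is made up.

```julia
# Hypothetical helper: for each label k, the fraction of draws in which at least
# one observation is assigned to cluster k. Z is (iterations × observations).
function occupancy(Z::AbstractMatrix)
    Kmax = Int(maximum(Z))
    occ = zeros(Kmax)
    for it in axes(Z, 1)
        for k in unique(view(Z, it, :))   # each occupied label counted once per draw
            occ[Int(k)] += 1
        end
    end
    return occ ./ size(Z, 1)
end

# Made-up example: 3 draws of assignments for 4 observations
Z = [1 1 2 2;
     1 2 3 3;
     1 1 1 1]
occupancy(Z)   # ≈ [1.0, 2/3, 1/3]
```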
