Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Last active March 28, 2024 20:03
Show Gist options
  • Save torfjelde/37be5a672d29e473983b8e82b45c2e41 to your computer and use it in GitHub Desktop.
Save torfjelde/37be5a672d29e473983b8e82b45c2e41 to your computer and use it in GitHub Desktop.
Converting output from `generated_quantities(model, chain)` into an `MCMCChains.Chains` object
julia> using Turing
julia> include("utils.jl")
julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m ~ Normal(0, √s)
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end
return (m = m, s = s)
end
demo (generic function with 1 method)
julia> xs = randn(100) .+ 1;
julia> m = demo(xs);
julia> chain = sample(m, MH(), MCMCThreads(), 100, 2);
┌ Warning: Only a single thread available: MCMC chains are not sampled in parallel
└ @ AbstractMCMC ~/.julia/packages/AbstractMCMC/iOkTf/src/sample.jl:197
Sampling (1 threads) 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
julia> res = DynamicPPL.generated_quantities(m, chain);
julia> size(res)
(100, 2)
julia> Chains(res)
Chains MCMC chain (100×2×2 Array{Float64,3}):
Iterations = 1:100
Thinning interval = 1
Chains = 1, 2
Samples per chain = 100
parameters = m, s
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
m 0.8961 0.2833 0.0200 0.0538 5.4693 1.2846
s 1.4348 1.4394 0.1018 0.1462 62.8920 1.0039
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
m -0.0137 0.7715 0.9439 1.0597 1.1702
s 0.8233 1.0051 1.3981 1.4125 1.6766
julia> # Or creating a chain by hand:
res = [(x1 = randn(), x2 = randn(2), x3 = randn(2, 2)) for i = 1:100];
julia> Chains(res)
Chains MCMC chain (100×7×1 Array{Float64,3}):
Iterations = 1:100
Thinning interval = 1
Chains = 1
Samples per chain = 100
parameters = x1, x2[1], x2[2], x3[1,1], x3[2,1], x3[1,2], x3[2,2]
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Missing Float64 Float64
x1 0.1048 1.0150 0.1015 missing 83.5494 0.9905
x2[1] -0.0370 1.0645 0.1064 missing 52.3812 0.9975
x2[2] -0.0174 1.1423 0.1142 missing 109.7089 0.9903
x3[1,1] -0.1262 0.9917 0.0992 missing 215.2624 0.9927
x3[2,1] -0.1030 0.8943 0.0894 missing 115.0757 0.9949
x3[1,2] 0.1921 0.9276 0.0928 missing 126.8230 1.0107
x3[2,2] 0.0725 1.0082 0.1008 missing 92.7707 0.9946
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
x1 -1.7972 -0.6828 0.2499 0.7517 1.9684
x2[1] -2.1026 -0.8716 -0.0407 0.7425 2.3251
x2[2] -2.2431 -0.8199 -0.0915 0.7274 2.1686
x3[1,1] -1.9870 -0.8212 -0.2149 0.5845 1.7913
x3[2,1] -1.6532 -0.8203 -0.1102 0.5373 1.5690
x3[1,2] -1.5999 -0.3648 0.2584 0.8105 1.8125
x3[2,2] -2.0730 -0.6014 0.0374 0.7459 2.0825
"""
    generate_names(val)

Return the flat list of column names for `val`, starting from an empty name prefix.
"""
function generate_names(val)
    return generate_names("", val)
end

# A scalar occupies exactly one column, named by the prefix alone.
function generate_names(vn_str::String, val::Real)
    return [vn_str]
end
"""
    generate_names(vn_str::String, val::NamedTuple)

Return the column names for every field of `val` as one flat `Vector{String}`.

Each field name is appended directly to `vn_str` (no separator) and the result is used
as the prefix when naming that field's value. An empty `NamedTuple` yields an empty
vector.
"""
function generate_names(vn_str::String, val::NamedTuple)
    # `map` over the keys would produce a tuple of per-field vectors; concatenating keeps
    # the return type in line with the other methods. `init` covers the no-field case.
    return mapreduce(vcat, keys(val); init=String[]) do k
        generate_names("$(vn_str)$(k)", val[k])
    end
end
"""
    generate_names(vn_str::String, val::AbstractArray{<:Real})

Return one name per element of `val`, formed as `vn_str` followed by the element's
comma-separated Cartesian index in square brackets (e.g. `x[2,1]`). Names are listed in
the iteration order of `CartesianIndices(val)`.
"""
function generate_names(vn_str::String, val::AbstractArray{<:Real})
    labels = Vector{String}(undef, length(val))
    for (pos, idx) in enumerate(CartesianIndices(val))
        subscript = join(Tuple(idx), ",")
        labels[pos] = string(vn_str, "[", subscript, "]")
    end
    return labels
end
"""
    generate_names(vn_str::String, val::AbstractArray{<:AbstractArray})

Return the names for an array whose elements are themselves arrays.

For every outer index, the names of the inner array (generated with an empty prefix) are
appended to `vn_str[<outer index>]`, giving names such as `x[1][2]`.
"""
function generate_names(vn_str::String, val::AbstractArray{<:AbstractArray})
    results = String[]
    for idx in CartesianIndices(val)
        outer = "$(vn_str)[$(join(Tuple(idx), ","))]"
        # Recurse into `generate_names`; the previous call to an undefined `f` made this
        # method throw for every non-empty input.
        for inner in generate_names("", val[idx])
            push!(results, outer * inner)
        end
    end
    return results
end
"""
    flatten(val::Real)

Wrap a scalar in a one-element vector.
"""
function flatten(val::Real)
    return [val]
end
"""
    flatten(val::AbstractArray{<:Real})

Return the elements of `val` as a new `Vector`, in column-major order.

Always returns a vector: an empty array gives an empty vector and a one-element array
gives a one-element vector. Elements are copied as stored, without promotion.
"""
function flatten(val::AbstractArray{<:Real})
    # Folding `vcat` over the elements throws on empty input, returns a bare scalar for
    # a single element, and re-allocates on every step; a single copy avoids all three.
    return collect(vec(val))
end
# Element type found by repeatedly taking `eltype` until it is no longer an array type.
_leaf_eltype(::Type{T}) where {T} = T
_leaf_eltype(::Type{T}) where {T<:AbstractArray} = _leaf_eltype(eltype(T))

"""
    flatten(val::AbstractArray{<:AbstractArray})

Flatten every element of `val` and concatenate the results, in the iteration order of
`val`. An empty `val` yields an empty vector of the innermost element type.
"""
function flatten(val::AbstractArray{<:AbstractArray})
    # Reducing an empty collection with `vcat` has no neutral element and would throw.
    isempty(val) && return Vector{_leaf_eltype(eltype(val))}()
    return mapreduce(flatten, vcat, val)
end
"""
    vectup2chainargs(ts::AbstractVector{<:NamedTuple})

Turn a vector of `NamedTuple` samples into `(values, names)`.

`values` is a 3-dimensional `Array` with one row per sample, one column per flattened
entry and a trailing dimension of size 1 (a single chain). `names` are the column names
generated from the first sample.
"""
function vectup2chainargs(ts::AbstractVector{<:NamedTuple})
    template = first(ts)
    fields = keys(template)

    # Column names are derived from the first sample only.
    vns = mapreduce(vcat, fields) do field
        generate_names(string(field), template[field])
    end

    # One flat column of values per sample.
    columns = map(eachindex(ts)) do i
        sample = ts[i]
        mapreduce(field -> flatten(sample[field]), vcat, fields)
    end

    # After the adjoint each sample is a row; the trailing axis marks a single chain.
    samples_by_params = reduce(hcat, columns)'
    nrows, ncols = size(samples_by_params)
    return Array(reshape(samples_by_params, (nrows, ncols, 1))), vns
end
"""
    vectup2chainargs(ts::AbstractMatrix{<:NamedTuple})

Convert a matrix of `NamedTuple` samples, with one column per chain, into
`(values, names)`. Each column is converted with the vector method and the per-chain
arrays are stacked along the third dimension.

# Throws
- `ArgumentError` if `ts` is empty, or if the generated variable names are not identical
  (including their order) for every chain.
"""
function vectup2chainargs(ts::AbstractMatrix{<:NamedTuple})
    if isempty(ts)
        throw(ArgumentError("cannot build chain arguments from an empty sample matrix"))
    end

    per_chain = map(axes(ts, 2)) do chain_idx
        vectup2chainargs(view(ts, :, chain_idx))
    end
    vals = first.(per_chain)
    vns = last.(per_chain)

    # Chains are stacked positionally and only the first chain's names are returned, so
    # every chain must produce exactly the same names in the same order. (Comparing each
    # name list against the union of all of them can never detect a mismatch.)
    reference = first(vns)
    if !all(==(reference), vns)
        throw(ArgumentError("variable names differ between chains"))
    end

    arr = cat(vals...; dims=3)
    return arr, reference
end
"""
    MCMCChains.Chains(ts::AbstractArray{<:NamedTuple})

Construct a `MCMCChains.Chains` from an array of `NamedTuple` samples by converting it
with `vectup2chainargs` and forwarding the resulting values and names.
"""
# NOTE(review): this adds a method to a constructor owned by MCMCChains for an argument
# type from Base — confirm this extension is intended to live outside MCMCChains.
function MCMCChains.Chains(ts::AbstractArray{<:NamedTuple})
    values, names = vectup2chainargs(ts)
    return MCMCChains.Chains(values, names)
end
@joshualeond
Copy link

Hey Tor, this is great! Just curious if there's a reason why this isn't included in Turing.jl?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment