@torfjelde
Last active November 18, 2023 22:26
Simple example of using NUTS with the new iterator interface in AbstractMCMC.jl, available with Turing.jl > 0.15.
julia> using Turing, Random
julia> @model function gdemo(xs)
           # Assumptions
           σ² ~ InverseGamma(2, 3)
           μ ~ Normal(0, √σ²)
           # Observations
           for i = 1:length(xs)
               xs[i] ~ Normal(μ, √σ²)
           end
       end
gdemo (generic function with 2 methods)
julia> # Set up.
xs = randn(100);
julia> model = gdemo(xs);
julia> # Sampler.
alg = NUTS(0.65);
julia> kwargs = (nadapts=50,);
julia> num_samples = 100;
julia> ### The following two methods are equivalent ###
## Using `sample` ##
rng = MersenneTwister(42);
julia> chain = sample(rng, model, alg, num_samples; kwargs...)
┌ Info: Found initial step size
└ ϵ = 0.4
Sampling 100%|█████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (100×14×1 Array{Float64, 3}):
Iterations = 51:1:150
Number of chains = 1
Samples per chain = 100
Wall duration = 1.12 seconds
Compute duration = 1.12 seconds
parameters = σ², μ
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136        9.6587
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       59.4304

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239
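(Note that the chain is labelled Iterations = 51:1:150: `sample` took 150 steps in total and discarded the first `nadapts = 50` adaptation steps, leaving the requested `num_samples = 100` draws.)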
julia> ## Using the iterator-interface ##
rng = MersenneTwister(42);
julia> spl = DynamicPPL.Sampler(alg);
julia> nadapts = 50;
julia> # Create an iterator we can just step through.
it = AbstractMCMC.Stepper(rng, model, spl, kwargs);
julia> # Initial sample and state.
transition, state = iterate(it);
┌ Info: Found initial step size
└ ϵ = 0.4
julia> # Simple container to hold the samples.
transitions = [];
julia> # Simple condition that says we only want `num_samples` samples.
condition(spls) = length(spls) < num_samples
condition (generic function with 1 method)
julia> # Sample until `condition` is no longer satisfied
       while condition(transitions)
           # For an iterator we pass in the previous `state` as the second argument
           transition, state = iterate(it, state)
           # Save `transition` if we're not adapting anymore
           if state.i > nadapts
               push!(transitions, transition)
           end
       end
julia> length(transitions), state.i, state.i == length(transitions) + nadapts
(100, 150, true)
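Since `AbstractMCMC.Stepper` is driven through `iterate` just like any other Julia iterator, the manual loop above can also be written with Base's iterator utilities. The following is only a sketch, not part of the original session; it assumes a fresh stepper and reuses `model`, `spl`, `kwargs`, `nadapts`, and `num_samples` from above:

# Hypothetical alternative using Base iterator utilities (not from the original gist).
rng = MersenneTwister(42)
it = AbstractMCMC.Stepper(rng, model, spl, kwargs)

transitions = []
# Take `nadapts + num_samples` steps and keep only the post-adaptation transitions.
for (i, transition) in enumerate(Iterators.take(it, nadapts + num_samples))
    i > nadapts && push!(transitions, transition)
end

One caveat: unlike the `while` loop above, this version does not keep the final `state` around, which `bundle_samples` below expects as an argument.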
julia> # Finally, if you want to convert the vector of `transitions` into a
       # `MCMCChains.Chains`, as is typically done:
       chain = AbstractMCMC.bundle_samples(
           map(identity, transitions),  # trick to concretize the eltype of `transitions`
           model,
           spl,
           state,
           MCMCChains.Chains
       )
Chains MCMC chain (100×14×1 Array{Float64, 3}):
Iterations = 1:1:100
Number of chains = 1
Samples per chain = 100
parameters = σ², μ
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136       missing
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239
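Compared to the chain produced by `sample`, this one is labelled Iterations = 1:1:100 and reports `ess_per_sec` as missing, since no wall/compute duration was recorded when bundling the transitions. If you want the iteration labels to match, and assuming your version of MCMCChains provides `setrange` (this step is not in the original session), something like the following should work:

# Hypothetical: relabel the iterations to account for the discarded adaptation steps.
chain = MCMCChains.setrange(chain, (nadapts + 1):(nadapts + num_samples))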