Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Last active December 20, 2021 12:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save torfjelde/5dd1ed93a81759c98ff0ef3feeb24237 to your computer and use it in GitHub Desktop.
This requires `NestedSamplers#tor/improvements` and `Turing@0.19.2` or higher.
julia> include("nested_samplers.jl")
julia> @model function demo()
m ~ Normal()
x ~ Normal(m, 1)
y ~ Normal(x, 1)
return (; m, x, y)
end
demo (generic function with 2 methods)
julia> model = NestedModel(demo() | (y = 1, ));
julia> bounds = Bounds.MultiEllipsoid
NestedSamplers.Bounds.MultiEllipsoid
julia> prop = Proposals.Slice(slices=10)
NestedSamplers.Proposals.Slice
slices: Int64 10
scale: Float64 1.0
julia> # 1000 live points
sampler = Nested(2, 1000; bounds=bounds, proposal=prop)
Nested{UnionAll, NestedSamplers.Proposals.Slice}(2, 1000, NestedSamplers.Bounds.MultiEllipsoid, 1.25, 18000, 2000, 0.1, NestedSamplers.Proposals.Slice
slices: Int64 10
scale: Float64 1.0
, 0.0009995003330836028)
julia> chain = first(sample(model, sampler; dlogz=0.2))
Nested Sampling 100%|███████████████████████████████████████████████| Time: 0:00:00
┌ Warning: timestamp of type Missing unknown
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364
┌ Warning: timestamp of type Missing unknown
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364
Chains MCMC chain (3420×3×1 Array{Float64, 3}):
Log evidence = -1.6317018558704408
Iterations = 1:1:3420
Number of chains = 1
Samples per chain = 3420
parameters = Parameter 1, Parameter 2
internals = weights
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
Parameter 1 0.2859 0.8713 0.0149 0.0512 40.5689 1.0509
Parameter 2 0.5766 1.0119 0.0173 0.0999 14.0664 1.1749
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
Parameter 1 -1.5621 -0.2712 0.3187 0.8734 1.9186
Parameter 2 -2.0324 0.1966 0.8158 1.1451 2.2301
julia> # Or from Turing.jl's side.
chain = sample(
demo() | (y = 1, ),
NestedSampler(1000; bounds=bounds, proposal=prop),
1000; # TODO: Make it so we don't have to specify the number of steps?
dlogz=0.2
)
Sampling 100%|██████████████████████████████████████████████████████| Time: 0:00:00
┌ Warning: timestamp of type Missing unknown
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364
┌ Warning: timestamp of type Missing unknown
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364
Chains MCMC chain (2000×3×1 Array{Float64, 3}):
Log evidence = -1.658376606163276
Iterations = 1:1:2000
Number of chains = 1
Samples per chain = 2000
parameters = m, x
internals = weights
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
m 0.1037 0.9324 0.0208 0.0692 16.5079 1.1133
x 0.2792 1.2525 0.0280 0.1375 8.0195 1.2631
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
m -1.7807 -0.5142 0.1251 0.7495 1.8497
x -2.4141 -0.4269 0.3345 1.1117 2.5301
using Turing
using NestedSamplers
using Setfield: Setfield
using Random: Random
# Implementation of Turing's `InferenceAlgorithm` interface.
"""
    NestedSampler

Inference algorithm adapter that lets Turing.jl drive NestedSamplers.jl.

Stores the number of live points, the bounding scheme and proposal used by
`NestedSamplers.Nested`, plus any additional keyword arguments forwarded to
the `Nested` constructor when sampling starts.
"""
struct NestedSampler{TBounds,TProposal,TKwargs} <: Turing.Inference.InferenceAlgorithm
    # Number of live points for the nested sampler.
    num_particles::Int
    # Bounding scheme, e.g. `Bounds.MultiEllipsoid`.
    bounds::TBounds
    # Proposal used to evolve live points, e.g. `Proposals.Slice`.
    proposal::TProposal
    # Extra keyword arguments forwarded to `NestedSamplers.Nested`.
    kwargs::TKwargs
end
"""
    NestedSampler(num_particles=1000; bounds, proposal, kwargs...)

Keyword-friendly constructor for [`NestedSampler`](@ref).

`num_particles` now has a default (matching the examples above) so that the
zero-argument form `NestedSampler()` — used as a default sampler when wrapping
a model in a `NestedModel` — no longer throws a `MethodError`.

# Arguments
- `num_particles`: number of live points (default `1000`).

# Keywords
- `bounds`: bounding scheme (default `Bounds.MultiEllipsoid`).
- `proposal`: proposal algorithm (default `Proposals.Slice(slices=10)`).
- `kwargs...`: forwarded verbatim to `NestedSamplers.Nested`.
"""
function NestedSampler(
    num_particles=1000;
    bounds=Bounds.MultiEllipsoid,
    proposal=Proposals.Slice(slices=10),
    kwargs...
)
    return NestedSampler(num_particles, bounds, proposal, kwargs)
end
"""
    NestedSamplerState

State carried between `AbstractMCMC.step` calls when sampling a Turing model
with a [`NestedSampler`](@ref). Bundles the wrapped NestedSamplers.jl model
and sampler together with their internal state and a representative varinfo.
"""
struct NestedSamplerState{TModel,TVarInfo,TState,TSampler}
    "model from NestedSamplers.jl"
    model::TModel
    "representative `AbstractVarInfo` from the target model"
    varinfo::TVarInfo
    "state from NestedSamplers.jl"
    state::TState
    "sampler from NestedSamplers.jl"
    sampler::TSampler
end
# Use the sampler itself to generate the initial parameters; this routes the
# initial draw through the `DynamicPPL.assume` overload defined below.
function DynamicPPL.initialsampler(sampler::DynamicPPL.Sampler{<:NestedSampler})
    return sampler
end

# The sampler is not restricted to any subset of the model's variables.
function DynamicPPL.getspace(::DynamicPPL.Sampler{<:NestedSampler})
    return ()
end
"""
    AbstractMCMC.step(rng, model, spl::DynamicPPL.Sampler{<:NestedSampler}; kwargs...)

Initial step: build the `NestedModel` and `NestedSamplers.Nested` sampler from
the Turing model and algorithm specification, take the first step with them,
and return the transition together with a [`NestedSamplerState`](@ref).
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler{<:NestedSampler};
    kwargs...
)
    # Representative `AbstractVarInfo` for the target model.
    vi = DynamicPPL.VarInfo(rng, model, spl)
    # Wrap the Turing model so NestedSamplers.jl can evaluate it.
    wrapped_model = NestedSamplers.NestedModel(model, vi, spl)
    # Translate the algorithm specification into a `NestedSamplers.Nested`
    # sampler; the dimensionality is the number of linearized parameters.
    algorithm = spl.alg
    inner_sampler = NestedSamplers.Nested(
        length(vi[spl]), algorithm.num_particles;
        bounds=algorithm.bounds,
        proposal=algorithm.proposal,
        algorithm.kwargs...
    )
    # Take the first step with the wrapped model and inner sampler.
    transition, inner_state = AbstractMCMC.step(
        rng, wrapped_model, inner_sampler;
        kwargs...
    )
    # Bundle everything needed by subsequent steps.
    return transition, NestedSamplerState(wrapped_model, vi, inner_state, inner_sampler)
end
"""
    AbstractMCMC.step(rng, model, spl, state::NestedSamplerState; kwargs...)

Subsequent steps: delegate to NestedSamplers.jl using the wrapped model,
inner sampler, and inner state stored in `state`.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler{<:NestedSampler},
    state::NestedSamplerState;
    kwargs...
)
    transition, inner_state = AbstractMCMC.step(
        rng, state.model, state.sampler, state.state;
        kwargs...
    )
    # Only the inner state changes between steps.
    new_state = Setfield.@set state.state = inner_state
    return transition, new_state
end
"""
    DynamicPPL.assume(rng, sampler::DynamicPPL.Sampler{<:NestedSampler}, dist, vn, vi)

Handle a `~` statement for an unobserved variable when sampling with
[`NestedSampler`](@ref).

If `vn` is already present in `vi`, its stored value is interpreted as a
unit-cube coordinate and mapped through the prior's quantile function into
the original space (the standard nested-sampling prior transform). Otherwise
the variable is initialized with a fresh draw from the prior.

Returns `(val, logp, vi)` where `logp` is zero: the prior is accounted for
by the unit-cube transform, so only likelihood terms accumulate in `vi`'s
log-probability (which `NestedModel` below uses as the log-likelihood).
"""
function DynamicPPL.assume(
    rng::Random.AbstractRNG,
    sampler::DynamicPPL.Sampler{<:NestedSampler},
    dist::Distribution,
    vn::DynamicPPL.VarName,
    vi::DynamicPPL.AbstractVarInfo,
)
    if haskey(vi, vn)
        # If it's already present, we assume it's in unit-cube space
        # but we want to return it in the original space.
        # TODO: We should probably specify this somehow, e.g. using
        # a field in `sampler` or something.
        # `invlogcdf(dist, log(u))` is the quantile function of `dist`
        # evaluated at `u`, i.e. the prior transform.
        val = invlogcdf(dist, log(vi[vn]))
        # `!!` functions may mutate in place or return a new varinfo
        # depending on the varinfo type; always rebind `vi`.
        vi = DynamicPPL.setindex!!(vi, val, vn)
    else
        # First time we see this variable: initialize from the prior.
        val = rand(rng, dist)
        vi = DynamicPPL.push!!(vi, vn, val, dist, sampler)
    end
    # Zero log-prior contribution (see docstring); `zero(eltype(val))`
    # keeps the element type of `val`.
    return val, zero(eltype(val)), vi
end
"""
    AbstractMCMC.bundle_samples(ts, model, spl, state, ::Type{MCMCChains.Chains}; kwargs...)

Bundle the transitions into an `MCMCChains.Chains`, delegating to
NestedSamplers.jl and then replacing the generic parameter names with the
variable names recorded in the varinfo.
"""
function AbstractMCMC.bundle_samples(
    ts::Vector,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler{<:NestedSampler},
    state::NestedSamplerState,
    chain_type::Type{MCMCChains.Chains};
    kwargs...
)
    # NestedSamplers.jl returns a tuple whose first element is the chain.
    bundled = AbstractMCMC.bundle_samples(
        ts, model, state.sampler, state.state, chain_type;
        kwargs...
    )
    raw_chain = first(bundled)
    # Rename "Parameter i" -> actual variable names from the varinfo.
    generic_names = names(raw_chain, :parameters)
    actual_names = map(string, keys(state.varinfo))
    return MCMCChains.replacenames(raw_chain, Dict(zip(generic_names, actual_names)))
end
###################################
### NestedSamplers.jl interface ###
###################################
# Convenience constructor: fall back to the global RNG.
NestedSamplers.NestedModel(model::DynamicPPL.Model) =
    NestedSamplers.NestedModel(Random.GLOBAL_RNG, model)
"""
    NestedSamplers.NestedModel(rng::Random.AbstractRNG, model::DynamicPPL.Model)

Construct a `NestedModel` from a Turing `model`, using `rng` to initialize the
varinfo.
"""
function NestedSamplers.NestedModel(rng::Random.AbstractRNG, model::DynamicPPL.Model)
    # The sampler is only used for dispatch (the unit-cube `assume` above) and
    # for indexing into the varinfo; `num_particles` is never read on this
    # path, so 0 is fine. The previous `NestedSampler()` threw a `MethodError`
    # because `num_particles` had no default value.
    sampler = DynamicPPL.Sampler(NestedSampler(0))
    vi = DynamicPPL.VarInfo(rng, model, sampler)
    return NestedSamplers.NestedModel(model, vi, sampler)
end
"""
    NestedSamplers.NestedModel(model, vi, sampler=DynamicPPL.Sampler(NestedSampler(0)))

Wrap a Turing `model` as a `NestedModel` whose prior transform maps unit-cube
samples into the model's parameter space (via the `DynamicPPL.assume` overload
above) and whose log-likelihood is the log-probability accumulated during
model evaluation.
"""
function NestedSamplers.NestedModel(
    model::DynamicPPL.Model,
    vi::DynamicPPL.AbstractVarInfo,
    # `num_particles` is unused here, so 0 is fine; the previous default
    # `NestedSampler()` threw a `MethodError` because `num_particles` had
    # no default value.
    sampler::DynamicPPL.AbstractSampler=DynamicPPL.Sampler(NestedSampler(0))
)
    # Work on a copy so repeated evaluations don't clobber the caller's varinfo.
    vi_base = deepcopy(vi)
    # NOTE: fixed typo in the closure name ("loglikelihod" -> "loglikelihood");
    # the name is internal, so no caller is affected.
    function prior_transform_and_loglikelihood_model(u)
        # Update in unit-cube space; `assume` transforms to the original
        # space during evaluation.
        vi_new = DynamicPPL.setindex!!(vi_base, u, sampler)
        # Evaluate the model, computing the transformed variables and the
        # log-likelihood.
        _, vi_new = DynamicPPL.evaluate!!(model, vi_new, sampler)
        # Return the new (transformed) samples and the log-likelihood.
        return vi_new[sampler], DynamicPPL.getlogp(vi_new)
    end
    return NestedSamplers.NestedModel(prior_transform_and_loglikelihood_model)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment