Last active
December 20, 2021 12:22
-
-
Save torfjelde/5dd1ed93a81759c98ff0ef3feeb24237 to your computer and use it in GitHub Desktop.
This requires the `tor/improvements` revision of NestedSamplers.jl (`NestedSamplers#tor/improvements`) and Turing.jl v0.19.2 (`Turing@0.19.2`) or higher.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
julia> include("nested_samplers.jl") | |
julia> @model function demo() | |
m ~ Normal() | |
x ~ Normal(m, 1) | |
y ~ Normal(x, 1) | |
return (; m, x, y) | |
end | |
demo (generic function with 2 methods) | |
julia> model = NestedModel(demo() | (y = 1, )); | |
julia> bounds = Bounds.MultiEllipsoid | |
NestedSamplers.Bounds.MultiEllipsoid | |
julia> prop = Proposals.Slice(slices=10) | |
NestedSamplers.Proposals.Slice | |
slices: Int64 10 | |
scale: Float64 1.0 | |
julia> # 1000 live points | |
sampler = Nested(2, 1000; bounds=bounds, proposal=prop) | |
Nested{UnionAll, NestedSamplers.Proposals.Slice}(2, 1000, NestedSamplers.Bounds.MultiEllipsoid, 1.25, 18000, 2000, 0.1, NestedSamplers.Proposals.Slice | |
slices: Int64 10 | |
scale: Float64 1.0 | |
, 0.0009995003330836028) | |
julia> chain = first(sample(model, sampler; dlogz=0.2)) | |
Nested Sampling 100%|███████████████████████████████████████████████| Time: 0:00:00 | |
┌ Warning: timestamp of type Missing unknown | |
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364 | |
┌ Warning: timestamp of type Missing unknown | |
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364 | |
Chains MCMC chain (3420×3×1 Array{Float64, 3}): | |
Log evidence = -1.6317018558704408 | |
Iterations = 1:1:3420 | |
Number of chains = 1 | |
Samples per chain = 3420 | |
parameters = Parameter 1, Parameter 2 | |
internals = weights | |
Summary Statistics | |
parameters mean std naive_se mcse ess rhat | |
Symbol Float64 Float64 Float64 Float64 Float64 Float64 | |
Parameter 1 0.2859 0.8713 0.0149 0.0512 40.5689 1.0509 | |
Parameter 2 0.5766 1.0119 0.0173 0.0999 14.0664 1.1749 | |
Quantiles | |
parameters 2.5% 25.0% 50.0% 75.0% 97.5% | |
Symbol Float64 Float64 Float64 Float64 Float64 | |
Parameter 1 -1.5621 -0.2712 0.3187 0.8734 1.9186 | |
Parameter 2 -2.0324 0.1966 0.8158 1.1451 2.2301 | |
julia> # Or from Turing.jl's side. | |
chain = sample( | |
demo() | (y = 1, ), | |
NestedSampler(1000; bounds=bounds, proposal=prop), | |
1000; # TODO: Make it so we don't have to specify the number of steps? | |
dlogz=0.2 | |
) | |
Sampling 100%|██████████████████████████████████████████████████████| Time: 0:00:00 | |
┌ Warning: timestamp of type Missing unknown | |
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364 | |
┌ Warning: timestamp of type Missing unknown | |
└ @ MCMCChains ~/.julia/packages/MCMCChains/pPqxj/src/chains.jl:364 | |
Chains MCMC chain (2000×3×1 Array{Float64, 3}): | |
Log evidence = -1.658376606163276 | |
Iterations = 1:1:2000 | |
Number of chains = 1 | |
Samples per chain = 2000 | |
parameters = m, x | |
internals = weights | |
Summary Statistics | |
parameters mean std naive_se mcse ess rhat | |
Symbol Float64 Float64 Float64 Float64 Float64 Float64 | |
m 0.1037 0.9324 0.0208 0.0692 16.5079 1.1133 | |
x 0.2792 1.2525 0.0280 0.1375 8.0195 1.2631 | |
Quantiles | |
parameters 2.5% 25.0% 50.0% 75.0% 97.5% | |
Symbol Float64 Float64 Float64 Float64 Float64 | |
m -1.7807 -0.5142 0.1251 0.7495 1.8497 | |
x -2.4141 -0.4269 0.3345 1.1117 2.5301 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Turing | |
using NestedSamplers | |
using Setfield: Setfield | |
using Random: Random | |
# Implementation of `InferenceAlgorithm`.
"""
    NestedSampler{TBounds,TProposal,TKwargs}

Turing.jl `InferenceAlgorithm` whose sampling is delegated to `NestedSamplers.Nested`.

# Fields
- `num_particles::Int`: number of live points, passed as the second positional
  argument of `NestedSamplers.Nested`.
- `bounds`: forwarded as the `bounds` keyword of `NestedSamplers.Nested`.
- `proposal`: forwarded as the `proposal` keyword of `NestedSamplers.Nested`.
- `kwargs`: any remaining keyword arguments, splatted into `NestedSamplers.Nested`.
"""
struct NestedSampler{TBounds,TProposal,TKwargs} <: Turing.Inference.InferenceAlgorithm
    num_particles::Int
    bounds::TBounds
    proposal::TProposal
    kwargs::TKwargs
end
"""
    NestedSampler(num_particles; bounds=Bounds.MultiEllipsoid,
                  proposal=Proposals.Slice(slices=10), kwargs...)

Create a `NestedSampler` that uses `num_particles` live points. Any extra keyword
arguments are kept as-is and later forwarded to `NestedSamplers.Nested`.
"""
function NestedSampler(num_particles; bounds=Bounds.MultiEllipsoid,
                       proposal=Proposals.Slice(slices=10), kwargs...)
    # Hand off to the default (field-wise) constructor.
    return NestedSampler(num_particles, bounds, proposal, kwargs)
end
"""
    NestedSamplerState(model, varinfo, state, sampler)

Sampler state threaded through `AbstractMCMC.step` for `NestedSampler`. It bundles the
NestedSamplers.jl objects with a representative `AbstractVarInfo` of the Turing model.
"""
struct NestedSamplerState{TModel,TVarInfo,TState,TSampler}
    "the `NestedSamplers.NestedModel` that wraps the Turing model"
    model::TModel
    "a representative `AbstractVarInfo` of the Turing model (its keys name the parameters)"
    varinfo::TVarInfo
    "the sampler state returned by NestedSamplers.jl's `AbstractMCMC.step`"
    state::TState
    "the `NestedSamplers.Nested` sampler instance"
    sampler::TSampler
end
"""
    DynamicPPL.initialsampler(sampler::DynamicPPL.Sampler{<:NestedSampler})

Return `sampler` itself, so that wherever DynamicPPL asks for an initial sampler it uses
the `NestedSampler`-tagged sampler instead of DynamicPPL's default.
"""
function DynamicPPL.initialsampler(sampler::DynamicPPL.Sampler{<:NestedSampler})
    return sampler
end
"""
    DynamicPPL.getspace(::DynamicPPL.Sampler{<:NestedSampler})

Return an empty space `()`. The sampler is not restricted to a subset of variables, so
`varinfo[spl]` covers every model parameter.
"""
function DynamicPPL.getspace(::DynamicPPL.Sampler{<:NestedSampler})
    return ()
end
"""
    AbstractMCMC.step(rng, model::DynamicPPL.Model,
                      spl::DynamicPPL.Sampler{<:NestedSampler}; kwargs...)

Initial step. It wraps the Turing `model` in a `NestedSamplers.NestedModel` and builds a
`NestedSamplers.Nested` sampler from the settings stored in `spl.alg`. It then takes the
first NestedSamplers.jl step, forwarding `kwargs` to it.

Returns the NestedSamplers.jl transition and a `NestedSamplerState`.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler{<:NestedSampler};
    kwargs...
)
    settings = spl.alg

    # Representative `VarInfo`; its flattened length is the problem dimensionality.
    vi = DynamicPPL.VarInfo(rng, model, spl)
    num_dims = length(vi[spl])

    inner_model = NestedSamplers.NestedModel(model, vi, spl)
    inner_sampler = NestedSamplers.Nested(
        num_dims, settings.num_particles;
        bounds=settings.bounds, proposal=settings.proposal, settings.kwargs...
    )

    transition, inner_state = AbstractMCMC.step(
        rng, inner_model, inner_sampler; kwargs...
    )
    return transition, NestedSamplerState(inner_model, vi, inner_state, inner_sampler)
end
"""
    AbstractMCMC.step(rng, model::DynamicPPL.Model,
                      spl::DynamicPPL.Sampler{<:NestedSampler},
                      state::NestedSamplerState; kwargs...)

Subsequent step. It advances the wrapped NestedSamplers.jl sampler by one step and
returns the transition together with a `NestedSamplerState` that holds the new inner
state. `model` and `spl` are not consulted, because everything needed is in `state`.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler{<:NestedSampler},
    state::NestedSamplerState;
    kwargs...
)
    transition, next_inner = AbstractMCMC.step(
        rng, state.model, state.sampler, state.state; kwargs...
    )
    # Same as `Setfield.@set state.state = next_inner`: rebuild with only `state` replaced.
    updated = NestedSamplerState(state.model, state.varinfo, next_inner, state.sampler)
    return transition, updated
end
"""
    DynamicPPL.assume(rng, sampler::DynamicPPL.Sampler{<:NestedSampler}, dist, vn, vi)

Handle a `~` statement for a latent variable under `NestedSampler`.

- If `vn` is not yet in `vi`, draw a value from `dist` with `rng` and push it into `vi`.
- Otherwise, treat the stored value `u` as a unit-cube coordinate and map it back to the
  support of `dist` via `invlogcdf(dist, log(u))`. The mapped value is written back.

The returned log-density contribution is always zero: the prior enters through the
unit-cube transform, so the accumulated log density is the log-likelihood alone.
NOTE(review): the scalar inverse-CDF map presumably limits this to univariate `dist`.
"""
function DynamicPPL.assume(
    rng::Random.AbstractRNG,
    sampler::DynamicPPL.Sampler{<:NestedSampler},
    dist::Distribution,
    vn::DynamicPPL.VarName,
    vi::DynamicPPL.AbstractVarInfo,
)
    if !haskey(vi, vn)
        # First encounter: sample from the prior and register the variable.
        draw = rand(rng, dist)
        return draw, zero(eltype(draw)), DynamicPPL.push!!(vi, vn, draw, dist, sampler)
    end

    # Stored value is in unit-cube space; return it in the original space.
    # TODO: make this convention explicit, e.g. using a field in `sampler`.
    x = invlogcdf(dist, log(vi[vn]))
    return x, zero(eltype(x)), DynamicPPL.setindex!!(vi, x, vn)
end
"""
    AbstractMCMC.bundle_samples(ts, model, spl::DynamicPPL.Sampler{<:NestedSampler},
                                state::NestedSamplerState, ::Type{MCMCChains.Chains};
                                kwargs...)

Bundle the samples with NestedSamplers.jl's own `bundle_samples` and take the chain (the
first element of its result). The chain's parameter columns are then renamed to the
variable names stored in `state.varinfo`.

NOTE(review): names are paired positionally with `zip`. This presumably assumes every
variable is scalar, so that the number of keys matches the number of columns.
"""
function AbstractMCMC.bundle_samples(
    ts::Vector,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler{<:NestedSampler},
    state::NestedSamplerState,
    chain_type::Type{MCMCChains.Chains};
    kwargs...
)
    bundled = AbstractMCMC.bundle_samples(
        ts, model, state.sampler, state.state, chain_type; kwargs...
    )
    chain = first(bundled)

    generic_names = names(chain, :parameters)
    turing_names = map(string, keys(state.varinfo))
    renaming = Dict(zip(generic_names, turing_names))
    return MCMCChains.replacenames(chain, renaming)
end
################################### | |
### NestedSamplers.jl interface ### | |
################################### | |
"""
    NestedSamplers.NestedModel(model::DynamicPPL.Model)

Wrap a DynamicPPL `model` in a `NestedSamplers.NestedModel`, using `Random.GLOBAL_RNG`
as the random number generator.
"""
NestedSamplers.NestedModel(model::DynamicPPL.Model) =
    NestedSamplers.NestedModel(Random.GLOBAL_RNG, model)
"""
    NestedSamplers.NestedModel(rng::Random.AbstractRNG, model::DynamicPPL.Model)

Wrap a DynamicPPL `model` in a `NestedSamplers.NestedModel`. A representative `VarInfo`
is first built with `rng` under a `NestedSampler`-tagged sampler. The model, that
`VarInfo` and the sampler are then passed to the 3-argument `NestedModel` method.
"""
function NestedSamplers.NestedModel(rng::Random.AbstractRNG, model::DynamicPPL.Model)
    # `NestedSampler` has no zero-argument constructor, so `NestedSampler()` throws a
    # `MethodError`. The live-point count stored in the sampler is never read on this
    # path (it only selects the `NestedSampler` methods), so any value works.
    sampler = DynamicPPL.Sampler(NestedSampler(1))
    vi = DynamicPPL.VarInfo(rng, model, sampler)
    return NestedSamplers.NestedModel(model, vi, sampler)
end
"""
    NestedSamplers.NestedModel(model::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo,
                               sampler=DynamicPPL.Sampler(NestedSampler(1)))

Build a `NestedSamplers.NestedModel` from a combined prior-transform / log-likelihood
function of a unit-cube point `u`. That function:

1. writes `u` into (a copy of) `vi` for the variables selected by `sampler`;
2. evaluates `model`, computing the transformed variables and the log-likelihood; with
   a `NestedSampler`-tagged sampler, the `assume` overload maps each coordinate back to
   its prior's support and adds zero to the log density;
3. returns the transformed parameter vector and the accumulated log density.

NOTE(review): `vi_base` is shared by every call of the returned function and may be
updated in place by `setindex!!`. If so, the function is not safe to call concurrently.
"""
function NestedSamplers.NestedModel(
    model::DynamicPPL.Model,
    vi::DynamicPPL.AbstractVarInfo,
    # `NestedSampler` has no zero-argument constructor, so the previous default
    # `NestedSampler()` threw a `MethodError`. The live-point count is never read on
    # this path, so any value works.
    sampler::DynamicPPL.AbstractSampler=DynamicPPL.Sampler(NestedSampler(1))
)
    vi_base = deepcopy(vi)
    function prior_transform_and_loglikelihood_model(u)
        # Update in unit-cube space.
        vi_new = DynamicPPL.setindex!!(vi_base, u, sampler)
        # Evaluate model, computing the transformed variables and the loglikelihood.
        _, vi_new = DynamicPPL.evaluate!!(model, vi_new, sampler)
        # Return the new samples and loglikelihood.
        return vi_new[sampler], DynamicPPL.getlogp(vi_new)
    end
    return NestedSamplers.NestedModel(prior_transform_and_loglikelihood_model)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment