Last active
February 16, 2022 22:42
-
-
Save sethaxen/b18e4101a7fd0b2cfe986dd240e99213 to your computer and use it in GitHub Desktop.
Getting parameter ranges from Turing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using DynamicPPL | |
# utilities for working with Turing model parameter names using only the DynamicPPL API

"""
    flattened_varnames_list(model::DynamicPPL.Model) -> Vector{Symbol}

Get a vector of varnames as `Symbol`s with one-to-one correspondence to the
flattened parameter vector.

A variable whose range covers a single parameter contributes `Symbol(varname)`. A
variable whose range covers several parameters contributes one entry per parameter,
each suffixed with its 1-based position within the variable (e.g.
`Symbol("x[:,3][2]")`). A model without any variables yields an empty vector.

# Examples
```julia
julia> @model function demo()
           s ~ Dirac(1)
           x = Matrix{Float64}(undef, 2, 4)
           x[1, 1] ~ Dirac(2)
           x[2, 1] ~ Dirac(3)
           x[3] ~ Dirac(4)
           y ~ Dirac(5)
           x[4] ~ Dirac(6)
           x[:, 3] ~ arraydist([Dirac(7), Dirac(8)])
           x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)])
           return s, x, y
       end
demo (generic function with 2 methods)

julia> flattened_varnames_list(demo())
10-element Vector{Symbol}:
 :s
 Symbol("x[1,1]")
 Symbol("x[2,1]")
 Symbol("x[3]")
 Symbol("x[4]")
 Symbol("x[:,3][1]")
 Symbol("x[:,3][2]")
 Symbol("x[[2, 1],4][1]")
 Symbol("x[[2, 1],4][2]")
 :y
```
"""
function flattened_varnames_list(model::DynamicPPL.Model)
    # Compute the mapping once and reuse it: every call to `varnames_to_ranges(model)`
    # builds a fresh `VarInfo`, so sizing the output from one call and filling it from
    # another could disagree.
    varnames_ranges = varnames_to_ranges(model)
    # The largest range end is the length of the flattened vector; `init=0` makes a
    # model without variables return an empty vector instead of throwing.
    nsyms = maximum(last, values(varnames_ranges); init=0)
    # NOTE: slots not covered by any range stay `undef`; this assumes the ranges
    # jointly cover `1:nsyms`.
    syms = Vector{Symbol}(undef, nsyms)
    for (var_name, range) in varnames_ranges
        sym = Symbol(var_name)
        if length(range) == 1
            syms[first(range)] = sym
        else
            # Multi-parameter variable: label each parameter with its position.
            for (i, idx) in enumerate(range)
                syms[idx] = Symbol("$(sym)[$(i)]")
            end
        end
    end
    return syms
end
# code snippet shared by @torfjelde

"""
    varnames_to_ranges(model::DynamicPPL.Model)
    varnames_to_ranges(varinfo::DynamicPPL.VarInfo)
    varnames_to_ranges(metadata::DynamicPPL.Metadata)

Get `Dict` mapping variable names in model to their ranges in a corresponding
parameter vector.

When given a `model`, a `DynamicPPL.VarInfo` is constructed from it first.

# Examples
```julia
julia> @model function demo()
           s ~ Dirac(1)
           x = Matrix{Float64}(undef, 2, 4)
           x[1, 1] ~ Dirac(2)
           x[2, 1] ~ Dirac(3)
           x[3] ~ Dirac(4)
           y ~ Dirac(5)
           x[4] ~ Dirac(6)
           x[:, 3] ~ arraydist([Dirac(7), Dirac(8)])
           x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)])
           return s, x, y
       end
demo (generic function with 2 methods)

julia> demo()()
(1, Any[2.0 4.0 7 10; 3.0 6.0 8 9], 5)

julia> varnames_to_ranges(demo())
Dict{AbstractPPL.VarName, UnitRange{Int64}} with 8 entries:
  s           => 1:1
  x[4]        => 5:5
  x[:,3]      => 6:7
  x[1,1]      => 2:2
  x[2,1]      => 3:3
  x[[2, 1],4] => 8:9
  x[3]        => 4:4
  y           => 10:10
```
"""
function varnames_to_ranges end
function varnames_to_ranges(model::DynamicPPL.Model)
    # Build a VarInfo from the model and defer to the method for that VarInfo.
    varinfo = DynamicPPL.VarInfo(model)
    return varnames_to_ranges(varinfo)
end
function varnames_to_ranges(varinfo::DynamicPPL.UntypedVarInfo)
    # Forward the varinfo's `metadata` field to the `Metadata` method.
    md = varinfo.metadata
    return varnames_to_ranges(md)
end
# Method for `TypedVarInfo`, whose `metadata` holds one entry per group of variables.
# The ranges reported for each entry are shifted by the largest range end seen in all
# preceding entries, so the returned ranges index one concatenated parameter vector.
# Returns a `Dict{DynamicPPL.VarName,UnitRange{Int}}` (empty when there is no metadata).
function varnames_to_ranges(varinfo::DynamicPPL.TypedVarInfo)
    vns2ranges = Dict{DynamicPPL.VarName,UnitRange{Int}}()
    offset = 0
    # A plain loop keeps `offset` an ordinary local (no boxed closure capture) and
    # makes the in-order accumulation explicit.
    for md in values(varinfo.metadata)
        local_ranges = varnames_to_ranges(md)
        for (vn, r) in local_ranges
            vns2ranges[vn] = offset .+ r
        end
        # Advance by the largest end among this entry's (entry-local) ranges.
        offset += maximum(last, values(local_ranges); init=0)
    end
    return vns2ranges
end
# Method for a single `Metadata`: look up each variable name's position in
# `metadata.idcs` and pair the name with the entry of `metadata.ranges` at that position.
function varnames_to_ranges(metadata::DynamicPPL.Metadata)
    names = metadata.vns
    lookup = Dict{eltype(names),eltype(metadata.ranges)}()
    for vn in names
        position = metadata.idcs[vn]
        lookup[vn] = metadata.ranges[position]
    end
    return lookup
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment