Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Last active May 17, 2024 08:32
Show Gist options
  • Save torfjelde/e455f1a5c44c65496ecd651ed49881a5 to your computer and use it in GitHub Desktop.
Save torfjelde/e455f1a5c44c65496ecd651ed49881a5 to your computer and use it in GitHub Desktop.
Example of how to _generate_ a Turing.jl model programmatically. This can be useful when working with performance-critical code where we want to, for example, unroll loops of `~` statements to improve performance.
julia> using DynamicPPL, Distributions
julia> # Model template: wraps a `NamedTuple` whose keys are variable names and whose
       # values are placed on the right-hand side of `~` when a model is generated from it
       # (see `define_model_from_template`). `names` and `V` are the NamedTuple's key tuple
       # and value-tuple type, so the template's layout is known from its type alone.
       struct NTModel{names,V}
           nt::NamedTuple{names,V}
       end
julia> model_template = NTModel((; a = Normal(0, 1), b = Normal(100, 1)))
NTModel{(:a, :b), Tuple{Normal{Float64}, Normal{Float64}}}((a = Normal{Float64}(μ=0.0, σ=1.0), b = Normal{Float64}(μ=100.0, σ=1.0)))
julia> function define_model_from_template(
           model_template::NTModel; model_name::Symbol=gensym(:nt_model)
       )
           # Build and evaluate a zero-argument `@model` function from `model_template`:
           # each `name => dist` entry becomes the statement `name ~ dist` (in template
           # order), and the model returns a `NamedTuple` of all sampled variables.
           # `model_name` defaults to a fresh gensym, so every call defines a new function;
           # callers must keep the returned function to use it.
           template = model_template.nt
           varnames = keys(template)
           # One `~` statement per template entry; the distribution object itself is
           # spliced into the expression rather than re-created from source.
           tilde_stmts = [:($varname ~ $dist) for (varname, dist) in pairs(template)]
           # Tuple expression `(a, b, ...)` that collects the sampled variables.
           ret_tuple = Expr(:tuple, varnames...)
           # `@eval` evaluates in the module containing this definition (the REPL's
           # `Main` here); its value is the newly defined model function.
           return @eval @model function $(model_name)()
               $(tilde_stmts...)
               return NamedTuple{$varnames}($ret_tuple)
           end
       end
define_model_from_template (generic function with 2 methods)
julia> # Generate the model function from the template.
       # NOTE: keep the returned function. Each call defines a model function under a
       # fresh gensym'd name, so there is no fixed name to refer to it by afterwards.
       demo_model = model_template |> define_model_from_template
##nt_model#369 (generic function with 2 methods)
julia> # `demo_model()` creates a model instance; calling that instance runs the model,
       # which returns a NamedTuple of the sampled `a` and `b`.
       demo_model()()
(a = 1.4612352529081432, b = 100.0517306153272)
julia> # A generated model can be reused inside another model via `@submodel`.
       @model function outer_model(inner_model)
           # Run `inner_model` as part of this model and bind its return value
           # (a NamedTuple with fields `a` and `b`) to `parameters`.
           @submodel parameters = inner_model
           # The inner draws serve as the mean and standard deviation of `x`.
           μ = parameters.a
           σ = parameters.b
           x ~ Normal(μ, σ)
           return (x = x, parameters = parameters)
       end
outer_model (generic function with 2 methods)
julia> # Instantiate the outer model, passing an instance of the generated model as
       # its `inner_model` argument.
       model = outer_model(demo_model())
Model{typeof(outer_model), (:inner_model,), (), (), Tuple{Model{var"###nt_model#369", (), (), (), Tuple{}, Tuple{}, DefaultContext}}, Tuple{}, DefaultContext}(outer_model, (inner_model = Model{var"###nt_model#369", (), (), (), Tuple{}, Tuple{}, DefaultContext}(var"##nt_model#369", NamedTuple(), NamedTuple(), DefaultContext()),), NamedTuple(), DefaultContext())
julia> # Run the combined model: it returns `x` together with the inner model's
       # `parameters`.
       model()
(x = -7.9766227799777365, parameters = (a = 0.5489158156933944, b = 99.56385067809187))
julia> # Alternative approach: specify the model with a macro. Note that this does not
       # support programmatic generation, since the specification must be written out
       # literally at the call site.
       """
           @model_from_namedtuple lhs1=rhs1 lhs2=rhs2 ...

       Construct a model from specifications of the form `lhs = rhs`.

       Each `lhs = rhs` becomes the statement `lhs ~ rhs` in the generated model, in the
       order given, and the model returns a `NamedTuple` of all the `lhs` variables. Because
       order is preserved, a later `rhs` may refer to an earlier `lhs`.

       Every argument must be an assignment whose left-hand side is a plain variable name,
       and the names must be unique; otherwise an `ArgumentError` is thrown when the macro
       is expanded.

       The model function gets a fresh `gensym`-ed name on every expansion, so capture the
       returned function.

       # Example
       ```julia
       julia> nt_demo = @model_from_namedtuple a=Normal(0, 1) b=Normal(100, 1)
       ##nt_model#590 (generic function with 2 methods)

       julia> model = nt_demo();

       julia> model()
       (a = -0.8077731696095273, b = 99.11691965855493)
       ```

       Can also handle dependencies if ordered correctly.
       ```julia
       julia> model_demo_alt = @model_from_namedtuple a=Normal(0, 1) b=Normal(10 * a, 1)
       ##nt_model#747 (generic function with 2 methods)

       julia> model_demo_alt()()
       (a = 1.9066394738007792, b = 18.95744677736865)
       ```
       """
       macro model_from_namedtuple(exprs...)
           # Validate the syntax with explicit errors instead of `@assert`. `@assert` is
           # meant for internal invariants and may be disabled. Errors surface at
           # macro-expansion time.
           for ex in exprs
               Meta.isexpr(ex, :(=)) ||
                   throw(ArgumentError("All expressions should be of the form `lhs = rhs`."))
           end
           # Extract the LHS and RHS of the expressions.
           lhs_rhs_iter = map(DynamicPPL.getargs_assignment, exprs)
           # Each LHS becomes both a `~` target and a `NamedTuple` field name, so only
           # plain variable names (no indexing or field access) are allowed.
           for (name, _) in lhs_rhs_iter
               name isa Symbol || throw(ArgumentError(
                   "Left-hand side must be a plain variable name, got `$(name)`.",
               ))
           end
           names = map(first, lhs_rhs_iter)
           # Duplicate names would produce an invalid `NamedTuple` type.
           allunique(names) ||
               throw(ArgumentError("Variable names must be unique, got $(names)."))
           model_name = gensym(:nt_model)
           # One `~` statement per specification, in the order given.
           body = [:($name ~ $dist) for (name, dist) in lhs_rhs_iter]
           # Tuple expression `(a, b, ...)` collecting the sampled variables.
           retvals = Expr(:tuple, names...)
           # Construct the actual model.
           expr = :(function $(model_name)()
               $(body...)
               return NamedTuple{$names}($retvals)
           end)
           return esc(DynamicPPL.model(__module__, __source__, expr, false))
       end
@model_from_namedtuple
julia> # Same two-variable specification as `model_template`, written via the macro.
       model_demo_alt = @model_from_namedtuple a=Normal(0, 1) b=Normal(100, 1)
##nt_model#602 (generic function with 2 methods)
julia> # Instantiate and run the macro-generated model in one expression.
       model_demo_alt()()
(a = -1.2205519233973439, b = 101.82300087682167)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment