|
import Zygote |
|
using Gen |
|
using Gen: RandomChoiceNode, StaticIRGenerativeFunction, ArgumentNode, TrainableParameterNode, GenerativeFunctionCallNode, JuliaNode |
|
|
|
"""
Per-call book-keeping used while generating code for one generative function
call. Frames form an explicit call stack: each frame is stored under its
`name` in `CodeGenState.stack`, and `CodeGenState.current_stack_frame`
selects the active one.
"""
struct StackFrame
    # Unique name of this frame; key into `CodeGenState.stack`.
    name::Symbol

    # Name of the caller's frame, or `nothing` for the top-level frame.
    parent::Union{Nothing, Symbol}

    # Maps a GenerativeFunctionCallNode's name to the name of the callee's frame.
    callee_frames::Dict{Symbol, Symbol}

    # Maps to a node in the callee stack frame:
    # (a GenerativeFunctionCallNode's name -> the callee's return-value node).
    callee_retvals::Dict{Symbol, Any}

    # Nodes in the parent stack frame.
    # (The actual argument nodes this call was invoked with.)
    arg_list::Vector

    # The (sub)trace for this call; static values are read from it.
    trace::Trace

    # The selection restricted to this call's address namespace.
    selection::Selection

    # Variable names to refer to different nodes of the computation graph.
    variable_names::Dict{Symbol, Symbol}

    # A set of "dynamic nodes" (their names) that may have changed since the code was generated.
    dynamic_nodes :: Set{Symbol}

    # Callable mapping a local choice address to its fully-qualified address
    # (prefixed with the call path down from the top-level trace).
    address_prefixer

    # Variable names for storing dL/dx for each intermediate result x (indexed by node.name)
    dLd::Dict{Symbol, Symbol}

    # vector_index_base_expression::Union{Nothing, Expr} # Useful in loops/Map, in case we need to index at e.g. 4i+j
end
|
|
|
"""
Mutable accumulator threaded through code generation.

Collects the statements of the forward (logp) pass, the forward gradient
pass, and the reverse (adjoint) pass, plus the mapping between trace
addresses and positions in the flat input/gradient vectors.
"""
mutable struct CodeGenState
    # Variable name for the logp accumulation variable
    logp_name :: Symbol

    # Input vector
    # (name of the flat `Vector{T}` of continuous choices the generated
    # functions take as their argument)
    input_vector_name::Symbol
    # Next free 1-based position in the input vector.
    current_vector_index_offset::Int

    # Statements
    # Forward pass for the plain logp function.
    log_p_statements::Vector{Expr}
    # Forward pass for the gradient function (pullback-instrumented).
    grad_log_p_statements::Vector{Expr}
    # Reverse pass; emitted in generation order, spliced in REVERSE order.
    grad_log_p_reverse_statements::Vector{Expr}

    # These will be used to generate code for converting
    # traces to and from vectors.
    # (Fully-qualified addresses, in vector order.)
    address_order::Vector

    # The subset of dL/dx variable names corresponding to the selected choices,
    # in the order they are stored in the vector.
    output_gradient_variables::Vector{Symbol}

    # Variables related to particular function calls; the "stack"
    current_stack_frame :: Symbol
    stack :: Dict{Symbol, StackFrame}
end
|
|
|
# Return the currently active `StackFrame` (the one named by
# `state.current_stack_frame`).
frame(state::CodeGenState) = state.stack[state.current_stack_frame]
|
|
|
# Resolve the actual argument bound to `node` (arguments live in the caller's
# frame), apply `f` to it with the parent frame active, then restore the
# current frame and return `f`'s result.
function with_argument_node(f, node::ArgumentNode, state)
    active = frame(state)
    ir = Gen.get_ir(typeof(Gen.get_gen_fn(active.trace)))
    position = findfirst(candidate -> candidate.name == node.name, ir.arg_nodes)
    bound_argument = active.arg_list[position]

    saved_frame_name = state.current_stack_frame
    state.current_stack_frame = active.parent
    outcome = f(bound_argument, state)
    state.current_stack_frame = saved_frame_name
    return outcome
end
|
|
|
# Look up the node representing the callee's return value, apply `f` to it
# with the callee's stack frame active, then restore the current frame and
# return `f`'s result. Mirrors `with_argument_node`, which applies `f` to the
# resolved argument node.
function with_retval_node(f, node::GenerativeFunctionCallNode, state)
    retval_node = frame(state).callee_retvals[node.name]
    old_frame_name = state.current_stack_frame
    state.current_stack_frame = frame(state).callee_frames[node.name]
    # BUG FIX: this previously called `f(node, state)`, silently ignoring the
    # `retval_node` just looked up. That re-dispatched `f` on the call node
    # inside the callee frame (whose `callee_frames` lacks `node.name`), so
    # e.g. `is_dynamic`/`is_tracked` on a call's return value always came out
    # false. `f` must be applied to the callee's return-value node.
    result = f(retval_node, state)
    state.current_stack_frame = old_frame_name
    return result
end
|
|
|
# Fallback: a node is dynamic iff its name was recorded in the active
# frame's `dynamic_nodes` set.
is_dynamic(node, state::CodeGenState) = node.name in frame(state).dynamic_nodes
|
|
|
# An argument is dynamic iff the node it is bound to in the parent frame is.
is_dynamic(node::ArgumentNode, state::CodeGenState) =
    with_argument_node(is_dynamic, node, state)
|
|
|
# A generative function call is dynamic only if it was actually descended
# into (a callee frame exists) and its return-value node is dynamic.
function is_dynamic(node::GenerativeFunctionCallNode, state::CodeGenState)
    # Was the call skipped?
    haskey(frame(state).callee_frames, node.name) || return false
    return with_retval_node(is_dynamic, node, state)
end
|
|
|
# True iff at least one of `node`'s inputs is dynamic in the current frame.
function has_dynamic_inputs(node::Union{JuliaNode, RandomChoiceNode, GenerativeFunctionCallNode}, state::CodeGenState)
    for input in node.inputs
        if is_dynamic(input, state)
            return true
        end
    end
    return false
end
|
|
|
# Record `node` as dynamic in the active frame.
mark_dynamic!(node, state::CodeGenState) =
    push!(frame(state).dynamic_nodes, node.name)
|
|
|
# Fallback: a node's gradient is tracked iff a dL/dx variable was allocated
# for it in the active frame.
is_tracked(node, state::CodeGenState) = haskey(frame(state).dLd, node.name)
|
|
|
# An argument is tracked iff the node it is bound to in the parent frame is.
is_tracked(node::ArgumentNode, state::CodeGenState) =
    with_argument_node(is_tracked, node, state)
|
|
|
# A call's return value is tracked only if the call was descended into and
# the callee's return-value node is tracked.
function is_tracked(node::GenerativeFunctionCallNode, state::CodeGenState)
    haskey(frame(state).callee_frames, node.name) || return false
    return with_retval_node(is_tracked, node, state)
end
|
|
|
# Fallback: return the name of the dL/dx accumulator variable for `node`.
# Callers must check `is_tracked` first.
dLd_variable(node, state::CodeGenState) = frame(state).dLd[node.name]
|
|
|
# An argument's adjoint variable lives with the node it is bound to in the
# parent frame.
dLd_variable(node::ArgumentNode, state::CodeGenState) =
    with_argument_node(dLd_variable, node, state)
|
|
|
# Assume that the call was actually made if we are looking for a dLd variable:
# the adjoint lives with the callee's return-value node.
dLd_variable(node::GenerativeFunctionCallNode, state::CodeGenState) =
    with_retval_node(dLd_variable, node, state)
|
|
|
# Read the value recorded for `node` in the active frame's trace (the value
# it had when `compile_continuous` was invoked).
function static_node_value(node::Gen.StaticIRNode, state::CodeGenState)
    fieldname = Gen.get_value_fieldname(node)
    return Base.getproperty(frame(state).trace, fieldname)
end
|
|
|
# Return a variable name holding `node`'s value, emitting an initialization
# statement (with the traced value baked in as a literal) the first time the
# node is referenced. Subsequent references reuse the cached name.
function refer_to_static_node!(node, state::CodeGenState)
    # We may have already looked this variable up.
    cache = frame(state).variable_names
    haskey(cache, node.name) && return cache[node.name]

    # Otherwise add a new initialization statement to both forward passes:
    fresh_variable = gensym(node.name)
    init = :($fresh_variable = $(static_node_value(node, state)))
    push!(state.log_p_statements, init)
    push!(state.grad_log_p_statements, init)

    cache[node.name] = fresh_variable
    return fresh_variable
end
|
|
|
# Referring to an argument defers to the node it is bound to in the parent
# frame.
refer_to_node!(node::ArgumentNode, state::CodeGenState) =
    with_argument_node(refer_to_node!, node, state)
|
|
|
# Return a variable name holding this Julia node's value, emitting its
# computation (and, if needed, Zygote pullback machinery for the reverse
# pass) the first time it is referenced.
function refer_to_node!(node::JuliaNode, state::CodeGenState)

    # Static nodes just get their traced value baked in.
    if !is_dynamic(node, state)
        return refer_to_static_node!(node, state)
    end

    # Already emitted? Reuse the variable.
    if haskey(frame(state).variable_names, node.name)
        return frame(state).variable_names[node.name]
    end

    # This is a dynamic Julia node for which we have not issued
    # instructions yet. Get the variable names for its inputs,
    # then assign a new variable to the result of evaluating
    # node.fn:
    all_inputs = [refer_to_node!(inp, state) for inp in node.inputs]
    result_variable = gensym(node.name)
    frame(state).variable_names[node.name] = result_variable
    basic_assmt = :($result_variable = $(node.fn)($(all_inputs...)))
    push!(state.log_p_statements, basic_assmt)

    # To check whether gradients will be needed, we can look at
    # the type of the value. Rather than rely on node.typ, which
    # could be too general, we can look at the type of the actual
    # value computed for this node in the existing trace.
    static_value_type = typeof(static_node_value(node, state))
    needs_gradients = static_value_type <: Union{Real, AbstractArray{<:Real}}
    if !needs_gradients
        # Non-differentiable value: the gradient pass evaluates it plainly.
        push!(state.grad_log_p_statements, basic_assmt)
        return result_variable
    end

    # If gradients are needed, instead of evaluating the value by calling node.fn,
    # call its Zygote pullback and, in the reverse direction, evaluate the adjoint.
    intermediate_deriv_variable = gensym("dLd$(node.name)")
    zygote_backward_variable = gensym("zygote_backward_$(node.name)")
    zygote_result_variable = gensym("zygote_result_$(node.name)")
    # Zero of the same shape/type as the traced value, to initialize the adjoint.
    zero_val = zero(static_node_value(node, state))
    # Zygote.pullback returns (value, back); unpack with first/last below.
    push!(state.grad_log_p_statements, :($zygote_result_variable = Zygote.pullback($(node.fn), $(all_inputs...))))
    push!(state.grad_log_p_statements, :($result_variable = first($zygote_result_variable)::$(static_value_type)))
    push!(state.grad_log_p_statements, :($zygote_backward_variable = last($zygote_result_variable)))
    push!(state.grad_log_p_statements, :($intermediate_deriv_variable = $zero_val))
    frame(state).dLd[node.name] = intermediate_deriv_variable

    # In the reverse pass, accumulate gradients to inputs:
    # (the pullback returns one gradient per input, indexed positionally).
    gradient_variable = gensym("$(node.name)_gradients")
    accumulation_statements = [:($(dLd_variable(n, state)) += ($gradient_variable[$i])) for (i, n) in enumerate(node.inputs) if is_tracked(n, state)]
    push!(state.grad_log_p_reverse_statements, quote
        $gradient_variable = $zygote_backward_variable($intermediate_deriv_variable)
        $(accumulation_statements...)
    end)

    return result_variable
end
|
|
|
# When referring to a random choice node, there are two cases:
# * It is one of the nodes in the selection, in which case we will already
#   have initialized a variable for it (in `process_node!`) and
#   refer_to_static_node! will look up that variable name.
# * It is not one of the nodes in the selection, and we should use its
#   static value.
# Either way, `refer_to_static_node!` does the right thing.
refer_to_node!(node::RandomChoiceNode, state::CodeGenState) =
    refer_to_static_node!(node, state)
|
|
|
# A Julia node becomes dynamic as soon as any of its inputs is dynamic;
# code emission itself is deferred until the node is first referenced.
process_node!(node::JuliaNode, state::CodeGenState) =
    if has_dynamic_inputs(node, state)
        mark_dynamic!(node, state)
    end
|
|
|
# Process a random choice: if it is selected, emit code to unpack its value
# from the flat input vector and allocate its adjoint accumulator; if it is
# selected or has dynamic inputs, emit the logpdf contribution and the
# reverse-pass gradient accumulation.
function process_node!(node::RandomChoiceNode, state::CodeGenState)
    # A RandomChoiceNode is dynamic if it lies in the current selection.
    current_selection = frame(state).selection
    if node.addr in current_selection
        mark_dynamic!(node, state)

        # This is one of the special variables, so lots of housekeeping to do.

        # Unpack the value of the variable from the vector, and
        # initialize the variable to hold either a scalar or array.
        value_variable = gensym(node.name)
        frame(state).variable_names[node.name] = value_variable
        static_value = static_node_value(node, state)
        index_expression(i) = i # isnothing(state.vector_index_base_expression) ? i : :(state.vector_index_base_expression + $i)
        if static_value isa Real
            assmt = :($value_variable = $(state.input_vector_name)[$(index_expression(state.current_vector_index_offset))])
            state.current_vector_index_offset += 1
        elseif static_value isa AbstractArray{<:Real}
            value_size = size(static_value)
            value_length = length(static_value)
            start_index = index_expression(state.current_vector_index_offset)
            last_index = index_expression(state.current_vector_index_offset + value_length - 1)
            state.current_vector_index_offset += value_length

            if length(value_size) == 1
                assmt = :($value_variable = @view $(state.input_vector_name)[$start_index:$last_index])
            else
                # BUG FIX: the parenthesis previously closed the `reshape`
                # call right after the @view argument, leaving `$value_size`
                # stranded outside the call and the expression unbalanced.
                # The view must be parenthesized (to delimit the macro call)
                # and passed to `reshape` together with the target size.
                assmt = :($value_variable = reshape((@view $(state.input_vector_name)[$start_index:$last_index]), $value_size))
            end
        else
            @error "Selection ($(frame(state).address_prefixer(node.addr))) was of type $(typeof(static_value)); expecting Real or AbstractArray{<:Real}"
        end
        push!(state.log_p_statements, assmt)
        push!(state.grad_log_p_statements, assmt)

        # Add to the address order, keeping in mind the address prefix:
        addr = frame(state).address_prefixer(node.addr)
        push!(state.address_order, addr)

        # Set up the variable that will track the derivative w.r.t. this node:
        derivative_variable = gensym("dLd$(node.name)")
        frame(state).dLd[node.name] = derivative_variable
        push!(state.output_gradient_variables, derivative_variable)
        zero_val = zero(static_value)
        push!(state.grad_log_p_statements, :($derivative_variable = $zero_val))
    end

    # We need to compute logpdf if this node is dynamic (selected)
    # or has dynamic parents.
    if is_dynamic(node, state) || has_dynamic_inputs(node, state)
        all_inputs = [refer_to_node!(n, state) for n in node.inputs]
        value_variable = refer_to_node!(node, state)

        log_p_increment = :($(state.logp_name) += Gen.logpdf($(node.dist), $(value_variable), $(all_inputs...)))
        push!(state.log_p_statements, log_p_increment)
        push!(state.grad_log_p_statements, log_p_increment)

        # Which gradients do we need? Collect (adjoint variable, index into
        # the logpdf_grad tuple) pairs; index 1 is the output gradient,
        # indices 2.. are the distribution's argument gradients.
        needed_gradients = []
        if has_output_grad(node.dist) && is_tracked(node, state)
            push!(needed_gradients, (dLd_variable(node, state), 1))
        end
        for (i, has_argument_grad) in enumerate(has_argument_grads(node.dist))
            if has_argument_grad && is_tracked(node.inputs[i], state)
                push!(needed_gradients, (dLd_variable(node.inputs[i], state), i+1))
            end
        end
        # Add the gradient call and accumulation logic in the reverse pass.
        grad_results_var = gensym("$(node.name)_logp_gradients")
        assmts = [:($var += $grad_results_var[$grad]) for (var, grad) in needed_gradients]
        if !isempty(assmts)
            push!(state.grad_log_p_reverse_statements, quote
                $grad_results_var = logpdf_grad($(node.dist), $(value_variable), $(all_inputs...))
                $(assmts...)
            end)
        end
    end
end
|
|
|
# Return a variable name holding the return value of a generative function
# call. Calls that were skipped in the forward direction get their traced
# return value baked in; calls that were descended into defer to their
# return-value node in the callee frame.
function refer_to_node!(node::GenerativeFunctionCallNode, state::CodeGenState)
    if !haskey(frame(state).callee_frames, node.name)
        # This node was not descended into in the forward direction.
        # Have we already processed?
        if haskey(frame(state).variable_names, node.name)
            return frame(state).variable_names[node.name]
        end
        # If not, add an assignment baking in the subtrace's return value:
        static_value_variable = gensym(node.name)
        value = get_retval(Base.getproperty(frame(state).trace, Gen.get_subtrace_fieldname(node)))
        assmt = :($static_value_variable = $value)
        push!(state.log_p_statements, assmt)
        push!(state.grad_log_p_statements, assmt)
        frame(state).variable_names[node.name] = static_value_variable
        return static_value_variable
    end

    # Otherwise, let the callee handle:
    return with_retval_node(refer_to_node!, node, state)
end
|
|
|
# Process a generative function call node: decide whether the call must be
# descended into and, if so, push a fresh callee frame, generate code for the
# callee, and record the callee's return-value node for later references.
function process_node!(node::GenerativeFunctionCallNode, state::CodeGenState)
    # First determine whether we need to descend at all.
    callee_selection = frame(state).selection[node.addr]
    if isempty(callee_selection) && !has_dynamic_inputs(node, state)
        return
    end

    # Either we have dynamic inputs or a random choice in the callee is selected.
    # Prepare the state to descend.
    current_prefixer = frame(state).address_prefixer
    current_frame = state.current_stack_frame
    callee_trace = Base.getproperty(frame(state).trace, Gen.get_subtrace_fieldname(node))
    # The callee frame's address_prefixer nests the callee addresses under
    # this call's address before applying the current prefixer.
    callee_frame = StackFrame(gensym(node.name), current_frame, Dict(), Dict(), node.inputs,
        callee_trace, callee_selection, Dict(), Set(),
        addr -> current_prefixer(node.addr => addr), Dict())
    state.stack[callee_frame.name] = callee_frame
    frame(state).callee_frames[node.name] = callee_frame.name
    state.current_stack_frame = callee_frame.name

    # Make the call (dispatches on the callee's type: static IR vs black box).
    retval_node = process_generative_function_call!(node.generative_function, state)

    # Reset to how things were, and remember the callee's return-value node.
    state.current_stack_frame = current_frame
    frame(state).callee_retvals[node.name] = retval_node
end
|
|
|
# Fallback: node kinds with no specialized handler (e.g. ArgumentNode,
# TrainableParameterNode) require no code generation.
function process_node!(node, state)
end
|
|
|
"""
Stand-in node for an argument of the top-level call, whose value comes from
the trace's argument tuple rather than from a node in an enclosing IR.
"""
struct ExternalArgumentNode
    # Name used to cache the emitted variable in the frame.
    name :: Symbol
    # 1-based position into `get_args(trace)`.
    arg_idx :: Int
end
|
|
|
# Return a variable name holding the top-level argument's value, emitting an
# initialization (with the traced argument baked in) on first reference.
function refer_to_node!(node::ExternalArgumentNode, state::CodeGenState)
    cache = frame(state).variable_names
    haskey(cache, node.name) && return cache[node.name]

    fresh = gensym()
    actual_value = get_args(frame(state).trace)[node.arg_idx]
    binding = :($fresh = $actual_value)
    push!(state.log_p_statements, binding)
    push!(state.grad_log_p_statements, binding)
    cache[node.name] = fresh
    return fresh
end
|
|
|
# Generate code for a static-IR callee by processing its IR nodes in order,
# and hand back the IR's return node so the caller can register it as the
# call's return-value node.
function process_generative_function_call!(gen_fn::StaticIRGenerativeFunction, state::CodeGenState)
    ir = Gen.get_ir(typeof(gen_fn))
    foreach(n -> process_node!(n, state), ir.nodes)
    return ir.return_node
end
|
|
|
"""
Return-value node for a black-box (non-static-IR) generative function call,
recording the generated variable names for the call's return value and its
adjoint accumulator.
"""
struct BlackBoxRetvalNode
    # Variable holding the callee's return value in the generated code.
    ret_var::Symbol
    # Variable accumulating dL/d(retval) in the generated code.
    dLd_var::Symbol
end
|
|
|
# Because a black-box function may have likelihood terms we need to evaluate,
# if it was called, we have definitely run it and added reverse-pass code to
# get gradients of it. There is therefore no case in which an executed
# black-box call's return value is static or untracked.
|
# An executed black-box call's return value is always treated as dynamic.
is_dynamic(::BlackBoxRetvalNode, state::CodeGenState) = true
|
# An executed black-box call's return value always has an adjoint variable.
is_tracked(::BlackBoxRetvalNode, state::CodeGenState) = true
|
# The adjoint variable was allocated when the black-box call was processed
# and is stored on the node itself.
dLd_variable(node::BlackBoxRetvalNode, state::CodeGenState) = node.dLd_var
|
# The return-value variable was allocated when the black-box call was
# processed and is stored on the node itself.
function refer_to_node!(node::BlackBoxRetvalNode, state::CodeGenState)
    # BUG FIX: this previously returned the bare name `ret_var`, which is
    # undefined in this scope and would raise an UndefVarError at runtime;
    # the variable name lives in the node's `ret_var` field.
    return node.ret_var
end
|
|
|
# Resolve all arguments of the current call in the PARENT frame (where the
# argument nodes live), returning three parallel vectors: the emitted
# variable names, the Gen argdiffs (UnknownChange for dynamic args), and the
# adjoint variable names (or `nothing` for untracked args).
function refer_to_all_arguments!(state::CodeGenState)
    # We briefly step up a frame.
    arg_list = frame(state).arg_list
    current_frame = state.current_stack_frame
    state.current_stack_frame = state.stack[current_frame].parent
    # Then evaluate all arguments.
    args = [refer_to_node!(n, state) for n in arg_list]
    argdiffs = [is_dynamic(n, state) ? Gen.UnknownChange() : Gen.NoChange() for n in arg_list]
    dLds = [is_tracked(n, state) ? dLd_variable(n, state) : nothing for n in arg_list]
    # Then set it back.
    state.current_stack_frame = current_frame
    return args, argdiffs, dLds
end
|
|
|
# Generate code for a black-box (non-static-IR) callee. In the forward pass
# we `update` the subtrace with the (possibly changed) arguments and add its
# score to logp; in the reverse pass we call `choice_gradients` with an empty
# selection and accumulate the returned argument gradients into the callers'
# adjoint variables. Returns a BlackBoxRetvalNode naming the result and its
# adjoint variable.
function process_generative_function_call!(gen_fn::GenerativeFunction, state::CodeGenState)
    # If the current selection is empty, then a black-box generative function call is like
    # a glorified RandomChoiceNode: we call `update` (with new args but no choicemap)
    # and use the score of the new trace. For gradients, we use choice_gradients but with
    # no nodes selected, and look only at the `arg_grads`.
    if !isempty(frame(state).selection)
        @error "We currently only support the selection of continuous variables in StaticIRGenerativeFunctions, but you have selected $(frame(state).selection) within a $(typeof(gen_fn))."
    end

    # In a forward pass, we should run `update` to create a new trace, and evaluate
    # logp. Note that `choice_gradients` will be necessary whether or not there is a
    # dLdretval, so maybe not a bad idea to create one. The challenge is that it needs
    # to be associated with the callee's return value — which is what the
    # BlackBoxRetvalNode returned below accomplishes.

    argument_variables, argdiffs, argdLds = refer_to_all_arguments!(state)
    trace_variable = gensym("subtrace")
    new_trace_var = gensym("updpated_subtrace")
    result_var = gensym("callee_retval")
    update_call = quote
        # TODO: Is it really OK to just put the trace into the code of the generative function like this?
        # Another option is to have a name-of-variable for the trace, and create variable names
        # for subtraces. Hard to say whether that will be faster or not... But should test.
        $trace_variable = $(frame(state).trace)
        $new_trace_var, = Gen.update($trace_variable, ($(argument_variables...),), ($(argdiffs...),), Gen.EmptyChoiceMap())
        $(state.logp_name) += Gen.get_score($new_trace_var)
        $result_var = Gen.get_retval($new_trace_var)
    end
    push!(state.log_p_statements, update_call)
    push!(state.grad_log_p_statements, update_call)

    # Which gradients do we need? Build one `+=` per tracked argument that
    # the distribution of gradients from `choice_gradients` covers.
    dLd_assmts = Expr[]
    arg_grads_var = gensym("arg_grads_var")
    for (i, (arg_dLd, has_argument_grad)) in enumerate(zip(argdLds, has_argument_grads(gen_fn)))
        if has_argument_grad && !isnothing(arg_dLd)
            push!(dLd_assmts, :($arg_dLd += $arg_grads_var[$i]))
        end
    end

    # Figure out which args have gradients tracked, and for each of them, increment dLdarg
    dLd_var = gensym("dLdcallee_retval")
    push!(state.grad_log_p_statements, :($dLd_var = zero($result_var))) # TODO: make this depend on result type

    choice_gradients_call = quote
        $arg_grads_var, = Gen.choice_gradients($new_trace_var, Gen.EmptySelection(), $dLd_var)
        $(dLd_assmts...)
    end

    push!(state.grad_log_p_reverse_statements, choice_gradients_call)

    return BlackBoxRetvalNode(result_var, dLd_var)
end
|
|
|
|
|
# TODO: design question of whether we should require `compile_continuous` each time |
|
# anything from the rest of the trace changes, or just when control flow changes. |
|
# Currently we take the stricter path (the former), because values from the old trace |
|
# are baked in. |
|
"""
    compile_continuous(trace, selection)

Generate and `eval` four specialized functions for the continuous choices in
`selection` of `trace`:

1. `vector_from_trace(tr)`  — flatten the selected choices into a `Vector{Float64}`;
2. `trace_from_vector(v)`   — `update` the baked-in trace with choices unpacked from `v`;
3. `logp(v)`                — the log density of the trace as a function of `v`;
4. `grad_logp(v)`           — `(logp, ∇logp)` w.r.t. `v`.

Returned in that order. Values from the old trace are baked into the
generated code as literals, so the result is only valid until anything else
in the trace changes (see the note above this function).
"""
function compile_continuous(trace::Gen.StaticIRTrace, selection::Selection)
    gen_fn = Gen.get_gen_fn(trace)

    # Set up a top-level stack frame with the desired arguments.
    # (The only thing that should happen at this level is that ExternalArgumentNodes from one layer down are evaluated.)
    top_level_stack_frame = StackFrame(:top_level, nothing, Dict(), Dict(), [], trace, selection, Dict(), Set(), (addr -> @error "No choice nodes in top-level frame"), Dict())

    # Set up a frame for the call, whose arguments come from the trace itself.
    arg_list = [ExternalArgumentNode(gensym("arg$i"), i) for i in 1:length(get_args(trace))]
    call_frame = StackFrame(gensym("main_call"), :top_level, Dict(), Dict(), arg_list, trace, selection, Dict(), Set(), addr -> addr, Dict())

    # Create the stack
    stack = Dict(:top_level => top_level_stack_frame, call_frame.name => call_frame)

    # Create the state (vector indices start at 1; statement lists empty).
    state = CodeGenState(:logp, :continuous_choices, 1, Expr[], Expr[], Expr[], [], Symbol[], call_frame.name, stack)

    # Run the call; this populates all the statement lists in `state`.
    process_generative_function_call!(gen_fn, state)

    # Code to evaluate logpdf.
    logp_function_name = gensym(:continuous_logpdf)
    log_p_code = quote
        function $(logp_function_name)(continuous_choices::Vector{T}) where T <: Real
            logp::T = zero(T) # Should it just be "0"?
            $(state.log_p_statements...)
            return logp
        end
    end

    # Code to evaluate logpdf and compute gradient w.r.t. continuous_choices vector.
    grad_log_p_function_name = gensym(:continuous_grad_logpdf)

    # Code to pack gradients back into a vector.
    # TODO: Instead of relying on the trace, the state should probably
    # track and provide the indices here. This will be necessary for Map.
    vector_index = 1
    grad_assmts = []
    for (address, grad) in zip(state.address_order, state.output_gradient_variables)
        value_size = size(trace[address])
        value_length = length(trace[address])
        if length(value_size) == 0
            # Scalar choice: one slot.
            push!(grad_assmts, quote gradlogp[$vector_index] = $grad end)
        elseif length(value_size) == 1
            # Vector choice: a contiguous slice.
            push!(grad_assmts, quote gradlogp[$vector_index:$(vector_index+value_length-1)] = $grad end)
        else
            # Higher-rank choice: flatten before storing into the slice.
            push!(grad_assmts, quote gradlogp[$vector_index : $(vector_index+value_length-1)] = reshape($grad, ($value_length,)) end)
        end
        vector_index += value_length
    end

    # Gradient code: forward statements, then the reverse-pass statements in
    # REVERSE order (standard reverse-mode accumulation), then the packing.
    grad_log_p_code = quote
        function $(grad_log_p_function_name)(continuous_choices::Vector{T}) where T <: Real
            logp::T = zero(T)
            gradlogp::Vector{T} = zero(continuous_choices)
            $(state.grad_log_p_statements...)
            $(reverse(state.grad_log_p_reverse_statements)...)
            $(grad_assmts...)
            return logp, gradlogp
        end
    end

    # Vector from trace: splat each selected choice (scalar or array) in
    # address order into a flat Float64 vector.
    vector_from_trace_function_name = gensym(:vector_from_trace)
    process_addr(addr) = addr isa Symbol ? Meta.quot(addr) : addr
    vector_from_trace_code = quote
        function $(vector_from_trace_function_name)(tr::$(typeof(trace)))
            Float64[$([:(tr[$(process_addr(addr))]...) for addr in state.address_order]...)]
        end
    end

    # Trace from vector: build a choicemap slicing/reshaping each address's
    # value out of the flat vector, then update the baked-in trace with it.
    trace_from_vector_function_name = gensym(:trace_from_vector)

    choicemap_var = gensym("choices")
    assignments = Expr[]
    i = 1
    for addr in state.address_order
        value_size = size(trace[addr])
        value_len = length(trace[addr])
        addr = process_addr(addr)
        if length(value_size) == 0
            push!(assignments, :($choicemap_var[$addr] = continuous_choices[$i]))
        elseif length(value_size) == 1
            push!(assignments, :($choicemap_var[$addr] = continuous_choices[$i:$(i+value_len-1)]))
        else
            push!(assignments, :($choicemap_var[$addr] = reshape(continuous_choices[$i:$(i+value_len-1)], $value_size)))
        end
        i += value_len
    end
    trace_from_vector_code = quote
        function $(trace_from_vector_function_name)(continuous_choices::Vector{Float64})
            $choicemap_var = choicemap()
            $(assignments...)
            new_tr, = update($trace, $(get_args(trace)), map(_ -> Gen.NoChange(), $(get_args(trace))), $choicemap_var)
            return new_tr
        end
    end

    return eval(vector_from_trace_code), eval(trace_from_vector_code), eval(log_p_code), eval(grad_log_p_code)
end
|
|
|
|