@alex-lew
Last active February 16, 2021 17:49
GenCompileContinuous Experiments

GenCompileContinuous.jl

This is a report on an experiment that implements a single function, compile_continuous, aimed at users of the Gen.jl static modeling language.

Given a trace tr of a generative function written in the static modeling language, and a selection sel of addresses of continuous choices (i.e., choices whose values are <: Union{Real, AbstractArray{<:Real}}), compile_continuous(tr, sel) returns four functions:

  • trace_to_vec(trace): Extract the choices with addresses in sel into a 1D vector of reals. This assumes that each choice has the same dimension and size as it did in tr.
  • vec_to_trace(vec): Given a 1D vector vec of reals, e.g. one produced by trace_to_vec(trace), interpret it as specifying values for each choice in sel, and produce a trace that is equal to tr except for those choices (which are replaced by the values in vec).
  • logpdf(vec): Given a 1D vector vec of reals, e.g. one produced by trace_to_vec(trace), interpret it as specifying values for each choice in sel, and compute the logpdf of the model with those choices replacing the ones in tr.
  • grad_logpdf(vec): Same as logpdf, but also returns a gradient (another 1D vector) of the logpdf with respect to each value in vec.

It is assumed that changing the values of the choices in sel cannot alter the set of addresses present in tr (that is, control flow does not depend on the selected choices).
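
For instance, in a hypothetical model like the following (written in the dynamic DSL for brevity), :x could not be included in sel, because the set of addresses in the trace depends on its value:

using Gen

# Hypothetical example (not from the report): the address :y exists only when x > 0,
# so changing the value of :x can change the trace's address set.
@gen function branching_model()
    x ~ normal(0, 1)
    if x > 0
        y ~ normal(x, 1)
    end
end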

The purpose of this is to allow for fast gradient-based updating of continuous variables in a model, holding other variables constant, without creating new choicemaps and traces at every iteration. For example, MAP optimization of variables in sel for a trace tr can be expressed as follows:

# Compile functions specialized to this trace
to_vec, to_trace, logp, gradlogp = compile_continuous(tr, sel)

# Optimize using vectors
v = to_vec(tr)
for i=1:N_ITERS
  _, grad = gradlogp(v)
  v += STEP_SIZE * grad
end

# Convert back to a trace
tr = to_trace(v)

Because gradlogp reads values directly from the vector (using @view to avoid copying them), rather than first converting it to a trace, it can be much faster than Gen's existing map_optimize!. A sketch of the kind of code that gets generated is shown below.
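
To make this concrete, here is a minimal, hypothetical sketch of the kind of logp code that might be emitted for a toy model with a single selected scalar choice :mu with a normal(0, 1) prior and one observation fixed at 2.3 in tr (array-valued choices are instead read as @view slices, reshaped if necessary). This is an illustration of the approach, not the library's verbatim output:

using Gen

# Hypothetical generated code for a toy model (assumed names and values):
function logp_sketch(continuous_choices::Vector{T}) where T <: Real
    logp::T = zero(T)
    mu = continuous_choices[1]                  # selected choice read straight from the vector
    logp += Gen.logpdf(normal, mu, 0.0, 1.0)    # prior term
    logp += Gen.logpdf(normal, 2.3, mu, 1.0)    # likelihood term; the observation 2.3 is baked in from tr
    return logp
end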

The vector interface also plays well with other Julia ecosystem packages, e.g. AdvancedHMC:

# Compile functions specialized to this trace
to_vec, to_trace, logp, gradlogp = compile_continuous(tr, sel)

# Run AdvancedHMC using vectors
v = to_vec(tr)
metric = DiagEuclideanMetric(length(v))
hamiltonian = Hamiltonian(metric, logp, gradlogp)
initial_ϵ = find_good_stepsize(hamiltonian, v)
integrator = Leapfrog(initial_ϵ)
proposal = AdvancedHMC.StaticTrajectory(integrator, 5) # Or, e.g., AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
samples, stats = sample(hamiltonian, proposal, v, N_ITERS, adaptor, N_ADAPTATION_ITERS; progress=true);

# Convert the final sample back to a trace
tr = to_trace(last(samples))

(Severe) Limitations

  1. Currently, only continuous choices made in the top-level static generative function or its static callees can be included in sel. This means that choices made by combinators cannot yet be included (although they can contribute to likelihoods, when upstream choices in sel change).
  2. The current mechanism for creating the four functions returned by compile_continuous is to use eval. This has the downside that, when compile_continuous is called within a function (rather than at the top level of a script), the returned closures have a newer "world age" than the rest of the function -- which causes Julia to error when they are called. This can be side-stepped by using Base.invokelatest (e.g. Base.invokelatest(logp, v); see the sketch after this list), but with a performance penalty. The performance hit hasn't been too bad in my limited tests, but this is a wart that we should definitely try to fix.
  3. Parts of the input trace tr that are not in sel are explicitly included (as constants) in the generated code. This can have two downsides: (1) the functions returned by compile_continuous are specialized to a particular rest of the trace, rather than just a particular control flow path, so that if e.g. a Metropolis-Hastings move changes a discrete variable, compile_continuous needs to be called anew; and (2) when the values are large (in terms of memory), the returned function rebuilds them (using a constant literal), instead of reading them from tr.
  4. When it can be statically proven that a likelihood term will not change no matter the values of the choices in sel, that term is currently not computed by logpdf. This does not affect the gradients, but it means that logpdfs may be off by an additive constant. It would be simple enough to emit code that adds the total log constant at the end, before returning logpdf (so that, e.g., logpdf-based diagnostics taken for different input tr values are comparable).
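
For point 2, here is a minimal sketch (an assumed helper, not part of the library) of the Base.invokelatest workaround when compile_continuous is called inside a function:

# Hypothetical helper illustrating the invokelatest workaround for the world-age issue.
function map_optimize_inside_function(tr, sel; n_iters=100, step_size=1e-3)
    to_vec, to_trace, logp, gradlogp = compile_continuous(tr, sel)
    # The compiled closures live in a newer world age, so call them via invokelatest:
    v = Base.invokelatest(to_vec, tr)
    for _ in 1:n_iters
        _, grad = Base.invokelatest(gradlogp, v)
        v += step_size * grad
    end
    return Base.invokelatest(to_trace, v)
end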

On point 2 (and sort of on point 3): we could potentially rewrite compile_continuous as a @generated function, which operates based on the StaticIR and a StaticSelection, which are available 'at the type level'. This version would not have access to the runtime values inside tr, and so the code it generates would dynamically read values from the trace, rather than statically compiling them in as constants. This could also circumvent the world age issue. However, there are parts of the compile_continuous code that use the runtime values inside tr to deduce things like array sizes and concrete types (e.g. of variables whose types the static DSL parser could not figure out). It is unclear how to rewrite these using @generated functions.

Another option is to change to_trace, logpdf and gradlogpdf to take as input a trace and a vector. The requirement would be that the input trace has the same shape as tr (the "template" trace), and the same control flow, but not necessarily the same choice values. Users could then, at the top level of their script, make several calls to compile_continuous for a handful of "trace shapes," and call them inside their inference programs without triggering a world age error. This too seems unsatisfactory, though, because there are often an unbounded number of 'trace shapes.'

(In PClean I wound up going with invokelatest. But maybe someone has a better idea!)

Other notes

Using AD directly with logpdf. When only static generative functions are involved, logp is sometimes amenable to end-to-end AD by a framework like ForwardDiff. This can be faster than gradlogp, which separately differentiates each bit of Julia code with Zygote. However, when your model invokes non-static functions, we fall back to using the GFI to compute their logpdfs and gradients, which generally does not play well with AD.
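
As a sketch (assuming logp, gradlogp, and v come from the earlier examples), the comparison might look like:

using ForwardDiff

# End-to-end forward-mode AD through the compiled logp; only expected to work
# when every generative function involved is static.
fd_grad = ForwardDiff.gradient(logp, v)

# Compare against the gradient from the compiled reverse pass:
_, compiled_grad = gradlogp(v)
maximum(abs.(fd_grad .- compiled_grad))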

Extensibility. The code is designed to be extensible, e.g. so that combinators like Map or custom generative functions can support the generation of specialized code. Each custom generative function type must provide a method of GenCompileContinuous.process_generative_function_call!; I haven't, however, documented the interface.

# An example from Turing
using Flux
import Random
using Gen
#########################
######## BAYES NN #######
#########################
N = 80
M = round(Int, N / 4)
Random.seed!(1234)
# Generate artificial data.
x1s = rand(M) * 4.5; x2s = rand(M) * 4.5;
xt1s = Array([[x1s[i] + 0.5; x2s[i] + 0.5] for i = 1:M])
x1s = rand(M) * 4.5; x2s = rand(M) * 4.5;
append!(xt1s, Array([[x1s[i] - 5; x2s[i] - 5] for i = 1:M]))
x1s = rand(M) * 4.5; x2s = rand(M) * 4.5;
xt0s = Array([[x1s[i] + 0.5; x2s[i] - 5] for i = 1:M])
x1s = rand(M) * 4.5; x2s = rand(M) * 4.5;
append!(xt0s, Array([[x1s[i] - 5; x2s[i] + 0.5] for i = 1:M]))
# Store all the data for later.
xs = [xt1s; xt0s]
ts = [ones(2*M); zeros(2*M)]
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
# Turn a vector into a set of weights and biases.
function unpack(nn_params::AbstractVector)
    W₁ = reshape(nn_params[1:6], 3, 2)
    b₁ = reshape(nn_params[7:9], 3)
    W₂ = reshape(nn_params[10:15], 2, 3)
    b₂ = reshape(nn_params[16:17], 2)
    Wₒ = reshape(nn_params[18:19], 1, 2)
    bₒ = reshape(nn_params[20:20], 1)
    return W₁, b₁, W₂, b₂, Wₒ, bₒ
end
# Construct a neural network using Flux and return a predicted value.
function nn_forward(xs, nn_params::AbstractVector{T}) where T
    xs = hcat(xs...)
    W₁, b₁, W₂, b₂, Wₒ, bₒ = unpack(nn_params)
    nn = Chain(Dense(W₁, b₁, tanh),
               Dense(W₂, b₂, tanh),
               Dense(Wₒ, bₒ, σ))
    return vec(nn(xs)::Array{T,2})
end;
@gen (static) function bern((grad)(p::Float64))
    result ~ bernoulli(p)
    return result
end
@load_generated_functions
const make_predictions = Map(bern)
# First a version that uses reshaping.
using LinearAlgebra
@gen (static, grad) function bayes_nn(xs::Vector{Vector{Float64}})
    mu::Vector{Float64} = zeros(20)
    sig::Array{Float64,2} = Matrix(sig^2*1.0I, 20, 20)
    nn_params ~ mvnormal(mu, sig)
    preds::Vector{Float64} = nn_forward(xs, nn_params)
    ts ~ make_predictions(preds)
    z::Float64 = 0.0
    return z
end
@load_generated_functions
constraints = choicemap()
set_submap!(constraints, :ts, choicemap([((i => :result) => Bool(t)) for (i, t) in enumerate(ts)]...))
tr, = generate(bayes_nn, (xs,), constraints);
nn_to_vec, vec_to_nn, nn_log_p, nn_glog_p = compile_continuous(tr, select(:nn_params))
# Try MAP optimization
function simple_map_optimize(glogp, v, iters, step_size=0.01)
    for i=0:iters
        lp, glp = glogp(v)
        if i % 100 == 0
            println(lp)
        end
        v += step_size * glp
    end
    return v
end
v = simple_map_optimize(nn_glog_p, nn_to_vec(tr), 5000)
println("Final logp: $(nn_log_p(v))")
## Try HMC
using AdvancedHMC
# Choose parameter dimensionality and initial parameter value
D = 20; initial_θ = ones(20)
# Set the number of samples to draw and warmup iterations
n_samples, n_adapts = 2_000, 1_000
# Define a Hamiltonian system
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, nn_log_p, nn_glog_p)
# Define a leapfrog solver, with initial step size chosen heuristically
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)
# Define an HMC sampler
proposal = AdvancedHMC.StaticTrajectory(integrator, 5)
# Or, for NUTS e.g.:
# proposal = AdvancedHMC.NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
# Run the sampler to draw samples from the posterior, where
# - `samples` will store the samples
# - `stats` will store diagnostic statistics for each sample
samples, stats = sample(hamiltonian, proposal, initial_θ, 100, adaptor, 50; progress=true);
println("HMC final logp: $(nn_log_p(last(samples)))")
# Plotting code from Turing:
using Plots
# Plot data points.
function plot_data()
    x1 = map(e -> e[1], xt1s)
    y1 = map(e -> e[2], xt1s)
    x2 = map(e -> e[1], xt0s)
    y2 = map(e -> e[2], xt0s)
    Plots.scatter(x1, y1, color="red", clim=(0,1))
    Plots.scatter!(x2, y2, color="blue", clim=(0,1))
end
plot_data()
x_range = collect(range(-6,stop=6,length=25))
y_range = collect(range(-6,stop=6,length=25))
Z = [nn_forward([[x, y]], last(samples))[1] for x=x_range, y=y_range]
contour!(x_range, y_range, Z)
import Zygote
using Gen
using Gen: RandomChoiceNode, StaticIRGenerativeFunction, ArgumentNode, TrainableParameterNode, GenerativeFunctionCallNode, JuliaNode
struct StackFrame
name::Symbol
parent::Union{Nothing, Symbol}
callee_frames::Dict{Symbol, Symbol}
# Maps to a node in the callee stack frame:
callee_retvals::Dict{Symbol, Any}
# Nodes in the parent stack frame.
arg_list::Vector
trace::Trace
selection::Selection
# Variable names to refer to different nodes of the computation graph.
variable_names::Dict{Symbol, Symbol}
# A set of "dynamic nodes" (their names) that may have changed since the code was generated.
dynamic_nodes :: Set{Symbol}
address_prefixer
# Variable names for storing dL/dx for each intermediate result x (indexed by node.name)
dLd::Dict{Symbol, Symbol}
# vector_index_base_expression::Union{Nothing, Expr} # Useful in loops/Map, in case we need to index at e.g. 4i+j
end
mutable struct CodeGenState
# Variable name for the logp accumulation variable
logp_name :: Symbol
# Input vector
input_vector_name::Symbol
current_vector_index_offset::Int
# Statements
log_p_statements::Vector{Expr}
grad_log_p_statements::Vector{Expr}
grad_log_p_reverse_statements::Vector{Expr}
# These will be used to generate code for converting
# traces to and from vectors.
address_order::Vector
# The subset of dL/dx variable names corresponding to the selected choices,
# in the order they are stored in the vector.
output_gradient_variables::Vector{Symbol}
# Variables related to particular function calls; the "stack"
current_stack_frame :: Symbol
stack :: Dict{Symbol, StackFrame}
end
function frame(state::CodeGenState)
state.stack[state.current_stack_frame]
end
function with_argument_node(f, node::ArgumentNode, state)
ir = Gen.get_ir(typeof(Gen.get_gen_fn(frame(state).trace)))
arg_index = findfirst(n -> n.name == node.name, ir.arg_nodes)
arg = frame(state).arg_list[arg_index]
old_frame_name = state.current_stack_frame
state.current_stack_frame = frame(state).parent
result = f(arg, state)
state.current_stack_frame = old_frame_name
return result
end
function with_retval_node(f, node::GenerativeFunctionCallNode, state)
retval_node = frame(state).callee_retvals[node.name]
old_frame_name = state.current_stack_frame
state.current_stack_frame = frame(state).callee_frames[node.name]
result = f(node, state)
state.current_stack_frame = old_frame_name
return result
end
function is_dynamic(node, state::CodeGenState)
node.name in frame(state).dynamic_nodes
end
function is_dynamic(node::ArgumentNode, state::CodeGenState)
return with_argument_node(is_dynamic, node, state)
end
function is_dynamic(node::GenerativeFunctionCallNode, state::CodeGenState)
# Was the call skipped?
if !haskey(frame(state).callee_frames, node.name)
return false
end
return with_retval_node(is_dynamic, node, state)
end
function has_dynamic_inputs(node::Union{JuliaNode, RandomChoiceNode, GenerativeFunctionCallNode}, state::CodeGenState)
any(n -> is_dynamic(n, state), node.inputs)
end
function mark_dynamic!(node, state::CodeGenState)
push!(frame(state).dynamic_nodes, node.name)
end
function is_tracked(node, state::CodeGenState)
haskey(frame(state).dLd, node.name)
end
function is_tracked(node::ArgumentNode, state::CodeGenState)
return with_argument_node(is_tracked, node, state)
end
function is_tracked(node::GenerativeFunctionCallNode, state::CodeGenState)
if !haskey(frame(state).callee_frames, node.name)
return false
end
return with_retval_node(is_tracked, node, state)
end
function dLd_variable(node, state::CodeGenState)
frame(state).dLd[node.name]
end
function dLd_variable(node::ArgumentNode, state::CodeGenState)
return with_argument_node(dLd_variable, node, state)
end
function dLd_variable(node::GenerativeFunctionCallNode, state::CodeGenState)
# Assume that the call was actually made if we are looking for a dLd variable.
return with_retval_node(dLd_variable, node, state)
end
function static_node_value(node::Gen.StaticIRNode, state::CodeGenState)
trace = frame(state).trace
value = Base.getproperty(trace, Gen.get_value_fieldname(node))
end
function refer_to_static_node!(node, state::CodeGenState)
# We may have already looked this variable up.
if haskey(frame(state).variable_names, node.name)
return frame(state).variable_names[node.name]
end
# Otherwise add a new initialization statement:
static_value_variable = gensym(node.name)
value = static_node_value(node, state)
assmt = :($static_value_variable = $value)
push!(state.log_p_statements, assmt)
push!(state.grad_log_p_statements, assmt)
frame(state).variable_names[node.name] = static_value_variable
return static_value_variable
end
function refer_to_node!(node::ArgumentNode, state::CodeGenState)
return with_argument_node(refer_to_node!, node, state)
end
function refer_to_node!(node::JuliaNode, state::CodeGenState)
if !is_dynamic(node, state)
return refer_to_static_node!(node, state)
end
if haskey(frame(state).variable_names, node.name)
return frame(state).variable_names[node.name]
end
# This is a dynamic Julia node for which we have not issued
# instructions yet. Get the variable names for its inputs,
# then assign a new variable to the result of evaluating
# node.fn:
all_inputs = [refer_to_node!(inp, state) for inp in node.inputs]
result_variable = gensym(node.name)
frame(state).variable_names[node.name] = result_variable
basic_assmt = :($result_variable = $(node.fn)($(all_inputs...)))
push!(state.log_p_statements, basic_assmt)
# To check whether gradients will be needed, we can look at
# the type of the value. Rather than rely on node.typ, which
# could be too general, we can look at the type of the actual
# value computed for this node in the existing trace.
static_value_type = typeof(static_node_value(node, state))
needs_gradients = static_value_type <: Union{Real, AbstractArray{<:Real}}
if !needs_gradients
push!(state.grad_log_p_statements, basic_assmt)
return result_variable
end
# If gradients are needed, instead of evaluating the value by calling node.fn,
# call its Zygote pullback and, in the reverse direction, evaluate the adjoint.
intermediate_deriv_variable = gensym("dLd$(node.name)")
zygote_backward_variable = gensym("zygote_backward_$(node.name)")
zygote_result_variable = gensym("zygote_result_$(node.name)")
zero_val = zero(static_node_value(node, state))
push!(state.grad_log_p_statements, :($zygote_result_variable = Zygote.pullback($(node.fn), $(all_inputs...))))
push!(state.grad_log_p_statements, :($result_variable = first($zygote_result_variable)::$(static_value_type)))
push!(state.grad_log_p_statements, :($zygote_backward_variable = last($zygote_result_variable)))
push!(state.grad_log_p_statements, :($intermediate_deriv_variable = $zero_val))
frame(state).dLd[node.name] = intermediate_deriv_variable
# In the reverse pass, accumulate gradients to inputs:
gradient_variable = gensym("$(node.name)_gradients")
accumulation_statements = [:($(dLd_variable(n, state)) += ($gradient_variable[$i])) for (i, n) in enumerate(node.inputs) if is_tracked(n, state)]
push!(state.grad_log_p_reverse_statements, quote
$gradient_variable = $zygote_backward_variable($intermediate_deriv_variable)
$(accumulation_statements...)
end)
return result_variable
end
function refer_to_node!(node::RandomChoiceNode, state::CodeGenState)
# When referring to a random choice node, there are two cases:
# * It is one of the nodes in the selection, in which case we will already have initialized a variable for it
# and refer_to_static_node! will look up that variable name.
# * It is not one of the nodes in the selection, and we should use its static value.
return refer_to_static_node!(node, state)
end
function process_node!(node::JuliaNode, state::CodeGenState)
if has_dynamic_inputs(node, state)
mark_dynamic!(node, state)
end
end
function process_node!(node::RandomChoiceNode, state::CodeGenState)
# A RandomChoiceNode is dynamic if it lies in the current selection.
current_selection = frame(state).selection
if node.addr in current_selection
mark_dynamic!(node, state)
# This is one of the special variables, so lots of housekeeping to do.
# Unpack the value of the variable from the vector, and
# initialize the variable to hold either a scalar or array.
value_variable = gensym(node.name)
frame(state).variable_names[node.name] = value_variable
static_value = static_node_value(node, state)
index_expression(i) = i # isnothing(state.vector_index_base_expression) ? i : :(state.vector_index_base_expression + $i)
if static_value isa Real
assmt = :($value_variable = $(state.input_vector_name)[$(index_expression(state.current_vector_index_offset))])
state.current_vector_index_offset += 1
elseif static_value isa AbstractArray{<:Real}
value_size = size(static_value)
value_length = length(static_value)
start_index = index_expression(state.current_vector_index_offset)
last_index = index_expression(state.current_vector_index_offset + value_length - 1)
state.current_vector_index_offset += value_length
if length(value_size) == 1
assmt = :($value_variable = @view $(state.input_vector_name)[$start_index:$last_index])
else
assmt = :($value_variable = reshape((@view $(state.input_vector_name)[$start_index:$last_index]), $value_size))
end
else
@error "Selection ($(frame(state).address_prefixer(node.addr))) was of type $(typeof(static_value)); expecting Real or AbstractArray{<:Real}"
end
push!(state.log_p_statements, assmt)
push!(state.grad_log_p_statements, assmt)
# Add to the address order, keeping in mind the address prefix:
addr = frame(state).address_prefixer(node.addr)
push!(state.address_order, addr)
# Set up the variable that will track the derivative w.r.t. this node:
derivative_variable = gensym("dLd$(node.name)")
frame(state).dLd[node.name] = derivative_variable
push!(state.output_gradient_variables, derivative_variable)
zero_val = zero(static_value)
push!(state.grad_log_p_statements, :($derivative_variable = $zero_val))
end
# We need to compute logpdf if this node is dynamic (selected)
# or has dynamic parents.
if is_dynamic(node, state) || has_dynamic_inputs(node, state)
all_inputs = [refer_to_node!(n, state) for n in node.inputs]
value_variable = refer_to_node!(node, state)
log_p_increment = :($(state.logp_name) += Gen.logpdf($(node.dist), $(value_variable), $(all_inputs...)))
push!(state.log_p_statements, log_p_increment)
push!(state.grad_log_p_statements, log_p_increment)
# Which gradients do we need?
needed_gradients = []
if has_output_grad(node.dist) && is_tracked(node, state)
push!(needed_gradients, (dLd_variable(node, state), 1))
end
for (i, has_argument_grad) in enumerate(has_argument_grads(node.dist))
if has_argument_grad && is_tracked(node.inputs[i], state)
push!(needed_gradients, (dLd_variable(node.inputs[i], state), i+1))
end
end
# Add the gradient call and accumulation logic in the reverse pass.
grad_results_var = gensym("$(node.name)_logp_gradients")
assmts = [:($var += $grad_results_var[$grad]) for (var, grad) in needed_gradients]
if !isempty(assmts)
push!(state.grad_log_p_reverse_statements, quote
$grad_results_var = logpdf_grad($(node.dist), $(value_variable), $(all_inputs...))
$(assmts...)
end)
end
end
end
function refer_to_node!(node::GenerativeFunctionCallNode, state::CodeGenState)
if !haskey(frame(state).callee_frames, node.name)
# This node was not descended into in the forward direction.
# Have we already processed?
if haskey(frame(state).variable_names, node.name)
return frame(state).variable_names[node.name]
end
# If not, add an assignment:
static_value_variable = gensym(node.name)
value = get_retval(Base.getproperty(frame(state).trace, Gen.get_subtrace_fieldname(node)))
assmt = :($static_value_variable = $value)
push!(state.log_p_statements, assmt)
push!(state.grad_log_p_statements, assmt)
frame(state).variable_names[node.name] = static_value_variable
return static_value_variable
end
# Otherwise, let the callee handle:
return with_retval_node(refer_to_node!, node, state)
end
function process_node!(node::GenerativeFunctionCallNode, state::CodeGenState)
# First determine whether we need to descend at all.
callee_selection = frame(state).selection[node.addr]
if isempty(callee_selection) && !has_dynamic_inputs(node, state)
return
end
# Either we have dynamic inputs or a random choice in the callee is selected.
# Prepare the state to descend.
current_prefixer = frame(state).address_prefixer
current_frame = state.current_stack_frame
callee_trace = Base.getproperty(frame(state).trace, Gen.get_subtrace_fieldname(node))
callee_frame = StackFrame(gensym(node.name), current_frame, Dict(), Dict(), node.inputs,
callee_trace, callee_selection, Dict(), Set(),
addr -> current_prefixer(node.addr => addr), Dict())
state.stack[callee_frame.name] = callee_frame
frame(state).callee_frames[node.name] = callee_frame.name
state.current_stack_frame = callee_frame.name
# Make the call
retval_node = process_generative_function_call!(node.generative_function, state)
# Reset to how things were.
state.current_stack_frame = current_frame
frame(state).callee_retvals[node.name] = retval_node
end
function process_node!(node, state)
end
struct ExternalArgumentNode
name :: Symbol
arg_idx :: Int
end
function refer_to_node!(node::ExternalArgumentNode, state::CodeGenState)
if haskey(frame(state).variable_names, node.name)
return frame(state).variable_names[node.name]
end
variable_name = gensym()
value = get_args(frame(state).trace)[node.arg_idx]
assmt = :($variable_name = $value)
push!(state.log_p_statements, assmt)
push!(state.grad_log_p_statements, assmt)
frame(state).variable_names[node.name] = variable_name
return variable_name
end
function process_generative_function_call!(gen_fn::StaticIRGenerativeFunction, state::CodeGenState)
ir = Gen.get_ir(typeof(gen_fn))
for node in ir.nodes
process_node!(node, state)
end
return ir.return_node
end
struct BlackBoxRetvalNode
ret_var::Symbol
dLd_var::Symbol
end
# Because a black-box function may have likelihood terms we need to evaluate,
# if it was called, we have definitely run it and added reverse-pass code to
# get gradients of it; its return value is therefore always treated as dynamic and tracked.
function is_dynamic(::BlackBoxRetvalNode, state::CodeGenState)
return true
end
function is_tracked(::BlackBoxRetvalNode, state::CodeGenState)
return true
end
function dLd_variable(node::BlackBoxRetvalNode, state::CodeGenState)
return node.dLd_var
end
function refer_to_node!(node::BlackBoxRetvalNode, state::CodeGenState)
return node.ret_var
end
function refer_to_all_arguments!(state::CodeGenState)
# We briefly step up a frame.
arg_list = frame(state).arg_list
current_frame = state.current_stack_frame
state.current_stack_frame = state.stack[current_frame].parent
# Then evaluate all arguments.
args = [refer_to_node!(n, state) for n in arg_list]
argdiffs = [is_dynamic(n, state) ? Gen.UnknownChange() : Gen.NoChange() for n in arg_list]
dLds = [is_tracked(n, state) ? dLd_variable(n, state) : nothing for n in arg_list]
# Then set it back.
state.current_stack_frame = current_frame
return args, argdiffs, dLds
end
function process_generative_function_call!(gen_fn::GenerativeFunction, state::CodeGenState)
# If the current selection is empty, then a black-box generative function call is like
# a glorified RandomChoiceNode: we call `update` (with new args but no choicemap)
# and use the score of the new trace. For gradients, we use choice_gradients but with
# no nodes selected, and look only at the `arg_grads`.
if !isempty(frame(state).selection)
@error "We currently only support the selection of continuous variables in StaticIRGenerativeFunctions, but you have selected $(frame(state).selection) within a $(typeof(gen_fn))."
end
# In a forward pass, we should run `update` to create a new trace, and evaluate
# logp. Note that `choice_gradients` will be necessary whether or not there is a
# dLdretval, so maybe not a bad idea to create one. The challenge is that it needs
# to be associated with the callee's return value (hence the BlackBoxRetvalNode returned below).
argument_variables, argdiffs, argdLds = refer_to_all_arguments!(state)
trace_variable = gensym("subtrace")
new_trace_var = gensym("updpated_subtrace")
result_var = gensym("callee_retval")
update_call = quote
# TODO: Is it really OK to just put the trace into the code of the generative function like this?
# Another option is to have a name-of-variable for the trace, and create variable names
# for subtraces. Hard to say whether that will be faster or not... But should test.
$trace_variable = $(frame(state).trace)
$new_trace_var, = Gen.update($trace_variable, ($(argument_variables...),), ($(argdiffs...),), Gen.EmptyChoiceMap())
$(state.logp_name) += Gen.get_score($new_trace_var)
$result_var = Gen.get_retval($new_trace_var)
end
push!(state.log_p_statements, update_call)
push!(state.grad_log_p_statements, update_call)
# Which gradients do we need?
dLd_assmts = Expr[]
arg_grads_var = gensym("arg_grads_var")
for (i, (arg_dLd, has_argument_grad)) in enumerate(zip(argdLds, has_argument_grads(gen_fn)))
if has_argument_grad && !isnothing(arg_dLd)
push!(dLd_assmts, :($arg_dLd += $arg_grads_var[$i]))
end
end
# Figure out which args have gradients tracked, and for each of them, increment dLdarg
dLd_var = gensym("dLdcallee_retval")
push!(state.grad_log_p_statements, :($dLd_var = zero($result_var))) # TODO: make this depend on result type
choice_gradients_call = quote
$arg_grads_var, = Gen.choice_gradients($new_trace_var, Gen.EmptySelection(), $dLd_var)
$(dLd_assmts...)
end
push!(state.grad_log_p_reverse_statements, choice_gradients_call)
return BlackBoxRetvalNode(result_var, dLd_var)
end
# TODO: design question of whether we should require `compile_continuous` each time
# anything from the rest of the trace changes, or just when control flow changes.
# Currently we take the stricter path (the former), because values from the old trace
# are baked in.
function compile_continuous(trace::Gen.StaticIRTrace, selection::Selection)
gen_fn = Gen.get_gen_fn(trace)
# Set up a top-level stack frame with the desired arguments.
# (The only thing that should happen at this level is that ExternalArgumentNodes from one layer down are evaluated.)
top_level_stack_frame = StackFrame(:top_level, nothing, Dict(), Dict(), [], trace, selection, Dict(), Set(), (addr -> @error "No choice nodes in top-level frame"), Dict())
# Set up a frame for the call.
arg_list = [ExternalArgumentNode(gensym("arg$i"), i) for i in 1:length(get_args(trace))]
call_frame = StackFrame(gensym("main_call"), :top_level, Dict(), Dict(), arg_list, trace, selection, Dict(), Set(), addr -> addr, Dict())
# Create the stack
stack = Dict(:top_level => top_level_stack_frame, call_frame.name => call_frame)
# Create the state
state = CodeGenState(:logp, :continuous_choices, 1, Expr[], Expr[], Expr[], [], Symbol[], call_frame.name, stack)
# Run the call
process_generative_function_call!(gen_fn, state)
# Code to evaluate logpdf.
logp_function_name = gensym(:continuous_logpdf)
log_p_code = quote
function $(logp_function_name)(continuous_choices::Vector{T}) where T <: Real
logp::T = zero(T) # Should it just be "0"?
$(state.log_p_statements...)
return logp
end
end
# Code to evaluate logpdf and compute gradient w.r.t. continuous_choices vector.
grad_log_p_function_name = gensym(:continuous_grad_logpdf)
# Code to pack gradients back into a vector.
# TODO: Instead of relying on the trace, the state should probably
# track and provide the indices here. This will be necessary for Map.
vector_index = 1
grad_assmts = []
for (address, grad) in zip(state.address_order, state.output_gradient_variables)
value_size = size(trace[address])
value_length = length(trace[address])
if length(value_size) == 0
push!(grad_assmts, quote gradlogp[$vector_index] = $grad end)
elseif length(value_size) == 1
push!(grad_assmts, quote gradlogp[$vector_index:$(vector_index+value_length-1)] = $grad end)
else
push!(grad_assmts, quote gradlogp[$vector_index : $(vector_index+value_length-1)] = reshape($grad, ($value_length,)) end)
end
vector_index += value_length
end
# Gradient code
grad_log_p_code = quote
function $(grad_log_p_function_name)(continuous_choices::Vector{T}) where T <: Real
logp::T = zero(T)
gradlogp::Vector{T} = zero(continuous_choices)
$(state.grad_log_p_statements...)
$(reverse(state.grad_log_p_reverse_statements)...)
$(grad_assmts...)
return logp, gradlogp
end
end
# Vector from trace
vector_from_trace_function_name = gensym(:vector_from_trace)
process_addr(addr) = addr isa Symbol ? Meta.quot(addr) : addr
vector_from_trace_code = quote
function $(vector_from_trace_function_name)(tr::$(typeof(trace)))
Float64[$([:(tr[$(process_addr(addr))]...) for addr in state.address_order]...)]
end
end
# Trace from vector
trace_from_vector_function_name = gensym(:trace_from_vector)
choicemap_var = gensym("choices")
assignments = Expr[]
i = 1
for addr in state.address_order
value_size = size(trace[addr])
value_len = length(trace[addr])
addr = process_addr(addr)
if length(value_size) == 0
push!(assignments, :($choicemap_var[$addr] = continuous_choices[$i]))
elseif length(value_size) == 1
push!(assignments, :($choicemap_var[$addr] = continuous_choices[$i:$(i+value_len-1)]))
else
push!(assignments, :($choicemap_var[$addr] = reshape(continuous_choices[$i:$(i+value_len-1)], $value_size)))
end
i += value_len
end
trace_from_vector_code = quote
function $(trace_from_vector_function_name)(continuous_choices::Vector{Float64})
$choicemap_var = choicemap()
$(assignments...)
new_tr, = update($trace, $(get_args(trace)), map(_ -> Gen.NoChange(), $(get_args(trace))), $choicemap_var)
return new_tr
end
end
return eval(vector_from_trace_code), eval(trace_from_vector_code), eval(log_p_code), eval(grad_log_p_code)
end