# simplified from GP structure learning example, section 7.2
# https://github.com/probcomp/pldi2019-gen-experiments/tree/master/gp -- `cov_tree.jl` and `lightweight.jl`
using Gen
# abstract types for binary expression tree
abstract type Node end
abstract type LeafNode <: Node end
abstract type BinaryOpNode <: Node end
num_children(node::LeafNode) = 0
num_children(node::BinaryOpNode) = 2
# size(node) = total number of nodes in the subtree rooted at node
Base.size(::LeafNode) = 1
Base.size(node::BinaryOpNode) = node.size
# implement leaf nodes
struct Constant <: LeafNode; param::Float64; end
predict(node::Constant, xs::Vector{Float64}) = fill(node.param, length(xs))
struct Linear <: LeafNode; coefficient::Float64; end
predict(node::Linear, xs::Vector{Float64}) = node.coefficient .* xs
struct Periodic <: LeafNode; period::Float64; end
predict(node::Periodic, xs::Vector{Float64}) = sin.(2pi/node.period .* xs)
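# Illustrative check (not in the original gist): a Periodic leaf with period 0.5
# is a unit-amplitude sine completing one cycle per 0.5 units of x.
@assert predict(Periodic(0.5), [0.0, 0.125]) ≈ [0.0, 1.0]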
# implement binary operation nodes
struct Plus <: BinaryOpNode
    left::Node
    right::Node
    size::Int
end
# count this node as well as its children, so size(node) is the total node count
# of the subtree (the unbiased path sampler below relies on this)
Plus(left, right) = Plus(left, right, size(left) + size(right) + 1)
predict(node::Plus, xs::Vector{Float64}) = predict(node.left, xs) .+ predict(node.right, xs)
struct Times <: BinaryOpNode
    left::Node
    right::Node
    size::Int
end
Times(left, right) = Times(left, right, size(left) + size(right) + 1)
predict(node::Times, xs::Vector{Float64}) = predict(node.left, xs) .* predict(node.right, xs)
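# Illustrative usage of the node types above (demo values are arbitrary):
# assemble (Linear * Periodic) + Constant by hand and evaluate its mean function.
demo_tree = Plus(Times(Linear(2.0), Periodic(1.0)), Constant(0.5))
@assert predict(demo_tree, [0.0, 0.25]) ≈ [0.5, 1.0]  # 2x*sin(2pi*x) + 0.5 at x = 0, 0.25
@assert size(demo_tree) == 5  # 3 leaves + 2 binary ops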
# defining a prior over expression trees
leaf_nodes = [Constant, Linear, Periodic]
binary_ops = [Plus, Times]
leaf_dist = ones(length(leaf_nodes)) ./ length(leaf_nodes)
binary_op_dist = ones(length(binary_ops)) ./ length(binary_ops)
# probability of splitting into a binary op decays with depth:
# p_split(0) ≈ 0.90, p_split(2) ≈ 0.33, p_split(4) ≈ 0.12
p_split(depth) = exp(-depth/2 - 0.1)
@gen function expression_prior(depth)
    if @trace(bernoulli(1 - p_split(depth)), :leaf_or_binary)
        param = @trace(normal(0, 1), :param)
        leaf_index = @trace(categorical(leaf_dist), :leaf_index)
        node = leaf_nodes[leaf_index](param)
    else
        left = @trace(expression_prior(depth + 1), :left)
        right = @trace(expression_prior(depth + 1), :right)
        binary_op_index = @trace(categorical(binary_op_dist), :binary_op_index)
        node = binary_ops[binary_op_index](left, right)
    end
    return node
end
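# Illustrative draw from the prior (standard Gen API; the result varies run to run):
prior_trace = simulate(expression_prior, (0,))
println(get_retval(prior_trace))  # e.g. a Plus/Times tree, or a single leaf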
# put it all together
sigma = 1.0
@gen function model(xs::Vector{Float64})
    expr_tree::Node = @trace(expression_prior(0), :tree)
    mean = predict(expr_tree, xs)
    for i = 1:length(mean)
        @trace(normal(mean[i], sigma), (:y, i))
    end
    return expr_tree
end
function initialize_trace(xs::Vector{Float64}, ys::Vector{Float64})
    constraints = choicemap()
    for i = 1:length(ys)
        constraints[(:y, i)] = ys[i]
    end
    (trace, weight) = generate(model, (xs,), constraints)
    return trace
end
@gen function random_node_path_unbiased(node::Node)
    p_stop = isa(node, LeafNode) ? 1.0 : 1/size(node)
    if @trace(bernoulli(p_stop), :stop)
        return :tree
    else
        p_left = size(node.left) / (size(node) - 1)
        (next_node, direction) = @trace(bernoulli(p_left), :left) ? (node.left, :left) : (node.right, :right)
        rest_of_path = @trace(random_node_path_unbiased(next_node), :rest_of_path)
        if isa(rest_of_path, Pair)
            return :tree => direction => rest_of_path[2]
        else
            return :tree => direction
        end
    end
end
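# Illustrative (demo_tree is defined above): return values are Gen address paths
# into the model trace, e.g. :tree, :tree => :left, or :tree => :right => :left,
# naming the subtree that the MH move below will resample.
path_trace = simulate(random_node_path_unbiased, (demo_tree,))
println(get_retval(path_trace))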
@gen function regen_random_subtree(prev_trace)
    @trace(expression_prior(0), :new_subtree)
    @trace(random_node_path_unbiased(get_retval(prev_trace)), :path)
end
function subtree_involution(trace, fwd_assmt::ChoiceMap, path_to_subtree, proposal_args::Tuple)
    # Need to return a new trace, a bwd_assmt, and a weight.
    model_assmt = get_choices(trace)
    # backward assignment: the reverse move picks the same path and proposes
    # the current subtree back
    bwd_assmt = choicemap()
    set_submap!(bwd_assmt, :path, get_submap(fwd_assmt, :path))
    set_submap!(bwd_assmt, :new_subtree, get_submap(model_assmt, path_to_subtree))
    # splice the proposed subtree into the model trace at the chosen path
    new_trace_update = choicemap()
    set_submap!(new_trace_update, path_to_subtree, get_submap(fwd_assmt, :new_subtree))
    (new_trace, weight, _, _) =
        update(trace, get_args(trace), (NoChange(),), new_trace_update)
    (new_trace, bwd_assmt, weight)
end
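# Sanity-check sketch (commented out; `tr`, `fwd`, and `path` stand for a model
# trace and the choices/return value of one regen_random_subtree run): an
# involution must be its own inverse, so applying it twice should round-trip
# the trace, and the two update weights should cancel:
# (new_tr, bwd, w) = subtree_involution(tr, fwd, path, ())
# (old_tr, fwd2, w2) = subtree_involution(new_tr, bwd, path, ())
# @assert abs(w + w2) < 1e-8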
function run_mcmc(trace, iters::Int)
    score_traj = [get_score(trace)]
    n_accept = 0
    for iter = 1:iters
        (trace, accept) = mh(trace, regen_random_subtree, (), subtree_involution)
        push!(score_traj, get_score(trace))
        n_accept += accept
    end
    return trace, score_traj, n_accept
end
# constructing a training set
n = 100
root = Plus(Times(Periodic(0.5), Linear(-1.0)), Periodic(0.45))
x_train = 3 .* rand(n) .- 1.5
y_train = predict(root, x_train)  # noise-free targets; the model's sigma = 1.0 likelihood still scores them
# run some inference!
initial_trace = initialize_trace(x_train, y_train)
trace, score_traj, n_accept = run_mcmc(initial_trace, 10000)
println("acceptance rate: $(n_accept/length(score_traj))")
# plot results
xs = Vector{Float64}(range(-4,stop=4,length=1000));
ys = predict(root, xs);
y_pred = predict(get_retval(trace), xs);
using Plots
plot(xs, ys, label="underlying function")
plot!(xs, y_pred, label="predicted function")
scatter!(x_train, y_train, label="training set")