Last active November 18, 2019 14:08
Save maxentile/5b9c352f9ee75a0d01fafa2870d1c85a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
# Simplified from the GP structure-learning example (section 7.2).
# Source files `cov_tree.jl` and `lightweight.jl` in https://github.com/probcomp/pldi2019-gen-experiments/tree/master/gp
using Gen | |
# Abstract types for the binary expression tree.
abstract type Node end
abstract type LeafNode <: Node end
abstract type BinaryOpNode <: Node end

# Number of direct children: 0 for leaves, 2 for binary operators.
num_children(::LeafNode) = 0
num_children(::BinaryOpNode) = 2

# `size(node)` is the total number of nodes (internal nodes and leaves) in the subtree
# rooted at `node`. Binary nodes cache it in a field, so querying it is O(1).
Base.size(::LeafNode) = 1
Base.size(node::BinaryOpNode) = node.size

# Leaf nodes: each one carries a single scalar parameter.

# Constant function: f(x) = param.
struct Constant <: LeafNode
    param::Float64
end
predict(node::Constant, xs::AbstractVector{<:Real}) = fill(node.param, length(xs))

# Linear function: f(x) = coefficient * x.
struct Linear <: LeafNode
    coefficient::Float64
end
predict(node::Linear, xs::AbstractVector{<:Real}) = node.coefficient .* xs

# Sinusoid: f(x) = sin((2π / period) * x).
struct Periodic <: LeafNode
    period::Float64
end
predict(node::Periodic, xs::AbstractVector{<:Real}) = sin.(2pi / node.period .* xs)

# Binary operator nodes. Children are typed as the abstract `Node` on purpose: trees are
# heterogeneous, and parametrizing on child types would specialize on every tree shape.

# Pointwise sum of the two child functions.
struct Plus <: BinaryOpNode
    left::Node
    right::Node
    size::Int   # total node count of this subtree, including this node
end
# The `+ 1` counts this node itself. `random_node_path_unbiased` divides by `size` to
# pick a node uniformly, so leaving it out biases which subtree gets selected.
Plus(left, right) = Plus(left, right, size(left) + size(right) + 1)
function predict(node::Plus, xs::AbstractVector{<:Real})
    return predict(node.left, xs) .+ predict(node.right, xs)
end

# Pointwise product of the two child functions.
struct Times <: BinaryOpNode
    left::Node
    right::Node
    size::Int   # total node count of this subtree, including this node
end
# See `Plus`: `size` must include this node for uniform subtree selection.
Times(left, right) = Times(left, right, size(left) + size(right) + 1)
function predict(node::Times, xs::AbstractVector{<:Real})
    return predict(node.left, xs) .* predict(node.right, xs)
end
# Prior over expression trees. Leaf types and operator types are chosen uniformly.
# These are declared `const` so that the generative functions reading them see
# concretely typed globals instead of untyped (`Any`) ones.
const leaf_nodes = [Constant, Linear, Periodic]
const binary_ops = [Plus, Times]
const leaf_dist = fill(1 / length(leaf_nodes), length(leaf_nodes))
const binary_op_dist = fill(1 / length(binary_ops), length(binary_ops))

# Probability that a node at `depth` becomes a binary operator rather than a leaf:
# exp(-depth/2 - 0.1), i.e. about 0.90 at the root, decaying with depth.
p_split(depth) = exp(-depth / 2 - 0.1)
# Recursively sample an expression tree rooted at `depth`.
# `:leaf_or_binary` is true with probability 1 - p_split(depth), giving a leaf. A leaf
# draws a standard-normal parameter (`:param`) and a leaf type (`:leaf_index`). Otherwise
# the function samples `:left` and `:right` subtrees at depth + 1 and an operator
# (`:binary_op_index`) combining them.
@gen function expression_prior(depth)
    is_leaf = @trace(bernoulli(1 - p_split(depth)), :leaf_or_binary)
    if !is_leaf
        left_child = @trace(expression_prior(depth + 1), :left)
        right_child = @trace(expression_prior(depth + 1), :right)
        op_idx = @trace(categorical(binary_op_dist), :binary_op_index)
        return binary_ops[op_idx](left_child, right_child)
    end
    leaf_param = @trace(normal(0, 1), :param)
    leaf_idx = @trace(categorical(leaf_dist), :leaf_index)
    return leaf_nodes[leaf_idx](leaf_param)
end
# put it all together

# Observation noise standard deviation. Declared `const` so the model does not read an
# untyped global for every observation.
const sigma = 1.0

# Generative model: sample an expression tree from the prior (at `:tree`), evaluate it
# at `xs`, and observe each `(:y, i)` as Normal(prediction[i], sigma).
# Returns the sampled expression tree.
@gen function model(xs::Vector{Float64})
    expr_tree::Node = @trace(expression_prior(0), :tree)
    mean = predict(expr_tree, xs)
    for i in eachindex(mean)
        @trace(normal(mean[i], sigma), (:y, i))
    end
    return expr_tree
end
# Build an initial trace of `model(xs)` in which every observation `(:y, i)` is
# constrained to `ys[i]`. The unconstrained choices (the expression tree under `:tree`)
# are filled in by `generate`. The importance weight is discarded.
function initialize_trace(xs::Vector{Float64}, ys::Vector{Float64})
    observations = choicemap()
    for (i, y) in enumerate(ys)
        observations[(:y, i)] = y
    end
    trace, _ = generate(model, (xs,), observations)
    return trace
end
# Sample the address of a node in the tree rooted at `node`. The result is a Gen address
# rooted at `:tree`, e.g. `:tree` or `:tree => :left => :right`.
# At each step it stops at the current node with probability 1/size(node) (always at a
# leaf). Otherwise it descends left with probability size(left)/(size(node) - 1), or
# right with the remaining probability.
# NOTE(review): this is uniform over nodes only if `size` counts every node (internal
# nodes and leaves) — confirm against the Plus/Times constructors.
@gen function random_node_path_unbiased(node::Node)
    p_stop = node isa LeafNode ? 1.0 : 1 / size(node)
    if @trace(bernoulli(p_stop), :stop)
        return :tree
    end

    go_left = @trace(bernoulli(size(node.left) / (size(node) - 1)), :left)
    if go_left
        child, direction = node.left, :left
    else
        child, direction = node.right, :right
    end
    subpath = @trace(random_node_path_unbiased(child), :rest_of_path)

    # `subpath` is either the bare `:tree` or `:tree => rest`. Re-root `rest`
    # under `direction`.
    if subpath isa Pair
        return :tree => direction => subpath.second
    end
    return :tree => direction
end
# Proposal for subtree-regeneration MH. It draws a fresh subtree from the prior at
# `:new_subtree`, then picks (at `:path`) which node of the current tree it replaces.
# Returns the sampled address.
@gen function regen_random_subtree(prev_trace)
    current_tree = get_retval(prev_trace)
    @trace(expression_prior(0), :new_subtree)
    path = @trace(random_node_path_unbiased(current_tree), :path)
    return path
end
# Involution for subtree-regeneration MH.
# It grafts the proposal's `:new_subtree` into `trace` at `path_to_subtree`. It also
# builds the backward proposal choices: the same `:path`, with the displaced subtree as
# `:new_subtree`, so applying the move again would restore the original tree.
# Returns `(new_trace, bwd_assmt, weight)`, where `weight` is the log-weight that
# `update` reports for the graft.
function subtree_involution(trace, fwd_assmt::ChoiceMap, path_to_subtree, proposal_args::Tuple)
    current_choices = get_choices(trace)

    # Reverse move: reuse the forward path choices, and offer the subtree currently in
    # the model as the "new" subtree.
    bwd_assmt = choicemap()
    old_subtree = get_submap(current_choices, path_to_subtree)
    set_submap!(bwd_assmt, :path, get_submap(fwd_assmt, :path))
    set_submap!(bwd_assmt, :new_subtree, old_subtree)

    # Forward move: place the proposed subtree at the chosen address. The model
    # arguments are unchanged.
    graft = choicemap()
    set_submap!(graft, path_to_subtree, get_submap(fwd_assmt, :new_subtree))
    new_trace, weight, _, _ = update(trace, get_args(trace), (NoChange(),), graft)

    return new_trace, bwd_assmt, weight
end
# Run `iters` steps of involutive MH using `regen_random_subtree` as the proposal and
# `subtree_involution` as the involution.
# Returns three values:
#   - the final trace;
#   - `score_traj`, the trace score (`get_score`) before the first step and after every
#     step, so `iters + 1` entries;
#   - the number of accepted proposals.
function run_mcmc(trace, iters::Int)
    # Use the public `get_score` accessor rather than the trace's internal `score` field.
    score_traj = Float64[get_score(trace)]
    n_accept = 0
    for _ in 1:iters
        trace, accepted = mh(trace, regen_random_subtree, (), subtree_involution)
        push!(score_traj, get_score(trace))
        n_accept += accepted
    end
    return trace, score_traj, n_accept
end
# constructing a training set
n = 100
root = Plus(Times(Periodic(0.5), Linear(-1.0)), Periodic(0.45))
x_train = 3 .* rand(n) .- 1.5           # n points uniform on [-1.5, 1.5)
y_train = predict(root, x_train)        # noise-free targets from the true tree

# run some inference!
n_iters = 10_000
initial_trace = initialize_trace(x_train, y_train)
trace, score_traj, n_accept = run_mcmc(initial_trace, n_iters)
# `score_traj` also holds the initial score (n_iters + 1 entries), so divide by the
# number of MH steps rather than by its length.
println("acceptance rate: $(n_accept / n_iters)")

# plot results
xs = collect(range(-4.0, stop=4.0, length=1000));
ys = predict(root, xs);
y_pred = predict(get_retval(trace), xs);
using Plots
plot(xs, ys, label="underlying function")
plot!(xs, y_pred, label="predicted function")
scatter!(x_train, y_train, label="training set")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.