Created December 28, 2017 20:34
Probabilistic programming with source transformations, in Julia
using Distributions
using StatsFuns
using MacroTools
using MacroTools: postwalk
# A model is just a special kind of quoted expression:
myModel = quote
    real(μ)
    μ ~ Normal(0,5)
    positiveReal(σ)
    σ ~ Cauchy(0,3)
    for x in data
        x ~ Normal(μ,σ)
    end
end
# `real` and `positiveReal` act as declarations, identifying
# the support of each parameter.
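As a quick illustration (not part of the original gist), softplus maps the whole real line onto the positive reals, which is what makes it usable below for `positiveReal` parameters. StatsFuns exports `softplus`; the hand-rolled `mysoftplus` here is an equivalent sketch:

```julia
# Sketch: softplus(x) = log(1 + exp(x)) maps ℝ onto (0, ∞).
mysoftplus(x) = log1p(exp(x))

mysoftplus(0.0)    # log(2) ≈ 0.693
mysoftplus(-20.0)  # ≈ 2.1e-9, tiny but still positive
```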
# An inference method is a macro that transforms the model
# code. This gives flexibility and speed. Operational
# consequences of declarations and ~ are not fixed, but depend
# on the inference method.

# For example, here's an interpretation that implements a very
# simple Stan workalike:
macro logdensity(ex)
    # `ex` is the *name* of a model; `@eval` retrieves the quoted
    # expression it refers to at macro-expansion time.
    body = postwalk(@eval ($ex)) do x
        if @capture(x, v_ ~ dist_)
            quote
                ℓ += logpdf($dist, $v)
            end
        elseif @capture(x, real(v_))
            quote
                $v = θ[𝚥]
                𝚥 += 1
            end
        elseif @capture(x, positiveReal(v_))
            quote
                $v = softplus(θ[𝚥])
                # Jacobian correction: v = softplus(θ[𝚥]) has derivative
                # logistic(θ[𝚥]), whose log is θ[𝚥] - softplus(θ[𝚥]).
                # Add this log-Jacobian to the log-density:
                ℓ += θ[𝚥] - $v
                𝚥 += 1
            end
        else x
        end
    end
    quote
        function(θ, data)
            ℓ = 0.0
            𝚥 = 1
            $body
            return ℓ
        end
    end
end
# The idea is to walk the code looking for parameter
# declarations. The parameters are represented as a vector of
# real values. Each time we see a parameter, we reparameterize
# so it comes from the reals, and we increment the index 𝚥.
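Concretely, here is a hand-expanded sketch of roughly what `@logdensity` generates for `myModel` (the name `handExpanded` and the exact line order are illustrative; the actual macro output also carries hygiene gensyms, and the Jacobian term is written here as `θ[2] - σ`, the sign the softplus transform calls for, which the gist's own comment flags as worth double-checking):

```julia
using Distributions, StatsFuns

# Hand-expanded sketch of the closure @logdensity builds for myModel:
handExpanded = function(θ, data)
    ℓ = 0.0
    μ = θ[1]                        # real(μ): read straight off θ
    ℓ += logpdf(Normal(0,5), μ)     # μ ~ Normal(0,5)
    σ = softplus(θ[2])              # positiveReal(σ): reparameterize
    ℓ += θ[2] - σ                   # log-Jacobian of softplus
    ℓ += logpdf(Cauchy(0,3), σ)     # σ ~ Cauchy(0,3)
    for x in data                   # x ~ Normal(μ,σ) for each datum
        ℓ += logpdf(Normal(μ,σ), x)
    end
    return ℓ
end
```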
# To use the Stan approach, we just apply the macro:
logDensity = @logdensity myModel

# From here, the next step would usually be inference; Stan uses NUTS
# or variational inference, either of which is pretty quick to get to
# in Julia. But for simplicity, let's just visualize the result for a
# fixed σ.
using Plots
pyplot()

μs = linspace(-10,10,1000)
data = randn(10) .- 5

# `logDensity` is in terms of the joint, not conditional, density.
# Here's a quick hack to get at the latter:
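The hack leans on `logsumexp`, which computes `log(sum(exp.(xs)))` without underflow; subtracting it from the grid of log-densities yields log-probabilities that sum to one. A minimal sketch with toy numbers (`mylogsumexp` is a hand-rolled stand-in for the StatsFuns version):

```julia
# Stable logsumexp: shift by the maximum before exponentiating.
mylogsumexp(xs) = (m = maximum(xs); m + log(sum(exp.(xs .- m))))

ℓs = [-1001.0, -1000.0, -1002.0]     # unnormalized log-densities on a grid
ps = exp.(ℓs .- mylogsumexp(ℓs))     # normalized grid probabilities
# sum(ps) ≈ 1, even though exp.(ℓs) alone would underflow to 0.0
```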
priorLogZ = logsumexp([logDensity([μ,1.0],[]) for μ in linspace(-30,30,1000)])
postLogZ = logsumexp([logDensity([μ,1.0],data) for μ in linspace(-30,30,1000)])
logprior = [logDensity([μ,1.0],[]) - priorLogZ for μ in μs]
logpost = [logDensity([μ,1.0],data) - postLogZ for μ in μs]

plot(μs, exp.(logprior), label="prior")
plot!(μs, exp.(logpost), label="posterior")