@cscherrer
Created December 28, 2017 20:34
Probabilistic programming with source transformations, in Julia
using Distributions
using StatsFuns
using MacroTools
using MacroTools: postwalk
# A model is just a special kind of quoted expression:
myModel = quote
    real(μ)
    μ ~ Normal(0,5)

    positiveReal(σ)
    σ ~ Cauchy(0,3)

    for x in data
        x ~ Normal(μ,σ)
    end
end
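# Since the model is just an ordinary `Expr`, we can inspect it directly.
# (An illustrative check, not part of the modeling code itself.)
@show typeof(myModel)   # Expr
@show myModel.head      # :block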
# `real` and `positiveReal` act as declarations, identifying the
# support of each parameter.

# An inference method is a macro that transforms the model code. This
# gives flexibility and speed: the operational meaning of the
# declarations and of `~` is not fixed, but depends on the chosen
# inference method.
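# The transformation below is built from MacroTools' `postwalk` and
# `@capture`. As a quick illustration (not part of the original code)
# of how the pattern matching works:
let ex = :(μ ~ Normal(0,5))
    if @capture(ex, v_ ~ dist_)
        println((v, dist))   # roughly prints (:μ, :(Normal(0, 5)))
    end
end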
# For example, here's an interpretation that implements a very
# simple Stan workalike:
macro logdensity(ex)
    body = postwalk(@eval ($ex)) do x
        if @capture(x, v_ ~ dist_)
            quote
                ℓ += logpdf($dist, $v)
            end
        elseif @capture(x, real(v_))
            quote
                $v = θ[𝚥]
                𝚥 += 1
            end
        elseif @capture(x, positiveReal(v_))
            quote
                $v = softplus(θ[𝚥])
                # Jacobian correction: add the log-determinant of the
                # transformation. Since d/dθ softplus(θ) = logistic(θ),
                # log|J| = log(logistic(θ)) = θ - softplus(θ).
                ℓ += θ[𝚥] - $v
                𝚥 += 1
            end
        else
            x
        end
    end

    quote
        function(θ, data)
            ℓ = 0.0
            𝚥 = 1
            $body
            return ℓ
        end
    end
end
# The idea is to walk the code looking for parameter
# declarations. The parameters are represented as a vector of
# real values. Each time we see a parameter, we reparameterize
# so it comes from the reals, and we increment the index 𝚥.
# To use the Stan approach, we just apply the macro:
logDensity = @logdensity myModel
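# Roughly speaking (a hedged sketch, not the literal macro output),
# `logDensity` now looks something like:
#
#     function (θ, data)
#         ℓ = 0.0
#         𝚥 = 1
#         μ = θ[𝚥]; 𝚥 += 1
#         ℓ += logpdf(Normal(0,5), μ)
#         σ = softplus(θ[𝚥]); ℓ += θ[𝚥] - σ; 𝚥 += 1
#         ℓ += logpdf(Cauchy(0,3), σ)
#         for x in data
#             ℓ += logpdf(Normal(μ,σ), x)
#         end
#         return ℓ
#     end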
# From here, the next step would usually be inference; Stan uses NUTS
# or variational inference, either of which is pretty quick to get to
# in Julia. But for simplicity, let's just visualize the result with
# the unconstrained coordinate θ[2] held at 1.0, i.e. σ = softplus(1.0).
using Plots
pyplot()
μs = linspace(-10,10,1000)
data = randn(10) .- 5
# `logDensity` is the joint log-density, not the conditional (posterior)
# one. Here's a quick hack that approximates the latter by normalizing
# over a grid of μ values:
priorLogZ = logsumexp([logDensity([μ,1.0],[]) for μ in linspace(-30,30,1000)])
postLogZ = logsumexp([logDensity([μ,1.0],data) for μ in linspace(-30,30,1000)])
logprior = [logDensity([μ,1.0],[]) - priorLogZ for μ in μs]
logpost = [logDensity([μ,1.0],data) - postLogZ for μ in μs]
plot(μs,exp.(logprior), label="prior")
plot!(μs, exp.(logpost), label="posterior")
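# Rough sanity check (not in the original gist): with θ[2] held at 1.0,
# σ is fixed at softplus(1.0), the model is Normal-Normal conjugate in μ,
# and the posterior over μ is Gaussian. Its mean should line up with the
# peak of the plotted posterior.
σfix = softplus(1.0)
s2   = 1 / (1/5^2 + length(data)/σfix^2)   # posterior variance of μ
m    = s2 * sum(data) / σfix^2             # posterior mean of μ
println("analytic posterior for μ: Normal($m, $(sqrt(s2)))")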