Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Created May 27, 2019 00:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save torfjelde/23603e16e592dd93d87c61bc609d77af to your computer and use it in GitHub Desktop.
Save torfjelde/23603e16e592dd93d87c61bc609d77af to your computer and use it in GitHub Desktop.
Simple parsing of a Turing.jl model definition into a MetaGraph, allowing visualization of the probabilistic model's dependency structure.
using MacroTools
using Turing
using LightGraphs, MetaGraphs
# Expressions
# Example models, kept unevaluated as quoted function definitions so that the
# parsing code below can walk their syntax. `a ~ D` statements become random
# variable vertices; function arguments (here `x`) are marked as observed.
# ex1: σ and μ are latent, `x` is observed.
ex1 = quote
m(x) = begin
# Assumptions
σ ~ InverseGamma(2,3)
μ ~ Normal(0,sqrt(σ))
# Observations
x ~ Normal(μ, sqrt(σ))
end
end
# ex2: latent ε and u, observed `z`, plus a `≃` statement. `y ≃ ...` is parsed
# into a non-random (deterministic) vertex rather than a random variable.
ex2 = quote
m(z) = begin
# priors
ε ~ q_ε
u ~ q_u
# observation
z ~ h_θ(u, ε)
# experimental interface to also track (conditionally) deterministic computations
y ≃ 1 + z
end
end
# ex3: the classic `gdemo` model — latent s and m shared by two observations
# `x` and `y` (both are function arguments, so both are marked observed).
ex3 = quote
gdemo(x, y) = begin
s ~ InverseGamma(2,3)
m ~ Normal(0,sqrt(s))
x ~ Normal(m, sqrt(s))
y ~ Normal(m, sqrt(s))
end
end
# Parsing and construction of graph
# Select which example model to turn into a graph.
ex = ex2
# Split the function definition into its components (:name, :args, :body, ...).
d = MacroTools.splitdef(ex)
@info d
@info d[:args]
@info d[:body]
# One vertex per `~` / `≃` statement; metadata is stored as vertex properties.
g = MetaDiGraph()
# Maps each model variable name to the vertex that most recently defined it.
# Keys are `Symbol`s and values are vertex indices, so type the Dict accordingly.
sym2vertex = Dict{Symbol,Int}()
"""
    add_rv!(g, sym, L, R)

Append a vertex representing the random variable `sym` (from an `L ~ R`
statement) to the meta-graph `g` and return the new vertex index.

The vertex is tagged with `:sym => sym`, `:rv => true` and `:expr => R`.
`L` is not used; it is accepted to mirror the other graph-building helpers.
"""
function add_rv!(g, sym, L, R)
    add_vertex!(g)
    # The freshly added vertex is always the last one.
    newv = last(vertices(g))
    for (key, val) in (:sym => sym, :rv => true, :expr => R)
        set_prop!(g, newv, key, val)
    end
    return newv
end
"""
    add_determinstic!(g, sym, L, R)

Append a vertex for the non-random quantity `sym` (from an `L ≃ R`
statement) to the meta-graph `g` and return the new vertex index.

The vertex is tagged with `:sym => sym`, `:rv => false` and `:expr => R`.
`L` is not used; it is accepted to mirror `add_rv!`.
"""
function add_determinstic!(g, sym, L, R)
    add_vertex!(g)
    vnew = last(vertices(g))
    # `:rv => false` distinguishes these nodes from random variables.
    for (name, value) in ((:sym, sym), (:rv, false), (:expr, R))
        set_prop!(g, vnew, name, value)
    end
    return vnew
end
"""
    add_deps!(g, idx, L, R)

Add an edge `src → idx` to `g` for every symbol occurring in the expression
`R` that is registered in the global `sym2vertex` map (variable name =>
vertex index).

Only bare `Symbol`s are matched. Literals (e.g. the string `"σ"` or a number)
and compound expressions therefore never create spurious edges. A vertex is
never linked to itself. `L` is unused. Returns `R` unchanged.
"""
function add_deps!(g, idx, L, R)
    MacroTools.postwalk(R) do e
        if e isa Symbol
            src = get(sym2vertex, e, nothing)
            # Skip unknown names (functions, constants) and self-references.
            if src !== nothing && src != idx
                add_edge!(g, src, idx)
            end
        end
        # Return the node itself so `postwalk` rebuilds `R` unchanged.
        return e
    end
    return R
end
# Walk the model body and add one vertex per `L ~ R` (random variable) or
# `L ≃ R` (deterministic) statement. Edges come from the variables referenced
# in `R`. Variables that are also function arguments are marked `:observed`.
expr = let argsyms = [first(MacroTools.splitarg(a)) for a in d[:args]]
    # `splitarg` strips annotations/defaults, so `x::Vector` still yields `:x`.
    MacroTools.postwalk(d[:body]) do e
        if @capture(e, L_ ~ R_)
            israndom = true
        elseif @capture(e, L_ ≃ R_)
            israndom = false
        else
            return e
        end
        sym = Symbol(L)
        idx = israndom ? add_rv!(g, sym, L, R) : add_determinstic!(g, sym, L, R)
        # check if observed
        if sym ∈ argsyms
            @info "observed" sym
            set_prop!(g, idx, :observed, true)
        end
        # Resolve dependencies BEFORE registering `sym`. A RHS that mentions the
        # same name then refers to its earlier definition, not this new vertex.
        add_deps!(g, idx, L, R)
        sym2vertex[sym] = idx
        return e
    end
end
# Inspect the constructed graph.
@info sym2vertex
@info collect(edges(g))
@info "Graph" g
# Show the expression recorded for `z`. Only `ex2` defines `z`, so guard the
# lookup to keep the script running when another example is selected.
haskey(sym2vertex, :z) && props(g, sym2vertex[:z])[:expr]
# visualize
using TikzPictures, TikzGraphs, Cairo, Fontconfig
# Special TikZ style for different types of nodes: random variables get rounded
# corners; observed nodes are filled green, all others blue.
node_styles = Dict{Int,String}()
for v ∈ vertices(g)
    v_props = props(g, v)
    s = "draw"
    if get(v_props, :rv, true)
        s *= ", rounded corners"
    end
    s *= get(v_props, :observed, false) ? ", fill=green!10" : ", fill=blue!10"
    node_styles[v] = s
    # Diagnostic output only; enable with JULIA_DEBUG=Main.
    @debug "node style" v s
end
# Render the underlying directed graph with each vertex labelled by its
# variable name and styled according to `node_styles`.
p = TikzGraphs.plot(
    g.graph,
    [String(props(g, v)[:sym]) for v ∈ vertices(g)],
    options="scale=2",
    node_styles=node_styles,
)
# Write to the system temp directory rather than a hard-coded `/tmp`.
# `tempdir()` is portable across platforms and honours TMPDIR.
TikzPictures.save(PDF(joinpath(tempdir(), "test.pdf")), p)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment