Skip to content

Instantly share code, notes, and snippets.

@cscherrer
Last active November 4, 2021 14:15
Show Gist options
  • Save cscherrer/de1856b78be8aaf3d3511650e61e6997 to your computer and use it in GitHub Desktop.
Save cscherrer/de1856b78be8aaf3d3511650e61e6997 to your computer and use it in GitHub Desktop.
Summarizing ZigZag results
using Soss
using ZigZagBoomerang
using MeasureTheory
using Soss: logdensity, xform, ConditionalModel
using ZigZagBoomerang
using ForwardDiff
using ForwardDiff: gradient!
using LinearAlgebra
using SparseArrays
using StructArrays
using TransformVariables
# Soss.xform(m::SpikeMixture) = Soss.xform(m.m)
# Adapted from https://github.com/mschauer/ZigZagBoomerang.jl/blob/master/Soss/sparsezigzag.jl
"""
    zigzag(m::ConditionalModel, T = 100000.0; c=10.0, adapt=false)

Sample from the conditioned Soss model `m` with a factorised ZigZag process run up to
continuous time `T`.

The sampler works in the unconstrained coordinates defined by `xform(m)`. The potential
is `ϕ(x) = -logdensity(m, θ(x)) - logjac(x)`. Its partial derivatives are computed one
coordinate at a time with forward-mode AD (ForwardDiff duals).

# Arguments
- `m`: the conditioned model to sample from.
- `T`: final time of the simulated trajectory.

# Keywords
- `c`: scalar passed to `pdmp` as the per-coordinate constant vector `c * ones(d)`
  (presumably the initial bound constant of the event-rate bounds — confirm against the
  ZigZagBoomerang docs).
- `adapt`: forwarded unchanged to `pdmp`.

# Returns
Whatever `ZigZagBoomerang.pdmp` returns. Callers in this file destructure it as
`trace, final, (num, acc)`.
"""
function zigzag(m::ConditionalModel, T = 100000.0; c=10.0, adapt=false)
    ℓ(pars) = logdensity(m, pars)
    t = xform(m)
    # Potential in unconstrained space. The log-Jacobian term accounts for the change of
    # variables x ↦ θ performed by the transform.
    function f(x)
        θ, logjac = TransformVariables.transform_and_logjac(t, x)
        return -ℓ(θ) - logjac
    end
    # Use the public accessor. The `dimension` struct field is an implementation detail
    # that not every transform type carries.
    d = TransformVariables.dimension(t)
    # Build a closure `(x, i) -> ∂f/∂x_i`. It seeds a dual-number directional derivative
    # along the i-th basis vector. The seed buffer `ith` is allocated once and reused
    # across calls, so it is set back to all zeros after each evaluation.
    function partiali()
        ith = zeros(d)
        return function (x, i)
            ith[i] = 1
            sa = StructArray{ForwardDiff.Dual{}}((x, ith))
            δ = f(sa).partials[]
            ith[i] = 0
            return δ
        end
    end
    ∇ϕi = partiali()
    # Random starting point: draw from `m`, keep only the sampled parameters (the keys
    # produced by the transform), and map them to unconstrained space.
    tkeys = keys(transform(t, zeros(d)))
    vars = Soss.select(rand(m), tkeys)
    t0 = 0.0
    x0 = inverse(t, vars)
    # Random initial velocity.
    θ0 = randn(d)
    Γ = sparse(I(d))
    return pdmp(∇ϕi, t0, x0, θ0, T, c * ones(d), ZigZag(Γ, zero(x0)); adapt=adapt)
end
# Bayesian linear regression of the observations `y` on the covariate vector `x`:
#   α    : intercept, `Uniform()` prior
#   β    : slope, `Normal()` prior
#   yhat : deterministic linear predictor α .+ β .* x
#   y    : y[j] ~ Normal(yhat[j], 2.0) for each index j of x
# NOTE(review): `Uniform()` is presumably supported on (0, 1), while the synthetic data
# below is generated with intercept -0.1 — confirm this prior/data mismatch is intended.
m = @model x begin
    α ~ Uniform()
    β ~ Normal()
    yhat = α .+ β .* x
    # presumably the second argument of `Normal` is the standard deviation — confirm
    y ~ For(eachindex(x)) do j
        Normal(yhat[j], 2.0)
    end
end
# Synthetic data: 20 standard-normal covariates, and responses on the line -0.1 + 2x
# with standard-normal noise added (the `$` keeps `randn(20)` a single un-dotted call).
x = randn(20);
obs = @. -0.1 + 2x + $(randn(20));
# Condition the model on the covariates, then on the observed responses.
posterior = m(; x) | (; y = obs)
using ParameterHandling
using OnlineStats
# Final time of the continuous-time ZigZag trajectory to simulate.
T = 10000.0
# Run the sampler (timed). The result of `pdmp` is destructured into the trajectory,
# the final state, and a pair of counters (presumably event / acceptance counts — confirm).
trace, final, (num, acc) = @time zigzag(posterior, T)
# `trace` is a continuous-time object; discretize it with step 0.1 to obtain samples.
# `sep` presumably splits the resulting `time => state` pairs into times `ts` and
# states `xs`.
ts, xs = ZigZagBoomerang.sep(discretize(trace, 0.1))
using Measurements
# Pair point estimates `μ` with spreads `σ` as `Measurement`s. Recurses through the
# tuple / named-tuple / array structure produced by ParameterHandling's `unflatten`.
_plusminus(μ::Real, σ::Real) = μ ± σ
_plusminus(μ::AbstractArray, σ::AbstractArray) = _plusminus.(μ, σ)
_plusminus(μ::Union{Tuple,NamedTuple}, σ::Union{Tuple,NamedTuple}) = map(_plusminus, μ, σ)

"""
    summarize(trace, tform; dt=0.1, ε=0.0001, maxiter=20000)

Summarize a continuous-time `trace` as `mean ± std` of each model parameter.

1. `trace` is discretized with step `dt`.
2. Each state is mapped through `tform` to the model's parameters.
3. The parameters are flattened with `flatten`, and each flattened coordinate is fed into
   its own `FitNormal` online estimator.

Iteration stops at whichever comes first:
- every coordinate's relative standard error `√(var/n) / |mean|` is below `ε`;
- `maxiter` points have been used.

Returns a structure shaped like `transform(tform, x)` whose leaves are `Measurement`s
`mean ± std` of the transformed draws. The statistics are already in parameter space,
so `tform` is not applied to them a second time.

Note: `√(var/n)` treats successive discretized points as independent. Points taken from
one continuous trajectory are correlated, so this likely understates the Monte Carlo
error.
"""
function summarize(trace, tform; dt=0.1, ε=0.0001, maxiter=20000)
    disc = discretize(trace, dt)
    # Each element of `disc` is a `time => state` pair. Transform one state only to learn
    # the flattened layout of the parameters (and the matching `unflatten`).
    v, unflatten = flatten(transform(tform, last(first(disc))))
    os = [FitNormal() for _ in v]
    # `abs` keeps coordinates with a negative mean from counting as converged.
    close_enough(n) = o -> √(var(o) / n) / abs(mean(o)) < ε
    n = 0
    for point in disc
        n += 1
        fit!.(os, first(flatten(transform(tform, last(point)))))
        if all(close_enough(n), os)
            @info "Summarization used $n points"
            break
        end
        # `>=` so that at most `maxiter` points are consumed.
        if n >= maxiter
            @info "Reached iteration limit $maxiter"
            break
        end
    end
    means = mean.(os)
    stds = std.(os)
    return _plusminus(unflatten(means), unflatten(stds))
end
# julia> summarize(trace, xform(posterior); maxiter=2000)
# [ Info: Reached iteration limit 2000
# (β = 1.45±0.51, α = 0.572±0.052)
# julia> summarize(trace, xform(posterior); ε=0.01)
# [ Info: Summarization used 5746 points
# (β = 1.48±0.42, α = 0.583±0.062)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment