Skip to content

Instantly share code, notes, and snippets.

@mschauer
Last active March 16, 2021 11:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mschauer/638f2f50abd6469cec1c38e561374d24 to your computer and use it in GitHub Desktop.
Save mschauer/638f2f50abd6469cec1c38e561374d24 to your computer and use it in GitHub Desktop.
Simplistic ZigZag for Soss
using Soss: logdensity, xform, ConditionalModel
using ZigZagBoomerang
using ForwardDiff
using ForwardDiff: gradient!
using LinearAlgebra
using SparseArrays
using StructArrays
using TransformVariables
# The change of variables for a spike-and-slab `SpikeMixture` is the one of its
# continuous component `m.m`.
# NOTE(review): `using Soss: …` does not bind the name `Soss` itself; this
# qualified definition relies on `Soss` being loaded elsewhere (e.g. `using Soss`).
Soss.xform(m::SpikeMixture) = Soss.xform(m.m)

"""
    kappa(m)

Per-coordinate `κ` value handed to the sticky sampler (`sspdmp`) in
`sparse_zigzag`.

For a `SpikeMixture` with weight `w = m.w` this returns `1/(1/w - 1)`, which is
`w/(1 - w)` mathematically and `Inf` when `w == 1`.
NOTE(review): presumably `w` is the weight of the continuous (slab) component;
confirm against the MeasureTheory version in use.
"""
kappa(m::SpikeMixture{WeightedMeasure{Float64,Lebesgue{ℝ}},Float64}) = 1/(1/m.w - 1)

# A component without a spike gets κ = Inf — presumably meaning the coordinate
# is never held at zero by the sticky sampler (TODO confirm against `sspdmp`).
kappa(m::WeightedMeasure{Float64,Lebesgue{ℝ}}) = Inf
"""
    sparse_zigzag(m::ConditionalModel, T = 1000.0; c = 10.0, adapt = false)

Run the sticky ZigZag sampler `sspdmp` (ZigZagBoomerang) for the Soss model `m`
up to final time `T`, and return whatever `sspdmp` returns.

The sampler works on the unconstrained vector `x` of the transform `xform(m)`.
Its potential is `ϕ(x) = -logdensity(m, θ) - logjac` with `(θ, logjac)` the
transformed parameters and log-Jacobian. Partial derivatives `∂ϕ/∂x_i` are
computed one coordinate at a time with forward-mode dual numbers.

# Arguments
- `m`: conditioned Soss model to sample from.
- `T`: final time of the process.

# Keywords
- `c`: value of every entry of the vector `c * ones(d)` passed to `sspdmp`
  (presumably the initial Poisson-rate bounds).
- `adapt`: forwarded unchanged to `sspdmp`.
"""
function sparse_zigzag(m::ConditionalModel, T = 1000.0; c=10.0, adapt=false)
    ℓ(pars) = logdensity(m, pars)
    t = xform(m)

    # Potential on the unconstrained space, including the log-Jacobian term.
    function f(x)
        (θ, logjac) = TransformVariables.transform_and_logjac(t, x)
        return -ℓ(θ) - logjac
    end

    # Public API instead of reading the transform's internal `dimension` field.
    d = TransformVariables.dimension(t)

    # Closure (x, i) -> ∂ϕ/∂x_i(x): seed direction i in a reused buffer, evaluate
    # f on dual numbers, read the single partial, then clear the seed again.
    function partiali()
        ith = zeros(d)
        function (x, i)
            ith[i] = 1
            sa = StructArray{ForwardDiff.Dual{}}((x, ith))
            δ = f(sa).partials[]
            ith[i] = 0
            return δ
        end
    end
    ∇ϕi = partiali()

    # Draw a random starting point and velocity.
    tkeys = keys(t(zeros(d)))
    vars = Soss.select(rand(m), tkeys)

    # One κ per sampled variable, taken from that variable's base measure.
    bm_ = Soss.basemeasure(m, vars)
    bm = getindex.(Ref(bm_.data), keys(vars))
    κ = kappa.(bm)

    t0 = 0.0
    x0 = inverse(t, vars)
    # NOTE(review): ZigZag velocities are commonly drawn from {-1, +1}
    # (e.g. `rand([-1.0, 1.0], d)`); Gaussian velocities are used here —
    # confirm this is intended, especially for the sticky sampler.
    θ0 = randn(d)

    # Second `ZigZag` argument (presumably the reference mean μ): `zero(x0)`
    # rather than `0 * x0`, which would turn non-finite entries into NaN.
    Z = ZigZag(sparse(I(d)), zero(x0))
    return sspdmp(∇ϕi, t0, x0, θ0, T, c * ones(d), Z, κ; adapt=adapt)
end
using Soss: logdensity, xform, ConditionalModel
using ZigZagBoomerang
using ForwardDiff
using ForwardDiff: gradient!
using LinearAlgebra
using SparseArrays
using StructArrays
using TransformVariables
# Change of variables for a spike-and-slab `SpikeMixture`: delegate to the
# transform of its continuous component `m.m`.
# NOTE(review): this repeats an identical definition near the top of the file
# (the second half looks like a separate gist file). If both halves end up in
# one module, keep only one copy — redefining a method warns, and is rejected
# during package precompilation on recent Julia versions.
function Soss.xform(m::SpikeMixture)
    component = m.m
    return Soss.xform(component)
end
"""
    zigzag(m::ConditionalModel, T = 1000.0; c = 10.0, adapt = false)

Run the ZigZag sampler `pdmp` (ZigZagBoomerang) for the Soss model `m` up to
final time `T`, and return whatever `pdmp` returns.

The sampler works on the unconstrained vector `x` of the transform `xform(m)`.
Its potential is `ϕ(x) = -logdensity(m, θ) - logjac` with `(θ, logjac)` the
transformed parameters and log-Jacobian. Partial derivatives `∂ϕ/∂x_i` are
computed one coordinate at a time with forward-mode dual numbers.

# Arguments
- `m`: conditioned Soss model to sample from.
- `T`: final time of the process.

# Keywords
- `c`: value of every entry of the vector `c * ones(d)` passed to `pdmp`
  (presumably the initial Poisson-rate bounds).
- `adapt`: forwarded unchanged to `pdmp`.
"""
function zigzag(m::ConditionalModel, T = 1000.0; c=10.0, adapt=false)
    ℓ(pars) = logdensity(m, pars)
    t = xform(m)

    # Potential on the unconstrained space, including the log-Jacobian term.
    function f(x)
        (θ, logjac) = TransformVariables.transform_and_logjac(t, x)
        return -ℓ(θ) - logjac
    end

    # Public API instead of reading the transform's internal `dimension` field.
    d = TransformVariables.dimension(t)

    # Closure (x, i) -> ∂ϕ/∂x_i(x): seed direction i in a reused buffer, evaluate
    # f on dual numbers, read the single partial, then clear the seed again.
    function partiali()
        ith = zeros(d)
        function (x, i)
            ith[i] = 1
            sa = StructArray{ForwardDiff.Dual{}}((x, ith))
            δ = f(sa).partials[]
            ith[i] = 0
            return δ
        end
    end
    ∇ϕi = partiali()

    # Draw a random starting point and velocity.
    tkeys = keys(t(zeros(d)))
    vars = Soss.select(rand(m), tkeys)
    t0 = 0.0
    x0 = inverse(t, vars)
    # NOTE(review): ZigZag velocities are commonly drawn from {-1, +1}
    # (e.g. `rand([-1.0, 1.0], d)`); Gaussian velocities are used here —
    # confirm this is intended.
    θ0 = randn(d)

    # Second `ZigZag` argument (presumably the reference mean μ): `zero(x0)`
    # rather than `0 * x0`, which would turn non-finite entries into NaN.
    Z = ZigZag(sparse(I(d)), zero(x0))
    return pdmp(∇ϕi, t0, x0, θ0, T, c * ones(d), Z; adapt=adapt)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment