Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@matsueushi
Created April 4, 2019 03:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matsueushi/58d78e222db8189594825aa294e390f5 to your computer and use it in GitHub Desktop.
Save matsueushi/58d78e222db8189594825aa294e390f5 to your computer and use it in GitHub Desktop.
GaussianRandomWalk for Mamba
# GaussianRandomWalk for Mamba
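# \begin{align*}
# Y_0 &\sim D,\\
# Y_{i+1} &= Y_i + \mu_i + \epsilon_i,\quad \epsilon_i \sim \mathrm{Normal}(0, \sigma),
# \end{align*}
# where D is the distribution of the initial value, \mu_i is the drift of step i,
# and \sigma is the standard deviation of every step.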
# Reference:
# Mamba documentation, "User-Defined Multivariate Distributions"
# https://mambajl.readthedocs.io/en/latest/mcmc/distributions.html#user-defined-multivariate-distributions
using Distributed
@everywhere extensions = quote
    using Distributions
    import Distributions: length, insupport, _logpdf

    # Gaussian random walk with drift over a vector x of length length(mu) + 1:
    #   x[1]   ~ init
    #   x[i+1] ~ Normal(x[i] + mu[i], sig)   for i in 1:length(mu)
    # `init` is a type parameter so the field is concretely typed; an abstract
    # field type would box the value and block specialization of logpdf.
    struct GaussianRandomWalk{D<:ContinuousUnivariateDistribution} <: ContinuousMultivariateDistribution
        mu::Vector{Float64}  # drift added at each step
        sig::Float64         # standard deviation of each step's Normal noise
        init::D              # distribution of the first element
    end

    # Keep the plain call form working for any drift/scale convertible to Float64
    # (e.g. `GaussianRandomWalk([1, 3], 1.0, Normal())`), inferring D from `init`.
    # The auto-generated parametric constructor only accepts Vector{Float64}/Float64.
    function GaussianRandomWalk(mu, sig, init::D) where {D<:ContinuousUnivariateDistribution}
        return GaussianRandomWalk{D}(mu, sig, init)
    end

    # Dimension: one initial value plus one value per drift term.
    length(d::GaussianRandomWalk) = length(d.mu) + 1

    # Support: any vector of the right length whose entries are all finite.
    function insupport(d::GaussianRandomWalk, x::AbstractVector{T}) where {T<:Real}
        return length(d) == length(x) && all(isfinite, x)
    end

    # Log-density: log p(x[1]) + sum_i log Normal(x[i+1]; x[i] + mu[i], sig).
    function _logpdf(d::GaussianRandomWalk, x::AbstractVector{T}) where {T<:Real}
        # Broadcasting would silently stretch a length-1 slice, so check explicitly.
        length(x) == length(d) || throw(DimensionMismatch(string(
            "expected a vector of length ", length(d), ", got length ", length(x))))
        prev = @view x[1:end-1]  # x[1], ..., x[n-1]
        next = @view x[2:end]    # x[2], ..., x[n]
        # Fully fused broadcast: no temporary arrays for the means or distributions.
        steps = logpdf.(Normal.(d.mu .+ prev, d.sig), next)
        return logpdf(d.init, x[1]) + sum(steps)
    end
end
# Smoke-test the extensions in a throwaway module before loading them into Mamba,
# so definition errors surface without touching `Main`.
using Distributions
module Testing end
Core.eval(Testing, extensions)
d = Testing.GaussianRandomWalk([1, 3], 1.0, Normal())
let x = [2.0, 3.0, 3.0]
    # With mu = [1, 3], sig = 1 and init = Normal(): x[1] ~ Normal(0, 1),
    # x[2] ~ Normal(x[1] + 1, 1), x[3] ~ Normal(x[2] + 3, 1).
    expected = logpdf(Normal(), x[1]) + logpdf(Normal(x[1] + 1, 1.0), x[2]) +
               logpdf(Normal(x[2] + 3, 1.0), x[3])
    Testing.insupport(d, x) ||
        error("insupport rejected a finite vector of length $(length(d))")
    actual = Testing.logpdf(d, x)
    # Fail loudly instead of just computing a value nobody checks.
    actual ≈ expected || error("logpdf mismatch: got $actual, expected $expected")
    actual
end
# Load Mamba and the GaussianRandomWalk definitions on every Distributed process
# (presumably so chains run on workers can see them).
@everywhere using Mamba
@everywhere eval(extensions)

# Model: the observed series y (99 steps plus an initial value, i.e. length 100)
# is a driftless Gaussian random walk. `sig` is the step *variance*: its square
# root is the standard deviation of every step and of the starting value.
# The lambda argument must stay named `sig`; Mamba matches lambda argument
# names to node names.
model = Model(
    y = Stochastic(1,
        sig -> begin
            sd = sqrt(sig)
            GaussianRandomWalk(zeros(99), sd, Normal(0, sd))
        end,
        false,
    ),
    sig = Stochastic(() -> InverseGamma(0.001, 0.001)),
)

# Sample `sig` with Mamba's AMWG sampler (initial proposal scale 10.0).
scheme = [AMWG(:sig, 10.0)]
setsamplers!(model, scheme)
# Simulate observations: the cumulative sum of a 100-dimensional isotropic
# MvNormal draw with standard deviation sqrt(100), i.e. a random walk whose
# start and increments are presumably iid with variance 100.
data = Dict(:y => cumsum(rand(MvNormal(100, sqrt(100)))))

# Number of MCMC chains. It drives both `inits` and `mcmc(...; chains)` so the
# two cannot drift apart (Mamba expects one initial-value Dict per chain).
nchains = 3

# One initial-value Dict per chain. A comprehension (rather than `fill`) gives
# each chain its own Dict.
inits = [Dict(:y => data[:y], :sig => 1) for _ in 1:nchains]

sim = mcmc(model, data, inits, 21000, burnin = 1000, thin = 4, chains = nchains)
describe(sim)

# Compare the posterior summary of `sig` with the sample variance of the
# simulated steps.
println("Actual variance: ", var(diff(data[:y])))

p = Mamba.plot(sim, legend = true)
Mamba.draw(p, nrow = 1, ncol = 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment