# Last active: April 6, 2022 17:14
# Gist: devmotion/37d8d706938364eeb900bc3678860da6
# SDE inference (based on https://gist.github.com/mschauer/d1b95bc7031eb858e94de9fb86622c75)
# Note (GitHub): this file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters.
using AdvancedMH
using ArraysOfArrays
using CairoMakie
using DiffEqNoiseProcess
using Distributions
using StochasticDiffEq
using Turing

using InteractiveUtils
using LinearAlgebra: I
using Random
"""
    CrankNicolsonProposal(proposal; θ::Real)

Proposal wrapper that stores an inner `proposal` together with a pair of mixing weights.

The keyword constructor fills the weights with `sincospi(θ)`, so the fields hold
`sin(π * θ)` and `cos(π * θ)` (the angle is `π * θ`, not `θ`), and both fields share one
element type `T`.
"""
struct CrankNicolsonProposal{P,T} <: AdvancedMH.Proposal{P}
    proposal::P
    sinθ::T
    cosθ::T
end

function CrankNicolsonProposal(proposal; θ::Real)
    # `sincospi` returns `(sinpi(θ), cospi(θ))` in that order, matching the field order.
    s, c = sincospi(θ)
    return CrankNicolsonProposal(proposal, s, c)
end
# Initial proposal (no current state): a plain draw from the wrapped distribution.
function AdvancedMH.propose(
    rng::Random.AbstractRNG, proposal::CrankNicolsonProposal, ::DensityModel
)
    inner = proposal.proposal
    return rand(rng, inner)
end
# Proposal given the current state `X`: the element-wise combination
# `sinθ * X + cosθ * ξ`, where `ξ` is a fresh draw from the wrapped distribution.
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    proposal::CrankNicolsonProposal,
    ::DensityModel,
    X
)
    innovation = rand(rng, proposal.proposal)
    return @. proposal.sinθ * X + proposal.cosθ * innovation
end
"""
    Wiener(dt::Real, size::Dims)

Continuous array-variate distribution described by the square root of a step size `dt`
and the `size` of one draw.

Only `sqrt(dt)` is stored; its type becomes the first type parameter and the number of
dimensions of `size` becomes the second.
"""
struct Wiener{T,N} <: Distribution{ArrayLikeVariate{N},Continuous}
    sqrt_dt::T
    size::NTuple{N,Int}
end

function Wiener(dt::Real, size::Dims{N}) where {N}
    sqrt_dt = sqrt(dt)
    return Wiener{typeof(sqrt_dt),N}(sqrt_dt, size)
end
# Shape of a single draw, as stored in the distribution.
function Base.size(d::Wiener)
    return d.size
end

# Element type of a draw: the type of the stored `sqrt_dt`.
function Base.eltype(::Type{<:Wiener{T}}) where {T}
    return T
end
# Fill the vector `x` in place with one sampled path and return it.
#
# The entries are first filled with standard normal draws scaled by `d.sqrt_dt`, the
# first entry is then pinned to zero, and the cumulative sum turns the scaled draws into
# a path that starts at zero.
function Distributions._rand!(rng::Random.AbstractRNG, d::Wiener, x::AbstractVector)
    # An empty path has no starting value to pin; indexing it would throw.
    isempty(x) && return x
    randn!(rng, x)
    x .*= d.sqrt_dt
    # `firstindex` instead of a literal 1 keeps this correct for non-1-based vectors.
    x[firstindex(x)] = zero(eltype(x))
    cumsum!(x, x)
    return x
end
# Fill the matrix `X` in place with sampled paths and return it.
#
# Each row is one path and the second dimension is the path index: the entries are
# standard normal draws scaled by `d.sqrt_dt`, the first column is pinned to zero, and
# the cumulative sum runs along `dims=2`.
function Distributions._rand!(rng::Random.AbstractRNG, d::Wiener, X::AbstractMatrix)
    # With zero rows or zero columns there is nothing to sample, and taking a view of a
    # missing first column would throw.
    isempty(X) && return X
    randn!(rng, X)
    X .*= d.sqrt_dt
    # `firstindex(X, 2)` instead of a literal 1 keeps this correct for offset axes.
    fill!(view(X, :, firstindex(X, 2)), zero(eltype(X)))
    cumsum!(X, X; dims=2)
    return X
end
# The log ratio of proposal densities is taken to be zero for a Crank-Nicolson proposal
# that wraps a `Wiener` distribution, regardless of `state` and `candidate`.
AdvancedMH.logratio_proposal_density(
    ::CrankNicolsonProposal{<:Wiener}, state, candidate
) = 0
# Sanity check of the sampler: 10_000 rows, each a 3-step path with `dt = 0.5`.
# The first column is pinned to zero and increments are accumulated along the second
# dimension, so the column-wise mean should be close to 0 and the column-wise variance
# close to 0, 0.5 and 1.0.
wiener = Wiener(0.5, (10_000, 3))
W = rand(wiener)
mean(W; dims=1)
var(W; dims=1)

# One Crank-Nicolson step starting from `W`. `Random.default_rng()` is the documented
# accessor for the default RNG and replaces the legacy `Random.GLOBAL_RNG` binding.
Wstep = AdvancedMH.propose(
    Random.default_rng(), CrankNicolsonProposal(wiener; θ=0.35), DensityModel(identity), W
)
# In-place drift of the two-dimensional system: both components decay at rate 0.1 and
# are coupled through the term `0.2 * θ`. Writes into `du` and returns `nothing`;
# `t` is unused.
function f(du, u, θ, t)
    coupling = 0.2 * θ
    du[1] = -0.1 * u[1] + coupling * u[2]
    du[2] = -coupling * u[1] - 0.1 * u[2]
    return nothing
end
# In-place diffusion: every entry of `du` is set to the constant 0.15, independent of
# `u`, `θ` and `t`. Returns `nothing`.
function g(du, u, θ, t)
    du .= 0.15
    return nothing
end
# Initial state, time span, reference parameter and solver step size.
x0 = [1.0, 1.0]
tspan = (0.0, 20.0)
θ0 = 1.0
dt = 0.05

# Solver grid (step `dt`) and the coarser grid on which solutions are saved (every
# tenth solver step).
t = first(tspan):dt:last(tspan)
saveat = first(tspan):(10 * dt):last(tspan)

# Reference solution with Euler-Maruyama; the driving noise is recorded.
prob = SDEProblem{true}(f, g, x0, tspan, θ0)
sol = solve(prob, EM(); save_noise=true, dt, saveat)

# Recorded noise values, one column per entry of `sol.W.W`.
W0 = reduce(hcat, sol.W.W)

# Solve again, this time driven by the recorded noise on the grid `t`, and compare the
# saved states with the reference solution.
sol2 = let noise = NoiseGrid(t, ArrayOfSimilarVectors(W0))
    solve(remake(prob; noise), EM(); dt, saveat)
end
sol.u ≈ sol2.u
# Ensemble of 1000 trajectories using the solver's own noise.
ensembleprob = EnsembleProblem(prob)
ensemblesol = solve(ensembleprob, EM(), EnsembleThreads(); dt, saveat, trajectories=1000)
ensemblesummary = EnsembleSummary(ensemblesol)

# Same ensemble, but every trajectory is driven by a path drawn from `Wiener` and fed in
# as a noise grid on `t`.
ensembleprob_wiener = let t = t, wiener = Wiener(dt, (2, length(t)))
    function with_sampled_noise(prob, i, repeat)
        return remake(prob; noise=NoiseGrid(t, ArrayOfSimilarVectors(rand(wiener))))
    end
    EnsembleProblem(prob; prob_func=with_sampled_noise)
end
ensemblesol_wiener = solve(
    ensembleprob_wiener, EM(), EnsembleThreads();
    dt, saveat, trajectories=1000
)
ensemblesummary_wiener = EnsembleSummary(ensemblesol_wiener)

# Largest absolute difference between the two ensemble summaries.
maximum(abs, Array(ensemblesummary_wiener.u) .- Array(ensemblesummary.u))
# Noisy observations: the saved reference solution plus Gaussian noise with standard
# deviation ς. The zero-argument `randn.()` is part of the fused broadcast, so `randn()`
# is called once per entry.
ς = 0.2
Y = let clean = Array(sol)
    clean .+ ς .* randn.()
end
# Log density over the pair `(θ, W)`: solve the SDE with parameter `θ` driven by the
# noise path `W` and score the saved states against `Y` with
# `-sum(abs2, Y - X) / (2ς^2)`. Only this data-fit term is included here.
# Globals are captured through `let` so the closure sees fixed local bindings.
model = let prob = prob, dt = dt, saveat = saveat, t = t, twoς2 = 2 * ς^2, Y = Y
    function logtarget((θ, W))
        driven = remake(prob; noise=NoiseGrid(t, ArrayOfSimilarVectors(W)))
        path = solve(driven, EM(); p=θ, dt, saveat)
        return -sum(abs2, Y - Array(path)) / twoς2
    end
    DensityModel(logtarget)
end
# Draw one driving path and evaluate the log density at a slightly perturbed parameter.
W = rand(Wiener(dt, (2, length(t))))
model.logdensity((θ0 + 0.1, W))

# Inspect type inference of the log density. `@code_warntype` is provided by
# InteractiveUtils, which is only loaded automatically in interactive sessions, so the
# macro is qualified here to keep the script runnable non-interactively as well.
InteractiveUtils.@code_warntype model.logdensity((θ0 + 0.1, W))
# Proposals: a symmetric random walk for θ and a Crank-Nicolson step for the noise path.
θprop = SymmetricRandomWalkProposal(Normal(0, 0.05))
Wprop = CrankNicolsonProposal(Wiener(dt, (2, length(t))); θ=0.44)

# Starting value for θ and number of samples.
θ = 0.95
N = 100_000

# Metropolis-Hastings over both blocks, started at `(θ, W)`; samples are returned as a
# vector of named tuples.
chain = let proposals = (θ=θprop, W=Wprop), start = (; θ, W)
    sample(
        model, MetropolisHastings(proposals), N;
        init_params=start, chain_type=Vector{NamedTuple}
    )
end
# First entry of every sample, i.e. the θ component.
θs = [first(draw) for draw in chain]
@show mean(θs), std(θs)

# Second half of the chain, used for the averages below.
kept = (N ÷ 2):N

fig = Figure(; resolution=(2000, 500))

# Panel 1: trace of θ with the reference value (red) and the second-half mean (orange).
ax_θ = Axis(fig[1, 1])
lines!(ax_θ, θs)
lines!(ax_θ, [1, N], fill(θ0, 2); color=:red)
lines!(ax_θ, [1, N], fill(mean(view(θs, kept)), 2); color=:orange)

# Average of the sampled noise paths over the second half of the chain.
WM = mean(draw.W for draw in view(chain, kept))

# Panel 2: recorded noise (red) against the averaged path (orange), row by row.
ax_rows = Axis(fig[1, 2])
lines!(ax_rows, view(W0, 1, :); color=:red, linewidth=3)
lines!(ax_rows, view(WM, 1, :); linewidth=3, color=:orange)
lines!(ax_rows, view(W0, 2, :); color=:red)
lines!(ax_rows, view(WM, 2, :); color=:orange)

# Panel 3: every 5000th sampled path, coloured by iteration, with the recorded noise and
# the averaged path drawn on top.
ax_paths = Axis(fig[1, 3])
lines!(ax_paths, W0; color=:red)
for i in reverse(1:5_000:N)
    lines!(ax_paths, chain[i].W; color=fill(i, length(t)), colorrange=(1, N))
end
lines!(ax_paths, W0; color=:red, linewidth=4)
lines!(ax_paths, WM; color=:orange, linewidth=4) # posterior latent mean

fig
save("fig.png", fig)
# Gibbs sampling with Turing
#
# Turing model of the same inference problem: a standard normal prior on θ, a `Wiener`
# prior on the driving noise `W`, and Gaussian observations of the saved SDE states.
# Keyword arguments: `prob` (the SDE problem), `Y` (observations, one column per saved
# state), `saveat` (save grid), `t` (solver grid) and `ς` (observation standard deviation).
@model function cn_model(; prob, Y, saveat, t, ς)
    # parameter
    θ ~ Normal()

    # SDE
    dt = step(t)
    W ~ Wiener(dt, (2, length(t)))
    sol = solve(
        remake(prob; noise=NoiseGrid(t, ArrayOfSimilarVectors(W))), EM();
        p=θ, dt=dt, saveat=saveat
    )

    # observations
    # `MvNormal(μ, σ::Real)` is deprecated; the isotropic covariance is spelled out
    # instead. `ς` is a standard deviation, hence the covariance `ς^2 * I`.
    Σ = ς^2 * I
    for i in 1:size(Y, 2) # bug in DynamicPPL does not allow to broadcast...
        Y[:, i] ~ MvNormal(sol.u[i], Σ)
    end
    return nothing
end
# DynamicPPL wants to evaluate `logpdf(::Wiener, ::AbstractMatrix)`...
# The stub returns a constant zero for any real matrix.
function Distributions.logpdf(::Wiener, ::AbstractMatrix{<:Real})
    return 0.0
end
# Sample the Turing model with one MH sampler that uses the θ and W proposals; the
# initial parameters are θ followed by the flattened noise path.
chain_turing = let target = cn_model(; prob, Y, saveat, t, ς)
    sample(
        target, MH(:θ => θprop, :W => Wprop), N;
        init_params=vcat(θ, vec(W)), chain_type=Vector{NamedTuple},
    )
end; # do not print summary statistics to save time ;)
#= Gibbs sampling works as well
chain_turing = sample(
    cn_model(; prob, Y, saveat, t, ς), Gibbs(MH(:θ => θprop), MH(:W => Wprop)), N;
    init_params=vcat(θ, vec(W)), chain_type=Vector{NamedTuple},
); # do not print summary statistics to save time ;)
=#
# First element of the θ entry of every Turing sample.
θs_turing = [draw.θ[1] for draw in chain_turing]
@show mean(θs_turing), std(θs_turing)

# Second half of the chain, used for the averages below.
kept = (N ÷ 2):N

fig = Figure(; resolution=(2000, 500))

# Panel 1: trace of θ with the reference value (red) and the second-half mean (orange).
ax_θ = Axis(fig[1, 1])
lines!(ax_θ, θs_turing)
lines!(ax_θ, [1, N], fill(θ0, 2); color=:red)
lines!(ax_θ, [1, N], fill(mean(view(θs_turing, kept)), 2); color=:orange)

# Average of the first element of the W entry over the second half of the chain.
WM_turing = mean([draw.W[1] for draw in view(chain_turing, kept)])

# Panel 2: recorded noise (red) against the averaged path (orange), row by row.
ax_rows = Axis(fig[1, 2])
lines!(ax_rows, view(W0, 1, :); color=:red, linewidth=3)
lines!(ax_rows, view(WM_turing, 1, :); linewidth=3, color=:orange)
lines!(ax_rows, view(W0, 2, :); color=:red)
lines!(ax_rows, view(WM_turing, 2, :); color=:orange)

# Panel 3: every 5000th sampled path, coloured by iteration, with the recorded noise and
# the averaged path drawn on top.
ax_paths = Axis(fig[1, 3])
lines!(ax_paths, W0; color=:red)
for i in reverse(1:5_000:N)
    lines!(ax_paths, chain_turing[i].W[1]; color=fill(i, length(t)), colorrange=(1, N))
end
lines!(ax_paths, W0; color=:red, linewidth=4)
lines!(ax_paths, WM_turing; color=:orange, linewidth=4) # posterior latent mean

fig
save("fig_turing.png", fig)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.
# Uncertainty visualisation for latent