@mschauer
Last active April 13, 2020 14:36
Reference implementation of backwards filtering, forward guiding with https://arxiv.org/abs/1712.03807
using LinearAlgebra
using Random
using GaussianDistributions
using GaussianDistributions: logpdf
pair(u) = u[1], u[2]
pair(p::Gaussian) = p.μ, p.Σ
skiplast(r) = r[1:end-1]
# time grid
dt = 0.01
T = 10.0
s = 0:dt:T
# model dX = b(x)dt + σ(x)dW
b(x) = -0.1x - 2.5sin(x*2pi) + 0.5
σ(x) = 0.9
# linear approximation of b and constant approximation of σ
B = -0.1
β = 0.5
b̃(x) = B*x + β
σ̃ = 0.9
# observation times t
ti = 1:10:length(s)
t = s[ti]
# observation scheme Y ∼ N(L*X, σϵ^2)
L = 1.0
σϵ = 0.2
Σ = σϵ*σϵ'
# Kalman correction step, https://en.wikipedia.org/wiki/Kalman_filter#Update
"""
correct(u::T, v, H)
Correction step of a Kalman filter with `u = (x, P)` the prediction with uncertainty
covariance `P`, and `v = (y, R)` the observation with uncertainty covariance `R`
and the observation operator `H`. See https://en.wikipedia.org/wiki/Kalman_filter#Update.
"""
function correct(u, v, H)
x, Ppred = pair(u)
y, R = pair(v)
yres = y - H*x # innovation residual
S = (H*Ppred*H' + R) # innovation covariance
K = Ppred*H'*inv(S) # Kalman gain
x = x + K*yres
P = (I - K*H)*Ppred*(I - K*H)' + K*R*K'
(x, P), yres, S
end
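# A quick sanity check of `correct` (illustrative, not part of the original gist):
# in 1d the Joseph-form update should reproduce the conjugate Gaussian posterior,
# here with prior N(0, 1), observation 1.0 with variance 0.04 and H = 1.
let
    (x, P), yres, S = correct((0.0, 1.0), (1.0, 0.04), 1.0)
    @assert x ≈ 1/1.04 && P ≈ 0.04/1.04 && S ≈ 1.04
end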
# Sample the model
"""
forwardsample(s, ti, x)
Simulate trajectory on timegrid `s` and observations at times `s[ti]`
using the Euler-Maruyama scheme.
"""
function forwardsample(s, ti, x)
xs = typeof(x)[]
ys = typeof(L*x)[]
for i in skiplast(eachindex(s))
if i in ti
push!(ys, L*x + σϵ*randn())
end
push!(xs, x)
x = x + b(x)*dt + σ(x)*sqrt(dt)*randn()
end
push!(xs, x)
if lastindex(s) in ti
push!(ys, L*x + σϵ*randn())
end
xs, ys
end
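# Illustrative structural check (assumes the globals s, ti, L, σϵ, b, σ, dt defined
# above, and runs before the seeded experiment below): `forwardsample` returns one
# state per grid point and one observation per index in `ti`.
let
    xs_demo, ys_demo = forwardsample(s, ti, 0.0)
    @assert length(xs_demo) == length(s) && length(ys_demo) == length(ti)
end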
# Compute the marginal approximate filtering distributions backwards, given the data `ys`
"""
    backwardfilter(s, ti, ys, πT) -> ps, p0

Backward filtering, starting from the `N(ν, P)` prior `πT`, assuming that `ys`
contains observations at times `t = s[ti]` with `y ∼ N(L X[t], Σ)`.
"""
function backwardfilter(s, ti, ys, πT)
    @assert lastindex(s) in ti
    j = length(ys)
    p, _ = correct(πT, (ys[j], Σ), L)
    ps = [p]
    ν, P = pair(p)
    for i in eachindex(s)[end-1:-1:1]
        # backward Euler step for ν and P
        P = P - dt*(B*P + P*B' - σ̃*σ̃')
        ν = ν - dt*(B*ν + β)
        push!(ps, (ν, P))
        if i in ti # condition on the observation at s[i]
            j = j - 1
            p, _ = correct((ν, P), (ys[j], Σ), L)
            (ν, P) = pair(p)
        end
    end
    reverse!(ps), (ν, P)
end
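# Note (a summary of the loop above, in the spirit of https://arxiv.org/abs/1712.03807):
# each entry of `ps` is a pair (ν, P) parametrising the approximate filtering density
# at one grid point; between observations (ν, P) evolve by backward Euler steps of
#   dν/dt = B ν + β   and   dP/dt = B P + P B' - σ̃ σ̃',
# and at each observation time a Kalman correction step conditions on the data.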
"""
forwardguiding(s, x, ps) -> xs, ll
Forward sample a guided trajectory `xs` starting in `x` and compute it's
log-likelihood `ll`.
"""
function forwardguiding(s, x, ps)
llstep(x, r, P) = dot(b(x) - b̃(x), r)*dt - 0.5*tr((σ(x)*σ(x)' - σ̃*σ̃')*(inv(P) - r*r'))*dt
xs = typeof(x)[]
ll = 0.0
for i in skiplast(eachindex(s))
push!(xs, x)
ν, P = pair(ps[i])
r = inv(P)*(ν - x)
ll += llstep(x, r, P) # accumulate log-likelihood
x = x + b(x)*dt + σ(x)*σ(x)'*r*dt + σ(x)*sqrt(dt)*randn() # evolution guided by observations
end
push!(xs, x)
xs, ll
end
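# Note (a summary of the code above): the guided proposal discretises
#   dXᵒ = b(Xᵒ) dt + σ(Xᵒ)σ(Xᵒ)' r(t, Xᵒ) dt + σ(Xᵒ) dW,   r(t, x) = P(t)⁻¹(ν(t) - x),
# while `ll` accumulates the Euler discretisation of the log Radon-Nikodym weight
#   ∫ (b - b̃)(Xᵒ)' r dt - ½ ∫ tr((σσ' - σ̃σ̃')(P⁻¹ - r r')) dt.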
Random.seed!(123)
# First generate data from the model for illustration
π0 = Gaussian(0.0, 1.0)
x0 = rand(π0)
xs, ys = forwardsample(s, ti, x0) # sample trajectory
# run backwards filter given the observations ys
πT = Gaussian(0.0, 10.0) # prior for the backward filter
ps, p0 = backwardfilter(s, ti, ys, πT)
# sample trajectories and their importance weight
K = 10
x̂s = Vector(undef, K)
ll = zeros(K)
for k in 1:K
    x0 = rand(Gaussian(p0...)) # sample from p0
    x̂s[k], ll[k] = forwardguiding(s, x0, ps)
    # correct for having used the backward prior πT instead of our actual prior π0
    ll[k] += logpdf(π0, x̂s[k][1]) - logpdf(πT, x̂s[k][end])
end
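# The entries of `ll` are unnormalised log importance weights; a self-normalised
# weight vector could be formed as follows (a sketch, not in the original script):
# w = exp.(ll .- maximum(ll)); w ./= sum(w)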
lmax = maximum(exp.(ll)) # maximum of importance weights
# Plot samples of the latent trajectories, colored according to importance weight
using Plots
pl = Plots.scatter(t, ys, color=:orange, markersize=2., label="obs",legend=:outertopright) # observations
for k in 1:K
    Plots.plot!(pl, s, x̂s[k], color=:maroon, lw = 0.6, alpha = exp(ll[k])/lmax, label="sample $k") # samples
end
Plots.plot!(pl, s, xs, color=:lightseagreen, label="x true") # ground truth
display(pl)
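# ------------------------------------------------------------------------------
# Second script of the gist: the same model and machinery as above, extended with
# inference for the parameter θ by random-walk Metropolis-Hastings with
# preconditioned Crank-Nicolson updates of the innovations.
# ------------------------------------------------------------------------------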
using LinearAlgebra
using Random
using GaussianDistributions
using GaussianDistributions: logpdf
using Parameters
pair(u) = u[1], u[2]
pair(p::Gaussian) = p.μ, p.Σ
skiplast(r) = r[1:end-1]
# model dX = b(x)dt + σ(x)dW
# argument M contains Model parameters, see below
b(x, M) = -0.1x - M.θ*sin(x*2pi) + 0.5
σ(x, M) = 0.9
# linear approximation of b and constant approximation of σ
b̃(x, M) = M.B*x + M.β
σ̃(M) = M.σ̃
# time grid
dt = 0.01
T = 10.0
s = 0:dt:T
# observation times t
ti = 1:10:length(s)
t = s[ti]
@with_kw struct Model{R} @deftype R # in 1d all parameters can be of the same type R
    # unknown parameter
    θ = 2.5
    # parameters for the linear approximation of b and the constant approximation of σ
    B = -0.1
    β = 0.5
    σ̃ = 0.9
    # observation scheme Y ∼ N(L*X, σϵ^2)
    L = 1.0
    σϵ = 0.2
    Σ = σϵ*σϵ'
end
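# Quick illustration of the keyword constructor generated by Parameters.jl's @with_kw
# (not part of the original script): unspecified fields take the defaults declared above.
let Mdemo = Model(θ = 1.0)
    @assert Mdemo.θ == 1.0 && Mdemo.B == -0.1 && Mdemo.L == 1.0
end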
# Kalman correction step, https://en.wikipedia.org/wiki/Kalman_filter#Update
"""
correct(u::T, v, H)
Correction step of a Kalman filter with `u = (x, P)` the prediction with uncertainty
covariance `P`, and `v = (y, R)` the observation with uncertainty covariance `R`
and the observation operator `H`. See https://en.wikipedia.org/wiki/Kalman_filter#Update.
"""
function correct(u, v, H, c = 0.0)
x, Ppred = pair(u)
y, R = pair(v)
yres = y - H*x # innovation residual
S = (H*Ppred*H' + R) # innovation covariance
K = Ppred*H'*inv(S) # Kalman gain
x = x + K*yres
P = (I - K*H)*Ppred*(I - K*H)' + K*R*K'
c = c - logpdf(Gaussian(zero(y), R), y)
(x, P), c, yres, S
end
# Sample the model
"""
forwardsample(s, ti, x)
Simulate trajectory on timegrid `s` and observations at times `s[ti]`
using the Euler-Maruyama scheme.
"""
function forwardsample(M, s, ti, x)
@unpack L, σϵ = M
xs = typeof(x)[]
ys = typeof(L*x)[]
for i in skiplast(eachindex(s))
dt = s[i+1] - s[i]
if i in ti
push!(ys, L*x + σϵ*randn())
end
push!(xs, x)
x = x + b(x, M)*dt + σ(x, M)*sqrt(dt)*randn()
end
push!(xs, x)
if lastindex(s) in ti
push!(ys, L*x + σϵ*randn())
end
xs, ys
end
# Compute the marginal approximate filtering distributions backwards, given the data `ys`
"""
    backwardfilter(M, s, ti, ys, πT, c = 0.0) -> ps, p0, c

Backward filtering for the model `M`, starting from the `N(ν, P)` prior `πT`,
assuming that `ys` contains observations at times `t = s[ti]` with `y ∼ N(L X[t], Σ)`.
`exp(-c)` is the integration constant from Theorem 3.3.
"""
function backwardfilter(M, s, ti, ys, πT, c = 0.0)
    @unpack L, Σ, B, β, σ̃ = M
    @assert lastindex(s) in ti
    j = length(ys)
    p, _, c = correct(πT, (ys[j], Σ), L, c)
    ps = [p]
    ν, P = pair(p)
    for i in eachindex(s)[end-1:-1:1]
        dt = s[i+1] - s[i]
        # backward Euler step for ν and P
        P = P - dt*(B*P + P*B' - σ̃*σ̃')
        ν = ν - dt*(B*ν + β)
        H = inv(P)
        F = H*ν
        c += β*F*dt + 0.5*F'*σ̃*σ̃'*F*dt - 0.5*sum(H .* (σ̃*σ̃'))*dt
        push!(ps, (ν, P))
        if i in ti # condition on the observation at s[i]
            j = j - 1
            p, _, c = correct((ν, P), (ys[j], Σ), L, c)
            (ν, P) = pair(p)
        end
    end
    reverse!(ps), (ν, P), c
end
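# Note (a summary of the loop above): the increments added to `c` discretise
#   ∫ (β'F + ½ F'σ̃σ̃'F - ½ tr(H σ̃σ̃')) dt,   with H = P⁻¹ and F = H ν,
# so that, together with the contributions collected in `correct`, exp(-c) is the
# integration constant from Theorem 3.3 mentioned in the docstring.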
"""
forwardguiding(M, s, x, ps, Z) -> xs, ll
Forward sample a guided trajectory `xs` starting in `x` and compute it's
log-likelihood `ll` with innovations `Z = randn(length(s))`.
"""
function forwardguiding(M, s, x, ps, Z=randn(length(s)))
llstep(x, r, P) = dot(b(x, M) - b̃(x, M), r)*dt - 0.5*tr((σ(x, M)*σ(x, M)' - σ̃(M)*σ̃(M)')*(inv(P) - r*r'))*dt
xs = typeof(x)[]
ll = 0.0
for i in skiplast(eachindex(s))
dt = s[i+1] - s[i]
push!(xs, x)
ν, P = pair(ps[i])
r = inv(P)*(ν - x)
ll += llstep(x, r, P) # accumulate log-likelihood
x = x + b(x, M)*dt + σ(x, M)*σ(x, M)'*r*dt + σ(x, M)*sqrt(dt)*Z[i] # evolution guided by observations
end
push!(xs, x)
xs, ll
end
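# Note: for fixed parameters, passing the same innovations `Z` reproduces the same
# guided path; the Crank-Nicolson update in `randomwalkmcmc` below relies on this
# by perturbing `Z` instead of resampling the path from scratch.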
"""
randomwalkmcmc(s, ti, ys, θ0, iters, ρ = 0.9, σθ = 0.01)
Infer parameter θ using Metropolis-Hastings with joint update of
innovations (Crank Nicolson with parameter ρ) and parameter θ (Gaussian random walk
with stepsize σθ)
"""
function randomwalkmcmc(s, ti, ys, θ0, iters, ρ = 0.9, σθ = 0.01)
θ = θ0
Mᵒ = Model(θ = θ)
θs = [θ]
# sample initial latent path
ps, p0, c = backwardfilter(M, s, ti, ys, πT)
x = rand(Gaussian(p0...))
Z = randn(length(s))
x̂, ll = forwardguiding(M, s, x, ps, Z)
acc = 0
for iter in 1:iters
# random walk proposal for parameter
θᵒ = θ + σθ* randn()
# independent proposal for starting point
x0ᵒ = rand(Gaussian(p0...))
# compute filtering density for guiding
Mᵒ = Model(θ = θᵒ)
ps, p0, c = backwardfilter(Mᵒ, s, ti, ys, πT)
ν0, P0 = p0
# random walk proposal for innovations
Zᵒ = ρ*Z + sqrt(1 - ρ^2)*randn(length(s))
# compute latent path
x̂ᵒ, llᵒ = forwardguiding(Mᵒ, s, x0ᵒ, ps, Zᵒ)
llᵒ += logpdf(π0, x̂ᵒ[1]) - logpdf(πT, x̂ᵒ[end])
llᵒ += -c + (-0.5*x0ᵒ' + ν0')*inv(P0)*x0ᵒ # constant may change if σ depends on parameter
# Metropolis-Hastings accept/reject for joint proposal of starting point, path, parameter
if rand() < exp(llᵒ - ll)
θ = θᵒ
ll = llᵒ
x0 = x0ᵒ
x̂ = x̂ᵒ
Z = Zᵒ
acc += 1
end
push!(θs, θ)
end
θs, acc/iters
end
Random.seed!(123)
# Set true model
θtrue = 2.5
M = Model(θ = θtrue)
# First generate data from the model for illustration
π0 = Gaussian(0.0, 1.0)
x0 = rand(π0)
xs, ys = forwardsample(M, s, ti, x0) # sample trajectory
# run backwards filter given the observations ys
πT = Gaussian(0.0, 10.0) # prior for the backward filter
ps, p0, c = backwardfilter(M, s, ti, ys, πT)
# sample trajectories and their importance weight
K = 10
x̂s = Vector(undef, K)
ll = zeros(K)
for k in 1:K
    x0 = rand(Gaussian(p0...)) # sample from p0
    x̂s[k], ll[k] = forwardguiding(M, s, x0, ps)
    # correct for having used the backward prior πT instead of our actual prior π0
    ll[k] += logpdf(π0, x̂s[k][1]) - logpdf(πT, x̂s[k][end])
end
lmax = maximum(exp.(ll)) # maximum of importance weights
# inference for parameter θ
θ = 0.2θtrue # start somewhere wrong
iters = 50000
ρ = 0.9 # random walk parameter for innovation update (Crank Nicolson scheme)
σθ = 0.03 # stepsize randomwalk parameter
θs, a = @time randomwalkmcmc(s, ti, ys, θ, iters, ρ, σθ)
println("Acceptance rate: ", a)
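# A rough posterior summary for θ (a sketch, not in the original script; the burn-in
# length is an arbitrary choice):
# using Statistics
# mean(θs[10001:end]), std(θs[10001:end])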
# Plot samples of the latent trajectories, colored according to importance weight
using Plots
pl = Plots.scatter(t, ys, color=:orange, markersize=2., label="obs",legend=:outertopright) # observations
for k in 1:K
    Plots.plot!(pl, s, x̂s[k], color=:maroon, lw = 0.6, alpha = exp(ll[k])/lmax, label="sample $k") # samples
end
Plots.plot!(pl, s, xs, color=:lightseagreen, label="x true") # ground truth
display(pl)
# Plot samples of the mcmc chain for θ
pl2 = Plots.plot(0:10:iters, θs[1:10:end], label = "theta, samples")
Plots.plot!(pl2, 0:10:iters, fill(θtrue, length(0:10:iters)), label= "theta, true")
display(pl2)
@fmeulen commented Apr 13, 2020:
Just to be sure: in line 190, is this \log \tilde\rho(0, x_0), expressed in the Hfc-parametrisation?
