@mschauer
Created September 4, 2020 12:30
Differentiable Kalman filter
```julia
using GaussianDistributions
using GaussianDistributions: correct, ⊕
using LinearAlgebra
using Statistics

lchol(a) = cholesky(a).L
# Log-likelihood contribution of the innovation yres with covariance S
llikelihood(yres, S) = GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres)

# State space system
#
# x[0] ∼ N(x0, P0)
# x[k] = Φx[k−1] + w[k],  w[k] ∼ N(θ, Q)   (drift θ: the parameter we differentiate in)
# y[k] = Hx[k] + v[k],    v[k] ∼ N(0, R)
# In the code, model.Q and obsmodel.R hold the full noise laws as Gaussians.
prior = Gaussian([1., 0.], Matrix(1.0I, 2, 2))
modelf(θ = [1.0, 0.5]) = (Φ = [0.8 0.3; -0.1 0.8],
                          Q = Gaussian(θ, [0.2 0.0; 0.0 1.0]))
model = modelf()
obsmodel = (H = [1.0 0.5],
            R = Gaussian([0.0], Matrix(1.0I, 1, 1)))
n = 50 # transitions

# Forward simulate
function forward(prior, model, obsmodel, n)
    x = rand(prior)
    y = obsmodel.H*x + rand(obsmodel.R)
    xs = [x]
    ys = [y]
    for i in 1:n
        x = model.Φ*x + rand(model.Q)
        y = obsmodel.H*x + rand(obsmodel.R)
        push!(xs, x)
        push!(ys, y)
    end
    xs, ys
end

# Kalman filter (named kalman_filter so it does not shadow Base.filter)
function kalman_filter(prior, model, obsmodel, ys)
    x = prior
    x, yres, S = GaussianDistributions.correct(x, ys[1] + obsmodel.R, obsmodel.H)
    ll = llikelihood(yres, S)
    xs = Any[x]  # Any: under ForwardDiff, later states carry dual numbers, the first does not
    for i in 2:length(ys)
        x = model.Φ*x ⊕ model.Q  # predict
        x, yres, S = GaussianDistributions.correct(x, ys[i] + obsmodel.R, obsmodel.H)  # correct
        ll += llikelihood(yres, S)
        push!(xs, x)
    end
    xs, ll
end

xstrue, ys = forward(prior, model, obsmodel, n)
xs, ll = kalman_filter(prior, model, obsmodel, ys)

using Plots
p = plot(0:n, first.(ys), color=:red, label="observations")
plot!(p, 0:n, first.(mean.(xs)), color=:blue, label="filtered mean x1")
plot!(p, 0:n, last.(mean.(xs)), color=:blue, label="filtered mean x2")
plot!(p, 0:n, first.(xstrue), color=:black, label="true x1")
plot!(p, 0:n, last.(xstrue), color=:black, label="true x2")
display(p)

using ForwardDiff
# Gradient of the log-likelihood with respect to the drift θ, evaluated at a random θ
grad = ForwardDiff.gradient(θ -> kalman_filter(prior, modelf(θ), obsmodel, ys)[2], rand(2))
@show grad
```
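One way to sanity-check the ForwardDiff gradient is a central finite difference. The sketch below is a dependency-free, hypothetical scalar analogue of the model above: the names `scalar_loglik`, `fd` and `f`, the default parameters and the observations are all made up for illustration, and it does not use GaussianDistributions.jl.

```julia
# Hypothetical scalar analogue of the gist's model (names and defaults are ours):
#   x[k] = ϕ x[k-1] + w[k], w[k] ~ N(θ, q);   y[k] = x[k] + v[k], v[k] ~ N(0, r)
# Returns the log-likelihood of the observations ys as a function of the drift θ.
function scalar_loglik(θ, ys; ϕ = 0.8, q = 0.2, r = 1.0, m0 = 1.0, p0 = 1.0)
    m, p = m0, p0                          # prior mean and variance of x[0]
    ll = zero(θ)
    for (k, y) in enumerate(ys)
        if k > 1
            m, p = ϕ * m + θ, ϕ^2 * p + q  # predict: scalar analogue of Φ*x ⊕ Q
        end
        s = p + r                          # innovation variance
        e = y - m                          # innovation
        ll += -0.5 * (e^2 / s + log(2π * s))
        g = p / s                          # Kalman gain
        m, p = m + g * e, (1 - g) * p      # correct
    end
    return ll
end

# Central finite difference of a scalar function f at θ
fd(f, θ; h = 1e-3) = (f(θ + h) - f(θ - h)) / (2h)

ys = [1.2, 0.7, 1.9, 2.4, 1.1]             # made-up observations, illustration only
f(θ) = scalar_loglik(θ, ys)
dθ = fd(f, 0.3)
```

Because θ only shifts means and never enters a variance or gain, the log-likelihood is quadratic in θ (the same holds for the gist's model, where θ is the mean of `Q`). Central differences are then exact up to rounding for any step `h`. Since `scalar_loglik` is generic in the type of θ, `ForwardDiff.derivative(f, 0.3)` should agree with `dθ`.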
mschauer (author) commented Sep 5, 2020

The key point is that GaussianDistributions.jl provides

1.) arithmetic for Gaussians

        x = Φ*x ⊕ Q

to formulate your evolution law. Here `x` is the filtered state (mean and covariance wrapped in a `Gaussian`), `Φ` is the matrix of the linear forward map, and `⊕ Q` adds an independent Gaussian noise variable with law `Q` to the evolved state `Φ*x`.

This makes it easy to extend the model with control terms or whatever else you need.
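As a rough illustration of what `Φ*x ⊕ Q` amounts to, here is the predict step written with plain arrays from the standard library only. The helper `predict`, its keyword arguments `B` and `u`, and the chosen input matrix are hypothetical; we assume, without having checked the GaussianDistributions.jl source, that `Φ*x ⊕ Q` yields the same mean and covariance.

```julia
using LinearAlgebra

# Hypothetical helper (name and keywords are ours): the mean/covariance
# arithmetic we assume `Φ*x ⊕ Q` performs for a state x ~ N(μ, P) and an
# independent noise variable with law N(μQ, ΣQ).
function predict(μ, P, Φ, μQ, ΣQ; B = nothing, u = nothing)
    μnew = Φ * μ + μQ              # means: apply the linear map, then add the noise mean
    if B !== nothing
        μnew += B * u              # known control input B*u shifts the mean only
    end
    Pnew = Φ * P * Φ' + ΣQ         # covariances add because the noise is independent
    return μnew, Pnew
end

Φ = [0.8 0.3; -0.1 0.8]
μ0, P0 = [1.0, 0.0], Matrix(1.0I, 2, 2)
μQ, ΣQ = [1.0, 0.5], [0.2 0.0; 0.0 1.0]

μ1, P1 = predict(μ0, P0, Φ, μQ, ΣQ)                    # μ1 ≈ [1.8, 0.4]
B = reshape([0.0, 1.0], 2, 1)                           # hypothetical: u drives x2 only
μ2, P2 = predict(μ0, P0, Φ, μQ, ΣQ; B = B, u = [2.0])   # μ2 ≈ [1.8, 2.4], P2 == P1
```

Since `B*u` is deterministic it contributes nothing to the covariance, which is why a control term only changes the mean update.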

2.) a function for the correction/update step operating on Gaussian laws

        x, yres, S = correct(x, ys[i] + R, H)

which returns the corrected Gaussian together with the innovation `yres` and its covariance `S`, the quantities you need to...

3.) increment the likelihood

        ll += llikelihood(yres, S)

if you want to do inference on the dynamics of state space systems.
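For the correction and likelihood steps, a similar plain-array sketch: `kalman_correct` is the textbook Kalman update and `innovation_loglik` is the Gaussian log-density of the innovation. Both names are hypothetical, and we assume `correct` and `llikelihood` compute equivalent quantities. Where the gist passes the observation and its noise together as the Gaussian `ys[i] + R`, the sketch takes `y`, `μR` and `ΣR` separately.

```julia
using LinearAlgebra

# Textbook Kalman correction for a prior x ~ N(μ, P) and an observation
# y = H*x + v, v ~ N(μR, ΣR). Returns posterior mean, posterior covariance,
# innovation yres and innovation covariance S.
function kalman_correct(μ, P, y, H, μR, ΣR)
    yres = y - (H * μ + μR)        # innovation: observed minus predicted
    S = H * P * H' + ΣR            # innovation covariance
    K = (P * H') / S               # Kalman gain P H' S⁻¹
    return μ + K * yres, (I - K * H) * P, yres, S
end

# Log-density of yres under N(0, S), i.e. one likelihood increment:
#   -½ (yres' S⁻¹ yres + log det S + d log 2π)
function innovation_loglik(yres, S)
    C = cholesky(Symmetric(S))     # throws if S is not positive definite
    z = C.L \ yres                 # whitened residual: z'z = yres' S⁻¹ yres
    return -0.5 * (dot(z, z) + logdet(C) + length(yres) * log(2π))
end

H = [1.0 0.5]
μpost, Ppost, yres, S = kalman_correct([1.0, 0.0], Matrix(1.0I, 2, 2), [2.0],
                                       H, [0.0], Matrix(1.0I, 1, 1))
ll = innovation_loglik(yres, S)    # one term of the running sum `ll += ...`
```

Working with the Cholesky factor instead of forming `inv(S)` is the usual, numerically safer choice.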

That it plays nicely with ForwardDiff, StaticArrays, Complex, Unitful and other ecosystem components is just a nice extra.
