Created September 4, 2020 12:30
Differentiable Kalman filter
using GaussianDistributions
using GaussianDistributions: correct, ⊕
using LinearAlgebra
using Statistics

lchol(a) = cholesky(a).L
llikelihood(yres, S) = GaussianDistributions.logpdf(Gaussian(zero(yres), Symmetric(S)), yres)
# State space system
#
# x[0] ∼ N(x0, P0)
# x[k] = Φx[k−1] + w[k], w[k] ∼ N(0, Q)
# y[k] = Hx[k] + v[k], v[k] ∼ N(0, R)
prior = Gaussian([1., 0.], Matrix(1.0I, 2, 2))
modelf(θ = [1.0, 0.5]) = (Φ = [0.8 0.3; -0.1 0.8],
    Q = Gaussian(θ, [0.2 0.0; 0.0 1.0])
)
model = modelf()
obsmodel = (H = [1.0 0.5],
    R = Gaussian([0.0], Matrix(1.0I, 1, 1))
)
n = 50 # transitions
# Forward simulate
function forward(prior, model, obsmodel, n)
    x = rand(prior)
    y = obsmodel.H*x + rand(obsmodel.R)
    xs = [x]
    ys = [y]
    for i in 1:n
        x = model.Φ*x + rand(model.Q)
        y = obsmodel.H*x + rand(obsmodel.R)
        push!(xs, x)
        push!(ys, y)
    end
    xs, ys
end
# Kalman filter
function filter(prior, model, obsmodel, ys)
    x = prior
    x, yres, S = GaussianDistributions.correct(x, ys[1] + obsmodel.R, obsmodel.H)
    ll = llikelihood(yres, S)
    xs = Any[x]
    for i in 2:length(ys)
        x = model.Φ*x ⊕ model.Q
        x, yres, S = GaussianDistributions.correct(x, ys[i] + obsmodel.R, obsmodel.H)
        ll += llikelihood(yres, S)
        push!(xs, x)
    end
    xs, ll
end
xstrue, ys = forward(prior, model, obsmodel, n)
xs, ll = filter(prior, model, obsmodel, ys)

using Plots
p = plot(0:n, first.(ys), color=:red, label="observations")
plot!(p, 0:n, first.(mean.(xs)), color=:blue, label="filtered mean x1")
plot!(p, 0:n, last.(mean.(xs)), color=:blue, label="filtered mean x2")
plot!(p, 0:n, first.(xstrue), color=:black, label="true x1")
plot!(p, 0:n, last.(xstrue), color=:black, label="true x2")

using ForwardDiff
# Taking gradient of the likelihood with respect to a model parameter
grad = ForwardDiff.gradient(θ -> filter(prior, modelf(θ), obsmodel, ys)[2], rand(2))
@show grad
The key point here is that GaussianDistributions.jl provides
1.) arithmetic for Gaussians, as in `x = model.Φ*x ⊕ model.Q`, to formulate your evolution law. Here `x` is the filtered state (mean and covariance wrapped in a `Gaussian`), `Φ` is the matrix of the linear forward evolution map, and `⊕ Q` adds an independent Gaussian noise variable to the evolved state `Φ*x`. This setup makes it easy to extend the model with control terms or whatever else you need.
2.) a function for the correction/update step operating on Gaussian laws (`correct` in the code above), which returns the corrected Gaussian and the quantities (`yres` and `S`) you need to...
3.) increment the likelihood (`ll += llikelihood(yres, S)`) if you want to do inference on the dynamics of state space systems.
That GaussianDistributions.jl plays nicely with ForwardDiff, StaticArrays, Complex, Unitful and other ecosystem components is just a nice extra.
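For readers who want to see what the three ingredients above compute, here is a minimal sketch of one prediction, one correction and one likelihood increment written with plain `LinearAlgebra` only. The helper names `predict`, `update` and `loglik` are hypothetical and are not part of GaussianDistributions.jl. The formulas are the textbook Kalman equations, which may differ in detail from how the package implements `⊕` and `correct`. The observation value `0.7` is made up for illustration.

```julia
using LinearAlgebra

# Prediction: for x ~ N(m, P) and independent noise ~ N(q, Q),
# Φ*x + noise is Gaussian with the mean and covariance below.
predict(m, P, Φ, q, Q) = (Φ*m + q, Φ*P*Φ' + Q)

# Correction for an observation y = H*x + v with v ~ N(0, R).
# Returns corrected mean and covariance, the residual and its covariance.
function update(m, P, y, H, R)
    yres = y - H*m        # innovation (residual)
    S = H*P*H' + R        # innovation covariance
    K = P*H'/S            # Kalman gain
    return m + K*yres, (I - K*H)*P, yres, S
end

# Log-density of N(0, S) evaluated at yres: the per-step likelihood increment.
function loglik(yres, S)
    C = cholesky(Symmetric(S))
    return -(dot(yres, C \ yres) + logdet(C) + length(yres)*log(2π)) / 2
end

# Same matrices as in the gist; q plays the role of the noise mean θ.
Φ = [0.8 0.3; -0.1 0.8]
q, Q = [1.0, 0.5], [0.2 0.0; 0.0 1.0]
H, R = [1.0 0.5], fill(1.0, 1, 1)

m, P = [1.0, 0.0], Matrix(1.0I, 2, 2)        # prior mean and covariance
m, P = predict(m, P, Φ, q, Q)                # one evolution step
m, P, yres, S = update(m, P, [0.7], H, R)    # assimilate a made-up observation
ll = loglik(yres, S)                         # likelihood contribution of this step
```

Because every step is ordinary generic array arithmetic, summing `ll` over all observations gives a scalar function of the model parameters, which is the kind of function automatic differentiation can handle.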