# Last active: June 28, 2020 15:21
# Save wupeifan/c02b272cb52bd3ffc2b4af404a562761 to your computer and use it in GitHub Desktop.
#
# Kalman filter + Zygote + Turing
#
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals
# hidden Unicode characters. Learn more about bidirectional Unicode characters.
using DifferentiableStateSpaceModels, ModelingToolkit, SparseArrays, LinearAlgebra, Parameters,
      Test, TimerOutputs, BenchmarkTools
using Turing, Zygote, ChainRules, Optim
using Random
Turing.setadbackend(:zygote)

# Generate fake data from the RBC example model.
# Seed the global RNG so the synthetic shocks, and therefore `obs` and the posterior draws,
# are reproducible from run to run.
Random.seed!(1234)
H, mod_vals = Examples.rbc()
model = FirstOrderStateSpaceModel(H; mod_vals...)
p = [0.2, 0.02, 0.01]
θ_value = [0.5, 0.95]           # data-generating parameters
solution, _ = solve_model!(model, θ = θ_value, p = p)
T = 200
# 1 × T shock matrix scaled by 0.01 (presumably one shock over T periods — confirm).
eps_value = randn(1, T) * 0.01
seq = solve_sequence(model, solution, eps_value)
# Select two of the four simulated variables as observables.
Q = zeros(2, 4)
Q[1, 1] = 1 # c
Q[2, 3] = 1 # k
obs = [Q * seq.u[i] for i in 1:T]
# Measurement-noise covariance: 1e-4 on each observable, no cross-correlation.
Ω = diagm(ones(2)) * 0.0001
"""
    solve_kalman(m, sol, Q, obs, Ω, x_0 = nothing)

Run a Kalman filter over `obs` for the first-order solution `sol` of model `m`, and
propagate the derivatives of every filter quantity with respect to the model parameters θ
alongside it.

The prediction step maps the state mean and covariance through `h_x` and adds the shock
covariance `η * Σ * η'`. The observation matrix is `G = Q * [g_x; I]`, and `Ω` is added to
the predicted observation covariance. The filter starts from mean `x_0` (zeros when
omitted) and covariance `0.01 * I`; neither depends on θ.

# Arguments
- `m`: model; only `n_x`, `n_θ` and `η` are read from it.
- `sol`: first-order solution; `h_x`, `g_x`, `Σ` and their per-parameter derivatives
  `h_x_θ[j]`, `g_x_θ[j]`, `Σ_θ[j]` are read from it.
- `Q`: `n_z × (size(g_x, 1) + n_x)` matrix selecting the observed variables.
- `obs`: vector of observations, one length-`n_z` vector per period.
- `Ω`: `n_z × n_z` measurement-noise covariance.
- `x_0`: optional initial state mean of length `n_x`.

# Returns
A named tuple `(z, V, z_θ, V_θ)`. For each period `i`:
- `z[i]` (length `n_z`) is the predicted observation mean.
- `V[i]` (`n_z × n_z`, symmetrized) is the predicted observation covariance.
- `z_θ[i]` (`n_z × n_θ`) and `V_θ[i]` (`n_z × n_z × n_θ`) are their θ-derivatives.

# Throws
`ArgumentError` if `x_0` is given and `length(x_0) != n_x`.
"""
function solve_kalman(m::AbstractFirstOrderStateSpaceModel, sol::FirstOrderSolution, Q,
                      obs, Ω, x_0 = nothing)
    @unpack n_x, n_θ, η = m
    @unpack h_x, g_x, h_x_θ, g_x_θ, Σ, Σ_θ = sol
    (isnothing(x_0) || length(x_0) == n_x) ||
        throw(ArgumentError("Length of x_0 mismatches model"))

    T = length(obs)
    n_z = size(Q, 1)
    z = [zeros(n_z) for _ in 1:T]
    V = [zeros(n_z, n_z) for _ in 1:T]
    # Column j of z_θ[i] is ∂z[i]/∂θ_j, a length-n_z vector (previously mis-sized as n_x).
    z_θ = [zeros(n_z, n_θ) for _ in 1:T]
    V_θ = [zeros(n_z, n_z, n_θ) for _ in 1:T]

    # Time-invariant pieces, built once instead of on every period.
    G = Q * vcat(g_x, diagm(ones(n_x)))
    G_θ = [Q * vcat(g_x_θ[j], zeros(n_x, n_x)) for j in 1:n_θ]
    ηΣη = η * Σ * η'
    ηΣη_θ = [η * Σ_θ[j] * η' for j in 1:n_θ]

    cur_x = isnothing(x_0) ? zeros(n_x) : copy(x_0)
    cur_P = diagm(ones(n_x)) * 0.01
    cur_x_θ = [zeros(n_x) for _ in 1:n_θ]
    cur_P_θ = [zeros(n_x, n_x) for _ in 1:n_θ]

    for i in 1:T
        # Prediction. Derivatives are updated first because they need the previous x and P.
        for j in 1:n_θ
            cur_x_θ[j] = h_x_θ[j] * cur_x + h_x * cur_x_θ[j]
            cur_P_θ[j] = h_x_θ[j] * cur_P * h_x' + h_x * cur_P_θ[j] * h_x' +
                         h_x * cur_P * h_x_θ[j]' + ηΣη_θ[j]
        end
        cur_x = h_x * cur_x
        cur_P = h_x * cur_P * h_x' + ηΣη

        # Predicted observation moments and their derivatives (Ω does not depend on θ).
        for j in 1:n_θ
            z_θ[i][:, j] = G_θ[j] * cur_x + G * cur_x_θ[j]
            V_θ[i][:, :, j] = G_θ[j] * cur_P * G' + G * cur_P_θ[j] * G' +
                              G * cur_P * G_θ[j]'
        end
        z[i] = G * cur_x
        V[i] = G * cur_P * G' + Ω
        V[i] = (V[i] + V[i]') / 2.0 # make sure V is symmetric -- Hermitian form

        # Update. The innovation and V⁻¹ are shared by every term below.
        ν = obs[i] - z[i]
        V_inv = inv(V[i])
        for j in 1:n_θ
            V_θ_ij = @view V_θ[i][:, :, j]
            z_θ_ij = @view z_θ[i][:, j]
            # ∂/∂θ_j of P G' V⁻¹ ν, using ∂V⁻¹ = -V⁻¹ V_θ V⁻¹ and ∂ν = -z_θ.
            # The V_θ term must be subtracted; it was previously added, which disagreed
            # with the P_θ update below.
            cur_x_θ[j] += cur_P_θ[j] * G' * V_inv * ν + cur_P * G_θ[j]' * V_inv * ν -
                          cur_P * G' * V_inv * V_θ_ij * V_inv * ν -
                          cur_P * G' * V_inv * z_θ_ij
            # ∂/∂θ_j of P G' V⁻¹ G P, which the covariance update subtracts.
            cur_P_θ[j] -= cur_P_θ[j]' * G' * V_inv * G * cur_P +
                          cur_P' * G_θ[j]' * V_inv * G * cur_P -
                          cur_P' * G' * V_inv * V_θ_ij * V_inv * G * cur_P +
                          cur_P' * G' * V_inv * G_θ[j] * cur_P +
                          cur_P' * G' * V_inv * G * cur_P_θ[j]
        end
        cur_x += cur_P * G' * V_inv * ν
        cur_P -= cur_P' * G' * V_inv * G * cur_P
    end
    return (z = z, V = V, z_θ = z_θ, V_θ = V_θ)
end
# Smoke-run the filter at the data-generating solution before wiring it into Turing.
# Note that this rebinds the global `seq` to the filter output.
seq = solve_kalman(model, solution, Q, obs, Ω, nothing)
# Re-solve the global `model` at parameters `θ` (using the global `p`), run the Kalman
# filter on the global data `obs`, and return the tuple `(z, V)` of per-period predicted
# observation means and covariances.
function seq_z_V(θ)
    sol, _ = solve_model!(model; θ = θ, p = p)
    filtered = solve_kalman(model, sol, Q, obs, Ω)
    return filtered.z, filtered.V
end
# Reverse-mode rule for `seq_z_V`. It returns the primal `(z, V)` together with a pullback
# that maps cotangents `(Δz, ΔV)` to `Δθ`. The pullback uses the derivatives `z_θ` / `V_θ`
# that `solve_kalman` already computes, rather than having the AD backend trace the filter.
function ChainRules.rrule(::typeof(seq_z_V), θ)
    solution, _ = solve_model!(model, θ = θ, p = p)
    seq = solve_kalman(model, solution, Q, obs, Ω)
    function seq_pullback((Δz, ΔV))
        Δθ = zeros(size(θ))
        n_θ = length(θ)
        # An unused output, or an unused period within one, reaches the pullback as
        # `nothing` or a ChainRules zero instead of an array. It contributes nothing, so skip it.
        if Δz isa AbstractVector
            for (Δz_i, z_θ_i) in zip(Δz, seq.z_θ)
                Δz_i isa AbstractArray || continue
                # Chain rule: Δθ_j += Σ_k Δz_i[k] * ∂z_i[k]/∂θ_j.
                Δθ .+= z_θ_i' * Δz_i
            end
        end
        if ΔV isa AbstractVector
            for (ΔV_i, V_θ_i) in zip(ΔV, seq.V_θ)
                ΔV_i isa AbstractArray || continue
                for j in 1:n_θ
                    # Frobenius inner product with the j-th slice of ∂V_i/∂θ (no copy).
                    Δθ[j] += dot(view(V_θ_i, :, :, j), ΔV_i)
                end
            end
        end
        return ChainRules.NO_FIELDS, Δθ
    end
    return (seq.z, seq.V), seq_pullback
end
# Turing model for θ = [α, β] with uniform priors. Each observation `obs[i]` is modelled as
# `MvNormal(z[i], V[i])`, where `(z, V) = seq_z_V(θ)` are the Kalman filter's predicted
# observation moments.
# NOTE(review): `seq_z_V` filters the *global* `obs`, not this argument. The model is
# therefore only coherent when it is called with that same data.
@model function estimate_θ(obs)
    α ~ Uniform(0.4, 0.6)
    β ~ Uniform(0.8, 0.99)
    θ = [α, β]
    z, V = seq_z_V(θ)
    # Iterate over the data actually passed in, not the global `T`.
    for i in eachindex(obs)
        obs[i] ~ MvNormal(z[i], V[i])
    end
end
# Condition the model on the data and draw posterior samples with Turing's HMC sampler
# (arguments: leapfrog step size 0.001, 7 leapfrog steps).
n_samples = 1000
est_model = estimate_θ(obs)
chain = sample(est_model, HMC(0.001, 7), n_samples; progress = true, save_state = true)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.