Skip to content

Instantly share code, notes, and snippets.

@wupeifan
Last active June 28, 2020 15:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wupeifan/c02b272cb52bd3ffc2b4af404a562761 to your computer and use it in GitHub Desktop.
Save wupeifan/c02b272cb52bd3ffc2b4af404a562761 to your computer and use it in GitHub Desktop.
Kalman filter + Zygote + Turing
using DifferentiableStateSpaceModels, ModelingToolkit, SparseArrays, LinearAlgebra, Parameters, Test, TimerOutputs, BenchmarkTools
using Turing, Zygote, ChainRules, Optim
using Random

Turing.setadbackend(:zygote)

# Generate fake data: solve the RBC example model at known parameters, simulate it
# for `T` periods with small random shocks, and observe two of its variables.
H, mod_vals = Examples.rbc()
model = FirstOrderStateSpaceModel(H; mod_vals...)
p = [0.2, 0.02, 0.01]
θ_value = [0.5, 0.95]   # data-generating values of the parameters estimated later
solution, _ = solve_model!(model, θ = θ_value, p = p)
T = 200
# Fixed-seed RNG so the simulated shocks (and therefore `obs`) are reproducible.
data_rng = MersenneTwister(1234)
eps_value = randn(data_rng, 1, T) * 0.01
seq = solve_sequence(model, solution, eps_value)
# Observation matrix: selects c (entry 1) and k (entry 3) of each 4-element `seq.u[i]`.
Q = zeros(2, 4)
Q[1, 1] = 1 # c
Q[2, 3] = 1 # k
obs = [Q * seq.u[i] for i in 1:T]
# Measurement-error covariance, sized to the number of observables (rows of Q).
Ω = Matrix(0.0001I, size(Q, 1), size(Q, 1))
"""
    solve_kalman(m, sol, Q, obs, Ω, x_0 = nothing)

Run a Kalman filter on `obs` for the first-order model `m` with solution `sol`, and
propagate alongside it the derivatives of all filter quantities with respect to the
model's `n_θ` parameters.

Each step predicts the state mean as `h_x * x` and covariance as
`h_x * P * h_x' + η * Σ * η'`. It then forms the predicted observation
`z = G * x`, with `G = Q * [g_x; I]`, and its covariance `V = G * P * G' + Ω`.
Finally it applies the standard Kalman update with the observed `obs[t]`. The initial
state covariance is `0.01 * I`.

# Arguments
- `m`: model; `n_x`, `n_θ` and `η` are read from it.
- `sol`: first-order solution; `h_x`, `g_x`, `Σ` and the per-parameter derivatives
  `h_x_θ[j]`, `g_x_θ[j]`, `Σ_θ[j]` are read from it.
- `Q`: observation matrix; its column count must equal `size(g_x, 1) + n_x`.
- `obs`: vector of `T` observation vectors, each of length `size(Q, 1)`.
- `Ω`: `size(Q, 1) × size(Q, 1)` measurement-error covariance.
- `x_0`: initial state mean of length `n_x`; zeros when `nothing`. It is treated as
  independent of θ.

# Returns
A named tuple `(z, V, z_θ, V_θ)`. For each `t in 1:T`:
- `z[t]` is the predicted observation mean given `obs[1:t-1]`.
- `V[t]` is its (symmetrized) covariance.
- `z_θ[t]` is the `n_z × n_θ` derivative of `z[t]`.
- `V_θ[t]` is the `n_z × n_z × n_θ` derivative of `V[t]`.

# Throws
- `ArgumentError` if `x_0` is given and its length differs from `n_x`.
"""
function solve_kalman(m::AbstractFirstOrderStateSpaceModel, sol::FirstOrderSolution, Q, obs, Ω, x_0 = nothing)
    @unpack n_x, n_θ, η = m
    @unpack h_x, g_x, h_x_θ, g_x_θ, Σ, Σ_θ = sol
    (isnothing(x_0) || length(x_0) == n_x) ||
        throw(ArgumentError("Length of x_0 mismatches model"))
    T = length(obs)
    n_z = size(Q, 1)
    z = [zeros(n_z) for _ in 1:T]
    V = [zeros(n_z, n_z) for _ in 1:T]
    # Derivatives of the predicted observation live in observation space (n_z rows).
    z_θ = [zeros(n_z, n_θ) for _ in 1:T]
    V_θ = [zeros(n_z, n_z, n_θ) for _ in 1:T]

    # The observation map G = Q [g_x; I], its θ-derivatives and the state shock
    # covariance do not change over time, so they are built once.
    G = Q * vcat(g_x, diagm(ones(n_x)))
    G_θ = [Q * vcat(g_x_θ[j], zeros(n_x, n_x)) for j in 1:n_θ]
    ηΣη = η * Σ * η'

    cur_x = isnothing(x_0) ? zeros(n_x) : copy(x_0)
    cur_P = diagm(ones(n_x)) * 0.01
    # Initial mean and covariance do not depend on θ, so their derivatives start at 0.
    cur_x_θ = [zeros(n_x) for _ in 1:n_θ]
    cur_P_θ = [zeros(n_x, n_x) for _ in 1:n_θ]

    for i in 1:T
        # --- Prediction step ---------------------------------------------------
        # Derivatives are propagated first because they need the pre-prediction
        # cur_x / cur_P.
        for j in 1:n_θ
            cur_x_θ[j] = h_x_θ[j] * cur_x + h_x * cur_x_θ[j]
            cur_P_θ[j] = h_x_θ[j] * cur_P * h_x' + h_x * cur_P_θ[j] * h_x' +
                         h_x * cur_P * h_x_θ[j]' + η * Σ_θ[j] * η'
        end
        cur_x = h_x * cur_x
        cur_P = h_x * cur_P * h_x' + ηΣη

        # --- Predicted observation moments and their derivatives ---------------
        z[i] = G * cur_x
        V[i] = G * cur_P * G' + Ω
        V[i] = (V[i] + V[i]') / 2.0 # make sure V is symmetric -- Hermitian form
        for j in 1:n_θ
            z_θ[i][:, j] = G_θ[j] * cur_x + G * cur_x_θ[j]
            V_θ[i][:, :, j] = G_θ[j] * cur_P * G' + G * cur_P_θ[j] * G' +
                              G * cur_P * G_θ[j]'
        end

        # --- Update step --------------------------------------------------------
        # x ← x + K r and P ← P - P' G' V⁻¹ G P, with gain K = P G' V⁻¹ and
        # innovation r = obs - z. Everything below uses the pre-update P, so it
        # is formed before cur_P / cur_P_θ are overwritten.
        V_inv = inv(V[i])
        r = obs[i] - z[i]
        K = cur_P * G' * V_inv
        PtGtV_inv = cur_P' * G' * V_inv
        GP = G * cur_P
        for j in 1:n_θ
            # d(V⁻¹)/dθ_j = -V⁻¹ V_θ V⁻¹ (the minus sign matters for both updates).
            V_inv_θ = -V_inv * V_θ[i][:, :, j] * V_inv
            K_θ = cur_P_θ[j] * G' * V_inv + cur_P * G_θ[j]' * V_inv + cur_P * G' * V_inv_θ
            cur_x_θ[j] += K_θ * r - K * z_θ[i][:, j]
            cur_P_θ[j] -= cur_P_θ[j]' * G' * V_inv * GP + cur_P' * G_θ[j]' * V_inv * GP +
                          cur_P' * G' * V_inv_θ * GP + PtGtV_inv * G_θ[j] * cur_P +
                          PtGtV_inv * G * cur_P_θ[j]
        end
        cur_x += K * r
        cur_P -= PtGtV_inv * GP
    end
    return (z = z, V = V, z_θ = z_θ, V_θ = V_θ)
end
# Run the filter once at the data-generating parameters (the `solution` for θ_value).
seq = solve_kalman(model, solution, Q, obs, Ω, nothing)
"""
    seq_z_V(θ)

Solve the model at parameters `θ` and run the Kalman filter. Return the tuple
`(z, V)` of per-period predicted observation means and covariances.

NOTE(review): this reads the non-const globals `model`, `p`, `Q`, `obs` and `Ω`, so
it always filters the global `obs`.
"""
function seq_z_V(θ)
    sol, _ = solve_model!(model; θ = θ, p = p)
    filtered = solve_kalman(model, sol, Q, obs, Ω)
    z_path, V_path = filtered.z, filtered.V
    return (z_path, V_path)
end
# Reverse-mode rule for `seq_z_V`. The forward pass runs the filter, which already
# carries the derivatives `z_θ` / `V_θ`. The pullback contracts the incoming
# cotangents `(Δz, ΔV)` with them:
#     Δθ = Σ_t z_θ[t]' Δz[t] + Σ_t [⟨V_θ[t][:, :, j], ΔV[t]⟩ for each j].
# Zero or absent cotangents contribute nothing instead of throwing. Examples are a
# `nothing` entry for a time step the caller never indexed, or a non-array zero
# tangent for an unused output.
function ChainRules.rrule(::typeof(seq_z_V), θ)
    solution, _ = solve_model!(model, θ = θ, p = p)
    seq = solve_kalman(model, solution, Q, obs, Ω)
    function seq_pullback((Δz, ΔV))
        Δθ = zeros(size(θ))
        # Mean contribution.
        if Δz isa AbstractVector
            for (Δz_t, z_θ_t) in zip(Δz, seq.z_θ)
                Δz_t isa AbstractVector || continue
                Δθ .+= z_θ_t' * Δz_t
            end
        end
        # Covariance contribution (Frobenius inner product per parameter).
        if ΔV isa AbstractVector
            for (ΔV_t, V_θ_t) in zip(ΔV, seq.V_θ)
                ΔV_t isa AbstractMatrix || continue
                for j in eachindex(Δθ)
                    Δθ[j] += dot(view(V_θ_t, :, :, j), ΔV_t)
                end
            end
        end
        return ChainRules.NO_FIELDS, Δθ
    end
    return (seq.z, seq.V), seq_pullback
end
# Turing model for the posterior of θ = [α, β] under uniform priors. Each observation
# is scored against the Gaussian one-step-ahead predictive distribution
# `MvNormal(z[i], V[i])` produced by the Kalman filter.
# NOTE(review): `seq_z_V` filters the *global* `obs`, not this argument, so the
# argument should be that same data set.
@model function estimate_θ(obs)
    α ~ Uniform(0.4, 0.6)
    β ~ Uniform(0.8, 0.99)
    θ = [α, β]
    z, V = seq_z_V(θ)
    # Iterate over the observations actually passed in, not the global `T`.
    for i in eachindex(obs)
        obs[i] ~ MvNormal(z[i], V[i])
    end
end
# Condition the model on the simulated observations and draw `n_samples` samples.
# `HMC(0.001, 7)` is Turing's fixed-step HMC sampler (step size 0.001, 7 leapfrog steps).
n_samples = 1000
est_model = estimate_θ(obs)
chain = sample(est_model, HMC(0.001, 7), n_samples; progress = true, save_state = true)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment