Skip to content

Instantly share code, notes, and snippets.

@wupeifan
Created October 16, 2020 01:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wupeifan/94bb6498562a0081bcbe31de20b3079d to your computer and use it in GitHub Desktop.
Save wupeifan/94bb6498562a0081bcbe31de20b3079d to your computer and use it in GitHub Desktop.
Forward Kalman code
# Forward Kalman
"""
    solve_kalman(m, sol, Q, obs, Ω, x_0 = nothing)

Run a forward Kalman filter pass over `obs` and, alongside it, propagate the
parameter sensitivities of the predicted observation mean and covariance.

At each step `i` the filter

1. predicts the state mean and covariance with `h_x * x` and
   `h_x * P * h_x' + η * Σ * η'`,
2. forms the predicted observation `z[i] = G * x` and its covariance
   `V[i] = G * P * G' + Ω` (symmetrized), where `G = Q * [g_x; I]`,
3. updates `x` and `P` with the innovation `obs[i] - z[i]`.

Every `*_p` / `*_θ` quantity is obtained by applying the product rule to the
matching equation above, with `h_x_p[j]`, `g_x_p[j]` and `Σ_p[j]` standing in for
the parameter-`j` derivatives of `h_x`, `g_x` and `Σ`. `Q`, `η` and `Ω` carry no
derivative terms.

# Arguments
- `m`: model; the fields `n_x`, `n_p` and `η` are read.
- `sol`: solution; the fields `h_x`, `g_x`, `h_x_p`, `g_x_p`, `Σ`, `Σ_p` are read.
- `Q`: matrix applied to `[g_x; I]`; `size(Q, 1)` is the observation length `n_z`.
- `obs`: observations, indexed as `obs[i]` for `i in 1:size(obs, 1)`. Each `obs[i]`
  is subtracted from a length-`n_z` vector, so this is presumably a vector of
  observation vectors (TODO confirm against callers).
- `Ω`: term added to `G * P * G'` when forming `V[i]`.
- `x_0`: optional initial state of length `n_x`; a zero vector is used when `nothing`.

# Returns
A named tuple `(z, V, z_θ, V_θ)` of length-`T` vectors, where `z[i]` and `V[i]` are
the predicted observation and its symmetrized covariance, and `z_θ[i][:, j]` /
`V_θ[i][:, :, j]` are their sensitivities to parameter `j`. `V_θ` is stored without
symmetrization.

# Throws
- `ArgumentError` if `x_0` is given and `length(x_0) != n_x`.

NOTE(review): `lyapd(A, C)` is assumed to return the solution `X` of
`X = A * X * A' + C` (as `MatrixEquations.lyapd` does) — confirm.
"""
function solve_kalman(m::AbstractFirstOrderExpectationalDifferenceModel, sol::FirstOrderSolution, Q, obs, Ω, x_0 = nothing)
    @unpack n_x, n_p, η = m
    @unpack h_x, g_x, h_x_p, g_x_p, Σ, Σ_p = sol
    (isnothing(x_0) || length(x_0) == n_x) ||
        throw(ArgumentError("Length of x_0 mismatches model"))

    T = size(obs, 1)
    n_z = size(Q, 1)

    # Outputs, one entry per observation.
    z = [zeros(n_z) for _ in 1:T]
    V = [zeros(n_z, n_z) for _ in 1:T]
    z_θ = [zeros(n_z, n_p) for _ in 1:T]
    V_θ = [zeros(n_z, n_z, n_p) for _ in 1:T]

    # Time-invariant pieces, built once instead of on every step.
    G = Q * vcat(g_x, diagm(ones(n_x)))
    G_p = [Q * vcat(g_x_p[j], zeros(n_x, n_x)) for j in 1:n_p]
    B = η * Σ * η'
    B_p = [η * Σ_p[j] * η' for j in 1:n_p]

    # Initial state mean, covariance and their sensitivities.
    # `cur_x` is only ever rebound below, so a shallow copy of `x_0` is enough.
    cur_x = isnothing(x_0) ? zeros(n_x) : copy(x_0)
    cur_P = lyapd(h_x, B)
    cur_x_p = [zeros(n_x) for _ in 1:n_p]
    cur_P_θ = [zeros(n_x, n_x) for _ in 1:n_p]
    for j in 1:n_p
        tmp = h_x_p[j] * cur_P * h_x'
        cur_P_θ[j] = lyapd(h_x, B_p[j] + tmp + tmp')
    end

    for i in 1:T
        # Prediction step. The sensitivities go first because they need the
        # pre-prediction `cur_x` and `cur_P`.
        for j in 1:n_p
            cur_x_p[j] = h_x_p[j] * cur_x + h_x * cur_x_p[j]
            cur_P_θ[j] = h_x_p[j] * cur_P * h_x' + h_x * cur_P_θ[j] * h_x' +
                         h_x * cur_P * h_x_p[j]' + B_p[j]
        end
        cur_x = h_x * cur_x
        cur_P = h_x * cur_P * h_x' + B

        # Predicted observation and its covariance.
        z[i] = G * cur_x
        S = G * cur_P * G' + Ω
        V[i] = (S + S') / 2.0 # make sure V is symmetric -- Hermitian form

        # `V[i]` is fixed for the rest of this step: invert it once.
        V_inv = inv(V[i])
        innovation = obs[i] - z[i]

        for j in 1:n_p
            G_θ = G_p[j]
            dz = G_θ * cur_x + G * cur_x_p[j]
            dV = G_θ * cur_P * G' + G * cur_P_θ[j] * G' + G * cur_P * G_θ'
            z_θ[i][:, j] = dz
            V_θ[i][:, :, j] = dV

            # Sensitivity of the measurement update. The mean update must run
            # before `cur_P_θ[j]` is overwritten, as it uses the predicted value.
            cur_x_p[j] += cur_P_θ[j] * G' * V_inv * innovation +
                          cur_P * G_θ' * V_inv * innovation -
                          cur_P * G' * V_inv * dV * V_inv * innovation -
                          cur_P * G' * V_inv * dz
            cur_P_θ[j] -= cur_P_θ[j]' * G' * V_inv * G * cur_P +
                          cur_P' * G_θ' * V_inv * G * cur_P -
                          cur_P' * G' * V_inv * dV * V_inv * G * cur_P +
                          cur_P' * G' * V_inv * G_θ * cur_P +
                          cur_P' * G' * V_inv * G * cur_P_θ[j]
        end

        # Measurement update of the state mean and covariance.
        cur_x += cur_P * G' * V_inv * innovation
        cur_P -= cur_P' * G' * V_inv * G * cur_P
    end

    return (z = z, V = V, z_θ = z_θ, V_θ = V_θ)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment