Gist THargreaves/bbc7cb1228c4f08098dff236c9824b85, created April 30, 2024 16:21
A combined proposal for multiple enhancements to the SSMProblems.jl interface that lead to a clean API for Rao-Blackwellised filtering
"""
This script contains a combined proposal for multiple enhancements to the existing
SSMProblems.jl interface that we have discussed over the past few months. The main
features are:
- Separation of latent dynamics and observation processes
- Introduction of control variables
- Unified interface for defining linear Gaussian models
- Interface for forward simulation and conditioning
- Hierarchical (conditional) latent dynamics
- Rao-Blackwellised filtering

Open questions and potential issues (of which there are many) are marked with TODOs. It
is probably easier to read this script if you ignore these at first and then look back at
them afterwards.

The Rao-Blackwellised filter is demonstrated on a dummy linear Gaussian model for
which a subset of the latent dynamics are treated as general (not necessarily linear or
Gaussian). This way the samples from the Rao-Blackwellised particle filter can be
compared to the ground truth Kalman filter estimates.
"""

using Distributions
using FillArrays
using LinearAlgebra
using LogExpFunctions
using Plots
using ProgressMeter
using Random
using StatsBase

# Validation
using DynamicIterators
using GaussianDistributions
using Kalman

# TODO: control and time-step inputs should be optional
# TODO: should we differentiate between time-steps and "real" time? Note, it's often the
# time-delta that is more important than the absolute time. Perhaps we could condition on
# time-observation pairs as in Kalman.jl.
# TODO: would be nice to have the controls have names rather than combining all latent,
# observation and conditioning controls into one vector.
# TODO: need to make a decision about whether an observation is available for the initial
# state. For now I have assumed this to be the case. If not, the `xs` and `ys` from forward
# simulation will have different lengths.
# TODO: check that time indexes for controls, latent states, and observations are consistent
# TODO: is there a macro for creating Random.default_rng() versions of sample methods?

############################
#### MODIFIED INTERFACE ####
############################

# TODO: do we ever need initialisation_logdensity? If so (/regardless), should we have an
# initialisation_distribution for convenience?

"""
Latent dynamics of a state space model. Should implement the following methods using the
value of the control vector and current time-step as inputs:
- `transition`
- `transition_logdensity`
Alternatively, you can specify a `transition_distribution` method that will be used to
generate the above two methods.

An `initialise` method should also be implemented to generate the initial state,
but this (obviously) doesn't need to accept the time-step.

Finally, a method `dim` returning the dimension of the latent state should be defined.
"""
abstract type LatentDynamics end

# Default transition methods when transition_distribution is defined
function transition(
    rng::AbstractRNG,
    dynamics::LatentDynamics,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_transition_distribution(dynamics)
    return rand(rng, transition_distribution(dynamics, x, u, t))
end

function transition_logdensity(
    dynamics::LatentDynamics,
    x_curr::AbstractVector,
    x_next::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_transition_distribution(dynamics)
    return logpdf(transition_distribution(dynamics, x_curr, u, t), x_next)
end

function check_transition_distribution(dynamics::LatentDynamics)
    if !hasmethod(
        transition_distribution,
        Tuple{typeof(dynamics),AbstractVector,AbstractVector,Integer},
    )
        calling_func = StackTraces.stacktrace()[2].func
        error("neither $calling_func nor transition_distribution is defined for $dynamics")
    end
end
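
# A minimal, stdlib-only sketch of the `transition_distribution` fallback pattern above.
# Every name here is hypothetical and illustrative (not part of the proposed API): a toy
# scalar Gaussian stands in for a Distributions.jl type, so that a dynamics type defining
# only `toy_transition_distribution` gets both a sampler and a log-density for free.
using Random

struct ToyGaussian # hypothetical stand-in for Distributions.Normal
    μ::Float64
    σ::Float64
end
toy_rand(rng::AbstractRNG, d::ToyGaussian) = d.μ + d.σ * randn(rng)
toy_logpdf(d::ToyGaussian, x::Real) = -0.5 * log(2π * d.σ^2) - (x - d.μ)^2 / (2 * d.σ^2)

abstract type ToyDynamics end

# Generic fallbacks, written once for the abstract type
function toy_transition(rng::AbstractRNG, dyn::ToyDynamics, x::Real)
    return toy_rand(rng, toy_transition_distribution(dyn, x))
end
function toy_transition_logdensity(dyn::ToyDynamics, x_curr::Real, x_next::Real)
    return toy_logpdf(toy_transition_distribution(dyn, x_curr), x_next)
end

# A user-defined model only needs to state what its transition distribution is
struct ToyRandomWalk <: ToyDynamics
    σ::Float64
end
toy_transition_distribution(dyn::ToyRandomWalk, x::Real) = ToyGaussian(x, dyn.σ)

# Usage: toy_transition(Random.default_rng(), ToyRandomWalk(1.0), 0.0)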
# Note: defining our linear Gaussian model in this way (rather than through fixed
# matrices/vectors) has an additional advantage on top of unifying the interface.
# Specifically, it allows for time-varying and control-dependent linear Gaussian dynamics,
# facilitating very expressive modelling. This could also be achieved by defining a vector
# of matrices/vectors, one for each time-step, but that is very memory-inefficient.
"""
Latent dynamics of a state space model that are linear and Gaussian. Methods should be
implemented to define how to construct the following quantities given the control vector
and time step:
- A: transition matrix
- b: drift vector
- Q: transition noise covariance
These should be named `calc_A`, `calc_b`, and `calc_Q` respectively. In the
simplest case these would simply return the values of matrices and vectors stored in the
model struct, but they can also be more complex functions that depend on the control and
time-step in arbitrary ways.

Likewise, a prior mean vector and covariance matrix should be specified by `calc_μ0` and
`calc_Σ0`, respectively, not depending on the time-step.

Note that we do not include a control matrix `B` as is standard in the literature, since
we have opted for the more general setting in which the drift vector can be any
non-linear function of the control. To retrieve this standard behaviour, one can simply
define `calc_b(dynamics, u, t) = dynamics.B * u`.
"""
abstract type LinearGaussianDynamics <: LatentDynamics end

# Default transitions for use in general particle filtering
function transition_distribution(
    dynamics::LinearGaussianDynamics, x::AbstractVector, u::AbstractVector, t::Integer
)
    A = calc_A(dynamics, u, t)
    b = calc_b(dynamics, u, t)
    Q = calc_Q(dynamics, u, t)
    return MvNormal(A * x + b, Q)
end

# Default initialisation for use in general particle filtering
function initialise(rng::AbstractRNG, dynamics::LinearGaussianDynamics, u::AbstractVector)
    μ0 = calc_μ0(dynamics, u)
    Σ0 = calc_Σ0(dynamics, u)
    return rand(rng, MvNormal(μ0, Σ0))
end

# Default model matrices/vectors can be defined. This is convenient for simple models that
# don't include a drift term. However, it could lead to obscure bugs if the user attempts
# to define `calc_X` but has a typo or an incorrect signature. Perhaps emitting a warning
# when default values are used would be useful. Note that the zero defaults assume the
# model struct has a dimension field `D`.
calc_A(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = I
calc_b(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = Zeros(dynamics.D)
calc_Q(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = I
calc_μ0(dynamics::LinearGaussianDynamics, u::AbstractVector) = Zeros(dynamics.D)
calc_Σ0(dynamics::LinearGaussianDynamics, u::AbstractVector) = I
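
# One possible way to address the concern above, sketched with hypothetical names that are
# not part of the proposal: have the default announce itself via `@warn ... maxlog=1`, so
# that a mistyped user `calc_A` method does not silently fall back to the identity without
# any notice. `I` is LinearAlgebra's UniformScaling, which composes with vectors and
# matrices of any size; that is why it works as a dimension-free default.
using LinearAlgebra

abstract type WarnedDynamics end
struct NoOverrides <: WarnedDynamics end

function calc_A_warned(dynamics::WarnedDynamics, u::AbstractVector, t::Integer)
    # Logged only the first time this statement is reached
    @warn "Using default transition matrix A = I for $(typeof(dynamics))" maxlog=1
    return I
end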
# TODO: think this should have two dimensions, input and output
"""
Observation process of a state space model. Should implement the following methods using
the value of the control vector and current time-step as inputs:
- `observation`
- `observation_logdensity`
Alternatively, you can specify an `observation_distribution` method that will be used to
generate the above two methods.

A method `dim` returning the dimension of the observation should also be defined.
"""
abstract type ObservationProcess end

# Default observation methods when observation_distribution is defined
function observation(
    rng::AbstractRNG,
    process::ObservationProcess,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_observation_distribution(process)
    return rand(rng, observation_distribution(process, x, u, t))
end

function observation_logdensity(
    process::ObservationProcess,
    x::AbstractVector,
    y::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_observation_distribution(process)
    return logpdf(observation_distribution(process, x, u, t), y)
end

# TODO: is a try-catch more efficient here? Or better, can we cache the result?
function check_observation_distribution(process::ObservationProcess)
    if !hasmethod(
        observation_distribution,
        Tuple{typeof(process),AbstractVector,AbstractVector,Integer},
    )
        calling_func = StackTraces.stacktrace()[2].func
        error("neither $calling_func nor observation_distribution is defined for $process")
    end
end
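
# One answer to the caching TODO above, sketched with hypothetical names: replace the
# runtime `hasmethod` check with an opt-in trait. Because `has_distribution` dispatches on
# the type alone, the compiler can typically constant-fold the check, and no stack trace
# needs to be inspected. The trade-off is that users must opt in explicitly.
abstract type ToyProcess end
struct WithDist <: ToyProcess end
struct WithoutDist <: ToyProcess end

# Default: no `*_distribution` method is available
has_distribution(::Type{<:ToyProcess}) = false
has_distribution(::Type{WithDist}) = true

# Stand-in for a real distribution object
toy_distribution(::WithDist, x) = (mean=x, var=1.0)

function toy_observation_mean(process::P, x) where {P<:ToyProcess}
    has_distribution(P) || error("observation_distribution is not defined for $P")
    return toy_distribution(process, x).mean
end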
"""
Observation process of a state space model that is linear and Gaussian. Methods should
be implemented to define how to construct the following quantities given the control
vector and time step:
- H: observation matrix
- c: observation bias vector
- R: observation noise covariance
These should be named `calc_H`, `calc_c`, and `calc_R` respectively. In the simplest case
these would simply return the values of matrices and vectors stored in the model struct,
but they can also be more complex functions that depend on the control and time-step.
"""
abstract type LinearGaussianObservation <: ObservationProcess end

# Default observation for use in general particle filtering
function observation_distribution(
    process::LinearGaussianObservation, x::AbstractVector, u::AbstractVector, t::Integer
)
    H = calc_H(process, u, t)
    c = calc_c(process, u, t)
    R = calc_R(process, u, t)
    return MvNormal(H * x + c, R)
end

# Default model matrices/vectors (the zero bias assumes a dimension field `D`)
calc_H(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = I
calc_c(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = Zeros(process.D)
calc_R(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = I

# Regular and hierarchical models are both state space models, so we introduce an abstract
# type that they will both inherit from
abstract type AbstractStateSpaceModel end

# TODO: need a way of ensuring that the latent and observation dimensions match (for all H).
"""A state space model combines latent dynamics and an observation process."""
struct StateSpaceModel{LD<:LatentDynamics,OP<:ObservationProcess} <: AbstractStateSpaceModel
    latent_dynamics::LD
    observation_process::OP
end

# We make this split interface compatible with the existing interface by defining
# initialisation, transition, and observation methods for the entire SSM. This also allows
# these methods to be specialised for the hierarchical SSM (see below).
# TODO: should we also define log-densities/distributions here?
function initialise(rng::AbstractRNG, model::AbstractStateSpaceModel, u::AbstractVector)
    return initialise(rng, model.latent_dynamics, u)
end

function transition(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    return transition(rng, model.latent_dynamics, x, u, t)
end

function observation(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    return observation(rng, model.observation_process, x, u, t)
end

############################
#### FORWARD SIMULATION ####
############################

# TODO: could do with better control of element types
function StatsBase.sample(
    rng::AbstractRNG, model::AbstractStateSpaceModel, controls::AbstractVector
)
    xs = Vector{Vector}(undef, length(controls))
    ys = Vector{Vector}(undef, length(controls))
    for t in 1:length(controls)
        xs[t] = if t == 1
            initialise(rng, model, controls[t])
        else
            transition(rng, model, xs[t - 1], controls[t], t)
        end
        ys[t] = observation(rng, model, xs[t], controls[t], t)
    end
    return xs, ys
end

function StatsBase.sample(model::AbstractStateSpaceModel, controls::AbstractVector)
    return sample(Random.default_rng(), model, controls)
end
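
# A sketch of one way to address the element-type TODO above, shown on a hypothetical
# scalar AR(1) model with unit observation gain (not the interface types): draw the initial
# state first and let its type fix the concrete element type of the containers. With vector
# states, the first observation would fix the type of `ys` in the same way.
using Random

function simulate_ar1(rng::AbstractRNG, a::Real, σx::Real, σy::Real, n_steps::Integer)
    x1 = randn(rng)                          # initial state ~ N(0, 1)
    xs = Vector{typeof(x1)}(undef, n_steps)  # concretely typed, unlike Vector{Vector}
    ys = similar(xs)
    for t in 1:n_steps
        xs[t] = t == 1 ? x1 : a * xs[t - 1] + σx * randn(rng)
        ys[t] = xs[t] + σy * randn(rng)      # y_t = x_t + observation noise
    end
    return xs, ys
end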
# Forward simulation is easy even when a linear Gaussian model is defined through matrices
# and vectors.
struct DefaultLinearGaussianModel <: LinearGaussianDynamics
    D::Int # dimension
end
dim(model::DefaultLinearGaussianModel) = model.D

struct DefaultLinearGaussianObservation <: LinearGaussianObservation
    D::Int # dimension
end
dim(model::DefaultLinearGaussianObservation) = model.D

# Example sampling (dynamics and observation dimensions must match)
unit_ssm = StateSpaceModel(
    DefaultLinearGaussianModel(2), DefaultLinearGaussianObservation(2)
)
T = 100
control = [zeros(1) for _ in 1:T] # arbitrary control whilst not yet optional
xs, ys = sample(unit_ssm, control)

# Plot the first dimension of the latent states and observations
p = plot(; title="Forward Simulation")
plot!(1:T, first.(xs); label="Latent Variables")
scatter!(1:T, first.(ys); label="Observations")
display(p)

######################
#### CONDITIONING ####
######################

# A state space model can be conditioned on a sequence of observations. This avoids
# observations needing to be passed around as arguments to samplers and also allows a
# distinction between forward simulation and filtering/smoothing.
# HACK: simplified type for observations to make this work with the RB example
struct ConditionedStateSpaceModel{M<:AbstractStateSpaceModel}
    model::M
    observations::Vector
end
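
# A sketch (hypothetical names) of how the HACK above might be lifted: parametrise the
# observation container type so the wrapper stays concretely typed whatever the
# observations are, and provide a small `condition`-style helper to build it.
struct ToyConditioned{M,Y<:AbstractVector}
    model::M
    observations::Y
end
toy_condition(model, observations::AbstractVector) = ToyConditioned(model, observations)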
##########################
#### KALMAN FILTERING ####
##########################

# Special types of state space models can be defined for dispatching to optimised methods.
const LinearGaussianStateSpaceModel =
    StateSpaceModel{LD,OP} where {LD<:LinearGaussianDynamics,OP<:LinearGaussianObservation}

# For example, the Kalman filter would be defined as follows
function StatsBase.sample(
    rng::AbstractRNG,
    cond_model::ConditionedStateSpaceModel{<:LinearGaussianStateSpaceModel},
    controls::AbstractVector,
)
    # ...
end

# Alternatively, we might define sampling algorithms to, for example, differentiate between
# the vanilla Kalman filter, its square-root form, and even the RTS smoother.
abstract type SSMSamplingAlgorithm end
struct KalmanFilter <: SSMSamplingAlgorithm end
function StatsBase.sample(
    rng::AbstractRNG,
    cond_model::ConditionedStateSpaceModel{<:LinearGaussianStateSpaceModel},
    controls::AbstractVector,
    algorithm::KalmanFilter,
)
    # ...
end
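
# A stdlib-only sketch (hypothetical function name) of one predict/update step that the
# `# ...` bodies above would repeat for t = 1, ..., T. It follows the same A, b, Q, H, c, R
# convention as the `calc_*` interface, solves with a Cholesky factorisation of the
# innovation covariance S instead of forming `inv(S)`, and also returns the log-likelihood
# increment log N(y; H μ_pred + c, S) needed for the marginal likelihood.
using LinearAlgebra

function kalman_step_sketch(μ, Σ, y, A, b, Q, H, c, R)
    # Predict
    μ_pred = A * μ + b
    Σ_pred = A * Σ * A' + Q

    # Update
    v = y - (H * μ_pred + c)                      # innovation
    F = cholesky(Symmetric(H * Σ_pred * H' + R))  # factorises S = H Σ_pred H' + R
    K = (F \ (H * Σ_pred))'                       # gain Σ_pred H' S⁻¹ (Σ_pred, S symmetric)
    μ_new = μ_pred + K * v
    Σ_new = Σ_pred - K * H * Σ_pred
    ll = -0.5 * (length(y) * log(2π) + logdet(F) + dot(v, F \ v))
    return μ_new, Σ_new, ll
end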
# Importantly, because the linear Gaussian model is defined using a unified interface, there
# is nothing stopping you from sampling from it using a particle filter (e.g. for comparing
# its efficiency against the Kalman filter) or any other filtering method (e.g. for
# validating a new sampling algorithm).
# sample(rng, cond_linear_gaussian_model, controls, ParticleFilter(1000))

#############################
#### HIERARCHICAL MODELS ####
#############################

# One of the main motivations for separating the latent dynamics and observation processes
# is that it allows a natural definition of hierarchical models. The value of the outer
# state can be passed to the inner state as a control variable. Not only is this useful for
# Rao-Blackwellisation, but it can also be used for the (partially) independent particle
# filter of Lin, 2005.
struct HierarchicalStateSpaceModel{
    LD1<:LatentDynamics,LD2<:LatentDynamics,OP<:ObservationProcess
} <: AbstractStateSpaceModel
    outer_latent_dynamics::LD1
    inner_latent_dynamics::LD2
    observation_process::OP
end

# Special cases of hierarchical models can be defined for dispatching to optimised methods
# through Rao-Blackwellisation.
const ConditionallyGaussianStateSpaceModel = HierarchicalStateSpaceModel{
    LD1,LD2,OP
} where {LD1<:LatentDynamics,LD2<:LinearGaussianDynamics,OP<:LinearGaussianObservation}

# A hierarchical SSM is still a regular SSM and can be sampled from or filtered in a general
# setting by defining appropriate methods
# TODO: this could be made cleaner if log-densities are included, by defining a factorised
# Distributions.jl distribution
function initialise(rng::AbstractRNG, model::HierarchicalStateSpaceModel, u::AbstractVector)
    outer_x = initialise(rng, model.outer_latent_dynamics, u)
    inner_u = [outer_x; u]
    inner_x = initialise(rng, model.inner_latent_dynamics, inner_u)
    return [outer_x; inner_x]
end

function transition(
    rng::AbstractRNG,
    model::HierarchicalStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    # Transition outer dynamics; the inner dynamics are then conditioned on the new value
    outer_x = x[1:dim(model.outer_latent_dynamics)]
    outer_x = transition(rng, model.outer_latent_dynamics, outer_x, u, t)
    inner_u = [outer_x; u]

    # Transition inner dynamics
    inner_x = x[(dim(model.outer_latent_dynamics) + 1):end]
    inner_x = transition(rng, model.inner_latent_dynamics, inner_x, inner_u, t)

    return [outer_x; inner_x]
end
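
# When both levels are linear Gaussian, the transition above is itself one joint linear
# model. Because the inner drift is the *new* outer state, the joint matrices pick up
# A_outer and b_outer in their lower blocks. A noise-free sketch with a hypothetical helper
# name, assuming (as in the example below) equal outer and inner dimensions and an inner
# transition x_inner' = A_inner x_inner + x_outer':
function joint_transition_matrices(A_outer, b_outer, A_inner)
    d_outer, d_inner = size(A_outer, 1), size(A_inner, 1)
    d_outer == d_inner || throw(DimensionMismatch("outer and inner dimensions must match"))
    A = [
        A_outer zeros(d_outer, d_inner)
        A_outer A_inner
    ]
    b = [b_outer; b_outer]
    return A, b
end

# With noise, the joint covariance is [Q_outer Q_outer; Q_outer (Q_outer + Q_inner)].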
function observation(
    rng::AbstractRNG,
    model::HierarchicalStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    outer_x = x[1:dim(model.outer_latent_dynamics)]
    inner_u = [outer_x; u]
    inner_x = x[(dim(model.outer_latent_dynamics) + 1):end]
    return observation(rng, model.observation_process, inner_x, inner_u, t)
end
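
# The hierarchical methods above share one convention: the joint state is
# [outer_x; inner_x] and the inner control is [outer_x; u]. A sketch (hypothetical helper
# names) of factoring this out; `view` avoids copying the two slices.
split_state(x::AbstractVector, d_outer::Integer) =
    (view(x, 1:d_outer), view(x, (d_outer + 1):length(x)))

inner_control(outer_x::AbstractVector, u::AbstractVector) = vcat(outer_x, u)

# Usage: outer_x, inner_x = split_state(x, dim(model.outer_latent_dynamics))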
##############################
#### RAO-BLACKWELLISATION ####
##############################

# TODO: I'm currently only considering the conditionally Gaussian case of
# Rao-Blackwellisation. Other variants, such as conditionally discrete state space models,
# can also be implemented.
# TODO: it should be possible to define a generic Rao-Blackwellised filter that calls an
# arbitrary exact sampling algorithm for the inner dynamics (e.g. Kalman filter for linear
# Gaussian models or Forward-Backward for discrete models); these only need to return a
# weight update.
struct RaoBlackwellisedParticleFilter <: SSMSamplingAlgorithm
    n_particles::Int
end

# TODO: likewise, this should be more generic, but illustrates the point
struct RaoBlackwellisedParticle
    x::Vector{Float64}  # sampled outer state
    μ::Vector{Float64}  # filtered mean of the inner state
    Σ::Matrix{Float64}  # filtered covariance of the inner state
    log_w::Float64      # log-weight
    parent_idx::Int     # index of the resampled parent (0 at t = 1)
end
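
# A sketch of how the particle type might be made more generic (hypothetical name): the
# inner-state summary `z` is a type parameter, so it can hold e.g. a `(μ = μ, Σ = Σ)` named
# tuple for a conditionally Gaussian model or a probability vector for a conditionally
# discrete one, while the struct itself stays concretely typed.
struct GenericRBParticle{XT,ZT}
    x::XT            # sampled outer state
    z::ZT            # sufficient statistics of the inner state
    log_w::Float64   # log-weight
    parent_idx::Int  # index of the resampled parent (0 at the first step)
end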
# TODO: reduce re-used code between t == 1 and t > 1 cases
function StatsBase.sample(
    rng::AbstractRNG,
    cond_model::ConditionedStateSpaceModel{<:ConditionallyGaussianStateSpaceModel},
    controls::AbstractVector,
    algorithm::RaoBlackwellisedParticleFilter,
)
    ys = cond_model.observations
    T = length(ys)
    N = algorithm.n_particles

    # TODO: should we swap the order of the dimensions for the particle container?
    particle_container = Matrix{RaoBlackwellisedParticle}(undef, T, N)

    outer = cond_model.model.outer_latent_dynamics
    inner = cond_model.model.inner_latent_dynamics
    obs = cond_model.model.observation_process

    @showprogress for t in 1:T
        y = ys[t]
        u = controls[t]

        if t == 1
            for i in 1:N
                log_w = -log(N)
                parent_idx = 0

                # Initialise outer state from prior
                x = initialise(rng, outer, u)

                # The control vector for the inner dynamics is the concatenation of the
                # outer state and the control vector
                inner_u = [x; u]

                # Compute prior mean and covariance for inner state, conditioned on outer
                μ0 = calc_μ0(inner, inner_u)
                Σ0 = calc_Σ0(inner, inner_u)

                # Extract model matrices and vectors
                H = calc_H(obs, inner_u, t)
                c = calc_c(obs, inner_u, t)
                R = calc_R(obs, inner_u, t)

                # Filter initial states
                K = Σ0 * H' * inv(H * Σ0 * H' + R)
                μ = μ0 + K * (y - (H * μ0 + c))
                Σ = (I - K * H) * Σ0

                # Update weight given observation
                μ_y = H * μ0 + c
                Σ_y = H * Σ0 * H' + R
                log_w += logpdf(MvNormal(μ_y, Σ_y), y)

                particle_container[t, i] = RaoBlackwellisedParticle(
                    x, μ, Σ, log_w, parent_idx
                )
            end
        else
            # Resampling (using the supplied rng so runs are reproducible)
            weights = softmax(getproperty.(particle_container[t - 1, :], :log_w))
            parent_idxs = sample(rng, 1:N, Weights(weights), N)

            for i in 1:N
                # Resample
                parent = particle_container[t - 1, parent_idxs[i]]
                log_w = -log(N)

                # Transition outer state
                x = transition(rng, outer, parent.x, u, t)

                # See t = 1 case
                inner_u = [x; u]

                # Extract model matrices and vectors
                A = calc_A(inner, inner_u, t)
                b = calc_b(inner, inner_u, t)
                Q = calc_Q(inner, inner_u, t)
                H = calc_H(obs, inner_u, t)
                c = calc_c(obs, inner_u, t)
                R = calc_R(obs, inner_u, t)

                # Transition inner state
                μ_pred = A * parent.μ + b
                Σ_pred = A * parent.Σ * A' + Q

                # Filter state
                K = Σ_pred * H' * inv(H * Σ_pred * H' + R)
                μ = μ_pred + K * (y - (H * μ_pred + c))
                Σ = (I - K * H) * Σ_pred

                # Update weight given observation
                μ_y = H * μ_pred + c
                Σ_y = H * Σ_pred * H' + R
                Σ_y = (Σ_y + Σ_y') / 2 # HACK: force symmetry
                log_w += logpdf(MvNormal(μ_y, Σ_y), y)

                particle_container[t, i] = RaoBlackwellisedParticle(
                    x, μ, Σ, log_w, parent_idxs[i]
                )
            end
        end
    end

    return particle_container
end

function StatsBase.sample(
    model::ConditionedStateSpaceModel,
    controls::AbstractVector,
    algorithm::RaoBlackwellisedParticleFilter,
)
    return sample(Random.default_rng(), model, controls, algorithm)
end
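
# A stdlib-only sketch (hypothetical helper name) of systematic resampling, a common
# lower-variance alternative to the multinomial `sample(..., Weights(weights), N)` step in
# the filter above. It normalises unnormalised log-weights with a max-shift (as `softmax`
# does), draws a single uniform offset from the supplied `rng`, and returns parent indices
# in ascending order.
using Random

function systematic_resample(
    rng::AbstractRNG, log_ws::AbstractVector{<:Real}, n::Integer=length(log_ws)
)
    ws = exp.(log_ws .- maximum(log_ws))
    ws ./= sum(ws)
    cdf = cumsum(ws)
    cdf[end] = 1.0 # guard against round-off in the cumulative sum
    idxs = Vector{Int}(undef, n)
    r = rand(rng)  # single uniform offset shared by all n strata
    j = 1
    for i in 1:n
        u = (i - 1 + r) / n # i-th evenly spaced point
        while cdf[j] < u
            j += 1
        end
        idxs[i] = j
    end
    return idxs
end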
######################
#### FULL EXAMPLE ####
######################

"""
Linear Gaussian dynamics that are represented as general dynamics.
"""
struct DummyLinearGaussianDynamics{T} <: LatentDynamics
    A::Matrix{T} # transition matrix
    b::Vector{T} # drift vector
    Q::Matrix{T} # transition noise covariance
    μ0::Vector{T} # prior mean
    Σ0::Matrix{T} # prior covariance
    D::Int # dimension
end
dim(model::DummyLinearGaussianDynamics) = model.D

function initialise(
    rng::AbstractRNG, dynamics::DummyLinearGaussianDynamics, u::AbstractVector
)
    return rand(rng, MvNormal(dynamics.μ0, dynamics.Σ0))
end

function transition(
    rng::AbstractRNG,
    dynamics::DummyLinearGaussianDynamics,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    A, b, Q = dynamics.A, dynamics.b, dynamics.Q
    return rand(rng, MvNormal(A * x + b, Q))
end

struct InnerDynamics{T} <: LinearGaussianDynamics
    A::Matrix{T} # transition matrix
    Q::Matrix{T} # transition noise covariance
    D_outer::Int # dimension of outer state
    μ0::Vector{T} # prior mean
    Σ0::Matrix{T} # prior covariance
end

calc_A(model::InnerDynamics, u::AbstractVector, t::Integer) = model.A
# Set drift equal to the value of the outer dynamics (e.g. Langevin dynamics)
calc_b(model::InnerDynamics, u::AbstractVector, t::Integer) = u[1:(model.D_outer)]
calc_Q(model::InnerDynamics, u::AbstractVector, t::Integer) = model.Q
calc_μ0(model::InnerDynamics, u::AbstractVector) = model.μ0
calc_Σ0(model::InnerDynamics, u::AbstractVector) = model.Σ0

# Define arbitrary parameters (for this example, outer and inner dimensions must be the same)
D_outer = 2
D_inner = 2
A_outer = [0.5 0.2; 0.1 0.4]
b_outer = [0.1, 0.2]
Q_outer = [0.1 0.05; 0.05 0.1]
μ0_outer = [0.0, 0.0]
Σ0_outer = [1.0 0.0; 0.0 1.0]
A_inner = [0.9 0.1; 0.0 0.8]
Q_inner = [1.0 0.3; 0.3 1.2]
μ0_inner = [0.0, 0.0]
Σ0_inner = [1.0 0.0; 0.0 1.0]

example_model = HierarchicalStateSpaceModel(
    DummyLinearGaussianDynamics(A_outer, b_outer, Q_outer, μ0_outer, Σ0_outer, D_outer),
    InnerDynamics(A_inner, Q_inner, D_outer, μ0_inner, Σ0_inner),
    DefaultLinearGaussianObservation(D_inner),
)

# Simulate from the model
T = 100
control = [zeros(1) for _ in 1:T]
xs, ys = sample(example_model, control)

# Plot first dimension of both latent states and observations
p = plot(; title="Forward Simulation of Hierarchical Model")
plot!(1:T, getindex.(xs, 1); label="Outer Latent Variables")
plot!(1:T, getindex.(xs, D_outer + 1); label="Inner Latent Variables")
scatter!(1:T, getindex.(ys, 1); label="Observations")
display(p)

# Run Rao-Blackwellised particle filter
N = 10000
rbpf = RaoBlackwellisedParticleFilter(N)
particle_container = sample(ConditionedStateSpaceModel(example_model, ys), control, rbpf)
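
# A stdlib-only sketch of a standard diagnostic for the run above (hypothetical helper
# name): the effective sample size 1 / Σᵢ wᵢ² of the normalised weights, computed from
# unnormalised log-weights with the same max-shift that `softmax` uses. Values close to N
# indicate well-balanced weights; values close to 1 indicate weight degeneracy.
function effective_sample_size(log_ws::AbstractVector{<:Real})
    ws = exp.(log_ws .- maximum(log_ws))
    ws ./= sum(ws)
    return 1 / sum(abs2, ws)
end

# Usage: effective_sample_size(getproperty.(particle_container[t, :], :log_w))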
####################
#### VALIDATION ####
####################

# Define the joint Kalman model. The inner drift is the *new* outer state (the outer state
# is transitioned first), so substituting the outer transition into the inner one gives
#   x_inner' = A_outer x_outer + A_inner x_inner + b_outer + ε_outer + ε_inner,
# which determines the lower blocks of A, b and Q below.
A = [
    A_outer zeros(D_outer, D_inner)
    A_outer A_inner
]
b = [b_outer; b_outer]
Q = [
    Q_outer Q_outer
    Q_outer (Q_outer + Q_inner)
]
E = LinearEvolution(A, Gaussian(b, Q))

H = [zeros(D_inner, D_outer) I]
R = Diagonal(ones(D_inner))
O = LinearObservation(E, H, R)

μ0 = [μ0_outer; μ0_inner]
Σ0 = [
    Σ0_outer zeros(D_outer, D_inner)
    zeros(D_inner, D_outer) Σ0_inner
]
G0 = Gaussian(μ0, Σ0)

Y = [t => ys[t] for t in 1:T]

# Ground truth filtering
Xf, ll = kalmanfilter(O, 0 => G0, Y)

# Compare estimates of the filtered mean for the last inner state dimension and the first
# outer state dimension (expect the outer state estimates to be noisier)
inner_kalman_estimates = map(G -> G.μ[end], Xf.x)
inner_rbpf_estimates = Vector{Float64}(undef, T)
for t in 1:T
    weights = softmax(getproperty.(particle_container[t, :], :log_w))
    inner_rbpf_estimates[t] = sum(weights .* map(p -> p.μ[end], particle_container[t, :]))
end

p1 = plot(; title="Last Inner State")
plot!(p1, 1:T, inner_kalman_estimates; label="Kalman Estimates")
plot!(p1, 1:T, inner_rbpf_estimates; label="RBPF Estimates")

outer_kalman_estimates = map(G -> G.μ[1], Xf.x)
outer_rbpf_estimates = Vector{Float64}(undef, T)
for t in 1:T
    weights = softmax(getproperty.(particle_container[t, :], :log_w))
    outer_rbpf_estimates[t] = sum(weights .* map(p -> p.x[1], particle_container[t, :]))
end

p2 = plot(; title="First Outer State")
plot!(p2, 1:T, outer_kalman_estimates; label="Kalman Estimates")
plot!(p2, 1:T, outer_rbpf_estimates; label="RBPF Estimates")

plot(
    p1,
    p2;
    plot_title="Comparison of Kalman and Rao-Blackwellised Estimates",
    layout=(2, 1),
    size=(800, 800),
)
[Figure: final output of the script, comparing the RBPF estimates to the ground-truth Kalman filter estimates]