Created
May 30, 2024 21:52
An implementation of the Lévy SSM using a RBPF implemented in the proposed SSMProblems.jl
using Distributions
using FillArrays
using LinearAlgebra
using LogExpFunctions
using Plots
using ProgressMeter
using Random
using StatsBase
using LevyProcesses

############################
#### MODIFIED INTERFACE ####
############################

"""
Latent dynamics of a state space model. Should implement the following methods using the
value of the control vector and current time-step as inputs:
- `transition`
- `transition_logdensity`
Alternatively, you can specify a `transition_distribution` method that will be used to
generate the above two methods.
An `initialise` method should also be implemented to generate the initial state
but this (obviously) doesn't need to accept the time-step.
Finally, a method `dim` returning the dimension of the latent state should be defined.
"""
abstract type LatentDynamics end

# Default transition methods when transition_distribution is defined
function transition(
    rng::AbstractRNG,
    dynamics::LatentDynamics,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_transition_distribution(dynamics)
    return rand(rng, transition_distribution(dynamics, x, u, t))
end

function transition_logdensity(
    dynamics::LatentDynamics,
    x_curr::AbstractVector,
    x_next::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_transition_distribution(dynamics)
    return logpdf(transition_distribution(dynamics, x_curr, u, t), x_next)
end

function check_transition_distribution(dynamics::LatentDynamics)
    if !hasmethod(
        transition_distribution,
        Tuple{typeof(dynamics),AbstractVector,AbstractVector,Integer},
    )
        calling_func = StackTraces.stacktrace()[2].func
        error("neither $calling_func nor transition_distribution is defined for $dynamics")
    end
end

"""
Latent dynamics of a state space model that are linear and Gaussian. Methods should be
implemented to define how to construct the following quantities given the control vector
and time-step:
- A: transition matrix
- b: drift vector
- Q: transition noise covariance
These should be named `calc_A`, `calc_b`, and `calc_Q` respectively. In the
simplest case these would simply return the values of matrices and vectors stored in the
model struct, but they can also be more complex functions that depend on the control and
time-step in arbitrary ways.
Likewise, a prior mean vector and covariance matrix should be specified by `calc_μ0` and
`calc_Σ0`, respectively, not depending on the time-step.
Note that we do not include a control matrix `B` as is standard in the literature since
we have opted for the more general setting in which the drift vector can be any
non-linear function of the control. To retrieve this standard behaviour, one can simply
define `calc_b(dynamics, u, t) = dynamics.B * u`.
"""
# Note: defining our linear Gaussian model in this way (rather than fixed matrices/vectors)
# has an additional advantage on top of unifying the interface. Specifically, it allows for
# time-varying and control-dependent linear Gaussian dynamics, facilitating incredibly
# expressive modelling. This could be achieved by defining a vector of matrices/vectors, one
# for each time-step, but this is incredibly memory inefficient.
abstract type LinearGaussianDynamics <: LatentDynamics end
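
# Default transitions for use in general particle filtering
function transition_distribution(
    dynamics::LinearGaussianDynamics, x::AbstractVector, u::AbstractVector, t::Integer
)
    A, b, Q = calc_transition(dynamics, u, t)
    return MvNormal(A * x + b, Q)
end

# Fallback so that `calc_transition` works for any model that defines (or inherits the
# defaults for) `calc_A`, `calc_b`, and `calc_Q`. Without it, only models that overload
# `calc_transition` directly can be used with the default transition distribution.
function calc_transition(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer)
    return calc_A(dynamics, u, t), calc_b(dynamics, u, t), calc_Q(dynamics, u, t)
end

# Default initialisation for use in general particle filtering
function initialise(rng::AbstractRNG, dynamics::LinearGaussianDynamics, u::AbstractVector)
    μ0 = calc_μ0(dynamics, u)
    Σ0 = calc_Σ0(dynamics, u)
    return rand(rng, MvNormal(μ0, Σ0))
end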

# Default model matrices/vectors can be defined. This is convenient for simple models that
# don't include a drift term. This could potentially lead to obfuscated bugs if the user
# attempts to define `calc_X` but has a typo or incorrect signature. Perhaps including a
# warning that default values are being used would be useful.
# Note: the defaults for `calc_b` and `calc_μ0` assume the model struct has a field `D`.
calc_A(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = I
calc_b(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = Zeros(dynamics.D)
calc_Q(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = I
calc_μ0(dynamics::LinearGaussianDynamics, u::AbstractVector) = Zeros(dynamics.D)
calc_Σ0(dynamics::LinearGaussianDynamics, u::AbstractVector) = I

# Fallback used by the Rao-Blackwellised filter to obtain the prior mean and covariance
function calc_initial(dynamics::LinearGaussianDynamics, u::AbstractVector)
    return calc_μ0(dynamics, u), calc_Σ0(dynamics, u)
end
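
# Illustrative sketch (not part of the original gist): one way the standard control-matrix
# form x' = A x + B u + noise could be recovered with the interface above. The type name
# `ControlledLinearDynamics` and its fields are hypothetical, chosen only for this example.
struct ControlledLinearDynamics <: LinearGaussianDynamics
    A::Matrix{Float64}  # transition matrix
    B::Matrix{Float64}  # control matrix
    Q::Matrix{Float64}  # transition noise covariance
end
dim(dynamics::ControlledLinearDynamics) = size(dynamics.A, 1)
calc_A(dynamics::ControlledLinearDynamics, u::AbstractVector, t::Integer) = dynamics.A
calc_b(dynamics::ControlledLinearDynamics, u::AbstractVector, t::Integer) = dynamics.B * u
calc_Q(dynamics::ControlledLinearDynamics, u::AbstractVector, t::Integer) = dynamics.Q
# This struct has no field `D`, so the default prior mean is overridden explicitly.
calc_μ0(dynamics::ControlledLinearDynamics, u::AbstractVector) = zeros(dim(dynamics))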

# TODO: think this should have two dimensions, input and output
"""
Observation process of a state space model. Should implement the following methods using
the value of the control vector and current time-step as inputs:
- `observation`
- `observation_logdensity`
Alternatively, you can specify an `observation_distribution` method that will be used to
generate the above two methods.
A method `dim` returning the dimension of the observation should also be defined.
"""
abstract type ObservationProcess end

# Default observation methods when observation_distribution is defined
function observation(
    rng::AbstractRNG,
    process::ObservationProcess,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_observation_distribution(process)
    return rand(rng, observation_distribution(process, x, u, t))
end

function observation_logdensity(
    process::ObservationProcess,
    x::AbstractVector,
    y::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    check_observation_distribution(process)
    return logpdf(observation_distribution(process, x, u, t), y)
end

# TODO: is a try-catch more efficient here? Or better, can we cache the result?
function check_observation_distribution(process::ObservationProcess)
    if !hasmethod(
        observation_distribution,
        Tuple{typeof(process),AbstractVector,AbstractVector,Integer},
    )
        calling_func = StackTraces.stacktrace()[2].func
        error("neither $calling_func nor observation_distribution is defined for $process")
    end
end

"""
Observation process of a state space model that is linear and Gaussian. Methods should
be implemented to define how to construct the following quantities given the control
vector and time-step:
- H: observation matrix
- c: observation bias vector
- R: observation noise covariance
These should be named `calc_H`, `calc_c`, and `calc_R` respectively. In the simplest case
these would simply return the values of matrices and vectors stored in the model struct,
but they can also be more complex functions that depend on the control.
"""
abstract type LinearGaussianObservation <: ObservationProcess end

# Default observation for use in general particle filtering
function observation_distribution(
    process::LinearGaussianObservation, x::AbstractVector, u::AbstractVector, t::Integer
)
    H = calc_H(process, u, t)
    c = calc_c(process, u, t)
    R = calc_R(process, u, t)
    return MvNormal(H * x + c, R)
end

# Default model matrices/vectors
calc_H(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = I
calc_c(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = Zeros(process.D)
calc_R(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = I

# Regular and hierarchical models are both state space models, so we introduce an abstract
# type that they will both inherit from
abstract type AbstractStateSpaceModel end

# TODO: need a way of ensuring that the latent and observation dimensions match (for all H).
"""A state space model combines latent dynamics and an observation process."""
struct StateSpaceModel{LD<:LatentDynamics,OP<:ObservationProcess} <: AbstractStateSpaceModel
    latent_dynamics::LD
    observation_process::OP
end

# We make this split interface compatible with the existing interface by defining
# initialisation, transition, and observation methods for the entire SSM. This also allows
# these methods to be defined uniquely for the hierarchical SSM (see below).
# TODO: should we also define log-densities/distributions here?
function initialise(rng::AbstractRNG, model::AbstractStateSpaceModel, u::AbstractVector)
    return initialise(rng, model.latent_dynamics, u)
end

function transition(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    return transition(rng, model.latent_dynamics, x, u, t)
end

function observation(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    return observation(rng, model.observation_process, x, u, t)
end

############################
#### FORWARD SIMULATION ####
############################

# TODO: could do with better control of element types
function StatsBase.sample(
    rng::AbstractRNG, model::AbstractStateSpaceModel, controls::AbstractVector
)
    xs = Vector{Vector}(undef, length(controls))
    ys = Vector{Vector}(undef, length(controls))
    # Iterate over the `controls` argument (the original looped over the global `control`)
    for t in 1:length(controls)
        xs[t] = if t == 1
            initialise(rng, model, controls[t])
        else
            transition(rng, model, xs[t-1], controls[t], t)
        end
        ys[t] = observation(rng, model, xs[t], controls[t], t)
    end
    return xs, ys
end

function StatsBase.sample(model::AbstractStateSpaceModel, controls::AbstractVector)
    return sample(Random.default_rng(), model, controls)
end

# Forward simulation is easy even when a linear Gaussian model is defined through matrices
# and vectors.
struct DefaultLinearGaussianModel <: LinearGaussianDynamics
    D::Int  # dimension
end
dim(model::DefaultLinearGaussianModel) = model.D

struct DefaultLinearGaussianObservation <: LinearGaussianObservation
    D::Int  # dimension
end
dim(model::DefaultLinearGaussianObservation) = model.D

######################
#### CONDITIONING ####
######################

# A state space model can be conditioned on a sequence of observations. This avoids
# observations needing to be passed around as arguments to samplers and also allows a
# distinction between forward simulation and filtering/smoothing.
# HACK: simplified type for observations to make this work with the RB example
struct ConditionedStateSpaceModel{M<:AbstractStateSpaceModel}
    model::M
    observations::Vector
end

#############################
#### HIERARCHICAL MODELS ####
#############################

# One of the main motivations for separating the latent dynamics and observation processes
# is that it allows a natural definition of hierarchical models. The value of the outer
# state can be passed to the inner state as a control variable. Not only is this useful for
# Rao-Blackwellisation, but it can also be used for the (partially) independent particle
# filter of Lin, 2005.
struct HierarchicalStateSpaceModel{
    LD1<:LatentDynamics,LD2<:LatentDynamics,OP<:ObservationProcess
} <: AbstractStateSpaceModel
    outer_latent_dynamics::LD1
    inner_latent_dynamics::LD2
    observation_process::OP
end

# Special cases of hierarchical models can be defined for dispatching to optimised methods
# through Rao-Blackwellisation.
const ConditionallyGaussianStateSpaceModel = HierarchicalStateSpaceModel{
    LD1,LD2,OP
} where {LD1<:LatentDynamics,LD2<:LinearGaussianDynamics,OP<:LinearGaussianObservation}

# A hierarchical SSM is still a regular SSM and can be sampled from or filtered in a general
# setting by defining appropriate methods
# TODO: this could be made cleaner if logdensities are included by defining a factorised
# Distributions.jl distribution
function initialise(rng::AbstractRNG, model::HierarchicalStateSpaceModel, u::AbstractVector)
    outer_x = initialise(rng, model.outer_latent_dynamics, u)
    inner_u = [outer_x; u]
    inner_x = initialise(rng, model.inner_latent_dynamics, inner_u)
    return [outer_x; inner_x]
end

function transition(
    rng::AbstractRNG,
    model::HierarchicalStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    # Transition outer dynamics
    # outer_x = x[1:dim(model.outer_latent_dynamics)]
    # HACK
    outer_x = first(x)
    outer_x = transition(rng, model.outer_latent_dynamics, outer_x, u, t)
    inner_u = [outer_x; u]
    # Transition inner dynamics
    inner_x = x[(dim(model.outer_latent_dynamics)+1):end]
    inner_x = transition(rng, model.inner_latent_dynamics, inner_x, inner_u, t)
    return [outer_x; inner_x]
end

function observation(
    rng::AbstractRNG,
    model::HierarchicalStateSpaceModel,
    x::AbstractVector,
    u::AbstractVector,
    t::Integer,
)
    outer_x = x[1:dim(model.outer_latent_dynamics)]
    inner_u = [outer_x; u]
    inner_x = x[(dim(model.outer_latent_dynamics)+1):end]
    return observation(rng, model.observation_process, inner_x, inner_u, t)
end

######################################
#### ABSTRACT PARTICLE CONTAINERS ####
######################################

# Depending on your downstream use cases, you may want to store different quantities
# generated by the particle filter. In the most extreme case (e.g. backward simulation), you
# need to store the entire NxT particle history. For regular particle smoothing, you only
# need to store particles that survived resampling as in "Path storage in the particle
# filter". In an extreme case (e.g. large N, T where memory constraints are an issue), you
# may just want to store a summary of the particles at each time-step.
# The abstract particle container is designed to separate the details of particle storage
# from the particle filter algorithm. At each time-step the filtering algorithm passes all
# relevant variables to the particle container, which then stores them in whatever way it is
# defined to.
abstract type AbstractParticleContainer end

"""
A particle container that only stores the (weighted) mean of particles at each time-step.
"""
struct MeanParticleContainer{T} <: AbstractParticleContainer
    μ::Vector{T}
end
MeanParticleContainer(T::Type, N_steps::Int) = MeanParticleContainer{T}(Vector{T}(undef, N_steps))

# Should this be Vector{S} where S <: Real?
function store!(
    container::MeanParticleContainer{T},
    t::Integer,
    xs::AbstractVector{T},
    log_ws::AbstractVector{<:Real}
) where {T}
    weights = softmax(log_ws)
    container.μ[t] = sum(xs .* weights)
end
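
# Illustrative sketch (not part of the original gist): a second container covering the
# "store the entire NxT particle history" case described above. The name
# `FullParticleContainer` and its fields are hypothetical. It follows the same
# `store!(container, t, xs, log_ws)` calling convention as `MeanParticleContainer`, so a
# filter written against `AbstractParticleContainer` should not need to change.
struct FullParticleContainer{T} <: AbstractParticleContainer
    xs::Vector{Vector{T}}            # xs[t][i] is particle i at time-step t
    log_ws::Vector{Vector{Float64}}  # log_ws[t][i] is its unnormalised log-weight
end
function FullParticleContainer(T::Type, N_steps::Int)
    return FullParticleContainer{T}(
        Vector{Vector{T}}(undef, N_steps), Vector{Vector{Float64}}(undef, N_steps)
    )
end

function store!(
    container::FullParticleContainer{T},
    t::Integer,
    xs::AbstractVector{T},
    log_ws::AbstractVector{<:Real},
) where {T}
    # Copy the outer vectors so that later in-place weight updates or reindexing by the
    # filter do not alter the stored history (the particles themselves are not deep-copied)
    container.xs[t] = collect(xs)
    container.log_ws[t] = collect(Float64, log_ws)
    return container
end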

##############################
#### RAO-BLACKWELLISATION ####
##############################

abstract type SSMSamplingAlgorithm end

struct RaoBlackwellisedParticleFilter <: SSMSamplingAlgorithm
    n_particles::Int
end

# TODO: reduce re-used code between t == 1 and t > 1 cases
function StatsBase.sample!(
    rng::AbstractRNG,
    cond_model::ConditionedStateSpaceModel{<:ConditionallyGaussianStateSpaceModel},
    controls::AbstractVector,
    algorithm::RaoBlackwellisedParticleFilter,
    particle_container::AbstractParticleContainer
)
    ys = cond_model.observations
    T = length(ys)
    N = algorithm.n_particles
    outer = cond_model.model.outer_latent_dynamics
    inner = cond_model.model.inner_latent_dynamics
    obs = cond_model.model.observation_process
    xs = Vector{SampleJumps}(undef, N)
    μs = Vector{Vector{Float64}}(undef, N)
    Σs = Vector{Matrix{Float64}}(undef, N)
    log_ws = Vector{Float64}(undef, N)
    @showprogress for t in 1:T
        y = ys[t]
        u = controls[t]
        if t == 1
            for i in 1:N
                log_ws[i] = -log(N)
                # Initialise outer state from prior
                x = initialise(rng, cond_model.model.outer_latent_dynamics, u)
                # The control vector for the inner dynamics is the concatenation of the
                # outer state and the control vector
                inner_u = [x; u]
                xs[i] = x
                # Compute prior mean and covariance for inner state, conditioned on outer
                μ0, Σ0 = calc_initial(cond_model.model.inner_latent_dynamics, inner_u)
                # Extract model matrices and vectors
                H = calc_H(obs, inner_u, t)
                R = calc_R(obs, inner_u, t)
                # Filter initial states
                K = Σ0 * H' * inv(H * Σ0 * H' + R)
                μs[i] = μ0 + K * (y - H * μ0)
                Σs[i] = (I - K * H) * Σ0
                # Update weight given observation
                μ_y = H * μ0
                Σ_y = H * Σ0 * H' + R
                log_ws[i] += logpdf(MvNormal(μ_y, Σ_y), y)
            end
            store!(particle_container, t, μs, log_ws)
        else
            # Resampling
            weights = softmax(log_ws)
            parent_idxs = sample(1:N, Weights(weights), N)
            xs = xs[parent_idxs]
            μs = μs[parent_idxs]
            Σs = Σs[parent_idxs]
            for i in 1:N
                # Reset to uniform weights after resampling
                log_ws[i] = -log(N)
                # Transition outer state (the previous outer state is not used by the
                # subordinator transition, so an empty placeholder is passed)
                x = transition(rng, outer, SampleJumps(Float64[], Float64[]), u, t)
                # See t = 1 case
                inner_u = [x; u]
                xs[i] = x
                # Extract model matrices and vectors
                A, b, Q = calc_transition(inner, inner_u, t)
                H = calc_H(obs, inner_u, t)
                R = calc_R(obs, inner_u, t)
                # Transition inner state
                μ_pred = A * μs[i] + b
                Σ_pred = A * Σs[i] * A' + Q
                # Filter state
                K = Σ_pred * H' * inv(H * Σ_pred * H' + R)
                μs[i] = μ_pred + K * (y - H * μ_pred)
                Σs[i] = (I - K * H) * Σ_pred
                # Update weight given observation
                μ_y = H * μ_pred
                Σ_y = H * Σ_pred * H' + R
                Σ_y = (Σ_y + Σ_y') / 2  # HACK: force symmetry
                log_ws[i] += logpdf(MvNormal(μ_y, Σ_y), y)
            end
            store!(particle_container, t, μs, log_ws)
        end
    end
    return particle_container
end
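
# Illustrative sketch (not part of the original gist): the Kalman update that the filter
# above repeats in its t == 1 and t > 1 branches, factored into one helper. The name
# `kalman_update` is hypothetical and the filter above does not call it. It is only meant
# to show one possible way of addressing the TODO about re-used code. It assumes that `H`
# and `R` are such that the innovation covariance is a dense, positive definite matrix.
function kalman_update(μ_pred, Σ_pred, H, R, y)
    # Innovation covariance, symmetrised to guard against floating-point asymmetry
    S = H * Σ_pred * H' + R
    S = (S + S') / 2
    # Kalman gain via a linear solve rather than an explicit inverse
    K = Σ_pred * H' / S
    μ = μ_pred + K * (y - H * μ_pred)
    Σ = (I - K * H) * Σ_pred
    # Log-likelihood of the observation under the one-step-ahead predictive distribution
    log_lik = logpdf(MvNormal(H * μ_pred, S), y)
    return μ, Σ, log_lik
end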

function StatsBase.sample!(
    model::ConditionedStateSpaceModel,
    controls::AbstractVector,
    algorithm::RaoBlackwellisedParticleFilter,
    particle_container::AbstractParticleContainer
)
    return sample!(Random.default_rng(), model, controls, algorithm, particle_container)
end

######################
#### FULL EXAMPLE ####
######################

C = 1.0
β = 0.4
ϵ = 1e-10
μw = 0.0
σw = 1.0
θ = -0.5
σe = 2.0

"""
Outer latent dynamics that sample the jumps of a Lévy process over each time interval.
The state is a `SampleJumps` object and the interval length is passed as the control.
"""
struct SubordinatorDynamics <: LatentDynamics
    p::LevyProcess
end
dim(model::SubordinatorDynamics) = 1

function initialise(
    rng::AbstractRNG, dyn::SubordinatorDynamics, u::AbstractVector
)
    return SampleJumps(Float64[], Float64[])
end

function transition(
    rng::AbstractRNG,
    dynamics::SubordinatorDynamics,
    x::SampleJumps,
    u::AbstractVector,
    t::Integer,
)
    dt = only(u)  # store time deltas as control variables
    return sample(rng, dynamics.p, dt)
end

struct LangevinLatentDynamics{T} <: LinearGaussianDynamics
    dyn::LevyDrivenLinearSDE{T}
end
dim(model::LangevinLatentDynamics) = 2

function calc_transition(model::LangevinLatentDynamics, u::AbstractVector, t::Integer)
    # TODO: this requires u to contain `Any` elements, which is not great for performance
    dt = u[2]
    jumps = u[1]
    dist = conditional_marginal(jumps, model.dyn, [0.0, 0.0], dt)
    A = exp(model.dyn.linear_dynamics, dt)
    b = dist.μ
    Q = dist.Σ
    return A, b, Q
end

calc_μ0(model::LangevinLatentDynamics, u::AbstractVector) = zeros(2)
calc_Σ0(model::LangevinLatentDynamics, u::AbstractVector) = Diagonal(ones(2))
function calc_initial(model::LangevinLatentDynamics, u::AbstractVector)
    return calc_μ0(model, u), calc_Σ0(model, u)
end

dyn = LangevinDynamics(θ)

struct PositionSelector <: LinearGaussianObservation
    σ2::Float64
end
dim(model::PositionSelector) = 1
calc_H(model::PositionSelector, u::AbstractVector, t::Integer) = [1.0 0.0]
calc_c(model::PositionSelector, u::AbstractVector, t::Integer) = zeros(1)
calc_R(model::PositionSelector, u::AbstractVector, t::Integer) = Diagonal([model.σ2])

subordinator = GammaProcess(C, β)
truncated_subordinator = truncate(subordinator, ϵ)
nvm = NormalVarianceMeanProcess(truncated_subordinator, μw, σw)
sde = LevyDrivenLinearSDE(nvm, dyn, [0.0, 1.0])

example_model = HierarchicalStateSpaceModel(
    SubordinatorDynamics(truncated_subordinator),
    LangevinLatentDynamics(sde),
    PositionSelector(σe^2),
)

# Simulate from the model
SEED = 1234
T = 200
finish = 100.0
control = [finish / T * ones(1) for _ in 1:T]
rng = Random.MersenneTwister(SEED)
xs, ys = sample(rng, example_model, control)

# Plot position state and observation
p = plot(; title="Forward Simulation of Hierarchical Model")
plot!(1:T, getindex.(xs, 2); label="Position")
scatter!(1:T, getindex.(ys, 1); label="Observations")
display(p)

# Plot velocity state
p = plot(; title="Forward Simulation of Hierarchical Model")
plot!(1:T, getindex.(xs, 3); label="Velocity")
display(p)

# Run Rao-Blackwellised particle filter
N = 10^4
rbpf = RaoBlackwellisedParticleFilter(N)
# TODO: is this not inefficient since Julia doesn't know length of inner vector?
particle_container = MeanParticleContainer(Vector{Float64}, T)
sample!(ConditionedStateSpaceModel(example_model, ys), control, rbpf, particle_container)

# Compare filtered trajectory to true trajectory
filtered_mean = particle_container.μ
p = plot(;)
plot!(1:T, getindex.(xs, 2); label="True Position")
plot!(1:T, getindex.(filtered_mean, 1); label="Filtered Position")
scatter!(1:T, getindex.(ys, 1); label="Observations", alpha=0.5)
display(p)

p2 = plot(;)
plot!(1:T, getindex.(xs, 3); label="True Velocity")
plot!(1:T, getindex.(filtered_mean, 2); label="Filtered Velocity")
plot(p, p2, layout=(2, 1), size=(800, 600))