Skip to content

Instantly share code, notes, and snippets.

@THargreaves
Created May 30, 2024 21:52
Show Gist options
  • Save THargreaves/9691f832563595debe27b754e4f3d67d to your computer and use it in GitHub Desktop.
Save THargreaves/9691f832563595debe27b754e4f3d67d to your computer and use it in GitHub Desktop.
An implementation of the Lévy SSM using a RBPF implemented in the proposed SSMProblems.jl
using Distributions
using FillArrays
using LinearAlgebra
using LogExpFunctions
using Plots
using ProgressMeter
using Random
using StatsBase
using LevyProcesses
############################
#### MODIFIED INTERFACE ####
############################
"""
Latent dynamics of a state space model. Should implement the following methods using the
value of the control vector and current time-step as inputs:
- `transition`
- `transition_logdensity`
Alternatively, you can specify a `transition_distribution method`` that will be used to
generate the above two methods.
An `initialise` method should also be implemented to generate the initial state
but this (obviously) doesn't need to accept the time-step.
Finally, a method `dim` returning the dimension of the latent state should be defined.
"""
abstract type LatentDynamics end
# Default transition methods when transition_distribution is defined
function transition(
rng::AbstractRNG,
dynamics::LatentDynamics,
x::AbstractVector,
u::AbstractVector,
t::Integer,
)
check_transition_distribution(dynamics)
return rand(rng, transition_distribution(dynamics, x, u, t))
end
function transition_logdensity(
dynamics::LatentDynamics,
x_curr::AbstractVector,
x_next::AbstractVector,
u::AbstractVector,
t::Integer,
)
check_transition_distribution(dynamics)
return logpdf(transition_distribution(dynamics, x_curr, u, t), x_next)
end
function check_transition_distribution(dynamics::LatentDynamics)
if !hasmethod(
transition_distribution,
Tuple{typeof(dynamics),AbstractVector,AbstractVector,Integer},
)
calling_func = StackTraces.stacktrace()[2].func
error("neither $calling_func nor transition_distribution is defined for $dynamics")
end
end
"""
Latent dynamics of a state space model that are linear and Gaussian. Methods should be
implemented to define how to construct the following quanities given the control vector
and time step:
- A: transition matrix
- b: drift vector
- Q: transition noise covariance
These should be name `calc_A`, `calc_b`, and `calc_Q` respectively. In the
simplest case these would simply return the values of matrices and vectors stored in the
model struct, but they can also be more complex functions that depend on the control and
time-step in arbitrary ways.
Likewise, a prior mean vector and covariance matrix should be specified by `calc_μ0` and
`calc_Σ0`, respectively, not depending on the time-step.
Note, that we do not include a control matrix `B` as is standard in the literature since
we have opted for the more general setting in which the drift vector can be any
non-linear function of the control. To retrieve this standard behaviour, one can simple
define `calc_C(...) = dynamics.B * u`.
"""
# Note: defining our linear Gaussian model in this way (rather than fixed matrices/vectors)
# has an additional advantage on top of unifying the interface. Specifically, it allows for
# time-varying and control-dependent linear Gaussian dynamics, facilitating incredibly
# expressive modelling. This could be achieved by defining a vector of matrices/vectors, one
# for each time-step, but this is incredibly memory inefficient.
abstract type LinearGaussianDynamics <: LatentDynamics end
# Default transitions for use in general particle filtering
function transition_distribution(
dynamics::LinearGaussianDynamics, x::AbstractVector, u::AbstractVector, t::Integer
)
A, b, Q = calc_transition(dynamics, u, t)
return MvNormal(A * x + b, Q)
end
# Default initialisation for use in general particle filtering
function initialise(rng::AbstractRNG, dynamics::LinearGaussianDynamics, u::AbstractVector)
μ0 = calc_μ0(dynamics, u)
Σ0 = calc_Σ0(dynamics, u)
return rand(rng, MvNormal(μ0, Σ0))
end
# Default model matrices/vectors can be defined. This is convenient for simple models that
# don't include a drift term. This could potentially lead to obfuscated bugs if the user
# attempts to define `calc_X` but has a typo or incorrect signature. Perhaps including a
# warning that default values are being used would be useful.
calc_A(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = I
calc_b(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = Zeros(dynamics.D)
calc_Q(dynamics::LinearGaussianDynamics, u::AbstractVector, t::Integer) = I
calc_μ0(dynamics::LinearGaussianDynamics, u::AbstractVector) = Zeros(dynamics.D)
calc_Σ0(dynamics::LinearGaussianDynamics, u::AbstractVector) = I
# TODO: think this should have two dimensions, input and output
"""
Observation process of a state space model. Should implement the following methods using
the value of the control vector and current time-step as inputs:
- `observation`
- `observation_logdensity`
Alternatively, you can specify an observation_distribution method that will be used to
generate the above two methods.
A method `dim` returning the dimension of the observation should also be defined.
"""
abstract type ObservationProcess end
# Default observation methods when observation_distribution is defined
function observation(
rng::AbstractRNG,
process::ObservationProcess,
x::AbstractVector,
u::AbstractVector,
t::Integer,
)
check_observation_distribution(process)
return rand(rng, observation_distribution(process, x, u, t))
end
function observation_logdensity(
process::ObservationProcess,
x::AbstractVector,
y::AbstractVector,
u::AbstractVector,
t::Integer,
)
check_observation_distribution(process)
return logpdf(observation_distribution(process, x, u, t), y)
end
# TODO: is a try-catch more efficient here. Or better, can we cache the result?
function check_observation_distribution(process::ObservationProcess)
if !hasmethod(
observation_distribution,
Tuple{typeof(process),AbstractVector,AbstractVector,Integer},
)
calling_func = StackTraces.stacktrace()[2].func
error("neither $calling_func nor observation_distribution is defined for $process")
end
end
"""
Observation process of a state space model that is linear and Gaussian. Methods should
be implemented to define how to construct the following quanities given the control
vector and time step:
- H: observation matrix
- c: observation bias vector
- R: observation noise covariance
These should be name `calc_H`, `calc_c`, and `calc_R` respectively. In the simplest case
these would simply return the values of matrices and vectors stored in the model struct,
but they can also be more complex functions that depend on the control.
"""
abstract type LinearGaussianObservation <: ObservationProcess end
# Default observation for use in general particle filtering
function observation_distribution(
process::LinearGaussianObservation, x::AbstractVector, u::AbstractVector, t::Integer
)
H = calc_H(process, u, t)
c = calc_c(process, u, t)
R = calc_R(process, u, t)
return MvNormal(H * x + c, R)
end
# Default model matrices/vectors
calc_H(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = I
calc_c(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = Zeros(process.D)
calc_R(process::LinearGaussianObservation, u::AbstractVector, t::Integer) = I
# Regular and hierarchical models are both state space models, so we introduce an abstract
# type that they will both inherit from
abstract type AbstractStateSpaceModel end
# TODO: need a way of ensuring that the latent and observation dimensions match (for all H).
"""A state space model combines latent dynamics and an observation processes."""
struct StateSpaceModel{LD<:LatentDynamics,OP<:ObservationProcess} <: AbstractStateSpaceModel
latent_dynamics::LD
observation_process::OP
end
# We make this split interface compatible with the existing interface by defining
# initialisation, transition, and observation methods for the entire SSM. This also allows
# these methods to be defined uniquely for the hierarical SSM (see below).
# TODO: should we also define log-densities/distributions here?
function initialise(rng::AbstractRNG, model::AbstractStateSpaceModel, u::AbstractVector)
return initialise(rng, model.latent_dynamics, u)
end
function transition(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
x::AbstractVector,
u::AbstractVector,
t::Integer,
)
return transition(rng, model.latent_dynamics, x, u, t)
end
function observation(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
x::AbstractVector,
u::AbstractVector,
t::Integer,
)
return observation(rng, model.observation_process, x, u, t)
end
############################
#### FORWARD SIMULATION ####
############################
# TODO: could use with better control of element types
function StatsBase.sample(
rng::AbstractRNG, model::AbstractStateSpaceModel, controls::AbstractVector
)
xs = Vector{Vector}(undef, length(controls))
ys = Vector{Vector}(undef, length(controls))
for t in 1:length(control)
xs[t] = if t == 1
initialise(rng, model, controls[t])
else
transition(rng, model, xs[t-1], controls[t], t)
end
ys[t] = observation(rng, model, xs[t], controls[t], t)
end
return xs, ys
end
function StatsBase.sample(model::AbstractStateSpaceModel, controls::AbstractVector)
return sample(Random.default_rng(), model, controls)
end
# Forward simulation is easy even when a linear Gaussian model is defined through matrices
# and vectors.
struct DefaultLinearGaussianModel <: LinearGaussianDynamics
D::Int # dimension
end
dim(model::DefaultLinearGaussianModel) = model.D
struct DefaultLinearGaussianObservation <: LinearGaussianObservation
D::Int # dimension
end
dim(model::DefaultLinearGaussianObservation) = model.D
######################
#### CONDITIONING ####
######################
# A state space model can be conditioned on a sequence of observations. This avoids
# observations needing to passed around as arguments to samplers and also allows a
# distinction between forward simulation and filtering/smoothing.
# HACK: simplified type for observations to make this work with the RB example
struct ConditionedStateSpaceModel{M<:AbstractStateSpaceModel}
model::M
observations::Vector
end
#############################
#### HIERARCHICAL MODELS ####
#############################
# One of the main motivations for separating the latent dynamics and observation processes
# is that it allows a natural definition of hierarchical models. The value of the outer
# state can be passed to the inner state as a control variable. Not only is this useful for
# Rao-Blackwellisation, but it can also be used for (partially) independent particle filter
# of Lin, 2005.
struct HierarchicalStateSpaceModel{
LD1<:LatentDynamics,LD2<:LatentDynamics,OP<:ObservationProcess
} <: AbstractStateSpaceModel
outer_latent_dynamics::LD1
inner_latent_dynamics::LD2
observation_process::OP
end
# Special cases of hierarchical models can be defined for dispatching to optimised methods
# through Rao-Blackwellisation.
const ConditionallyGaussianStateSpaceModel = HierarchicalStateSpaceModel{
LD1,LD2,OP
} where {LD1<:LatentDynamics,LD2<:LinearGaussianDynamics,OP<:LinearGaussianObservation}
# A hierarchical SSM is still a regular SSM and can be sampled from or filtered in a general
# setting by defining appropriate methods
# TODO: this could be made cleaner if logdensities are including by defining a factorised
# Distributions.jl distribution
function initialise(rng::AbstractRNG, model::HierarchicalStateSpaceModel, u::AbstractVector)
outer_x = initialise(rng, model.outer_latent_dynamics, u)
inner_u = [outer_x; u]
inner_x = initialise(rng, model.inner_latent_dynamics, inner_u)
return [outer_x; inner_x]
end
function transition(
rng::AbstractRNG,
model::HierarchicalStateSpaceModel,
x::AbstractVector,
u::AbstractVector,
t::Integer,
)
# Transition outer dynamics
# outer_x = x[1:dim(model.outer_latent_dynamics)]
# HACK
outer_x = first(x)
outer_x = transition(rng, model.outer_latent_dynamics, outer_x, u, t)
inner_u = [outer_x; u]
# Transition inner dynamics
inner_x = x[(dim(model.outer_latent_dynamics)+1):end]
inner_x = transition(rng, model.inner_latent_dynamics, inner_x, inner_u, t)
return [outer_x; inner_x]
end
function observation(
rng::AbstractRNG,
model::HierarchicalStateSpaceModel,
x::AbstractVector,
u::AbstractVector,
t::Integer,
)
outer_x = x[1:dim(model.outer_latent_dynamics)]
inner_u = [outer_x; u]
inner_x = x[(dim(model.outer_latent_dynamics)+1):end]
return observation(rng, model.observation_process, inner_x, inner_u, t)
end
######################################
#### ABSTRACT PARTICLE CONTAINERS ####
######################################
# Depending on your downstream use cases, you may want to store different quantities
# generated by the particle filter. In the most extreme case (e.g. backward simulation), you
# need to store the entire NxT particle history. For regulary particle smoothing, you only
# need to store particles that survived resampling as in "Path storage in the particle
# filter". In an extreme case (e.g. large N, T where memory constaints are an issue), you
# may just want to store a summary of the particles at each time-step.
# The abstract particle container is designed to separate the details of particle storage
# from the particle filter algorithm. At each time-step the filtering algorithm passes all
# relevant variables to the particle container, which then stores them in whatever way it is
# defined to.
abstract type AbstractParticleContainer end
"""
A particle container that only stores the (weighted) mean of particles at each time-step.
"""
struct MeanParticleContainer{T} <: AbstractParticleContainer
μ::Vector{T}
end
MeanParticleContainer(T::Type, N_steps::Int) = MeanParticleContainer{T}(Vector{T}(undef, N_steps))
# Should this be Vector{S} where S <: Real?
function store!(
container::MeanParticleContainer{T},
t::Integer,
xs::AbstractVector{T},
log_ws::AbstractVector{<:Real}
) where {T}
weights = softmax(log_ws)
container.μ[t] = sum(xs .* weights)
end
##############################
#### RAO-BLACKWELLISATION ####
##############################
abstract type SSMSamplingAlgorithm end
struct RaoBlackwellisedParticleFilter <: SSMSamplingAlgorithm
n_particles::Int
end
# TODO: reduce re-used code between t == 1 and t > 1 cases
function StatsBase.sample!(
rng::AbstractRNG,
cond_model::ConditionedStateSpaceModel{<:ConditionallyGaussianStateSpaceModel},
controls::AbstractVector,
algorithm::RaoBlackwellisedParticleFilter,
particle_container::AbstractParticleContainer
)
ys = cond_model.observations
T = length(ys)
N = algorithm.n_particles
outer = cond_model.model.outer_latent_dynamics
inner = cond_model.model.inner_latent_dynamics
obs = cond_model.model.observation_process
xs = Vector{SampleJumps}(undef, N)
μs = Vector{Vector{Float64}}(undef, N)
Σs = Vector{Matrix{Float64}}(undef, N)
log_ws = Vector{Float64}(undef, N)
@showprogress for t in 1:T
y = ys[t]
u = controls[t]
if t == 1
for i in 1:N
log_ws[i] = -log(N)
# Initialise outer state from prior
x = initialise(rng, cond_model.model.outer_latent_dynamics, u)
# The control vector for the inner dynamics is the concatenation of the
# the outer state of the control vector
inner_u = [x; u]
xs[i] = x
# Compute prior mean and covariance for inner state, conditioned on outer
μ0, Σ0 = calc_initial(cond_model.model.inner_latent_dynamics, inner_u)
# Extract model matrices and vectors
H = calc_H(obs, inner_u, t)
R = calc_R(obs, inner_u, t)
# Filter initial states
K = Σ0 * H' * inv(H * Σ0 * H' + R)
μs[i] = μ0 + K * (y - H * μ0)
Σs[i] = (I - K * H) * Σ0
# Update weight given observation
μ_y = H * μ0
Σ_y = H * Σ0 * H' + R
log_ws[i] += logpdf(MvNormal(μ_y, Σ_y), y)
end
store!(particle_container, t, μs, log_ws)
else
# Resampling
weights = softmax(log_ws)
parent_idxs = sample(1:N, Weights(weights), N)
xs = xs[parent_idxs]
μs = μs[parent_idxs]
Σs = Σs[parent_idxs]
for i in 1:N
# Resample
log_ws[i] = -log(N)
# Transition outer state
x = transition(rng, outer, SampleJumps(Float64[], Float64[]), u, t)
# See t = 1 case
inner_u = [x; u]
xs[i] = x
# Extract model matrices and vectors
A, b, Q = calc_transition(inner, inner_u, t)
H = calc_H(obs, inner_u, t)
R = calc_R(obs, inner_u, t)
# Transition inner state
μ_pred = A * μs[i] + b
Σ_pred = A * Σs[i] * A' + Q
# Filter state
K = Σ_pred * H' * inv(H * Σ_pred * H' + R)
μs[i] = μ_pred + K * (y - H * μ_pred)
Σs[i] = (I - K * H) * Σ_pred
# Update weight given observation
μ_y = H * μ_pred
Σ_y = H * Σ_pred * H' + R
Σ_y = (Σ_y + Σ_y') / 2 # HACK: force symmetry
log_ws[i] += logpdf(MvNormal(μ_y, Σ_y), y)
end
store!(particle_container, t, μs, log_ws)
end
end
return particle_container
end
function StatsBase.sample!(
model::ConditionedStateSpaceModel,
controls::AbstractVector,
algorithm::RaoBlackwellisedParticleFilter,
particle_container::AbstractParticleContainer
)
return sample!(Random.default_rng(), model, controls, algorithm, particle_container)
end
######################
#### FULL EXAMPLE ####
######################
C = 1.0
β = 0.4
ϵ = 1e-10
μw = 0.0
σw = 1.0
θ = -0.5
σe = 2.0
"""
Linear Gaussian dynamics that are represented as general dynamics.
"""
struct SubordinatorDynamics <: LatentDynamics
p::LevyProcess
end
dim(model::SubordinatorDynamics) = 1
function initialise(
rng::AbstractRNG, dyn::SubordinatorDynamics, u::AbstractVector
)
return SampleJumps(Float64[], Float64[])
end
function transition(
rng::AbstractRNG,
dynamics::SubordinatorDynamics,
x::SampleJumps,
u::AbstractVector,
t::Integer,
)
dt = only(u) # store time deltas as control variables
return sample(rng, dynamics.p, dt)
end
struct LangevinLatentDynamics{T} <: LinearGaussianDynamics
dyn::LevyDrivenLinearSDE{T}
end
dim(model::LangevinLatentDynamics) = 2
function calc_transition(model::LangevinLatentDynamics, u::AbstractVector, t::Integer)
# TODO: this requires u to contain `Any` elements, which is not great for performance
dt = u[2]
jumps = u[1]
dist = conditional_marginal(jumps, model.dyn, [0.0, 0.0], dt)
A = exp(model.dyn.linear_dynamics, dt)
b = dist.μ
Q = dist.Σ
return A, b, Q
end
calc_μ0(model::LangevinLatentDynamics, u::AbstractVector) = zeros(2)
calc_Σ0(model::LangevinLatentDynamics, u::AbstractVector) = Diagonal(ones(2))
function calc_initial(model::LangevinLatentDynamics, u::AbstractVector)
return calc_μ0(model, u), calc_Σ0(model, u)
end
dyn = LangevinDynamics(θ)
struct PositionSelector <: LinearGaussianObservation
σ2::Float64
end
dim(model::PositionSelector) = 1
calc_H(model::PositionSelector, u::AbstractVector, t::Integer) = [1.0 0.0]
calc_c(model::PositionSelector, u::AbstractVector, t::Integer) = zeros(1)
calc_R(model::PositionSelector, u::AbstractVector, t::Integer) = Diagonal([model.σ2])
subordinator = GammaProcess(C, β)
truncated_subordinator = truncate(subordinator, ϵ)
nvm = NormalVarianceMeanProcess(truncated_subordinator, μw, σw)
sde = LevyDrivenLinearSDE(nvm, dyn, [0.0, 1.0])
example_model = HierarchicalStateSpaceModel(
SubordinatorDynamics(truncated_subordinator),
LangevinLatentDynamics(sde),
PositionSelector(σe^2),
)
# Simulate from the model
SEED = 1234
T = 200
finish = 100.0
control = [finish / T * ones(1) for _ in 1:T]
rng = Random.MersenneTwister(SEED)
xs, ys = sample(rng, example_model, control)
# Plot position state and observation
p = plot(; title="Forward Simulation of Hierarchical Model")
plot!(1:T, getindex.(xs, 2); label="Position")
scatter!(1:T, getindex.(ys, 1); label="Observations")
display(p)
# Plot velocity state
p = plot(; title="Forward Simulation of Hierarchical Model")
plot!(1:T, getindex.(xs, 3); label="Velocity")
display(p)
# Run Rao-Blackwellised particle filter
N = 10^4
rbpf = RaoBlackwellisedParticleFilter(N)
# TODO: is this not inefficient since Julia doesn't know length of inner vector?
particle_container = MeanParticleContainer(Vector{Float64}, T)
sample!(ConditionedStateSpaceModel(example_model, ys), control, rbpf, particle_container)
# Compare filtered trajectory to true trajectory
filtered_mean = particle_container.μ
p = plot(;)
plot!(1:T, getindex.(xs, 2); label="True Position")
plot!(1:T, getindex.(filtered_mean, 1); label="Filtered Position")
scatter!(1:T, getindex.(ys, 1); label="Observations", alpha=0.5)
display(p)
p2 = plot(;)
plot!(1:T, getindex.(xs, 3); label="True Velocity")
plot!(1:T, getindex.(filtered_mean, 2); label="Filtered Velocity")
plot(p, p2, layout=(2, 1), size=(800, 600))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment