Last active August 20, 2021 10:16
Efficient Sequential Importance Resampling (SIR) particle filter in JAX, using `jit` and `lax.scan`.
The model equations for this simple application are:
X_{n} = 0.5 * X_{n-1} + 25 * X_{n-1} / (1 + X_{n-1}^2) + 8 * cos(1.2 * n) + U
Y_{n} = X_{n}^2 / 20 + V
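where U ~ N(0, 10) and V ~ N(0, 1); the second argument of N is the variance, so U has standard deviation sqrt(10) and V has standard deviation 1.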
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import jax
from jax import (jit, random, partial)
def generate_hidden_states(T, x0=0.0, rng=None):
    """Simulate the hidden trajectory X_0, ..., X_{T-1} of the benchmark model.

    X_n = 0.5 X_{n-1} + 25 X_{n-1} / (1 + X_{n-1}^2) + 8 cos(1.2 n) + U_n,
    with U_n ~ N(0, 10) (variance 10, i.e. standard deviation sqrt(10)).

    Args:
        T: number of time steps to simulate; ``T == 0`` yields an empty array.
        x0: initial state X_0 (defaults to 0, as before).
        rng: optional random source exposing ``standard_normal()`` (e.g.
            ``np.random.default_rng(seed)``).  Defaults to the global
            ``np.random`` state, so seeding with ``np.random.seed`` keeps
            producing the same trajectories as before.

    Returns:
        NumPy float array of shape ``(T,)``.

    Raises:
        ValueError: if ``T`` is negative.
    """
    if T < 0:
        raise ValueError(f"T must be non-negative, got {T}")
    if rng is None:
        rng = np.random
    X = np.empty(T)
    if T == 0:
        # Previously X[0] = 0 raised IndexError for an empty trajectory.
        return X
    X[0] = x0
    noise_std = np.sqrt(10)
    for n in range(1, T):
        x_prev = X[n - 1]
        X[n] = (0.5 * x_prev + 25 * x_prev / (1 + np.square(x_prev))
                + 8 * np.cos(1.2 * n) + rng.standard_normal() * noise_std)
    return X
def generate_observations(X, rng=None):
    """Draw noisy observations Y_n = X_n^2 / 20 + V_n with V_n ~ N(0, 1).

    Args:
        X: hidden states; any array-like (lists are accepted).
        rng: optional random source exposing ``standard_normal(size)`` (e.g.
            ``np.random.default_rng(seed)``).  Defaults to the global
            ``np.random`` state, giving the same draws as before.

    Returns:
        NumPy array with the same shape as ``X``.
    """
    # asarray lets plain sequences through; the original needed X.shape.
    X = np.asarray(X)
    if rng is None:
        rng = np.random
    noise = rng.standard_normal(X.shape)  # V_n has unit standard deviation
    return np.square(X) / 20 + noise
def jax_gauss_pdf(x, mu, sigma):
    """Normal pdf: evaluate the N(mu, sigma^2) density at x (element-wise, broadcasting)."""
    variance = sigma ** 2
    normaliser = 1 / jnp.sqrt(2 * jnp.pi * variance)
    exponent = -0.5 * (x - mu) ** 2 / variance
    return normaliser * jnp.exp(exponent)
@partial(jit, static_argnums=(1, 3))
def jax_SIR(Y, T, key, nb_particles):
    """Bootstrap SIR particle filter for the benchmark model, compiled with jit.

    Model (the one simulated by generate_hidden_states / generate_observations)::

        X_n = 0.5 X_{n-1} + 25 X_{n-1} / (1 + X_{n-1}^2) + 8 cos(1.2 n) + U_n
        Y_n = X_n^2 / 20 + V_n,    U_n ~ N(0, 10),  V_n ~ N(0, 1)

    Particles are propagated through the transition density, weighted by the
    observation likelihood and resampled with replacement (categorical draws)
    at every step.  Steps 1..T-1 run inside ``jax.lax.scan``.

    Args:
        Y: 1-D observations; entries ``Y[0] .. Y[T-1]`` are used.
        T: number of time steps (static under jit; must be >= 1).
        key: JAX PRNG key.
        nb_particles: number of particles (static under jit).

    Returns:
        Array of shape ``(T,)`` with the filtered-mean estimate of each X_n.

    Raises:
        ValueError: if ``T < 1`` or ``Y`` has fewer than ``T`` entries (JAX
            indexing would otherwise silently clamp ``Y[t]``).
    """
    # NOTE(review): ``partial`` must behave like functools.partial; recent
    # JAX releases no longer export ``jax.partial`` -- import it from
    # functools if the module-level ``from jax import partial`` fails.
    Y = jnp.asarray(Y)
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")
    if Y.shape[0] < T:
        raise ValueError(f"Y has {Y.shape[0]} observations but T={T}")

    # Standard deviation of both the initial prior and U_n (variance 10).
    sqrt10 = jnp.sqrt(10.0)

    def log_likelihood(particles, y):
        # log N(y; particles**2 / 20, 1) without its additive constant,
        # which cancels when the weights are normalised.
        return -0.5 * jnp.square(y - jnp.square(particles) / 20)

    def estimate_and_resample(log_w, particles, key):
        # Normalise in the log domain: evaluating the density directly
        # underflows to 0 in float32 when every particle is far from the
        # observation, and w / sum(w) would then be NaN.
        w = jnp.exp(log_w - jnp.max(log_w))
        w = w / jnp.sum(w)
        # Estimate from the weighted particles *before* resampling, so the
        # resampling noise does not leak into the estimate.
        x_est = jnp.sum(w * particles)
        # categorical accepts unnormalised logits, so log_w is used as-is.
        idx = random.categorical(key, log_w, shape=(nb_particles,))
        return particles[idx], x_est

    # t = 0: sample the prior N(0, 10) and condition on Y[0].
    key, init_key, resample_key = random.split(key, 3)
    particles_t0 = random.normal(init_key, shape=(nb_particles,)) * sqrt10
    particles_t0, X0_est = estimate_and_resample(
        log_likelihood(particles_t0, Y[0]), particles_t0, resample_key)

    def SIR_one_step(carry, t):
        # After resampling every particle has weight 1/nb_particles, so the
        # incoming weights are uniform and only the new likelihood matters.
        particles_tm1, key = carry
        key, noise_key, resample_key = random.split(key, 3)
        particles_t = (0.5 * particles_tm1
                       + 25 * particles_tm1 / (1 + jnp.square(particles_tm1))
                       + 8 * jnp.cos(1.2 * t)
                       + random.normal(noise_key, shape=(nb_particles,)) * sqrt10)
        particles_t, Xt_est = estimate_and_resample(
            log_likelihood(particles_t, Y[t]), particles_t, resample_key)
        return (particles_t, key), Xt_est

    _, X_est_from_1 = jax.lax.scan(SIR_one_step, (particles_t0, key),
                                   jnp.arange(1, T))
    return jnp.concatenate([X0_est[None], X_est_from_1], axis=0)
if __name__ == "__main__":
    # Seed NumPy as well: the hidden states and observations are simulated
    # from the global np.random state, so without this every run differs
    # even though the JAX key is fixed.
    np.random.seed(0)
    key = random.PRNGKey(0)
    nb_particles = 500
    T = 50
    X = generate_hidden_states(T)
    Y = generate_observations(X)
    X_est = jax_SIR(Y, T, key, nb_particles)
    print("MSE", jnp.mean(jnp.square(X - X_est)))
    # NOTE(review): fig/axes are not used in the visible code; plotting of
    # X against X_est presumably follows here.
    fig, axes = plt.subplots(1, 1)
