Skip to content

Instantly share code, notes, and snippets.

@HGangloff
Last active August 20, 2021 10:16
Show Gist options
  • Save HGangloff/cae63fe2851ef92fdd80121cf590bee6 to your computer and use it in GitHub Desktop.
Save HGangloff/cae63fe2851ef92fdd80121cf590bee6 to your computer and use it in GitHub Desktop.
Efficient Sequential Importance Resampling Particle Filter with Jax using jit and lax.scan
'''
Sequential Importance Resampling Particle Filter with Jax using jit and
lax.scan
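The model equations for this simple application are:
X_{n} = 0.5 * X_{n-1} + 25 * X_{n-1} / (1 + X_{n-1}^2) + 8 * cos(1.2 * n) + U_n
Y_{n} = X_{n}^2 / 20 + V_n
where U_n ~ N(0, 10) and V_n ~ N(0, 1) are Gaussian noises drawn independently
at each step (the second parameter of N is the variance).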
'''
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import jax
from jax import (jit, random, partial)
def generate_hidden_states(T, rng=None):
    '''
    Simulate the hidden state trajectory X_0, ..., X_{T-1}.

    X_0 = 0 and, for n >= 1,
    X_n = 0.5 X_{n-1} + 25 X_{n-1} / (1 + X_{n-1}^2) + 8 cos(1.2 n) + U_n
    with U_n ~ N(0, 10) (variance 10, i.e. standard deviation sqrt(10)).

    Parameters
    ----------
    T : int
        Number of time steps. T <= 0 yields an empty array.
    rng : numpy.random.Generator, optional
        Noise source. When None (default), numpy's global random state is
        used, exactly as before this parameter existed.

    Returns
    -------
    X : ndarray of shape (max(T, 0),)
    '''
    # np.random.randn() and Generator.standard_normal() both return one
    # standard-normal float when called without arguments.
    standard_normal = np.random.randn if rng is None else rng.standard_normal
    # zeros() sets X_0 = 0 and also makes T == 0 safe (no X[0] assignment).
    X = np.zeros(max(T, 0))
    for n in range(1, T):
        X[n] = (0.5 * X[n-1] + 25 * X[n-1] / (1 + np.square(X[n-1])) +
                8 * np.cos(1.2 * n) + standard_normal() * np.sqrt(10))
    return X
def generate_observations(X, rng=None):
    '''
    Simulate noisy observations Y_n = X_n^2 / 20 + V_n with V_n ~ N(0, 1).

    Parameters
    ----------
    X : array_like
        Hidden states (any shape; lists are accepted).
    rng : numpy.random.Generator, optional
        Noise source. When None (default), numpy's global random state is
        used, exactly as before this parameter existed.

    Returns
    -------
    Y : ndarray with the same shape as X
    '''
    # asarray lets plain lists through; the original required `.shape`.
    X = np.asarray(X)
    if rng is None:
        noise = np.random.randn(*X.shape)
    else:
        noise = rng.standard_normal(X.shape)
    return np.square(X) / 20 + noise
def jax_gauss_pdf(x, mu, sigma):
    '''
    Gaussian density N(mu, sigma^2) evaluated elementwise at x.

    sigma is the standard deviation; inputs broadcast following jax.numpy
    rules.
    '''
    variance = sigma ** 2
    normaliser = 1 / jnp.sqrt(2 * jnp.pi * variance)
    exponent = -0.5 * (x - mu) ** 2 / variance
    return normaliser * jnp.exp(exponent)
@partial(jit, static_argnums=(1, 3))
def jax_SIR(Y, T, key, nb_particles):
    '''
    Bootstrap Sequential Importance Resampling filter for the model
    X_n = 0.5 X_{n-1} + 25 X_{n-1} / (1 + X_{n-1}^2) + 8 cos(1.2 n) + U_n,
    Y_n = X_n^2 / 20 + V_n,  U_n ~ N(0, 10), V_n ~ N(0, 1).

    Particles are propagated with the prior dynamics, weighted by the
    observation likelihood, and resampled (multinomially) at every step.

    Parameters
    ----------
    Y : array, 1-D
        Observations Y_0, ..., Y_{T-1}; must hold at least T values.
    T : int (static under jit)
        Number of time steps to filter (T >= 1).
    key : jax PRNG key
    nb_particles : int (static under jit)

    Returns
    -------
    X_est : array of shape (T,)
        Weighted-particle estimates of E[X_n | Y_0, ..., Y_n].

    Raises
    ------
    ValueError
        If T < 1 or Y has fewer than T entries (raised at trace time;
        out-of-range indexing would otherwise be silently clamped by jit).
    '''
    if T < 1 or jnp.shape(Y)[0] < T:
        raise ValueError("jax_SIR needs T >= 1 and len(Y) >= T")

    def log_likelihood(y, particles):
        # log N(y; x^2 / 20, 1) up to an additive constant that cancels when
        # the weights are normalised. Note x^2 / 20, not (x / 20)^2.
        return -0.5 * jnp.square(y - jnp.square(particles) / 20)

    def estimate_and_resample(log_w, particles, key):
        # Max-shift before exponentiating: linear-domain weights can all
        # underflow to 0 for an outlying observation, giving 0/0 = NaN.
        w = jnp.exp(log_w - jnp.max(log_w))
        w /= jnp.sum(w)
        # Estimate before resampling; resampling only adds Monte Carlo noise.
        x_est = jnp.sum(w * particles)
        # categorical accepts unnormalised logits directly.
        idx = random.categorical(key, log_w, shape=(nb_particles,))
        return particles[idx], x_est

    key, sub_key = random.split(key)
    particles_t0 = random.normal(sub_key, shape=(nb_particles,)) * jnp.sqrt(10)
    key, sub_key = random.split(key)
    particles_t0, X0_est = estimate_and_resample(
        log_likelihood(Y[0], particles_t0), particles_t0, sub_key)

    def SIR_one_step(carry, t):
        particles_tm1, key = carry
        key, noise_key, resample_key = random.split(key, 3)
        particles_t = (0.5 * particles_tm1 + 25 * particles_tm1 /
                       (1 + jnp.square(particles_tm1)) + 8 * jnp.cos(1.2 * t) +
                       random.normal(noise_key, shape=(nb_particles,)) *
                       jnp.sqrt(10))
        # Resampling happens at every step, so the incoming weights are
        # uniform and the new log-weights reduce to the log-likelihood.
        particles_t, Xt_est = estimate_and_resample(
            log_likelihood(Y[t], particles_t), particles_t, resample_key)
        return (particles_t, key), Xt_est

    _, X_est_from_1 = jax.lax.scan(SIR_one_step, (particles_t0, key),
                                   jnp.arange(1, T, 1))
    X_est = jnp.concatenate([X0_est[None, ...], X_est_from_1], axis=0)
    return X_est
if __name__ == "__main__":
    # Demo: simulate the model, run the particle filter, report the MSE and
    # plot the hidden states, observations and filtered estimates.
    key = random.PRNGKey(0)
    key, sub_key = random.split(key)
    nb_particles = 500
    T = 50
    X = generate_hidden_states(T)
    Y = generate_observations(X)
    # Use the freshly split sub-key (the original split it but never used it).
    X_est = jax_SIR(Y, T, sub_key, nb_particles)
    print("MSE", jnp.mean(jnp.square(X - X_est)))
    fig, axes = plt.subplots(1, 1)
    axes.plot(X, label="hidden state X")
    axes.plot(Y, label="observation Y")
    axes.plot(X_est, label="SIR estimate")
    axes.legend()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment