Last active
August 20, 2021 10:16
-
-
Save HGangloff/cae63fe2851ef92fdd80121cf590bee6 to your computer and use it in GitHub Desktop.
Efficient Sequential Importance Resampling Particle Filter with Jax using jit and lax.scan
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Sequential Importance Resampling (SIR) particle filter in JAX, using jit and
lax.scan.
The model equations for this simple application are:
X_{n} = 0.5 * X_{n-1} + 25 * X_{n-1} / (1 + X_{n-1}^2) + 8 * cos(1.2 * n) + U_{n}
Y_{n} = X_{n}^2 / 20 + V_{n}
where U_{n} ~ N(0, 10) and V_{n} ~ N(0, 1) are independent noises (the second
parameter of N is the variance, so U_{n} has standard deviation sqrt(10)).
'''
# `partial` used to be re-exported as `jax.partial`; recent JAX releases no
# longer provide it, so take it from its real home.
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import jit, random
def generate_hidden_states(T, rng=None):
    '''
    Simulate T hidden states of the benchmark nonlinear model.

    X_0 = 0 and, for n >= 1,
    X_n = 0.5 X_{n-1} + 25 X_{n-1} / (1 + X_{n-1}^2) + 8 cos(1.2 n) + U_n,
    with U_n ~ N(0, 10) (standard deviation sqrt(10)).

    Parameters
    ----------
    T : int
        Number of time steps. T == 0 yields an empty array; a negative T
        raises ValueError (from np.empty).
    rng : numpy.random.Generator, optional
        Source of the process noise. When omitted, the global legacy
        ``np.random`` state is used, so existing callers get the same draws.

    Returns
    -------
    numpy.ndarray
        Float array of shape (T,).
    '''
    if T == 0:
        # The original wrote X[0] unconditionally, which raised IndexError on
        # an empty trajectory.
        return np.empty(0)
    draw_noise = np.random.randn if rng is None else rng.standard_normal
    X = np.empty(T)
    X[0] = 0
    for n in range(1, T):
        prev = X[n - 1]
        X[n] = (0.5 * prev + 25 * prev / (1 + np.square(prev)) +
                8 * np.cos(1.2 * n) + draw_noise() * np.sqrt(10))
    return X
def generate_observations(X, rng=None):
    '''
    Simulate noisy observations Y_n = X_n^2 / 20 + V_n with V_n ~ N(0, 1).

    Parameters
    ----------
    X : array_like
        Hidden states; any shape. Lists are accepted as well as arrays.
    rng : numpy.random.Generator, optional
        Source of the observation noise. When omitted, the global legacy
        ``np.random`` state is used, so existing callers get the same draws.

    Returns
    -------
    numpy.ndarray
        Observations with the same shape as X.
    '''
    X = np.asarray(X)  # the original used X.shape, which fails on lists
    if rng is None:
        noise = np.random.randn(*X.shape)
    else:
        noise = rng.standard_normal(X.shape)
    obs_std = 1.0  # standard deviation of V
    return np.square(X) / 20 + noise * obs_std
def jax_gauss_pdf(x, mu, sigma):
    '''
    Normal density with mean ``mu`` and standard deviation ``sigma``,
    evaluated elementwise at ``x`` (jnp broadcasting rules apply).
    '''
    variance = sigma ** 2
    normalising_constant = 1 / jnp.sqrt(2 * jnp.pi * variance)
    return normalising_constant * jnp.exp(-0.5 * (x - mu) ** 2 / variance)
@partial(jit, static_argnums=(1, 3))
def jax_SIR(Y, T, key, nb_particles):
    '''
    Bootstrap SIR particle filter for the benchmark nonlinear model.

    Particles are propagated with the state equation
    X_n = 0.5 X_{n-1} + 25 X_{n-1} / (1 + X_{n-1}^2) + 8 cos(1.2 n) + U_n,
    U_n ~ N(0, 10), weighted with the likelihood Y_n ~ N(X_n^2 / 20, 1), and
    multinomially resampled at every step.

    Parameters
    ----------
    Y : array, shape (>= T,)
        Observations; only the first T entries are used.
    T : int
        Number of time steps to filter (static under jit; must be >= 1).
    key : jax PRNG key
        Randomness for initialisation, propagation and resampling.
    nb_particles : int
        Number of particles (static under jit).

    Returns
    -------
    jax.Array, shape (T,)
        Weighted-mean estimate of X_n for n = 0, ..., T - 1.

    Raises
    ------
    ValueError
        If T < 1 or Y holds fewer than T observations.
    '''
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")
    n_obs = jnp.shape(Y)[0]
    if n_obs < T:
        # JAX clamps out-of-range indices instead of raising, so a short Y
        # would silently reuse its last observation.
        raise ValueError(f"Y has {n_obs} observations but T={T}")

    def log_likelihood(y, particles):
        # log N(y; particles^2 / 20, 1) without the -0.5*log(2*pi) constant,
        # which cancels when the weights are normalised.
        return -0.5 * jnp.square(y - jnp.square(particles) / 20)

    def normalise(log_w):
        # Subtracting the max makes the largest weight exactly 1, so the sum
        # cannot underflow to 0 (plain pdf products gave 0/0 = NaN weights
        # for outlying observations).
        w = jnp.exp(log_w - jnp.max(log_w))
        return w / jnp.sum(w)

    def resampling(w, particles_t, key):
        # Multinomial resampling; afterwards every particle has weight 1/N.
        idx = random.categorical(key, jnp.log(w), shape=(nb_particles,))
        return jnp.ones(nb_particles) / nb_particles, particles_t[idx]

    def propagate(particles_tm1, t, key):
        # Draw X_t | X_{t-1} from the state equation (bootstrap proposal).
        noise = random.normal(key, shape=(nb_particles,)) * jnp.sqrt(10)
        return (0.5 * particles_tm1 +
                25 * particles_tm1 / (1 + jnp.square(particles_tm1)) +
                8 * jnp.cos(1.2 * t) + noise)

    # t = 0: particles from N(0, 10), weighted by the first observation.
    key, sub_key = random.split(key)
    particles = random.normal(sub_key, shape=(nb_particles,)) * jnp.sqrt(10)
    w = normalise(log_likelihood(Y[0], particles))
    # Estimate before resampling: the weighted mean avoids the extra Monte
    # Carlo noise that resampling introduces.
    X0_est = jnp.sum(w * particles)
    key, sub_key = random.split(key)
    w, particles = resampling(w, particles, sub_key)

    def SIR_one_step(carry, t):
        w_tm1, particles_tm1, key = carry
        key, sub_key = random.split(key)
        particles_t = propagate(particles_tm1, t, sub_key)
        # Observation mean is X_t^2 / 20, exactly as in the model and at t = 0.
        w = normalise(jnp.log(w_tm1) + log_likelihood(Y[t], particles_t))
        Xt_est = jnp.sum(w * particles_t)
        key, sub_key = random.split(key)
        w, particles_t = resampling(w, particles_t, sub_key)
        return (w, particles_t, key), Xt_est

    _, X_est_from_1 = jax.lax.scan(SIR_one_step, (w, particles, key),
                                   jnp.arange(1, T, 1))
    return jnp.concatenate([X0_est[None, ...], X_est_from_1], axis=0)
def main():
    '''
    Simulate the model, run the SIR filter on the simulated observations,
    print the mean squared error and plot states, observations and estimate.
    '''
    # Seed NumPy too, so the simulated data is as reproducible as the JAX key.
    np.random.seed(0)
    key = random.PRNGKey(0)
    key, sub_key = random.split(key)
    nb_particles = 500
    T = 50
    X = generate_hidden_states(T)
    Y = generate_observations(X)
    X_est = jax_SIR(Y, T, sub_key, nb_particles)
    print("MSE", jnp.mean(jnp.square(X - X_est)))
    _, axes = plt.subplots(1, 1)
    axes.plot(X, label="hidden state X")
    axes.plot(Y, label="observation Y")
    axes.plot(X_est, label="SIR estimate")
    axes.set_xlabel("time step n")
    axes.legend()
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment