Last active August 18, 2021 09:56
Simple implementation of the rescaled forward-backward algorithm for hidden Markov chains with independent Gaussian noise, written in JAX using the `lax.scan` function
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
def generate_observations(H, means, stds):
    """Draw one noisy observation per hidden state of a two-state chain.

    Each observation is ``means[k] + stds[k] * eps``, where ``k = H[t]`` and
    ``eps`` is standard normal noise. Two full noise arrays of shape
    ``H.shape`` are drawn from NumPy's global RNG (first for state 0, then
    for state 1). Each position keeps the value matching its state.

    Args:
        H: integer array of hidden states. Only positions equal to 0 or 1
            receive a non-zero observation.
        means: per-state means; entries 0 and 1 are used.
        stds: per-state standard deviations; entries 0 and 1 are used.

    Returns:
        Array with the same shape as ``H`` holding the observations.
    """
    noise_state0 = np.random.randn(*H.shape)
    noise_state1 = np.random.randn(*H.shape)
    obs_state0 = means[0] + noise_state0 * stds[0]
    obs_state1 = means[1] + noise_state1 * stds[1]
    in_state0 = H == 0
    in_state1 = H == 1
    return in_state0 * obs_state0 + in_state1 * obs_state1
def generate_hidden_states(T, A, p0):
    """Sample a path of length ``T`` from a discrete Markov chain.

    Args:
        T: number of time steps to sample. ``T <= 0`` yields an empty path.
        A: (K, K) transition matrix. Row ``i`` is the distribution of the
            next state given current state ``i``.
        p0: length-K distribution of the initial state.

    Returns:
        Integer NumPy array of shape ``(T,)`` with states in ``[0, K)``.
    """
    if T <= 0:
        return np.zeros(0, dtype=np.intp)
    # Convert once so the per-step row lookup and cumsum stay in NumPy.
    A = np.asarray(A)
    p0 = np.asarray(p0)
    H = np.empty(T, dtype=np.intp)
    # The number of states comes from p0 instead of being hard-coded to 2.
    H[0] = np.random.choice(len(p0), 1, p=p0)[0]
    for t in range(1, T):
        u = np.random.rand()
        cdf = np.cumsum(A[H[t - 1]])
        # Inverse-CDF draw: the first state whose cumulative probability
        # exceeds u. Clip so that round-off (cdf[-1] slightly below 1) can
        # never push the index past the last state; the original
        # ``np.nonzero(...)[0][0]`` raised IndexError in that case.
        H[t] = min(np.searchsorted(cdf, u, side="right"), len(cdf) - 1)
    return H
def jax_forward_one_step(alpha_tm1, t, X_pdf, A):
    """Advance the rescaled forward variable by one time step.

    Propagates ``alpha_tm1`` through the transition matrix ``A`` and weights
    the result by the emission likelihoods in column ``t`` of ``X_pdf``. The
    result is then rescaled so that it sums to one.

    Args:
        alpha_tm1: forward variable at time ``t - 1``.
        t: index of the current time step (a column of ``X_pdf``).
        X_pdf: emission likelihoods, indexed as ``X_pdf[state, time]``.
        A: transition matrix; the prediction sums ``alpha_tm1[i] * A[i, j]``
            over ``i``.

    Returns:
        The rescaled forward variable at time ``t``.
    """
    weighted_transitions = alpha_tm1[..., None] * A
    predicted = jnp.sum(weighted_transitions, axis=0)
    unnormalised = predicted * X_pdf[:, t]
    return unnormalised / jnp.sum(unnormalised)
def jax_backward_one_step(beta_tp1, t, X_pdf, A):
    """Step the rescaled backward variable from time ``t + 1`` to ``t``.

    For each state ``i`` this sums ``A[i, j] * X_pdf[j, t + 1] * beta_tp1[j]``
    over ``j``. The resulting vector is rescaled to sum to one.

    Args:
        beta_tp1: backward variable at time ``t + 1``.
        t: index of the time step being computed.
        X_pdf: emission likelihoods, indexed as ``X_pdf[state, time]``.
        A: transition matrix.

    Returns:
        The rescaled backward variable at time ``t``.
    """
    emission_next = X_pdf[:, t + 1]
    contributions = A * emission_next * beta_tp1
    unnormalised = jnp.sum(contributions, axis=1)
    return unnormalised / jnp.sum(unnormalised)
@functools.partial(jax.jit, static_argnums=(0,))
def jax_forward_backward(T, X_pdf, A, pO):
    """Run the rescaled forward-backward recursions with ``jax.lax.scan``.

    Every forward vector, and every backward vector except the all-ones
    terminal one, is rescaled to sum to one.

    Args:
        T: number of time steps. Marked static for ``jax.jit`` because it
            sets the lengths of the scanned index ranges.
        X_pdf: emission likelihoods indexed as ``X_pdf[state, time]``; the
            first ``T`` columns are read.
        A: transition matrix, used by the one-step forward/backward helpers.
        pO: initial state distribution. (The name uses the letter "O"; it is
            kept unchanged so keyword callers keep working.)

    Returns:
        ``(alpha, beta)``, each stacked along time with ``T`` rows: the
        rescaled forward and backward variables. Their row-wise normalised
        product gives the posterior marginals.
    """
    # Use the parameter. The original read a global ``p0`` here, which only
    # worked when the script's __main__ section happened to define one.
    alpha_init = pO * X_pdf[:, 0]
    # Rescale the first forward vector like every later one.
    alpha_init = alpha_init / jnp.sum(alpha_init)
    # Terminal backward vector of ones, sized to the number of states rather
    # than hard-coded to two.
    beta_init = jnp.ones_like(alpha_init)

    def scan_fn_a(alpha_tm1, t):
        alpha_t = jax_forward_one_step(alpha_tm1, t, X_pdf, A)
        # The carry for the next iteration and the value stored for this
        # iteration are the same vector.
        return alpha_t, alpha_t

    def scan_fn_b(beta_tp1, t):
        beta_t = jax_backward_one_step(beta_tp1, t, X_pdf, A)
        return beta_t, beta_t

    _, alpha = jax.lax.scan(scan_fn_a, alpha_init, jnp.arange(1, T))
    alpha = jnp.concatenate([alpha_init[None, ...], alpha], axis=0)

    # The backward recursion must run from t = T - 2 down to 0, so scan in
    # reverse. With reverse=True, lax.scan still stacks outputs in the order
    # of xs, so beta[t] lines up with time step t.
    _, beta = jax.lax.scan(
        scan_fn_b, beta_init, jnp.arange(0, T - 1), reverse=True
    )
    beta = jnp.concatenate([beta, beta_init[None, ...]], axis=0)
    return alpha, beta
def jax_get_post_marginals_probas(alpha, beta):
    """Combine forward and backward variables into posterior marginals.

    Args:
        alpha: forward variables, one row per time step (rows may be
            rescaled). Presumably shaped (T, K) as returned by
            ``jax_forward_backward``.
        beta: backward variables, same shape as ``alpha``.

    Returns:
        The product ``alpha * beta`` normalised so that each row (axis 1)
        sums to one.
    """
    post_marginals_probas = alpha * beta
    post_marginals_probas = post_marginals_probas / jnp.sum(
        post_marginals_probas, axis=1, keepdims=True
    )
    # The original returned the undefined name ``pmp`` (NameError).
    return post_marginals_probas
if __name__ == "__main__":
    # Model parameters: two hidden states with "sticky" transitions.
    p0 = jnp.array([0.5, 0.5])
    A = jnp.array([[0.9, 0.1], [0.1, 0.9]])
    T = 500
    means = jnp.array([0., 1.])
    stds = jnp.array([0.5, 0.5])

    # Simulate a hidden path and its noisy observations.
    H = generate_hidden_states(T, A, p0)
    X = generate_observations(H, means, stds)

    # Emission likelihoods, one row per state:
    # X_pdf[k, t] = N(X[t]; means[k], stds[k]).
    X_pdf = jnp.stack(
        [norm.pdf(X, means[k], stds[k]) for k in range(2)],
        axis=0,
    )

    alpha, beta = jax_forward_backward(T, X_pdf, A, p0)
    post_marginals_probas = jax_get_post_marginals_probas(alpha, beta)

    # Marginal MAP criterion: pick the most probable state at every step.
    mpm_seg = jnp.argmax(post_marginals_probas, axis=1)
    n_errors = jnp.count_nonzero(mpm_seg != H)
    error_rate = n_errors / len(H)
    print("Error rate", error_rate)
