Skip to content

Instantly share code, notes, and snippets.

@HGangloff
Last active August 18, 2021 09:56
Show Gist options
  • Save HGangloff/0698566a64d9f3db3e009d4ccc4bda75 to your computer and use it in GitHub Desktop.
Save HGangloff/0698566a64d9f3db3e009d4ccc4bda75 to your computer and use it in GitHub Desktop.
Simple implementation of the rescaled Forward-Backward algorithm for Hidden Markov Chains with independent Gaussian noise, written in JAX using the lax.scan function
import functools

import numpy as np
import jax.numpy as jnp
import jax
from jax.scipy.stats import norm
def generate_observations(H, means, stds):
    """Draw one Gaussian observation per hidden state.

    Each observation is ``means[h] + stds[h] * eps`` where ``h`` is the hidden
    label at that position and ``eps`` is standard normal noise drawn from
    NumPy's global random state.

    Args:
        H: array of integer state labels; each label indexes into ``means``
            and ``stds``.
        means: per-state means (any array-like, one entry per state).
        stds: per-state standard deviations (same length as ``means``).

    Returns:
        A NumPy array with the same shape as ``H``.
    """
    labels = np.asarray(H, dtype=np.intp)
    means = np.asarray(means)
    stds = np.asarray(stds)
    # Indexing the parameter vectors by label supports any number of states
    # (not only 0 and 1) and needs a single noise draw per observation.
    noise = np.random.randn(*labels.shape)
    return means[labels] + noise * stds[labels]
def generate_hidden_states(T, A, p0):
    """Sample a hidden Markov chain of length ``T``.

    The first state is drawn from ``p0``; every following state is drawn from
    the row of ``A`` selected by the previous state. Randomness comes from
    NumPy's global random state.

    Args:
        T: number of time steps to simulate.
        A: transition matrix, ``A[i]`` being the distribution of the next
            state given current state ``i``.
        p0: initial state distribution (one entry per state).

    Returns:
        A 1-D integer NumPy array of length ``T`` (empty when ``T <= 0``).
    """
    p0 = np.asarray(p0)
    A = np.asarray(A)
    n_states = p0.shape[0]
    if T <= 0:
        return np.array([], dtype=int)

    H = np.empty(T, dtype=int)
    H[0] = np.random.choice(n_states, 1, p=p0)[0]

    # Row-wise CDFs are loop invariant, so compute them once.
    cdf = np.cumsum(A, axis=1)
    for t in range(1, T):
        u = np.random.rand()
        # Inverse-CDF sampling: first index whose cumulative mass exceeds u.
        nxt = np.searchsorted(cdf[H[t - 1]], u, side="right")
        # A row may sum to slightly less than 1 after rounding; clamp so that
        # a draw above the last CDF value still maps to a valid state.
        H[t] = min(nxt, n_states - 1)
    return H
def jax_forward_one_step(alpha_tm1, t, X_pdf, A):
    """Advance the rescaled forward probabilities by one time step.

    Args:
        alpha_tm1: rescaled forward probabilities at time ``t - 1``.
        t: index of the time step being computed.
        X_pdf: observation likelihoods, indexed as ``X_pdf[state, time]``.
        A: transition matrix, ``A[i, j]`` weighting the move from ``i`` to ``j``.

    Returns:
        The forward probabilities at time ``t``, normalised to sum to one.
    """
    # Weight every transition by the probability of its source state, then
    # marginalise the source state out.
    weighted_transitions = alpha_tm1[..., None] * A
    predicted = jnp.sum(weighted_transitions, axis=0)
    # Fold in the likelihood of the observation at time t.
    unnormalised = predicted * X_pdf[:, t]
    # Rescaling keeps the recursion from underflowing on long sequences.
    return unnormalised / jnp.sum(unnormalised)
def jax_backward_one_step(beta_tp1, t, X_pdf, A):
    """Move the rescaled backward probabilities one time step back.

    Args:
        beta_tp1: rescaled backward probabilities at time ``t + 1``.
        t: index of the time step being computed.
        X_pdf: observation likelihoods, indexed as ``X_pdf[state, time]``.
        A: transition matrix, ``A[i, j]`` weighting the move from ``i`` to ``j``.

    Returns:
        The backward probabilities at time ``t``, normalised to sum to one.
    """
    # Scale each destination-state column by the likelihood of the next
    # observation, then by the backward message coming from t + 1.
    weighted = A * X_pdf[:, t + 1]
    weighted = weighted * beta_tp1
    # Marginalise over the destination state.
    unnormalised = jnp.sum(weighted, axis=1)
    # Rescaling keeps the recursion from underflowing on long sequences.
    return unnormalised / jnp.sum(unnormalised)
@functools.partial(jax.jit, static_argnums=(0,))
def jax_forward_backward(T, X_pdf, A, pO):
    """Run the rescaled forward-backward recursions with ``jax.lax.scan``.

    Args:
        T: number of time steps. Marked static for ``jax.jit`` because it
            fixes the length of the scanned index ranges.
        X_pdf: observation likelihoods, indexed as ``X_pdf[state, time]``.
        A: transition matrix, ``A[i, j]`` weighting the move from ``i`` to ``j``.
        pO: initial state distribution. (The name is spelled with a capital
            letter O; it is kept as-is so existing callers are unaffected.)

    Returns:
        A tuple ``(alpha, beta)`` of arrays indexed as ``[time, state]``
        holding the rescaled forward and backward probabilities.
    """
    # Use the argument itself rather than a same-looking module-level global.
    p_init = pO
    alpha_init = p_init * X_pdf[:, 0]
    # One entry per state, matching the forward message in shape and dtype.
    beta_init = jnp.ones_like(alpha_init)

    def scan_fn_a(alpha_tm1, t):
        alpha_t = jax_forward_one_step(alpha_tm1, t, X_pdf, A)
        # The carry for the next iteration and the value stored for this
        # iteration are the same.
        return alpha_t, alpha_t

    def scan_fn_b(beta_tp1, t):
        beta_t = jax_backward_one_step(beta_tp1, t, X_pdf, A)
        return beta_t, beta_t

    # Forward pass over t = 1 .. T-1; the initial value is prepended after.
    _, alpha = jax.lax.scan(scan_fn_a, alpha_init, jnp.arange(1, T, 1))
    alpha = jnp.concatenate([alpha_init[None, ...], alpha], axis=0)

    # Backward pass over t = T-2 .. 0; ``reverse=True`` walks the indices from
    # the end while keeping the stacked outputs in increasing time order.
    _, beta = jax.lax.scan(scan_fn_b, beta_init, jnp.arange(0, T - 1, 1),
                           reverse=True)
    beta = jnp.concatenate([beta, beta_init[None, ...]], axis=0)
    return alpha, beta
def jax_get_post_marginals_probas(alpha, beta):
    """Combine forward and backward probabilities into posterior marginals.

    Args:
        alpha: forward probabilities, indexed as ``[time, state]``.
        beta: backward probabilities, same layout as ``alpha``.

    Returns:
        An array with the same layout whose rows (axis 1) each sum to one.
    """
    post_marginals_probas = alpha * beta
    # Normalise over the state axis so each time step is a distribution.
    post_marginals_probas = post_marginals_probas / jnp.sum(
        post_marginals_probas, axis=1, keepdims=True)
    return post_marginals_probas
if __name__ == "__main__":
    # Two-state chain with sticky transitions and overlapping Gaussian
    # emissions.
    T = 500
    p0 = jnp.array([0.5, 0.5])
    A = jnp.array([[0.9, 0.1], [0.1, 0.9]])
    means = jnp.array([0., 1.])
    stds = jnp.array([0.5, 0.5])

    # Simulate the hidden chain first, then the observations it emits.
    H = generate_hidden_states(T, A, p0)
    X = generate_observations(H, means, stds)

    # Likelihood of every observation under each state, one row per state.
    likelihoods = [norm.pdf(X, means[k], stds[k]) for k in range(2)]
    X_pdf = jnp.stack(likelihoods, axis=0)

    alpha, beta = jax_forward_backward(T, X_pdf, A, p0)
    post_marginals_probas = jax_get_post_marginals_probas(alpha, beta)

    # Marginal MAP criterion: pick the most probable state at each time step.
    mpm_seg = jnp.argmax(post_marginals_probas, axis=1)
    n_errors = jnp.count_nonzero(mpm_seg != H)
    error_rate = n_errors / len(H)
    print("Error rate", error_rate)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment