Simple implementation of the rescaled forward-backward algorithm for hidden Markov chains with independent Gaussian noise, written in JAX using the lax.scan function
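For reference, the code implements the standard rescaled recursions (textbook notation, not part of the gist file itself). With transition matrix A, initial law p_0, and Gaussian emission densities b_j(x) = N(x; means[j], stds[j]^2):

\alpha_1(j) \propto p_0(j)\, b_j(x_1), \qquad \alpha_t(j) \propto \Big( \sum_i \alpha_{t-1}(i)\, A_{ij} \Big)\, b_j(x_t)
\beta_T(i) = 1, \qquad \beta_t(i) \propto \sum_j A_{ij}\, b_j(x_{t+1})\, \beta_{t+1}(j)
p(h_t = j \mid x_{1:T}) \propto \alpha_t(j)\, \beta_t(j)

Each \propto hides a renormalization to sum one; this rescaling is what keeps the recursions from underflowing on long chains.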
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
def generate_observations(H, means, stds):
    # Emit a Gaussian observation whose mean and std depend on the hidden state
    X = ((H == 0) * (means[0] + np.random.randn(*H.shape) * stds[0]) +
         (H == 1) * (means[1] + np.random.randn(*H.shape) * stds[1]))
    return X
def generate_hidden_states(T, A, p0):
    # Sample a Markov chain of length T by inverse-CDF sampling from each
    # row of the transition matrix A
    H = []
    H.append(np.random.choice(2, 1, p=np.asarray(p0))[0])
    for t in range(1, T):
        u = np.random.rand()
        p = A[H[t - 1]]
        H.append(np.nonzero(u < np.cumsum(p))[0][0])
    H = np.array(H)
    return H
def jax_forward_one_step(alpha_tm1, t, X_pdf, A):
    # One rescaled forward step: propagate through A, weight by the
    # observation likelihood, then renormalize to prevent underflow
    alpha_t = jnp.sum(alpha_tm1[..., None] * A, axis=0) * X_pdf[:, t]
    alpha_t /= jnp.sum(alpha_t)
    return alpha_t
def jax_backward_one_step(beta_tp1, t, X_pdf, A):
    # One rescaled backward step, renormalized like the forward pass
    beta_t = jnp.sum(A * X_pdf[:, t + 1] * beta_tp1, axis=1)
    beta_t /= jnp.sum(beta_t)
    return beta_t
@partial(jax.jit, static_argnums=(0,))
def jax_forward_backward(T, X_pdf, A, p0):
    alpha_init = p0 * X_pdf[:, 0]
    beta_init = jnp.array([1., 1.])

    def scan_fn_a(alpha_tm1, t):
        alpha_t = jax_forward_one_step(alpha_tm1, t, X_pdf, A)
        # the carry we want for the next iteration and the sample we want
        # to store for this iteration are the same
        return alpha_t, alpha_t

    def scan_fn_b(beta_tp1, t):
        beta_t = jax_backward_one_step(beta_tp1, t, X_pdf, A)
        return beta_t, beta_t

    _, alpha = jax.lax.scan(scan_fn_a, alpha_init, jnp.arange(1, T, 1))
    alpha = jnp.concatenate([alpha_init[None, ...], alpha], axis=0)
    _, beta = jax.lax.scan(scan_fn_b, beta_init, jnp.arange(0, T - 1, 1),
                           reverse=True)
    beta = jnp.concatenate([beta, beta_init[None, ...]], axis=0)
    return alpha, beta
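# Hypothetical extension, not in the original gist: the rescaled recursion
# discards its normalization constants, but their logs sum to the data
# log-likelihood, log p(x_{1:T}) = sum_t log(c_t). A sketch of a forward
# pass that also accumulates them:
@partial(jax.jit, static_argnums=(0,))
def jax_forward_loglik(T, X_pdf, A, p0):
    alpha_0 = p0 * X_pdf[:, 0]
    c_0 = jnp.sum(alpha_0)

    def scan_fn(alpha_tm1, t):
        alpha_t = jnp.sum(alpha_tm1[..., None] * A, axis=0) * X_pdf[:, t]
        c_t = jnp.sum(alpha_t)
        # carry the normalized alpha, store the log of the normalizer
        return alpha_t / c_t, jnp.log(c_t)

    _, log_cs = jax.lax.scan(scan_fn, alpha_0 / c_0, jnp.arange(1, T))
    return jnp.log(c_0) + jnp.sum(log_cs)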
def jax_get_post_marginals_probas(alpha, beta):
    # Posterior marginals p(h_t | x_{1:T}) are proportional to alpha_t * beta_t
    post_marginals_probas = alpha * beta
    post_marginals_probas /= jnp.sum(post_marginals_probas,
                                     axis=1, keepdims=True)
    return post_marginals_probas
if __name__ == "__main__":
    p0 = jnp.array([0.5, 0.5])
    A = jnp.array([[0.9, 0.1], [0.1, 0.9]])
    T = 500
    means = jnp.array([0., 1.])
    stds = jnp.array([0.5, 0.5])

    H = generate_hidden_states(T, A, p0)
    X = generate_observations(H, means, stds)
    # Precompute the Gaussian emission likelihoods for both states
    X_pdf = jnp.stack([norm.pdf(X, means[0], stds[0]),
                       norm.pdf(X, means[1], stds[1])], axis=0)

    alpha, beta = jax_forward_backward(T, X_pdf, A, p0)
    post_marginals_probas = jax_get_post_marginals_probas(alpha, beta)

    # Marginal MAP criterion: segment by the argmax of each posterior marginal
    mpm_seg = jnp.argmax(post_marginals_probas, axis=1)
    error_rate = jnp.count_nonzero(mpm_seg != H) / len(H)
    print("Error rate", error_rate)