Equivalence between a gradient ascent over the Expectation Maximization quantity Q and a gradient ascent over the model likelihood, in the case of training a Hidden Markov Chain with independent Gaussian noise
'''
Equivalence between a gradient ascent over the Expectation Maximization
quantity Q and a gradient ascent over the model likelihood, in the case of
training a Hidden Markov Chain with independent Gaussian noise
'''
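# Why the two ascents are equivalent (informal reminder): with
# Q(theta' | theta) = E[log p(X, H; theta') | X; theta], Fisher's identity
# gives
#     d/dtheta' Q(theta' | theta) |_{theta' = theta} = d/dtheta log p(X; theta),
# so a gradient step on Q, with the posterior marginals frozen at the current
# parameters, points in the same direction as a gradient step on the
# observed-data log-likelihood computed by the forward algorithm below.
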
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
from jax.scipy.special import logsumexp
import optax

def jax_lse_custom(a, axis=None):
    max_ = jax.lax.stop_gradient(jnp.amax(a, axis=axis))
    if axis == 1:
        max__ = jnp.stack([max_, max_], axis=1)
    elif axis == 0:
        max__ = jnp.stack([max_, max_], axis=0)
    else:
        max__ = max_
    return max_ + jnp.log(jnp.sum(jnp.exp(a - max__), axis=axis))

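# jax_lse_custom returns the same values as jax.scipy.special.logsumexp (the
# stop_gradient on the max is only the usual numerical-stability trick); the
# hard-coded stack means it assumes 2 entries along the reduced axis, i.e.
# 2 hidden states. A quick illustrative check (not run here):
#     a = jnp.array([[0., 1.], [2., 3.]])
#     jnp.allclose(jax_lse_custom(a, axis=1), logsumexp(a, axis=1))  # -> True
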
def jax_loggauss(x, mu, sigma):
    '''
    log norm pdf
    '''
    return (-jnp.log(jnp.sqrt(2 * jnp.pi * sigma ** 2)) - 0.5 *
            (x - mu) ** 2 / sigma ** 2)

def generate_observations(H, means, stds):
    X = ((H == 0) * (means[0] + np.random.randn(*H.shape) * stds[0]) +
         (H == 1) * (means[1] + np.random.randn(*H.shape) * stds[1]))
    return X

def generate_hidden_states(T, A, p0):
    H = []
    H.append(np.random.choice(2, 1, p=p0)[0])
    for t in range(1, T):
        u = np.random.rand()
        p = A[H[t - 1]]
        H.append(np.nonzero(u < np.cumsum(p))[0][0])
    H = np.array(H)
    return H

def jax_forward_one_step(llkh_t_1, t, lX_pdf, lA, rescaled=True):
    llkh_t = (jax_lse_custom(llkh_t_1[..., None] + lA, axis=0)
              + lX_pdf[:, t])
    if rescaled:
        llkh_t -= jax_lse_custom(llkh_t)
    return llkh_t

@jax.jit
def jax_forward_one_step_no_rescaled(llkh_t_1, t, lX_pdf, lA):
    return jax_forward_one_step(llkh_t_1, t, lX_pdf, lA, rescaled=False)

@jax.jit
def jax_forward_one_step_rescaled(llkh_t_1, t, lX_pdf, lA):
    return jax_forward_one_step(llkh_t_1, t, lX_pdf, lA, rescaled=True)

@jax.jit
def jax_backward_one_step_rescaled(beta_tp1, t, lX_pdf, lA):
    return jax_backward_one_step(beta_tp1, t, lX_pdf, lA, rescaled=True)

@jax.jit
def jax_backward_one_step_no_rescaled(beta_tp1, t, lX_pdf, lA):
    return jax_backward_one_step(beta_tp1, t, lX_pdf, lA, rescaled=False)

def jax_backward_one_step(beta_tp1, t, lX_pdf, lA, rescaled=True):
    """
    One step of the backward recursion, in the log domain
    """
    beta_t = jax_lse_custom(lA + lX_pdf[:, t + 1] + beta_tp1, axis=1)
    if rescaled:
        beta_t -= jax_lse_custom(beta_t)
    return beta_t

@partial(jax.jit, static_argnums=(0,))
def jax_log_forward_backward(T, lX_pdf, lA):
    # uniform prior p(h_0) = [0.5, 0.5], everything in the log domain
    llkh_init = jnp.log(0.5) + lX_pdf[:, 0]
    beta_init = jnp.array([0., 0.])
    def scan_fn_a(llkh_tm1, t):
        llkh_t = jax_forward_one_step_rescaled(llkh_tm1, t, lX_pdf, lA)
        # the carry that we want for the next iteration and the sample we want
        # to store for this iteration are the same
        return llkh_t, llkh_t
    def scan_fn_b(beta_tp1, t):
        beta_t = jax_backward_one_step_rescaled(beta_tp1, t, lX_pdf, lA)
        return beta_t, beta_t
    _, llkh = jax.lax.scan(scan_fn_a, llkh_init, jnp.arange(1, T, 1))
    llkh = jnp.concatenate([llkh_init[None, ...], llkh], axis=0)
    _, beta = jax.lax.scan(scan_fn_b, beta_init, jnp.arange(0, T - 1, 1),
                           reverse=True)
    beta = jnp.concatenate([beta, beta_init[None, ...]], axis=0)
    return llkh, beta

@jax.jit
def jax_get_post_marginals_probas(lllkh, lbeta):
    post_marginals_probas = lllkh + lbeta
    post_marginals_probas -= logsumexp(post_marginals_probas,
                                       axis=1, keepdims=True)
    pmp = jnp.exp(post_marginals_probas)
    #pmp = jnp.where(pmp < 1e-5, 1e-5, pmp)
    #pmp = jnp.where(pmp > 0.99999, 0.99999, pmp)
    return pmp

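# For reference, these are the usual smoothing marginals of forward-backward:
#     gamma_t(h) = p(h_t = h | x_{1:T}), proportional to alpha_t(h) * beta_t(h),
# computed in the log domain and normalised over h at each t before
# exponentiating.
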
@partial(jax.jit, static_argnums=(0,))
def jax_get_post_pair_marginals_probas(T, lllkh, lbeta, lA, lX_pdf,
                                       nb_classes=2):
    post_pair_marginals_probas = jnp.empty((T - 1, nb_classes, nb_classes))
    for h_t_1 in range(nb_classes):
        for h_t in range(nb_classes):
            post_pair_marginals_probas = \
                post_pair_marginals_probas.at[:, h_t_1, h_t].set(
                    lllkh[:T - 1, h_t_1] +
                    lA[h_t_1, h_t] +
                    lX_pdf[h_t, 1:] + lbeta[1:, h_t])
    post_pair_marginals_probas -= logsumexp(post_pair_marginals_probas,
                                            axis=(1, 2), keepdims=True)
    ppmp = jnp.exp(post_pair_marginals_probas)
    #ppmp = jnp.where(ppmp < 1e-5, 1e-5, ppmp)
    #ppmp = jnp.where(ppmp > 0.99999, 0.99999, ppmp)
    return ppmp

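# Similarly, the pairwise smoothing marginals follow the classical formula
#     xi_t(h_{t-1}, h_t) proportional to
#         alpha_{t-1}(h_{t-1}) * A(h_{t-1}, h_t) * p(x_t | h_t) * beta_t(h_t),
# which is exactly the sum of log terms filled in above, normalised over the
# two states at each t.
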
def reconstruct_A(T, A_sig_params, X):
    A = jnp.empty((2, 2))
    for h_t_1 in range(2):
        tmp = jnp.exp(-(A_sig_params[h_t_1]))
        tmp = tmp / (1 + tmp)
        A = A.at[h_t_1, 1].set(tmp)
        A = A.at[h_t_1, 0].set(1 - tmp)
    #A = jnp.where(A < 1e-5, 1e-5, A)
    #A = jnp.where(A > 0.99999, 0.99999, A)
    return A

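# reconstruct_A maps the unconstrained A_sig_params to a stochastic matrix
# through a sigmoid:
#     A[h, 1] = exp(-A_sig_params[h]) / (1 + exp(-A_sig_params[h]))
#             = sigmoid(-A_sig_params[h]),    A[h, 0] = 1 - A[h, 1],
# so each row of A stays on the simplex while the gradient updates act on the
# unconstrained parameters. For instance A_sig_params = [3., -3.] gives
# A close to [[0.95, 0.05], [0.05, 0.95]].
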
@jax.jit
def Q_A_pointwise(A_sig_params, input_, post_pair_marginals_probas, X):
    s = 0
    for h_t_1 in range(2):
        tmp = jnp.exp(-(A_sig_params[h_t_1]))
        tmp = tmp / (1 + tmp)
        #tmp = jnp.where(tmp < 1e-5, 1e-5, tmp)
        #tmp = jnp.where(tmp > 0.99999, 0.99999, tmp)
        for h_t in range(2):
            s += (post_pair_marginals_probas[input_ - 1, h_t_1, h_t] *
                  (h_t * jnp.log(tmp) + (1 - h_t) * jnp.log(1 - tmp))
                  )
    # the multiplications with h_t act like indicator functions
    return s

@jax.jit
def Q_means_stds_pointwise(means, stds, input_, post_marginals_probas, X):
    s = 0
    for h_t in range(2):
        s += (post_marginals_probas[input_, h_t] *
              jax_loggauss(X[input_], means[h_t], stds[h_t]))
    return s

vmap_Q_A_pointwise = jax.vmap(Q_A_pointwise,
                              in_axes=(None, 0, None, None))
vmap_Q_means_stds_pointwise = jax.vmap(Q_means_stds_pointwise,
                                       in_axes=(None, None, 0, None, None))

@jax.jit
def Q(A_sig_params, means, stds, post_marginals_probas,
      post_pair_marginals_probas, X):
    '''
    The Q quantity from the EM algorithm
    '''
    # NOTE jnp.mean because vmap returns a vector with one entry per t
    # NOTE the minus sign because we want to maximize and optax minimizes
    # NOTE T is read from the enclosing module scope (set in __main__)
    first_term_A = jnp.log(0.5) * (post_marginals_probas[0, 0] +
                                   post_marginals_probas[0, 1])
    return -(first_term_A +
             jnp.mean(vmap_Q_A_pointwise(A_sig_params, jnp.arange(1, T),
                                         post_pair_marginals_probas, X))
             + jnp.mean(vmap_Q_means_stds_pointwise(means, stds,
                                                    jnp.arange(0, T),
                                                    post_marginals_probas, X))
             )

Q_grad_A_sig_params = jax.grad(Q, argnums=0)
Q_grad_means = jax.grad(Q, argnums=1)
Q_grad_stds = jax.grad(Q, argnums=2)

@partial(jax.jit, static_argnums=(0,))
def jax_compute_llkh(T, A_sig_params, means, stds, X):
    """
    Compute the loglikelihood of an observed sequence given model parameters.
    Note that it needs a forward algorithm without rescaling
    """
    lX_pdf = jnp.stack([jax_loggauss(X, means[0], stds[0]),
                        jax_loggauss(X, means[1], stds[1])], axis=0)
    A = reconstruct_A(T, A_sig_params, X)
    lA = jnp.log(A)
    llkh_init = jnp.log(0.5) + lX_pdf[:, 0]
    def scan_fn_a(llkh_t_1, t):
        # llkh_t_1 is the former carry
        llkh_t = jax_forward_one_step_no_rescaled(llkh_t_1, t, lX_pdf, lA)
        # the next carry is also the sample we want to stack in memory
        return llkh_t, llkh_t
    # We just want the final carry value (llkh value) (at T)
    llkh_T, _ = jax.lax.scan(scan_fn_a, llkh_init, jnp.arange(1, T, 1))
    llkh = jax_lse_custom(llkh_T)
    return -llkh

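# With the unrescaled forward recursion, the observed-data log-likelihood is
# recovered at the last time step as
#     log p(x_{1:T}) = logsumexp_h alpha_T(h),
# which is what jax_lse_custom(llkh_T) computes; the sign is flipped so the
# returned quantity can be minimised (and differentiated) directly.
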
Llkh_grad_A_sig_params = jax.grad(jax_compute_llkh, argnums=1)
Llkh_grad_means = jax.grad(jax_compute_llkh, argnums=2)
Llkh_grad_stds = jax.grad(jax_compute_llkh, argnums=3)

def gradEM_iteration(T, X, A_sig_params, means, stds, opt_state_A_sig_params,
                     opt_state_means, opt_state_stds):
    '''
    One gradient iteration with gradient EM
    '''
    A = reconstruct_A(T, A_sig_params, X)
    lX_pdf = jnp.stack([jax_loggauss(X, means[0], stds[0]),
                        jax_loggauss(X, means[1], stds[1])], axis=0)
    lA = jnp.log(A)
    lllkh, lbeta = jax_log_forward_backward(T, lX_pdf, lA)
    post_marginals_probas = jax_get_post_marginals_probas(lllkh, lbeta)
    post_pair_marginals_probas = jax_get_post_pair_marginals_probas(
        T, lllkh, lbeta, lA, lX_pdf)
    q_grad_A_sig_params = Q_grad_A_sig_params(A_sig_params, means, stds,
        post_marginals_probas, post_pair_marginals_probas, X)
    q_grad_A_sig_params, opt_state_A_sig_params = \
        opt.update(q_grad_A_sig_params, opt_state_A_sig_params)
    q_grad_means = Q_grad_means(A_sig_params, means, stds,
        post_marginals_probas, post_pair_marginals_probas, X)
    q_grad_means, opt_state_means = \
        opt.update(q_grad_means, opt_state_means)
    q_grad_stds = Q_grad_stds(A_sig_params, means, stds,
        post_marginals_probas, post_pair_marginals_probas, X)
    q_grad_stds, opt_state_stds = \
        opt.update(q_grad_stds, opt_state_stds)
    A_sig_params = optax.apply_updates(A_sig_params, q_grad_A_sig_params)
    means = optax.apply_updates(means, q_grad_means)
    stds = optax.apply_updates(stds, q_grad_stds)
    return (A_sig_params, means, stds, opt_state_A_sig_params, opt_state_means,
            opt_state_stds)

def gradLlkh_iteration(T, X, A_sig_params, means, stds, opt_state_A_sig_params,
                       opt_state_means, opt_state_stds):
    '''
    One gradient iteration over the likelihood (Llkh function)
    '''
    llkh_grad_A_sig_params = Llkh_grad_A_sig_params(T, A_sig_params, means,
                                                    stds, X)
    llkh_grad_A_sig_params, opt_state_A_sig_params = \
        opt.update(llkh_grad_A_sig_params, opt_state_A_sig_params)
    llkh_grad_means = Llkh_grad_means(T, A_sig_params, means,
                                      stds, X)
    llkh_grad_means, opt_state_means = \
        opt.update(llkh_grad_means, opt_state_means)
    llkh_grad_stds = Llkh_grad_stds(T, A_sig_params, means,
                                    stds, X)
    llkh_grad_stds, opt_state_stds = \
        opt.update(llkh_grad_stds, opt_state_stds)
    A_sig_params = optax.apply_updates(A_sig_params, llkh_grad_A_sig_params)
    means = optax.apply_updates(means, llkh_grad_means)
    stds = optax.apply_updates(stds, llkh_grad_stds)
    return (A_sig_params, means, stds, opt_state_A_sig_params, opt_state_means,
            opt_state_stds)

if __name__ == "__main__":
    p0 = jnp.array([0.5, 0.5])
    A = jnp.array([[0.9, 0.1], [0.1, 0.9]])
    T = 500
    means = jnp.array([0., 1.])
    stds = jnp.array([0.5, 0.5])
    H = generate_hidden_states(T, A, p0)
    X = generate_observations(H, means, stds)
    # so far we have created some data with parameters that we now forget
    X_train = X
    T = len(X_train)
    # some random initial parameters
    A_sig_params_gradEM = jnp.array([3., -3.])
    means_gradEM = jnp.array([0.2, 0.5])
    stds_gradEM = jnp.array([0.2, 0.3])
    A_sig_params_gradLlkh = A_sig_params_gradEM.copy()
    means_gradLlkh = means_gradEM.copy()
    stds_gradLlkh = stds_gradEM.copy()
    lr = 0.01  # learning rate, shared by both optimizers
    opt = optax.adam(lr)
    opt_state_A_sig_params_gradEM = opt.init(A_sig_params_gradEM)
    opt_state_means_gradEM = opt.init(means_gradEM)
    opt_state_stds_gradEM = opt.init(stds_gradEM)
    opt_state_A_sig_params_gradLlkh = opt.init(A_sig_params_gradLlkh)
    opt_state_means_gradLlkh = opt.init(means_gradLlkh)
    opt_state_stds_gradLlkh = opt.init(stds_gradLlkh)
    llkh_list_gradEM = []
    A_sig_params_list_gradEM = []
    means_list_gradEM = []
    stds_list_gradEM = []
    llkh_list_gradLlkh = []
    A_sig_params_list_gradLlkh = []
    means_list_gradLlkh = []
    stds_list_gradLlkh = []
    nb_iter = 100
    # At initialization
    llkh_gradEM = -jax_compute_llkh(T, A_sig_params_gradEM,
        means_gradEM, stds_gradEM, X)  # NOTE the -1
    print("\nGradient EM INIT", "loglikelihood=", llkh_gradEM)
    llkh_list_gradEM.append(llkh_gradEM)
    A_sig_params_list_gradEM.append(A_sig_params_gradEM)
    means_list_gradEM.append(means_gradEM)
    stds_list_gradEM.append(stds_gradEM)
    llkh_gradLlkh = -jax_compute_llkh(T, A_sig_params_gradLlkh,
        means_gradLlkh, stds_gradLlkh, X)  # NOTE the -1
    print("\nGradient on likelihood INIT", "loglikelihood=",
          llkh_gradLlkh)
    llkh_list_gradLlkh.append(llkh_gradLlkh)
    A_sig_params_list_gradLlkh.append(A_sig_params_gradLlkh)
    means_list_gradLlkh.append(means_gradLlkh)
    stds_list_gradLlkh.append(stds_gradLlkh)
    for k in range(nb_iter):
        print("Iteration", k)
        (A_sig_params_gradEM, means_gradEM, stds_gradEM,
         opt_state_A_sig_params_gradEM, opt_state_means_gradEM,
         opt_state_stds_gradEM) = gradEM_iteration(T, X, A_sig_params_gradEM,
            means_gradEM, stds_gradEM, opt_state_A_sig_params_gradEM,
            opt_state_means_gradEM, opt_state_stds_gradEM)
        llkh_gradEM = -jax_compute_llkh(T, A_sig_params_gradEM,
            means_gradEM, stds_gradEM, X)
        print("\nGradient EM, loglikelihood=", llkh_gradEM)
        llkh_list_gradEM.append(llkh_gradEM)
        A_sig_params_list_gradEM.append(A_sig_params_gradEM)
        means_list_gradEM.append(means_gradEM)
        stds_list_gradEM.append(stds_gradEM)
        print(A_sig_params_gradEM, means_gradEM, stds_gradEM)
        (A_sig_params_gradLlkh, means_gradLlkh, stds_gradLlkh,
         opt_state_A_sig_params_gradLlkh, opt_state_means_gradLlkh,
         opt_state_stds_gradLlkh) = gradLlkh_iteration(T, X,
            A_sig_params_gradLlkh,
            means_gradLlkh, stds_gradLlkh, opt_state_A_sig_params_gradLlkh,
            opt_state_means_gradLlkh, opt_state_stds_gradLlkh)
        llkh_gradLlkh = -jax_compute_llkh(T, A_sig_params_gradLlkh,
            means_gradLlkh, stds_gradLlkh, X)
        print("\nGradient on likelihood, loglikelihood=",
              llkh_gradLlkh)
        llkh_list_gradLlkh.append(llkh_gradLlkh)
        A_sig_params_list_gradLlkh.append(A_sig_params_gradLlkh)
        means_list_gradLlkh.append(means_gradLlkh)
        stds_list_gradLlkh.append(stds_gradLlkh)
        print(A_sig_params_gradLlkh, means_gradLlkh, stds_gradLlkh)
    A_sig_params_list_gradEM = np.array(A_sig_params_list_gradEM)
    means_list_gradEM = np.array(means_list_gradEM)
    stds_list_gradEM = np.array(stds_list_gradEM)
    A_sig_params_list_gradLlkh = np.array(A_sig_params_list_gradLlkh)
    means_list_gradLlkh = np.array(means_list_gradLlkh)
    stds_list_gradLlkh = np.array(stds_list_gradLlkh)
    fig, axes = plt.subplots(1, 4)
    axes[0].plot(llkh_list_gradEM, label="gradEM")
    axes[0].plot(llkh_list_gradLlkh, label="gradLlkh")
    axes[0].legend()
    axes[0].set_title('log-likelihoods')
    axes[1].plot(A_sig_params_list_gradEM[:, 0], label="A_sig_params[0] gradEM")
    axes[1].plot(A_sig_params_list_gradLlkh[:, 0], label="A_sig_params[0] gradLlkh")
    axes[1].plot(A_sig_params_list_gradEM[:, 1], label="A_sig_params[1] gradEM")
    axes[1].plot(A_sig_params_list_gradLlkh[:, 1], label="A_sig_params[1] gradLlkh")
    axes[1].legend()
    axes[1].set_title('A_sig_params')
    axes[2].plot(means_list_gradEM[:, 0], label="means[0] gradEM")
    axes[2].plot(means_list_gradLlkh[:, 0], label="means[0] gradLlkh")
    axes[2].plot(means_list_gradEM[:, 1], label="means[1] gradEM")
    axes[2].plot(means_list_gradLlkh[:, 1], label="means[1] gradLlkh")
    axes[2].legend()
    axes[2].set_title('means')
    axes[3].plot(stds_list_gradEM[:, 0], label="stds[0] gradEM")
    axes[3].plot(stds_list_gradLlkh[:, 0], label="stds[0] gradLlkh")
    axes[3].plot(stds_list_gradEM[:, 1], label="stds[1] gradEM")
    axes[3].plot(stds_list_gradLlkh[:, 1], label="stds[1] gradLlkh")
    axes[3].legend()
    axes[3].set_title('stds')
    plt.show()