Skip to content

Instantly share code, notes, and snippets.

@michaelchughes
Created February 24, 2023 19:05
Show Gist options
  • Save michaelchughes/cacef0692d86ef3fb27b60e84a588890 to your computer and use it in GitHub Desktop.
Save michaelchughes/cacef0692d86ef3fb27b60e84a588890 to your computer and use it in GitHub Desktop.
Demo of JAX: build a Python list of arrays and stack it at the end, avoiding in-place item assignment (which JAX arrays do not support)
import numpy as np
import jax
import jax.numpy as jnp
import jax.nn
def calc_trans_mat(x_TD, r_KD, p_KK):
    ''' Compute transition matrix for each timestep

    One (K,K) matrix is built per step and collected in a Python list,
    which is stacked at the end. This avoids in-place item assignment,
    which JAX arrays do not support.

    Args
    ----
    x_TD : 2D array, (T,D)
        Observed features in time series
    r_KD : 2D array, (K,D)
        Weight matrix
    p_KK : 2D array, (K,K)
        Transition probability matrix
        Rows sum to one

    Returns
    -------
    trans_TKK : 3D array, (T-1, K, K)
        transition matrix for each timestep.
        Entry [i, j, k] is the softmax over k of
        log(p_KK[j, k]) + dot(r_KD[k], x_TD[i]), so every row sums to one.
        When T < 2 the result is an empty array of shape (0, K, K).
    '''
    # Number of timesteps comes from the data, never from a global.
    T = x_TD.shape[0]
    K = p_KK.shape[0]
    # Shape (T,1,K): the singleton middle axis broadcasts each timestep's
    # per-destination scores across all K rows (source states).
    rx_T1K = jnp.einsum("kd,td->tk", r_KD, x_TD)[:, np.newaxis, :]
    # Loop-invariant, so compute once.
    log_p_KK = jnp.log(p_KK)
    trans_KK_list = []
    for t in range(1, T):
        eta_t_KK = rx_T1K[t - 1] + log_p_KK
        # Normalize over destination states: rows of each matrix sum to one.
        trans_KK_list.append(jax.nn.softmax(eta_t_KK, axis=1))
    if not trans_KK_list:
        # Stacking an empty list raises; return a well-shaped empty result.
        return jnp.zeros((0, K, K), dtype=log_p_KK.dtype)
    # Stack the (K,K) matrices into an array of shape (T-1, K, K).
    return jnp.stack(trans_KK_list)
if __name__ == '__main__':
    # Small synthetic problem: T timesteps, D features, K states.
    T, D, K = 6, 2, 3

    # Fixed seed; the draw order below determines which values each array gets.
    prng = np.random.RandomState(101)
    x_TD = prng.randn(T, D)
    r_KD = prng.randn(K, D)
    p_KK = prng.rand(K, K)
    # Normalize each row so p_KK is a valid transition matrix.
    p_KK /= p_KK.sum(axis=1, keepdims=True)

    trans_TKK = calc_trans_mat(x_TD, r_KD, p_KK)
    for trans_KK in trans_TKK[:T - 1]:
        print(trans_KK)

    def calc_loss(r_KD):
        ''' Total log transition probability, viewed as a function of r_KD. '''
        return jnp.log(calc_trans_mat(x_TD, r_KD, p_KK)).sum()

    print("calc_loss")
    print(calc_loss(r_KD))

    # Reverse-mode gradient of the scalar loss with respect to its argument.
    calc_grad = jax.grad(calc_loss)
    print("calc_grad wrt r_KD")
    print(calc_grad(r_KD))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment