import jax.numpy as jnp
import jax.random as jrnd
import jax.scipy.stats.norm as normal
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers
from jax import vmap, grad, jit
import matplotlib.pyplot as plt
from matplotlib import cm
def log_regression_likelihood(yn, xn, k, W, S):
    # log N(yn | w_k^T xn, s_k^2) for expert k
    return normal.logpdf(yn, loc=jnp.dot(W[k], xn), scale=S[k])

def regression_likelihood(yn, xn, k, W, S):
    # N(yn | w_k^T xn, s_k^2) for expert k
    return normal.pdf(yn, loc=jnp.dot(W[k], xn), scale=S[k])

def logistic_likelihood(k, xn, V):
    # softmax gating probability p(z = k | xn, V)
    return jnp.exp(jnp.dot(V[k], xn)) / jnp.sum(jnp.exp(V @ xn))

def log_logistic_likelihood(k, xn, V):
    # log softmax gate, computed stably via logsumexp
    return jnp.dot(V[k], xn) - logsumexp(V @ xn)
def prob_z_given_y(k, yn, xn, W, V, S, K):
    # E-step responsibility: posterior probability that (xn, yn) was generated
    # by expert k under the current parameters
    reg_map = vmap(regression_likelihood,
                   in_axes=(None, None, 0, None, None))(yn, xn, jnp.arange(K), W, S)
    log_map = vmap(logistic_likelihood,
                   in_axes=(0, None, None))(jnp.arange(K), xn, V)
    return (logistic_likelihood(k, xn, V) * regression_likelihood(yn, xn, k, W, S)) / jnp.sum(reg_map * log_map)
def compute_rnk(y, x, W, V, S, K):
    # responsibilities r_nk for all data points and experts, shape (n_data, K)
    return vmap(lambda ks: vmap(lambda yn, xn: prob_z_given_y(
        ks, yn, xn, W, V, S, K))(y, x))(jnp.arange(K)).T
def m_step_obj(y, x, Rnk, W, V, S, K):
    # negative expected complete-data log likelihood under the responsibilities
    def map_func(n, k):
        return (Rnk[n, k] * log_regression_likelihood(y[n], x[n], k, W, S)
                + Rnk[n, k] * log_logistic_likelihood(k, x[n], V))
    return -vmap(lambda n: vmap(lambda k: map_func(n, k))(jnp.arange(K)))(jnp.arange(x.shape[0])).sum()
def m_step(y, x, Rnk, W, V, S, K, lr=0.001, num_iterations=1000):
    # M-step: maximise the expected complete-data log likelihood with Adam,
    # since the gating weights V have no closed-form update
    @jit
    def loss_fn(params):
        W, V, S = params
        return m_step_obj(y, x, Rnk, W, V, S, K)

    grad_fn = jit(grad(loss_fn))
    init_params = (W, V, S)
    opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
    opt_state = opt_init(init_params)
    for i in range(num_iterations):
        loss_value = loss_fn(get_params(opt_state))
        grads = grad_fn(get_params(opt_state))
        opt_state = opt_update(i, grads, opt_state)
    print(f"M-step objective: {loss_value}")
    optimized_params = get_params(opt_state)
    optimized_W, optimized_V, optimized_S = optimized_params
    return optimized_W, optimized_V, optimized_S
def run_EM(y, x, W_init, V_init, S_init, K, num_iterations=50):
    W, V, S = W_init, V_init, S_init
    for i in range(num_iterations):
        print(f"EM iteration {i}")
        # E step
        Rnk = compute_rnk(y, x, W, V, S, K)
        # M step
        W, V, S = m_step(y, x, Rnk, W, V, S, K)
    return W, V, S
def plot_soln(x, y, W, V, K, f_true=None):
    # left panel: data, each expert's line and the modal prediction;
    # right panel: gating probabilities p(z = k | x) across the input range
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    x_train = x[:, 0]
    expert_prob = vmap(lambda k: vmap(logistic_likelihood,
                                      in_axes=(None, 0, None))(k, x, V))(jnp.arange(K))
    argmax_prob = expert_prob.argmax(axis=0)
    axs[0].scatter(x_train, y, c="black", alpha=0.1, label="Training data")
    axs[0].plot(x_train,
                vmap(lambda n: W[argmax_prob[n], 0] * x_train[n]
                     + W[argmax_prob[n], 1])(jnp.arange(x.shape[0])),
                c="red", label="Modal Pred.")
    if f_true is not None:
        axs[0].plot(x_train, f_true, c="black", ls=":", label="$f(x)$")
    for k in range(K):
        axs[0].plot(x_train, W[k, 0] * x_train + W[k, 1], c=cm.Set2(k), alpha=0.6)
    axs[0].set_xlabel("$x$")
    axs[0].set_ylabel("$y$")
    axs[0].legend()
    for k in range(K):
        axs[1].plot(x_train, vmap(logistic_likelihood, in_axes=(None, 0, None))(k, x, V),
                    c=cm.Set2(k), label=f"k={k}")
    axs[1].set_ylabel("$p(z=k|x, \\theta)$")
    axs[1].set_xlabel("$x$")
    axs[1].legend()
    plt.tight_layout()
    plt.savefig("em-mole-sln.png", dpi=600)
    plt.show()
def main():
    # model parameters
    N = 2  # data dim, including bias
    K = 3  # number of experts
    W = jrnd.normal(jrnd.PRNGKey(2), shape=(K, N))  # regression weights
    V = jrnd.normal(jrnd.PRNGKey(3), shape=(K, N))  # logistic gating weights
    S = jnp.array([1.0] * K)  # noise std of each regression expert

    # construct training data: piecewise-linear function plus Gaussian noise
    x_train = jnp.linspace(-3, 3, 100)
    xp_train = jnp.vstack((x_train, jnp.ones(100))).T  # append bias feature
    w_true = jnp.array([[0.4, 1.0], [-2.0, 0.1], [2.0, 0.0]])
    f_true = (w_true[0] @ xp_train.T * (x_train < -1)
              + w_true[1] @ xp_train.T * ((-1 <= x_train) & (x_train < 1))
              + w_true[2] @ xp_train.T * (x_train >= 1))
    y_train = f_true + 0.5 * jrnd.normal(jrnd.PRNGKey(1), shape=(100,))

    # plot the raw training data
    plt.figure(figsize=(4, 3))
    plt.scatter(x_train, y_train, c="black", alpha=0.1, label="Training data")
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    plt.legend()
    plt.tight_layout()
    plt.savefig("em-mole-data.png", dpi=600)
    plt.show()

    # fit the mixture of experts with EM and plot the result
    W, V, S = run_EM(y_train, xp_train, W, V, S, K, num_iterations=20)
    plot_soln(xp_train, y_train, W, V, K, f_true=f_true)


if __name__ == "__main__":
    main()
@magnusross (author): Implements the EM algorithm for inference in a mixture of linear regression experts model, using JAX.
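For reference, the quantities the functions above compute can be written out as follows (the notation here is my own, chosen to match the code: $w_k$ and $v_k$ are the rows of `W` and `V`, and $s_k$ is `S[k]`). The E-step responsibility returned by `prob_z_given_y` / `compute_rnk` is

$$
r_{nk} = p(z_n = k \mid x_n, y_n, \theta)
= \frac{\pi_k(x_n)\, \mathcal{N}(y_n \mid w_k^\top x_n, s_k^2)}{\sum_{j=1}^{K} \pi_j(x_n)\, \mathcal{N}(y_n \mid w_j^\top x_n, s_j^2)},
\qquad
\pi_k(x) = \frac{\exp(v_k^\top x)}{\sum_{j=1}^{K} \exp(v_j^\top x)},
$$

and `m_step_obj` is the negative of the expected complete-data log likelihood

$$
Q(\theta) = \sum_{n} \sum_{k=1}^{K} r_{nk} \left[ \log \mathcal{N}(y_n \mid w_k^\top x_n, s_k^2) + \log \pi_k(x_n) \right],
$$

which `m_step` maximises numerically with Adam rather than via a closed-form update.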
