"""
Implements the EM algorithm for inference in a mixture of linear regression
experts model, using JAX.
"""
import jax.numpy as jnp
import jax.random as jrnd
import jax.scipy.stats.norm as normal
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers
from jax import vmap, grad, jit
import matplotlib.pyplot as plt
from matplotlib import cm
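# Model: a mixture of K linear regression experts with a softmax gating
# network. Each expert k has regression weights W[k] and noise scale S[k];
# the gate picks expert k with probability proportional to exp(V[k] @ x), so
#     p(y | x, theta) = sum_k p(z = k | x, V) * N(y | W[k] @ x, S[k]).
# EM alternates an E step (responsibilities r_nk = p(z_n = k | y_n, x_n))
# with an M step that maximises the expected complete-data log-likelihood.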
def log_regression_likelihood(yn, xn, k, W, S):
    # Log-density of y_n under expert k's linear-Gaussian regression.
    return normal.logpdf(yn, loc=jnp.dot(W[k], xn), scale=S[k])


def regression_likelihood(yn, xn, k, W, S):
    # Density of y_n under expert k's linear-Gaussian regression.
    return normal.pdf(yn, loc=jnp.dot(W[k], xn), scale=S[k])


def logistic_likelihood(k, xn, V):
    # Gating probability p(z = k | x_n, V): a softmax over linear scores.
    return jnp.exp(jnp.dot(V[k], xn)) / jnp.sum(jnp.exp(V @ xn))


def log_logistic_likelihood(k, xn, V):
    # Log gating probability, with the normaliser computed via logsumexp.
    return jnp.dot(V[k], xn) - logsumexp(V @ xn)
def prob_z_given_y(k, yn, xn, W, V, S, K):
    # E-step responsibility p(z_n = k | y_n, x_n, theta): the gate times the
    # expert likelihood, normalised over all K experts.
    reg_map = vmap(regression_likelihood, in_axes=(
        None, None, 0, None, None))(yn, xn, jnp.arange(K), W, S)
    gate_map = vmap(logistic_likelihood, in_axes=(
        0, None, None))(jnp.arange(K), xn, V)
    return (logistic_likelihood(k, xn, V)
            * regression_likelihood(yn, xn, k, W, S)) / jnp.sum(reg_map * gate_map)
def compute_rnk(y, x, W, V, S, K):
    # Responsibility matrix of shape (N, K): r_nk for every point and expert.
    return vmap(lambda ks: vmap(lambda yn, xn: prob_z_given_y(
        ks, yn, xn, W, V, S, K))(y, x))(jnp.arange(K)).T
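# An optional, numerically stabler E step (a sketch, not called by run_EM):
# work entirely in log space and normalise with logsumexp, which avoids
# underflow when the per-expert likelihoods are very small.
def compute_rnk_logspace(y, x, W, V, S, K):
    def log_joint(k, yn, xn):
        return (log_logistic_likelihood(k, xn, V)
                + log_regression_likelihood(yn, xn, k, W, S))
    log_j = vmap(lambda ks: vmap(lambda yn, xn: log_joint(
        ks, yn, xn))(y, x))(jnp.arange(K)).T  # shape (N, K)
    return jnp.exp(log_j - logsumexp(log_j, axis=1, keepdims=True))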
def m_step_obj(y, x, Rnk, W, V, S, K):
    # Negative expected complete-data log-likelihood: each (n, k) term is
    # weighted by the fixed responsibilities from the E step.
    def map_func(n, k):
        return Rnk[n, k] * (log_regression_likelihood(y[n], x[n], k, W, S)
                            + log_logistic_likelihood(k, x[n], V))
    return -vmap(lambda n: vmap(lambda k: map_func(n, k))(
        jnp.arange(K)))(jnp.arange(x.shape[0])).sum()
def m_step(y, x, Rnk, W, V, S, K, lr=0.001, num_iterations=1000):
    # M step: minimise the negative expected complete-data log-likelihood
    # with Adam, holding the responsibilities Rnk fixed.
    @jit
    def loss_fn(params):
        W, V, S = params
        return m_step_obj(y, x, Rnk, W, V, S, K)

    grad_fn = jit(grad(loss_fn))
    init_params = (W, V, S)
    opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
    opt_state = opt_init(init_params)
    for i in range(num_iterations):
        grads = grad_fn(get_params(opt_state))
        opt_state = opt_update(i, grads, opt_state)
        if (i + 1) % 100 == 0:
            print(f"M-step objective: {loss_fn(get_params(opt_state))}")
    optimized_W, optimized_V, optimized_S = get_params(opt_state)
    return optimized_W, optimized_V, optimized_S
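# Note that S is optimised unconstrained above, so nothing in the Adam
# updates prevents a noise scale from going negative (which would make the
# Gaussian densities NaN). A common remedy, sketched here as an assumption
# rather than part of the model above, is to optimise an unconstrained raw
# parameter and map it through softplus before use:
def to_positive_scale(s_raw):
    return jnp.logaddexp(s_raw, 0.0)  # softplus: log(1 + exp(s_raw)) > 0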
def run_EM(y, x, W_init, V_init, S_init, K, num_iterations=50):
    W, V, S = W_init, V_init, S_init
    for i in range(num_iterations):
        print(f"EM iteration {i}")
        # E step: compute responsibilities under the current parameters.
        Rnk = compute_rnk(y, x, W, V, S, K)
        # M step: re-fit the parameters against the fixed responsibilities.
        W, V, S = m_step(y, x, Rnk, W, V, S, K)
    return W, V, S
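# A small convergence monitor (an addition, not called by run_EM): the
# marginal log-likelihood
#     log p(y | x, theta) = sum_n logsumexp_k [log p(z=k|x_n) + log p(y_n|x_n, k)]
# should increase at every EM iteration, up to M-step optimisation error.
def marginal_log_likelihood(y, x, W, V, S, K):
    def per_point(yn, xn):
        log_terms = vmap(lambda k: log_logistic_likelihood(k, xn, V)
                         + log_regression_likelihood(yn, xn, k, W, S))(jnp.arange(K))
        return logsumexp(log_terms)
    return vmap(per_point)(y, x).sum()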
def plot_soln(x, y, W, V, K, f_true=None):
    # Left panel: data, the modal expert's prediction, and each expert's
    # line. Right panel: the gating probabilities across the input range.
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    x_train = x[:, 0]
    expert_prob = vmap(lambda k: vmap(logistic_likelihood, in_axes=(
        None, 0, None))(k, x, V))(jnp.arange(K))
    argmax_prob = expert_prob.argmax(axis=0)
    axs[0].scatter(x_train, y, c="black", alpha=0.1, label="Training data")
    axs[0].plot(x_train, vmap(lambda n: W[argmax_prob[n], 0]*x_train[n] +
                W[argmax_prob[n], 1])(jnp.arange(x.shape[0])), c="red", label="Modal Pred.")
    if f_true is not None:
        axs[0].plot(x_train, f_true, c="black", ls=":", label="$f(x)$")
    for k in range(K):
        axs[0].plot(x_train, W[k, 0]*x_train +
                    W[k, 1], c=cm.Set2(k), alpha=0.6)
    axs[0].set_xlabel("$x$")
    axs[0].set_ylabel("$y$")
    axs[0].legend()
    for k in range(K):
        axs[1].plot(x_train, vmap(logistic_likelihood, in_axes=(None, 0, None))(
            k, x, V), c=cm.Set2(k), label=f"k={k}")
    axs[1].set_ylabel("$p(z=k|x, \\theta)$")
    axs[1].set_xlabel("$x$")
    axs[1].legend()
    plt.tight_layout()
    plt.savefig("em-mole-sln.png", dpi=600)
    plt.show()
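# A sketch of a posterior-mean predictor, as one assumed way of querying the
# fitted model (plot_soln above instead plots the line of whichever expert
# the gate makes most probable):
def predict_mean(xn, W, V, K):
    gate = vmap(logistic_likelihood, in_axes=(0, None, None))(jnp.arange(K), xn, V)
    return jnp.sum(gate * (W @ xn))  # gate-weighted average of expert means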
def main():
    # parameters
    N = 2  # data dim, including bias
    K = 3  # number of mixtures
    W = jrnd.normal(jrnd.PRNGKey(2), shape=(K, N))  # regression weights
    V = jrnd.normal(jrnd.PRNGKey(3), shape=(K, N))  # logistic weights
    S = jnp.array([1.0] * K)  # noise scale of each regression

    # construct training data: a piecewise-linear function with three pieces
    x_train = jnp.linspace(-3, 3, 100)
    xp_train = jnp.vstack((x_train, jnp.ones(100))).T
    w_true = jnp.array([[0.4, 1.0], [-2.0, 0.1], [2.0, 0.0]])
    f_true = w_true[0]@xp_train.T * (x_train < -1) + w_true[1]@xp_train.T * (
        (-1 <= x_train) & (x_train < 1)) + w_true[2]@xp_train.T * (x_train >= 1)
    y_train = f_true + 0.5*jrnd.normal(jrnd.PRNGKey(1), shape=(100,))

    plt.figure(figsize=(4, 3))
    plt.scatter(x_train, y_train, c="black", alpha=0.1, label="Training data")
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    plt.legend()
    plt.tight_layout()
    plt.savefig("em-mole-data.png", dpi=600)
    plt.show()

    W, V, S = run_EM(y_train, xp_train, W, V, S, K, num_iterations=20)
    plot_soln(xp_train, y_train, W, V, K, f_true=f_true)


if __name__ == "__main__":
    main()