Skip to content

Instantly share code, notes, and snippets.

Last active August 9, 2019 19:11
Show Gist options
  • Save AbeHandler/5e0a6b3cf898bc9638e8b6c46e10de2d to your computer and use it in GitHub Desktop.
Save AbeHandler/5e0a6b3cf898bc9638e8b6c46e10de2d to your computer and use it in GitHub Desktop.
mixture_of_multinomials w/ EM
Display the source blob
Display the rendered blob
import numpy as np
from scipy.special import logsumexp
from tqdm import tqdm_notebook as tqdm
def init_pi(K):
    """Draw a random mixing-weight vector of length K.

    Each entry starts as an independent Uniform(0, 1) draw; the vector is then
    rescaled so that its entries sum to one.
    """
    raw_weights = np.random.uniform(0, 1, size=K)
    total = np.sum(raw_weights)
    return raw_weights / total
def init_phi(K, V):
    """Draw K random word distributions over a vocabulary of size V.

    Every row is sampled from a symmetric Dirichlet with all concentration
    parameters equal to 1, giving a (K, V) array whose rows sum to one.
    """
    concentration = np.full(V, 1.0)
    return np.random.dirichlet(concentration, size=K)
def init_lambda_d(N, K):
    """Return a random (N, K) responsibility matrix whose rows each sum to one."""
    weights = np.random.rand(N, K)
    row_totals = weights.sum(axis=1, keepdims=True)
    return weights / row_totals
def generate_data(N, K, V, real_pi, real_phi, C):
    """Sample N documents from a mixture of multinomials.

    For each document a component k is drawn from ``real_pi``, then C word
    tokens are drawn from ``real_phi[k]``; the document is stored as its
    vector of V word counts.

    Args:
        N: number of documents.
        K: number of mixture components.
        V: vocabulary size (number of columns of the count matrix).
        real_pi: length-K mixing weights.
        real_phi: (K, V) per-component word distributions.
        C: number of tokens per document.

    Returns:
        (D, ks): D is an (N, V) float array of word counts; ks is a list of
        the N component indices that generated the corresponding rows of D.
    """
    D = np.zeros((N, V))
    ks = []
    for d in tqdm(range(N)):
        k = np.random.choice(K, 1, p=real_pi)[0]
        # Record the latent label so callers can compare against the truth.
        ks.append(int(k))
        D[d] = np.random.multinomial(n=C, pvals=real_phi[k])
    return D, ks
def normalize_log_probs(lps):
    """Convert an (N, K) array of unnormalized log-probabilities to row-wise probabilities.

    The row maximum is subtracted before exponentiating (the log-sum-exp
    trick), so large-magnitude log-probabilities do not overflow or all
    underflow to zero.

    Returns:
        An (N, K) array whose rows sum to one.
    """
    shifted = lps - np.max(lps, axis=1, keepdims=True)
    p = np.exp(shifted)
    p = p / np.sum(p, axis=1, keepdims=True)
    # Compare every row sum against 1 with a float tolerance. Calling `.all()`
    # first would reduce to a bool, and True == 1.0 is always true.
    assert np.allclose(np.sum(p, axis=1), 1.0)
    return p
def log_prob_D_given_k(D, pi_hat_, phi_hat_, k):
    r"""Log joint probability of every document together with component k.

    For each row x of D:

        log p(x, z=k | \theta) = log p(x | z=k, \theta) + log p(z=k | \theta)
                               = \sum_v x_v log(phi_hat_[k, v]) + log(pi_hat_[k])

    The multinomial coefficient of x is omitted. It does not depend on k or on
    the parameters, so it cancels in posteriors and only shifts
    log-likelihoods by a constant.

    Args:
        D: (N, V) array of word counts.
        pi_hat_: mixing weights, indexed by component.
        phi_hat_: per-component word distributions, one row per component.
        k: component index.

    Returns:
        (N, 1) array of log joint probabilities.
    """
    n_docs = D.shape[0]  # derived from D rather than a module-level global
    # Sum across each row to get the log probability of all words in the
    # document under phi_hat_[k].
    log_prob_of_words = np.sum(D * np.log(phi_hat_[k]), axis=1)
    log_prob_of_class = np.log(pi_hat_[k])
    out = log_prob_of_words + log_prob_of_class
    assert out.shape == (n_docs,)
    return out.reshape(n_docs, 1)
def expected_complete_log_likelihood(lambda_d, pi_hat, phi_hat, D):
    r"""First term of the ELBO: \sum_n \sum_k q(z_n=k) log p(x_n, z_n=k | \theta).

    Args:
        lambda_d: (N, K) responsibilities; column k holds q(z_n = k).
        pi_hat: mixing weights; the number of components K is its length.
        phi_hat: per-component word distributions, one row per component.
        D: (N, V) array of word counts.

    Returns:
        The expected complete-data log-likelihood as a scalar.
    """
    sum_ = 0
    K = pi_hat.shape[0]
    for k in range(K):
        # Slicing with k:k + 1 keeps a column vector that lines up with the
        # (N, 1) output of log_prob_D_given_k, without needing a global N.
        sum_ += np.sum(lambda_d[:, k:k + 1] * log_prob_D_given_k(D, pi_hat, phi_hat, k=k))
    # Every log joint is <= 0, so a weighted sum of them is too.
    assert sum_ <= 0
    return sum_
def entropy(q):
    """Return the Shannon entropy -sum(q * log q), summed over every entry of q.

    Entries equal to zero contribute 0, which is the limit of x log x as x -> 0.
    Without this, 0 * log(0) produces NaN. That can happen for
    responsibilities whose exp() underflows to exactly 0 during normalization.
    """
    # Replace zeros by 1 inside the log (log 1 = 0) so those terms vanish
    # instead of becoming 0 * -inf = NaN.
    safe_q = np.where(q > 0, q, 1.0)
    return -1 * np.sum(q * np.log(safe_q))
def elbo(lambda_d, pi_hat, phi_hat, D):
    """Evidence lower bound with q = lambda_d.

    This is E_q[log p(x, z | theta)] + H(q): the expected complete-data
    log-likelihood plus the entropy of the responsibilities.
    """
    expected_ll = expected_complete_log_likelihood(lambda_d, pi_hat, phi_hat, D)
    q_entropy = entropy(lambda_d)
    return expected_ll + q_entropy
def e_step(pi_hat_, phi_hat_, D, N, K):
    """E-step: compute posterior responsibilities under the current parameters.

    Column k of the unnormalized log-posterior is
    sum_v D[n, v] * log(phi_hat_[k, v]) + log(pi_hat_[k]). Each row is then
    turned into probabilities by normalize_log_probs.

    Returns:
        An (N, K) array of responsibilities.
    """
    word_scores = [np.sum(D * np.log(phi_hat_[k]), axis=1) for k in range(K)]
    log_joint = np.column_stack(word_scores).reshape(N, K).astype(np.float64)
    log_joint += np.log(pi_hat_)  # broadcast the log prior across rows
    return normalize_log_probs(log_joint)
def m_step_phi(lambda_d, K_, phi_hat, D):
    """M-step for the per-component word distributions.

    phi_new[k, v] is proportional to sum_n lambda_d[n, k] * D[n, v], the
    responsibility-weighted count of word v, normalized over v.

    Args:
        lambda_d: (N, K) responsibilities.
        K_: number of components to update (rows 0 .. K_-1).
        phi_hat: current estimate; only its shape and dtype are used.
        D: (N, V) array of word counts.

    Returns:
        A new array shaped like phi_hat. Any rows at index >= K_ are left at zero.
    """
    phi_hat_new = np.zeros_like(phi_hat)
    # (K_, V) expected word counts per component in one matrix product; this
    # also removes the old dependency on a module-level N.
    expected_counts = lambda_d[:, :K_].T @ D
    phi_hat_new[:K_] = expected_counts / np.sum(expected_counts, axis=1, keepdims=True)
    return phi_hat_new
def m_step_pi(lambda_d):
    """M-step for the mixing weights: pi_k is proportional to sum_n lambda_d[n, k].

    Args:
        lambda_d: (N, K) responsibilities whose rows each sum to one.

    Returns:
        Length-K mixing weights that sum to one.
    """
    # Each row sums to one, so the grand total must equal the number of rows.
    # The row count is read from lambda_d itself rather than a global N.
    assert np.isclose(np.sum(lambda_d), lambda_d.shape[0], rtol=1e-10, atol=1e-8)
    n = np.sum(lambda_d, axis=0)
    return n / np.sum(n)
def observed_data_LL(pi_hat, phi_hat, K, D):
    r"""Observed-data log-likelihood \ell(\theta) = \sum_n log \sum_k p(x_n, z_n=k | \theta).

    This is the sum of the logs of each document's marginal probability, not
    an expectation of log joints. By Jensen's inequality it upper-bounds the
    ELBO.

    The inner sum over components is computed with logsumexp in log space.
    Exponentiating first underflows to 0 for long documents, and log(0) is
    -inf. The document's multinomial coefficient is omitted (a
    parameter-independent constant).

    Args:
        pi_hat: mixing weights; the first K entries are used.
        phi_hat: per-component word distributions; the first K rows are used.
        K: number of components.
        D: (N, V) array of word counts.

    Returns:
        The scalar log-likelihood, which is <= 0.
    """
    log_pi = np.log(pi_hat[:K])
    log_phi = np.log(phi_hat[:K])
    # (N, K) log joint: sum_v D[n, v] * log phi[k, v] + log pi[k]
    log_joint = D @ log_phi.T + log_pi
    per_doc = logsumexp(log_joint, axis=1)  # log p(x_n) for each document
    sum_of_n = np.sum(per_doc)
    assert sum_of_n <= 0
    return sum_of_n
# Experiment: simulate data from a random mixture of multinomials, fit it with
# EM, and check the standard EM invariants after every step.

# Fix NumPy's global RNG so runs are reproducible. The init_* and
# generate_data helpers all draw from it.
seed = 3
np.random.seed(seed)

N = 10000
K = 2
V = 3
C = 4  # context size: tokens per document
n_iters = 4
rtol = 1e-7  # relative tolerance for the monotonicity / bound checks

real_pi = init_pi(K)
real_phi = init_phi(K, V)
D, ks = generate_data(N, K, V, real_pi, real_phi, C)
# Optional: add pseudo-counts to avoid zero word counts.
# D += 1

phi_hat = init_phi(K, V)
pi_hat = init_pi(K)
lambda_d = init_lambda_d(N, K)

prev_ll = -np.inf
for i in range(n_iters):
    #### E step
    # Setting q to the exact posterior maximizes the ELBO over q, so the ELBO
    # must not decrease. The expected complete log-likelihood alone (without
    # the entropy term) carries no such guarantee.
    b4 = elbo(lambda_d, pi_hat, phi_hat, D)
    lambda_d = e_step(pi_hat, phi_hat, D, N, K)
    aft = elbo(lambda_d, pi_hat, phi_hat, D)
    assert aft >= b4 or np.isclose(b4, aft, rtol=rtol)

    #### M step
    # Each update maximizes the expected complete log-likelihood in its own
    # parameter with q held fixed, so that quantity must not decrease.
    b4 = expected_complete_log_likelihood(lambda_d, pi_hat, phi_hat, D)
    pi_hat = m_step_pi(lambda_d)
    aft = expected_complete_log_likelihood(lambda_d, pi_hat, phi_hat, D)
    assert aft >= b4 or np.isclose(b4, aft, rtol=rtol)

    b4 = aft
    phi_hat = m_step_phi(lambda_d, K, phi_hat, D)
    aft = expected_complete_log_likelihood(lambda_d, pi_hat, phi_hat, D)
    assert aft >= b4 or np.isclose(b4, aft, rtol=rtol)

    ll = observed_data_LL(pi_hat, phi_hat, K, D)
    bound = elbo(lambda_d, pi_hat, phi_hat, D)
    # The observed-data log-likelihood of discrete data is negative.
    assert ll < 0
    # Jensen: the ELBO never exceeds the observed-data log-likelihood.
    assert bound <= ll or np.isclose(bound, ll, rtol=rtol)
    # EM never decreases the observed-data log-likelihood.
    assert ll >= prev_ll or np.isclose(prev_ll, ll, rtol=rtol)
    prev_ll = ll
    print(bound, ll)

    # KL(true || estimate), summed over components. EM may recover the
    # components in a different order than the true ones (label switching),
    # so a large value can just mean the components are permuted.
    kl_phi = np.sum(real_phi * np.log(real_phi / phi_hat))
    kl_pi = np.sum(real_pi * np.log(real_pi / pi_hat))
    # print('kl', kl_pi + kl_phi)  # with enough points this should go down too;
    # with too few points the sample itself varies around the true parameters.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment