Skip to content

Instantly share code, notes, and snippets.

@DuaneNielsen
Created January 15, 2020 05:53
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DuaneNielsen/8c5bde8d35a46d60640d0579d913dcff to your computer and use it in GitHub Desktop.
Save DuaneNielsen/8c5bde8d35a46d60640d0579d913dcff to your computer and use it in GitHub Desktop.
EM algorithm (1D): uses a log-probability Bayes update for numerical stability.
import torch
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt
# EM algo demo, in pytorch
#
# Script-wide settings used by the rest of this file.
n: int = 40  # total number of samples; drawn as two halves of n // 2, so it must be even
k: int = 2  # number of mixture components
eps: float = torch.finfo(torch.float32).eps  # float32 machine epsilon; guards the M-step divisions
def plot(x, posterior, h=None):
    """Show the current EM state in a two-panel figure.

    Top panel ('p (H | x)'): one stacked bar per sample, split into the
    posterior responsibility of each component. Bottom panel ('H'): the
    density of every component, evaluated on 50 points spanning the top
    panel's x-range.

    Args:
        x: sample locations; squeezed to 1-D for plotting.
        posterior: responsibilities, one row per sample and one column per
            component.
        h: batched component distribution (e.g. ``Normal(mu, stdev)``) whose
            batch dimension matches ``posterior.shape[1]``. Defaults to the
            module-level ``h``, matching the original behaviour of reading
            the global set by the EM loop.

    Raises:
        NameError: if ``h`` is not given and no module-level ``h`` exists.
    """
    if h is None:
        h = globals().get('h')
        if h is None:
            raise NameError("plot() needs the component distribution: pass h=... "
                            "or define a module-level 'h'")

    num_components = posterior.shape[1]
    xs = x.squeeze()

    fig, ax = plt.subplots(nrows=2, ncols=1)

    # Stacked bars: column j sits on top of columns 0..j-1, so each bar's
    # total height is the row sum of the posterior.
    ax[0].title.set_text('p (H | x)')
    bottom = torch.zeros_like(posterior[:, 0])
    for j in range(num_components):
        ax[0].bar(xs, posterior[:, j], bottom=bottom, label=f'h{j + 1}')
        bottom = bottom + posterior[:, j]
    ax[0].legend()

    # Component densities over the same x-range as the bar chart.
    x_min, x_max = ax[0].get_xlim()
    x_axis = torch.linspace(x_min, x_max, 50)
    # (50, 1) broadcasts against the (num_components,) batch shape of h,
    # giving one density column per component.
    density = torch.exp(h.log_prob(x_axis.unsqueeze(1)))
    ax[1].title.set_text('H')
    lines = ax[1].plot(x_axis, density)
    for j, line in enumerate(lines):
        line.set_label(f'h{j + 1}')
    ax[1].legend()

    fig.tight_layout()
    plt.show()
def em_step(x, mu, stdev, prior, eps=torch.finfo(torch.float32).eps):
    """Run one EM iteration for a 1-D Gaussian mixture.

    The E-step is done in log space: log p(x | H) + log p(H) is normalised
    with ``logsumexp`` so that points far from every component do not
    underflow to 0 / 0.

    Args:
        x: samples as a column, shape (n, 1), so it broadcasts against the
            per-component parameters.
        mu: current component means, shape (k,).
        stdev: current component standard deviations, shape (k,).
        prior: current mixing weights p(H), shape (k,).
        eps: added to the per-component responsibility totals (avoids 0/0
            for an empty component) and used as a floor for the variance.

    Returns:
        Tuple ``(h, posterior, mu, stdev, prior)``: the component distribution
        used for the E-step, the (n, k) responsibilities p(H | x), and the
        re-estimated means, standard deviations and mixing weights.
    """
    h = Normal(mu, stdev)

    # E-step: posterior responsibilities via a log-space Bayes update.
    log_joint = h.log_prob(x) + prior.log()
    log_posterior = log_joint - torch.logsumexp(log_joint, dim=1, keepdim=True)
    posterior = torch.exp(log_posterior)

    # M-step: responsibility-weighted mean and variance per component.
    total = torch.sum(posterior, dim=0) + eps
    new_mu = torch.sum(posterior * x, dim=0) / total
    variance = torch.sum(posterior * (x - new_mu) ** 2, dim=0) / total
    # Floor the variance. A component that collapses onto a single point, or
    # owns no points, would otherwise reach stdev == 0. Normal() then either
    # raises (argument validation) or yields inf/NaN log-probs, and NaN never
    # satisfies allclose, so the convergence loop would never terminate.
    new_stdev = variance.clamp_min(eps).sqrt()
    new_prior = posterior.mean(0)
    return h, posterior, new_mu, new_stdev, new_prior


if __name__ == '__main__':
    # Ground truth: two clusters at -2 and +2 (std 0.5), n // 2 samples each.
    d1 = Normal(-2.0, 0.5)
    d2 = Normal(2.0, 0.5)
    x1 = d1.sample((n // 2,))
    x2 = d2.sample((n // 2,))
    x = torch.cat((x1, x2)).view(-1, 1)  # column vector, shape (n, 1)

    # Initial guess: both components start left of 0 with a narrow spread.
    mu = torch.tensor([-3.0, -2.5])
    stdev = torch.tensor([0.2, 0.2])
    prior = torch.tensor([0.5, 0.5])

    # Hard cap: allclose-based convergence is not guaranteed to trigger.
    max_iter = 1000
    converged = False
    i = 0
    while not converged and i < max_iter:
        # h is assigned at module level here; plot() reads it.
        h, posterior, new_mu, new_stdev, prior = em_step(x, mu, stdev, prior, eps)
        if i % 3 == 0:
            plot(x, posterior)
        converged = torch.allclose(new_mu, mu) and torch.allclose(new_stdev, stdev)
        mu, stdev = new_mu, new_stdev
        i += 1

    if not converged:
        print(f'EM did not converge within {max_iter} iterations')
    plot(x, posterior)
    print(i, mu, stdev, posterior.mean(0))
@AugustKarlstedt
Copy link

Really great example to learn from. Had this idea to "just use PyTorch" and found your post. Thanks for sharing!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment