Differentiable EM algorithm
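This gist fits a Gaussian mixture with an EM-style procedure in which the maximization step is done by gradient descent rather than by closed-form updates: the expectation step computes the posterior responsibilities and detaches them from the autograd graph, and SGD then maximizes the resulting expected complete-data log-likelihood over the component means, log-variances, and mixture logits. The demo below fits a 6-component mixture to samples drawn from two well-separated unit-variance Gaussians.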
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal

num_gaussian = 6   # number of mixture components
gaussian_dim = 1   # dimensionality of each component (the demo uses 1-D data)
# Fall back to CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Component means, initialized around 10 and made trainable after creation.
embedding_mean = 10 + torch.randn(
    num_gaussian,
    gaussian_dim,
    requires_grad=False,
    device=device
)
embedding_mean.requires_grad = True

# Per-component log-variances (optimized in log space so the variance stays positive).
embedding_log_variance = 10 * torch.ones(
    num_gaussian,
    gaussian_dim,
    requires_grad=False,
    device=device
)
embedding_log_variance.requires_grad = True

# Unnormalized mixture weights; a softmax turns them into prior probabilities.
gaussians_logits_prior = torch.zeros(
    num_gaussian,
    dtype=torch.float,
    requires_grad=True,
    device=device
)
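

# em_loss performs one differentiable EM step. The expectation step computes
# the posterior responsibilities p(component j | x_i) under the current
# parameters and detaches them from the graph; the maximization step's
# objective is then the expected complete-data log-likelihood
#     Q = sum_i sum_j gamma_ij * (log pi_j + log N(x_i; mu_j, sigma_j^2)),
# which is maximized by gradient ascent (SGD on -Q) instead of closed-form
# updates.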
def em_loss(embeddings):
    # Note: Normal expects a standard deviation, so the log-variance is halved
    # before exponentiation.
    gaussian_list = []
    for j in range(num_gaussian):
        mn = Normal(
            embedding_mean[j],
            (0.5 * embedding_log_variance[j]).exp())
        gaussian_list.append(mn)
    log_prior = nn.functional.log_softmax(gaussians_logits_prior, dim=0)
    # Expectation step: log joint log p(x_i, j), shape (num_samples, num_gaussian).
    log_p_embedding_and_gaussian = torch.t(torch.stack(
        [gaussian_list[j].log_prob(embeddings) + log_prior[j]
         for j in range(num_gaussian)
         ]))
    p_gaussians_condition_on_embeddings = \
        nn.functional.softmax(log_p_embedding_and_gaussian, dim=-1)
    # Gradient-based maximization step: responsibilities are detached so that
    # gradients flow only through the log joint.
    p_log_embeddings = log_p_embedding_and_gaussian * \
        p_gaussians_condition_on_embeddings.detach()
    return p_log_embeddings.sum(dim=-1)


if __name__ == "__main__":
    # Plain SGD over all three groups of mixture parameters.
    opt_em = optim.SGD(
        [embedding_mean,
         embedding_log_variance,
         gaussians_logits_prior],
        lr=1,
    )
    # Training data: two well-separated unit-variance Gaussians,
    # 1/3 of the samples from the first and 2/3 from the second.
    mn1 = Normal(torch.tensor(-100.0), torch.tensor(1.0))
    mn2 = Normal(torch.tensor(100.0), torch.tensor(1.0))

    for i in range(3000):
        opt_em.zero_grad()
        embeddings1 = mn1.sample((1000,)).to(device=device)
        embeddings2 = mn2.sample((2000,)).to(device=device)
        embeddings12 = torch.cat([embeddings1, embeddings2])

        embedding_loss = -em_loss(embeddings12.detach()).mean(dim=0)
        embedding_loss.backward(retain_graph=False)

        if i % 200 == 0:
            print(f"embedding_loss: {embedding_loss}")
            print(f"mean: {embedding_mean}")
            print(f"variance: {embedding_log_variance.exp()}")
            print(f"gaussian_probs: {torch.softmax(gaussians_logits_prior, dim=-1)}")

        opt_em.step()
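
    # Hypothetical sanity check (not part of the original gist): with 1000
    # samples from N(-100, 1) and 2000 from N(100, 1), a well-fit mixture
    # should put roughly 1/3 of its prior mass on components near -100 and
    # 2/3 near +100.
    with torch.no_grad():
        final_probs = torch.softmax(gaussians_logits_prior, dim=-1)
        final_means = embedding_mean.squeeze(-1)
        print(f"prior mass near -100: {final_probs[final_means < 0].sum().item():.3f}")
        print(f"prior mass near +100: {final_probs[final_means >= 0].sum().item():.3f}")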