@redwrasse
Created November 21, 2020 00:04
Attempted direct gradient descent on a 2-state Gaussian mixture model
# gmm_gd.py
"""
Direct gradient descent on a 2-state Gaussian mixture model.

Not the best way to do this; the EM algorithm is the usual choice.
Training is highly unstable.

Model:
    p(x) = pi * phi_1 + (1 - pi) * phi_2
    phi_1, phi_2 ~ Normal
    pi = p(z = 0)
    1 - pi = p(z = 1)

so the gradient of the negative log-likelihood is
    -grad_theta log p = -grad_theta log(pi * phi_1 + (1 - pi) * phi_2)
"""
import torch

pi_value = 3.14159265  # the constant pi; 'pi' below denotes the mixture weight p(z = 0)
dtype = torch.float
device = torch.device('cpu')

# synthetic data: N samples from N(3, 2^2) and N samples from N(-6, 2.5^2)
N = 1000
x1 = (torch.randn(N, device=device, dtype=dtype) * 2.) + 3.
x2 = (torch.randn(N, device=device, dtype=dtype) * 2.5) - 6.
x = torch.cat([x1, x2], dim=0)
# need sufficiently large initial sigma1, sigma2 values for numerical stability of the loss
mu1 = torch.randn(1, device=device, dtype=dtype, requires_grad=True)
sigma1 = torch.randn(1, device=device, dtype=dtype).clamp_min(1.5).requires_grad_()
mu2 = torch.randn(1, device=device, dtype=dtype, requires_grad=True)
sigma2 = torch.randn(1, device=device, dtype=dtype).clamp_min(1.5).requires_grad_()
# mixture weight, kept away from 0 and 1 at initialization
pi = (torch.rand(1, device=device, dtype=dtype) + 0.1).clamp_max(0.9).clamp_min(0.1).requires_grad_()
learning_rate = 1e-4
for i in range(10**5):
    # negative log-likelihood of the two-component mixture
    loss = -torch.mean(torch.log(
        pi * 1. / (sigma1 * (2 * pi_value) ** 0.5) * torch.exp(-(x - mu1) ** 2 / (2 * sigma1 ** 2))
        + (1 - pi) * 1. / (sigma2 * (2 * pi_value) ** 0.5) * torch.exp(-(x - mu2) ** 2 / (2 * sigma2 ** 2))))
    if i % 10**3 == 0:
        print(f'(i={i}) loss: {loss.item()} mu1: {mu1.item()}, mu2: {mu2.item()}, '
              f'sigma1: {sigma1.item()}, sigma2: {sigma2.item()}, pi: {pi.item()}')
    loss.backward()
    # plain SGD update on all five parameters, then reset accumulated gradients
    with torch.no_grad():
        mu1 -= learning_rate * mu1.grad
        sigma1 -= learning_rate * sigma1.grad
        mu2 -= learning_rate * mu2.grad
        sigma2 -= learning_rate * sigma2.grad
        pi -= learning_rate * pi.grad
        mu1.grad.zero_()
        sigma1.grad.zero_()
        mu2.grad.zero_()
        sigma2.grad.zero_()
        pi.grad.zero_()
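
Much of the instability, and the need for large initial sigmas, comes from evaluating the mixture density in probability space and updating constrained parameters (sigma > 0, 0 < pi < 1) directly. A minimal sketch of a more stable variant, assuming the same data tensor `x` as above (the unconstrained parameter names and file name are illustrative, not part of the original gist): compute the log-likelihood with torch.logsumexp and optimize log-sigma and logit-pi instead.

# gmm_gd_stable.py -- sketch only
import math
import torch

# unconstrained parameters: sigma_k = exp(log_sigma_k), pi = sigmoid(logit_pi)
mu1 = torch.randn(1, requires_grad=True)
mu2 = torch.randn(1, requires_grad=True)
log_sigma1 = torch.zeros(1, requires_grad=True)
log_sigma2 = torch.zeros(1, requires_grad=True)
logit_pi = torch.zeros(1, requires_grad=True)

def neg_log_likelihood(x):
    sigma1, sigma2 = log_sigma1.exp(), log_sigma2.exp()
    log_pi1 = torch.nn.functional.logsigmoid(logit_pi)    # log pi
    log_pi2 = torch.nn.functional.logsigmoid(-logit_pi)   # log (1 - pi)
    log_phi1 = -0.5 * ((x - mu1) / sigma1) ** 2 - torch.log(sigma1) - 0.5 * math.log(2 * math.pi)
    log_phi2 = -0.5 * ((x - mu2) / sigma2) ** 2 - torch.log(sigma2) - 0.5 * math.log(2 * math.pi)
    # log p(x) = logsumexp(log pi + log phi_1, log(1 - pi) + log phi_2)
    log_p = torch.logsumexp(torch.stack([log_pi1 + log_phi1, log_pi2 + log_phi2]), dim=0)
    return -log_p.mean()

These parameters can then be trained with the same manual update loop as above or with a built-in optimizer such as torch.optim.Adam.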
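
For comparison, a minimal sketch of the EM algorithm mentioned in the docstring, for the same 2-state mixture and again assuming the data tensor `x` from the gist; `em_gmm_2`, its initialization, and the iteration count are illustrative choices, not part of the original.

# gmm_em.py -- sketch only
import math
import torch

def em_gmm_2(x, n_iters=200):
    # crude but serviceable initialization
    mu = torch.stack([x.min(), x.max()])
    sigma = torch.ones(2) * x.std()
    pi = torch.tensor([0.5, 0.5])
    for _ in range(n_iters):
        # E-step: responsibilities gamma[n, k] = p(z = k | x_n)
        log_phi = (-0.5 * ((x[:, None] - mu) / sigma) ** 2
                   - torch.log(sigma) - 0.5 * math.log(2 * math.pi))
        gamma = torch.softmax(torch.log(pi) + log_phi, dim=1)
        # M-step: closed-form updates from the weighted sufficient statistics
        nk = gamma.sum(dim=0)
        mu = (gamma * x[:, None]).sum(dim=0) / nk
        sigma = torch.sqrt((gamma * (x[:, None] - mu) ** 2).sum(dim=0) / nk)
        pi = nk / x.shape[0]
    return mu, sigma, pi

mu_hat, sigma_hat, pi_hat = em_gmm_2(x)

Each EM iteration is guaranteed not to decrease the data log-likelihood, which is why it behaves far better here than raw gradient descent on the same objective.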