Variational Bayesian Gaussian Mixture Model with Full Covariance Matrices in Pyro
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
import matplotlib.pyplot as plt
pyro.enable_validation(True)
def FiveGaussians():
    '''
    Data comes from Corduneanu and Bishop:
    https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/bishop-aistats01.pdf
    '''
    means=torch.tensor([[0, 0], [3, -3], [3, 3], [-3, 3], [-3, -3]]).float()
    covariances=torch.tensor([[[1, 0],[0, 1]],[[1, 0.5],[0.5, 1]],[[1, -0.5],[-0.5, 1]],
                              [[1, 0.5],[0.5, 1]],[[1, -0.5],[-0.5, 1]]]).float()
    # 120 draws from each of the 5 components, flattened to a (600, 2) dataset.
    return dist.MultivariateNormal(means,covariances).rsample([120]).view(-1,2)
D=FiveGaussians()
@config_enumerate(default='parallel')
def model(data):
    d = data.shape[1]
    # Mixture weights; dist.Categorical normalizes them to sum to one.
    pi = pyro.param('weights', dist.Dirichlet(0.5 * torch.ones(K)).sample(), constraint=constraints.unit_interval)
    with pyro.plate('components', K):
        # Initialize each component's lower-Cholesky scale factor as per-dimension
        # Half-Cauchy scales times an LKJ correlation Cholesky factor.
        theta=dist.HalfCauchy(torch.ones(d)).rsample([K])
        eta = torch.ones(1)
        L_omega=dist.LKJCorrCholesky(d, eta).sample((K,))
        # The lower_cholesky constraint keeps T a valid scale_tril during optimization.
        T=pyro.param('T',torch.bmm(theta.diag_embed(),L_omega), constraint=constraints.lower_cholesky)
        # Component locations, initialized around the empirical data mean.
        means=data.mean(dim=0)
        scales=(0.5*torch.eye(data.size(1)))
        loc=pyro.param('loc',dist.MultivariateNormal(means,scales).rsample([K]))
    with pyro.plate('data', len(data)):
        assignment = pyro.sample('assignment', dist.Categorical(pi))
        pyro.sample('obs', dist.MultivariateNormal(loc[assignment],scale_tril=T[assignment]), obs=data)
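
# Reading the model: each component k has mean loc[k] and lower-Cholesky factor
# T[k], so its implied full covariance matrix is Sigma_k = T[k] @ T[k].T.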
@config_enumerate(default="parallel")
def full_guide(data):
    with pyro.plate('data', len(data)):
        # Mean-field variational posterior over the cluster assignments.
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment', dist.Categorical(assignment_probs))
K=5  # number of mixture components (a plain int, so it works as a plate size and sample shape)
pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, full_guide, optim, loss=elbo)
pyro.set_rng_seed(42)
loss=[]
for i in range(10000):
    step_loss=svi.step(D)
    loss.append(step_loss)
plt.semilogx(loss)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()
print([i for i in pyro.get_param_store().items()])
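
# Optional, illustrative sketch (not required for training): inspect the fitted
# mixture via the parameter names registered above ('loc', 'T', 'assignment_probs').
loc_hat=pyro.param('loc').detach()                          # (K, d) component locations
T_hat=pyro.param('T').detach()                              # (K, d, d) Cholesky factors
Sigma_hat=torch.bmm(T_hat, T_hat.transpose(-1, -2))         # (K, d, d) full covariances
labels=pyro.param('assignment_probs').detach().argmax(dim=-1)  # hard assignment per point
plt.scatter(D[:,0], D[:,1], c=labels, s=5, alpha=0.5)
plt.scatter(loc_hat[:,0], loc_hat[:,1], c='red', marker='x', s=100)
plt.title("Data colored by most probable assignment")
plt.show()
print(Sigma_hat)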