Skip to content

Instantly share code, notes, and snippets.

@dasayan05

dasayan05/mog.py

Last active May 6, 2020
Embed
What would you like to do?
Example usage of Pyro for fitting a Mixture of Gaussians (MoG)
import pyro, torch, numpy as np
import pyro.distributions as dist
import pyro.optim as optim
import pyro.infer as infer
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from scipy.stats import norm
plt.ioff()
def getdata(N, mean1=2.0, mean2=-1.0, std1=0.5, std2=0.5):
    """Draw ``N`` shuffled samples from a two-component 1-D Gaussian mixture.

    The first ``N // 2`` samples are drawn from ``Normal(mean1, std1)`` and
    the remaining ``N - N // 2`` from ``Normal(mean2, std2)``, so exactly
    ``N`` values are returned even when ``N`` is odd.

    Args:
        N: Total number of samples to generate.
        mean1: Mean of the first component.
        mean2: Mean of the second component.
        std1: Standard deviation of the first component.
        std2: Standard deviation of the second component.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``N`` in random order.
    """
    n_first = N // 2
    # Give the leftover sample (odd N) to the second component instead of
    # silently dropping it.
    n_second = N - n_first
    first = np.random.randn(n_first) * std1 + mean1
    second = np.random.randn(n_second) * std2 + mean2
    samples = np.concatenate([first, second], 0)
    # Shuffle in place so component membership is not encoded in the order.
    np.random.shuffle(samples)
    return torch.from_numpy(samples.astype(np.float32))
@infer.config_enumerate(default='parallel')
def model(data):
    """Two-component Gaussian mixture over ``data``.

    Learnable parameters (registered with ``pyro.param``):
      * ``"f"`` - Bernoulli probability used to draw the assignment ``"c"``
        (constrained to the unit interval).
      * ``"M"`` - the two component means.
      * ``"S"`` - the two component standard deviations (constrained positive).

    For every element of ``data`` an assignment ``"c"`` is sampled and the
    observation ``"x"`` is scored under the Normal selected by that
    assignment.
    """
    mix_prob = pyro.param(
        "f", torch.tensor([0.5]), constraint=dist.constraints.unit_interval
    )
    component_means = pyro.param("M", torch.tensor([1.5, 3.]))
    component_stds = pyro.param(
        "S", torch.tensor([0.5, 0.5]), constraint=dist.constraints.positive
    )

    with pyro.plate("data", len(data)):
        # The Bernoulli draw is a float tensor; cast it to an integer index
        # so it can select a component's mean and std.
        assignment = pyro.sample("c", dist.Bernoulli(mix_prob)).long()
        likelihood = dist.Normal(
            component_means[assignment], component_stds[assignment]
        )
        pyro.sample("x", likelihood, obs=data)
@infer.config_enumerate(default='parallel')
def guide(data):
    """Variational guide for the assignment site ``"c"``.

    Registers the parameter ``"pc"``: one Bernoulli probability per element
    of ``data``, randomly initialised with ``torch.rand`` and constrained to
    the unit interval. Each data point's ``"c"`` is sampled from its own
    Bernoulli.
    """
    n_points = len(data)
    per_point_prob = pyro.param(
        "pc", torch.rand(n_points), constraint=dist.constraints.unit_interval
    )
    with pyro.plate("data", n_points):
        pyro.sample("c", dist.Bernoulli(per_point_prob))
# Synthetic observations drawn from the two-component mixture.
data = getdata(200)

# Start from an empty parameter store so earlier runs do not leak state.
pyro.clear_param_store()

# Named `optimizer` (not `optim`) so the imported `pyro.optim` module alias
# is not shadowed. The empty dict passes no optimizer arguments to Adam.
optimizer = pyro.optim.Adam({})
svi = pyro.infer.SVI(model, guide, optimizer, infer.TraceEnum_ELBO())

# Three panels: loss curve, per-point posterior, fitted component densities.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

losses = []  # one entry per SVI step
T = 10000    # total number of SVI steps
# Loop-invariant quantities, computed once instead of on every redraw.
data_np = data.detach().numpy()
# Plain Python floats: keeps torch tensors out of NumPy / Matplotlib calls.
xmin, xmax = data.min().item(), data.max().item()
xs = np.linspace(xmin - 2, xmax + 2, 150)

for t in range(T):
    losses.append(svi.step(data))

    # Redraw the figure only every 50th step.
    if t % 50 != 0:
        continue

    # --- Panel 0: history of the values returned by svi.step ---
    ax[0].plot(losses, color='m')
    ax[0].scatter(len(losses), losses[-1], color='m')
    ax[0].annotate(f'{losses[-1]:.2f}', (len(losses)+50, losses[-1]+50))
    ax[0].set_xlabel("epochs")
    ax[0].set_ylabel("ELBO")
    ax[0].set_xlim([0, T])
    ax[0].set_ylim([0, 2500])

    # --- Panel 1: per-point guide probability "pc" against the data ---
    pc = pyro.param("pc")
    pc_np = pc.detach().numpy()
    ax[1].scatter(data_np, pc_np, c=pc_np)
    ax[1].set_ylim([-0.03, 1.03])
    ax[1].set_xlabel("Data axis")
    ax[1].set_ylabel(r"Posterior ($\lambda_i$)")

    # --- Panel 2: current component densities and the mixing parameter ---
    means, stds = pyro.param("M"), pyro.param("S")
    coinbias = pyro.param("f").detach().item()
    mean1, mean2 = means.detach().numpy()
    std1, std2 = stds.detach().numpy()
    y1 = norm.pdf(xs, mean1, std1)
    y2 = norm.pdf(xs, mean2, std2)
    p1, = ax[2].plot(xs, y1, color='r')
    p2, = ax[2].plot(xs, y2, color='b')
    ax[2].axvline(mean1, linestyle='--', color='r')
    ax[2].axvline(mean2, linestyle='--', color='b')
    cb = ax[2].axhline(coinbias, linestyle='--', color='black')
    ax[2].set_xlim([xmin - 2, xmax + 2])
    ax[2].set_ylim([-0.02, 0.8])
    ax[2].legend([p1, p2, cb], ['Gaussian 1', 'Gaussian 2', 'Coin Bias'], loc=2)
    # Data points drawn on the x-axis, coloured by their guide probability.
    ax[2].scatter(data_np, np.zeros_like(data_np), marker='x', c=pc_np)
    ax[2].set_xlabel("Data axis")
    ax[2].set_ylabel("Model densities")

    # NOTE: to dump frames to disk instead of (or in addition to) showing
    # them, call plt.savefig(f"tmp/{t}.png", bbox_inches='tight') here.
    plt.pause(0.01)

    # Clear every panel so the next redraw starts from empty axes.
    for panel in ax:
        panel.cla()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.