Skip to content

Instantly share code, notes, and snippets.

Last active May 6, 2020 16:54
What would you like to do?
Example usage of Pyro for fitting a two-component mixture of Gaussians (MoG) with stochastic variational inference (SVI)
import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.distributions as dist
import pyro.infer as infer
import pyro.optim as optim
import torch
from scipy.stats import norm

plt.style.use('ggplot')
def getdata(N, mean1=2.0, mean2=-1.0, std1=0.5, std2=0.5, rng=None):
    """Draw a synthetic 1-D two-cluster dataset of exactly ``N`` points.

    The first ``N // 2`` points are Normal(mean1, std1); the remaining
    ``N - N // 2`` points are Normal(mean2, std2). Points are not shuffled:
    the first cluster precedes the second.

    Args:
        N: total number of points (non-negative int).
        mean1, std1: mean and standard deviation of the first cluster.
        mean2, std2: mean and standard deviation of the second cluster.
        rng: optional ``numpy.random.Generator`` for reproducible draws.
            When ``None`` (the default), NumPy's global random state is
            used via ``np.random.randn``, as before.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(N,)``.
    """
    draw = np.random.randn if rng is None else rng.standard_normal
    n_first = N // 2
    # Give the leftover point of an odd N to the second cluster so the
    # result always has exactly N entries (previously N - 1 for odd N).
    n_second = N - n_first
    first = draw(n_first) * std1 + mean1
    second = draw(n_second) * std2 + mean2
    combined = np.concatenate([first, second], 0)
    return torch.from_numpy(combined.astype(np.float32))
def model(data):
    """Generative model: a two-component 1-D Gaussian mixture.

    Each element of ``data`` gets a latent binary assignment ``c`` drawn
    from Bernoulli(f). The element is then observed as
    Normal(M[c], S[c]).

    Learnable parameters registered in the Pyro param store:
      - "f": probability that ``c == 1`` (constrained to the unit interval).
      - "M": the two component means (index 0 and index 1).
      - "S": the two component standard deviations (constrained positive).

    Args:
        data: observed values. Presumably a 1-D float tensor such as the
            one returned by ``getdata`` (TODO confirm for other callers).
            The plate size is ``len(data)``.
    """
    f = pyro.param("f", torch.tensor([0.5]), constraint=dist.constraints.unit_interval)
    means = pyro.param("M", torch.tensor([1.5, 3.]))
    stds = pyro.param("S", torch.tensor([0.5, 0.5]), constraint=dist.constraints.positive)
    with pyro.plate("data", len(data)):
        c = pyro.sample("c", dist.Bernoulli(f))
        # Bernoulli samples are floats (0./1.); cast them to integer indices.
        # .long() keeps the tensor on its current device, whereas
        # .type(torch.LongTensor) always produced a CPU tensor.
        c = c.long()
        pyro.sample("x", dist.Normal(means[c], stds[c]), obs=data)
def guide(data):
    """Variational guide: an independent Bernoulli assignment per data point.

    Registers the learnable parameter "pc" (one probability per element of
    ``data``, randomly initialised in [0, 1) and constrained to the unit
    interval). Inside a plate of size ``len(data)`` it then samples the
    latent site "c" from Bernoulli(pc), matching the "c" site in ``model``.

    Args:
        data: the observations; only ``len(data)`` is used here.
    """
    n_points = len(data)
    assignment_probs = pyro.param(
        "pc",
        torch.rand(n_points),
        constraint=dist.constraints.unit_interval,
    )
    with pyro.plate("data", n_points):
        pyro.sample("c", dist.Bernoulli(assignment_probs))
def _plot_progress(ax, data, losses, t_max):
    """Draw one training-progress frame onto the three axes in ``ax``.

    ax[0]: loss history so far, with the latest value marked and annotated.
    ax[1]: each data point against its guide probability "pc".
    ax[2]: the two fitted Normal densities ("M", "S") with dashed lines at
           their means. The learned "f" is drawn as a horizontal reference
           line, and the data points are coloured by "pc".

    Args:
        ax: sequence of three matplotlib Axes.
        data: tensor of observations (as returned by ``getdata``).
        losses: per-step losses recorded so far; must be non-empty.
        t_max: total number of training steps (x-limit of the loss plot).
    """
    # Panel 0: loss curve.
    ax[0].plot(losses, color='m')
    ax[0].scatter(len(losses), losses[-1], color='m')
    ax[0].annotate(f'{losses[-1]:.2f}', (len(losses) + 50, losses[-1] + 50))
    ax[0].set_xlim([0, t_max])
    ax[0].set_ylim([0, 2500])

    # Panel 1: per-point guide probabilities.
    data_np = data.detach().numpy()
    pc = pyro.param("pc").detach().numpy()
    ax[1].scatter(data_np, pc, c=pc)
    ax[1].set_ylim([-0.03, 1.03])
    ax[1].set_xlabel("Data axis")
    ax[1].set_ylabel(r"Posterior ($\lambda_i$)")

    # Panel 2: fitted component densities.
    mean1, mean2 = pyro.param("M").detach().numpy()
    std1, std2 = pyro.param("S").detach().numpy()
    coinbias = pyro.param("f").detach().item()
    xmin, xmax = data.min().item(), data.max().item()
    xs = np.linspace(xmin - 2, xmax + 2, 150)
    p1, = ax[2].plot(xs, norm.pdf(xs, mean1, std1), color='r')
    p2, = ax[2].plot(xs, norm.pdf(xs, mean2, std2), color='b')
    ax[2].axvline(mean1, linestyle='--', color='r')
    ax[2].axvline(mean2, linestyle='--', color='b')
    cb = ax[2].axhline(coinbias, linestyle='--', color='black')
    ax[2].set_xlim([xmin - 2, xmax + 2])
    ax[2].set_ylim([-0.02, 0.8])
    ax[2].legend([p1, p2, cb], ['Gaussian 1', 'Gaussian 2', 'Coin Bias'], loc=2)
    ax[2].scatter(data_np, np.zeros_like(data_np), marker='x', c=pc)
    ax[2].set_xlabel("Data axis")
    ax[2].set_ylabel("Model densities")


# Training script: fit the mixture with SVI and periodically plot progress.
data = getdata(200)

# Start from an empty parameter store so re-running this code in the same
# interpreter (e.g. a notebook) does not resume from stale parameters.
pyro.clear_param_store()

# Named ``optimizer`` rather than ``optim`` so the ``pyro.optim`` module
# alias imported at the top of the file is not shadowed. No explicit
# hyper-parameters are given, so PyTorch's Adam defaults apply.
optimizer = pyro.optim.Adam({})
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.TraceEnum_ELBO())

fig, ax = plt.subplots(1, 3, figsize=(15, 4))
losses = []
T = 10000
for t in range(T):
    # One gradient step; svi.step returns this step's loss estimate.
    # Without this call `losses` stayed empty and losses[-1] raised.
    losses.append(svi.step(data))
    if t % 50 == 0:
        _plot_progress(ax, data, losses, T)
        # Let the GUI event loop render the frame before it is cleared.
        # To export frames instead, use e.g.:
        #   fig.savefig(f"tmp/{t}.png", bbox_inches='tight', pad_inches=0)
        plt.pause(0.001)
        for axis in ax:
            axis.cla()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment