Skip to content

Instantly share code, notes, and snippets.

Last active May 6, 2020 16:54
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
What would you like to do?
Example usage of Pyro to fit a mixture of Gaussians (MoG) with stochastic variational inference
import os

import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.distributions as dist
import pyro.infer as infer
import pyro.optim as optim
import torch
from scipy.stats import norm
def getdata(N, mean1=2.0, mean2=-1.0, std1=0.5, std2=0.5):
    """Draw N 1-D points from two Gaussian clusters of (nearly) equal size.

    The first ``N // 2`` points come from Normal(mean1, std1) and the
    remaining ``N - N // 2`` from Normal(mean2, std2), so exactly N points are
    returned even when N is odd. Points are not shuffled: the first cluster
    precedes the second. Uses NumPy's global random state.

    Args:
        N: total number of points to draw.
        mean1, std1: mean and standard deviation of the first cluster.
        mean2, std2: mean and standard deviation of the second cluster.

    Returns:
        A 1-D float32 torch tensor of length N.
    """
    n1 = N // 2
    n2 = N - n1  # the second cluster absorbs the extra point when N is odd
    d1 = np.random.randn(n1) * std1 + mean1
    d2 = np.random.randn(n2) * std2 + mean2
    return torch.from_numpy(np.concatenate([d1, d2], 0).astype(np.float32))
def model(data):
    """Generative model: a two-component Gaussian mixture with learnable params.

    Parameters registered in Pyro's param store:
      "f" -- probability that an assignment c equals 1 (constrained to [0, 1]);
             i.e. the weight of the component with mean M[1].
      "M" -- the two component means.
      "S" -- the two component standard deviations (constrained positive).

    Each observation i gets a latent assignment c_i ~ Bernoulli(f) and is
    observed as data[i] ~ Normal(M[c_i], S[c_i]).

    Args:
        data: observed values, one scalar per point -- presumably a 1-D
            tensor such as the one built by ``getdata``; TODO confirm.
    """
    f = pyro.param("f", torch.tensor([0.5]), constraint=dist.constraints.unit_interval)
    means = pyro.param("M", torch.tensor([1.5, 3.]))
    stds = pyro.param("S", torch.tensor([0.5, 0.5]), constraint=dist.constraints.positive)
    with pyro.plate("data", len(data)):
        c = pyro.sample("c", dist.Bernoulli(f))
        # Bernoulli samples are floats; cast to integers so they can index the
        # component parameters. `.long()` keeps the tensor on its current
        # device, whereas `.type(torch.LongTensor)` would force it onto the CPU.
        c = c.long()
        pyro.sample("x", dist.Normal(means[c], stds[c]), obs=data)
def guide(data):
    """Variational guide: an independent Bernoulli posterior for every point.

    Registers one learnable probability per observation under the param name
    "pc" (initialised with uniform random values, constrained to [0, 1]) and
    draws each assignment "c" from Bernoulli(pc) inside the same "data" plate
    that the model uses.
    """
    n_points = len(data)
    assignment_probs = pyro.param(
        "pc",
        torch.rand(n_points),
        constraint=dist.constraints.unit_interval,
    )
    with pyro.plate("data", n_points):
        pyro.sample("c", dist.Bernoulli(assignment_probs))
# --- Training script ---------------------------------------------------------
# Fit the two-component mixture `model` with the per-point `guide` via SVI and
# redraw a three-panel diagnostic figure every PLOT_EVERY steps:
#   ax[0]: loss history
#   ax[1]: the guide's posterior probability `pc` of c = 1 for each point
#   ax[2]: current component densities, their means, and the coin bias `f`
plt.style.use('ggplot')

data = getdata(200)

# Start from an empty param store so re-running (e.g. in a notebook) does not
# resume from parameters left over from a previous run.
pyro.clear_param_store()

# Named `optimizer` so it does not shadow the `pyro.optim as optim` import.
optimizer = pyro.optim.Adam({})
svi = pyro.infer.SVI(model, guide, optimizer, infer.TraceEnum_ELBO())

fig, ax = plt.subplots(1, 3, figsize=(15, 4))
losses = []
T = 10000
PLOT_EVERY = 50
SAVE_FRAMES = False  # set True to also write every frame to tmp/{t}.png

for t in range(T):
    # One optimisation step; SVI.step returns the loss estimate for this step.
    # Without it the parameters never change and `losses[-1]` below fails.
    losses.append(svi.step(data))

    if t % PLOT_EVERY != 0:
        continue

    # Clear at the start of a frame (not the end) so the last frame stays
    # visible after training finishes.
    for axis in ax:
        axis.cla()

    # Panel 0: loss curve with the latest value annotated.
    ax[0].plot(losses, color='m')
    ax[0].scatter(len(losses), losses[-1], color='m')
    ax[0].annotate(f'{losses[-1]:.2f}', (len(losses) + 50, losses[-1] + 50))
    ax[0].set_xlim([0, T])
    ax[0].set_ylim([0, 2500])

    # Panel 1: the guide's per-point probability that c = 1.
    pc = pyro.param("pc").detach().numpy()
    data_np = data.numpy()
    ax[1].scatter(data_np, pc, c=pc)
    ax[1].set_ylim([-0.03, 1.03])
    ax[1].set_xlabel("Data axis")
    ax[1].set_ylabel(r"Posterior ($\lambda_i$)")

    # Panel 2: current component densities and mixing probability.
    means, stds = pyro.param("M"), pyro.param("S")
    coinbias = pyro.param("f").detach().item()
    mean1, mean2 = means.detach().numpy()
    std1, std2 = stds.detach().numpy()
    # Convert to Python floats rather than handing 0-d tensors to NumPy.
    xmin, xmax = data.min().item(), data.max().item()
    xs = np.linspace(xmin - 2, xmax + 2, 150)
    p1, = ax[2].plot(xs, norm.pdf(xs, mean1, std1), color='r')
    p2, = ax[2].plot(xs, norm.pdf(xs, mean2, std2), color='b')
    ax[2].axvline(mean1, linestyle='--', color='r')
    ax[2].axvline(mean2, linestyle='--', color='b')
    # f is P(c = 1), i.e. the weight of the second (blue) component.
    cb = ax[2].axhline(coinbias, linestyle='--', color='black')
    ax[2].set_xlim([xmin - 2, xmax + 2])
    ax[2].set_ylim([-0.02, 0.8])
    ax[2].legend([p1, p2, cb], ['Gaussian 1', 'Gaussian 2', 'Coin Bias'], loc=2)
    ax[2].scatter(data_np, np.zeros_like(data_np), marker='x', c=pc)
    ax[2].set_xlabel("Data axis")
    ax[2].set_ylabel("Model densities")

    # Redraw the figure and let the GUI event loop run briefly.
    plt.pause(0.001)
    if SAVE_FRAMES:
        os.makedirs("tmp", exist_ok=True)
        # `pad_inches` is the savefig keyword for padding (`inches` is not).
        fig.savefig(f"tmp/{t}.png", bbox_inches='tight', pad_inches=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment