@iancze
Last active February 16, 2023 23:59
Pyro: using latent variables with a plate
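Two scripts follow. The first builds the model as a PyroModule and tries to declare the per-line latent log_amplitudes inside a pyro.plate in __init__; the commented-out variants record what was tried (the .expand([nlines]).to_event(1) form works fine, the plated PyroSample form errors on tensor shape, and the bare pyro.sample form runs but the "log_amplitudes" site never appears in the Predictive samples). The second script writes the equivalent model as a plain function with the plate inside the model body.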
import torch
from torch import nn
import numpy as np
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.infer import Predictive, MCMC, NUTS
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
def gaussian(x, A_i, mu_i, sigma_i=0.5):
    r"""
    Evaluate a Gaussian line profile of the form

    .. math::

        f(x) = A_i \exp \left(- \frac{(x - \mu_i)^2}{2 \sigma_i^2} \right)
    """
    return A_i * torch.exp(-((x - mu_i) ** 2) / (2 * sigma_i**2))
class Spectrum(PyroModule):
    def __init__(self, mus):
        super().__init__()
        self.mus = torch.as_tensor(mus)
        self.nlines = len(self.mus)

        # works fine
        # self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2).expand([self.nlines]).to_event(1))

        with pyro.plate("plate", self.nlines):
            # doesn't work, errors with tensor shape
            # self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2))

            # code executes but can't find "log_amplitudes" in samples
            self.log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))

        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    def intensities(self, x):
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            A_i = torch.pow(10.0, self.log_amplitudes[i])
            mu_i = self.mus[i]
            I += gaussian(x, A_i, mu_i)
        return I

    def forward(self, x, y, yerr):
        I = self.intensities(x) + self.baseline
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)
        return I
if __name__ == "__main__":
    np.random.seed(123)

    # create a fake dataset
    N = 80
    xs = torch.as_tensor(np.sort(np.random.uniform(0, 10, size=N)))
    true_mus = np.array([2.0, 4.5, 7.4])
    true_amplitudes = np.array([0.4, 1.0, 0.6])
    true_log_amplitudes = np.log10(true_amplitudes)
    true_baseline = 0.5
    yerr = 0.05

    ys = torch.zeros_like(xs)
    for i in range(len(true_amplitudes)):
        ys += gaussian(xs, true_amplitudes[i], true_mus[i])
    ys += true_baseline
    # add random noise
    ys += torch.as_tensor(np.random.normal(loc=0, scale=yerr, size=N))

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(nrows=1)
    ax.plot(xs.numpy(), ys.numpy(), "o")
    fig.savefig("data.png")

    # now create a model
    model = Spectrum(mus=true_mus)

    # define SVI guide
    guide = AutoDiagonalNormal(model)
    adam = pyro.optim.Adam({"lr": 0.03})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    num_iterations = 10
    pyro.clear_param_store()
    loss_tracker = np.empty(num_iterations)
    for j in range(num_iterations):
        # calculate the loss and take a gradient step
        loss_tracker[j] = svi.step(xs, ys, yerr)
        print(j)

    predictive = Predictive(model, guide=guide, num_samples=1)(xs, ys, yerr)
    for k, v in predictive.items():
        print(f"{k}: {v.shape}")
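The second script below is the same spectrum model written as a plain Pyro model function rather than a PyroModule, with the latent log_amplitudes drawn inside the pyro.plate in the model body.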
import torch
from torch import nn
import numpy as np
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.infer import Predictive, MCMC, NUTS
from pyro.infer.autoguide import AutoNormal
from pyro.infer import SVI, Trace_ELBO
def gaussian(x, A_i, mu_i, sigma_i=0.5):
    r"""
    Evaluate a Gaussian line profile of the form

    .. math::

        f(x) = A_i \exp \left(- \frac{(x - \mu_i)^2}{2 \sigma_i^2} \right)
    """
    return A_i * torch.exp(-((x - mu_i) ** 2) / (2 * sigma_i**2))
np.random.seed(123)
# create a fake dataset
N = 80
xs = torch.as_tensor(np.sort(np.random.uniform(0, 10, size=N)))
true_mus = np.array([2.0, 4.5, 7.4])
true_amplitudes = np.array([0.4, 1.0, 0.6])
true_log_amplitudes = np.log10(true_amplitudes)
true_baseline = 0.5
yerr = 0.05
ys = torch.zeros_like(xs)
for i in range(len(true_amplitudes)):
    ys += gaussian(xs, true_amplitudes[i], true_mus[i])
ys += true_baseline
# add random noise
ys += torch.as_tensor(np.random.normal(loc=0, scale=yerr, size=N))
import matplotlib.pyplot as plt
fig, ax = plt.subplots(nrows=1)
ax.plot(xs.numpy(), ys.numpy(), "o")
fig.savefig("data.png")
# define the model
def model_func(x, y, yerr):
    baseline = pyro.sample("baseline", dist.Normal(0.0, 1.0))

    with pyro.plate("plate", 3):
        log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))

    I = torch.zeros_like(x)
    for i in range(3):
        A_i = torch.pow(10.0, log_amplitudes[i])
        mu_i = true_mus[i]
        I += gaussian(x, A_i, mu_i)
    I += baseline

    with pyro.plate("data", len(y)):
        pyro.sample("obs", dist.Normal(I, yerr), obs=y)
# define SVI guide
guide = AutoNormal(model_func)
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model_func, guide, adam, loss=Trace_ELBO())
num_iterations = 1000
pyro.clear_param_store()
loss_tracker = np.empty(num_iterations)
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss_tracker[j] = svi.step(xs, ys, yerr)
    print(j)
predictive = Predictive(model_func, guide=guide, num_samples=1)(xs, ys, yerr)
for k, v in predictive.items():
    print(f"{k}: {v.shape}")
# https://forum.pyro.ai/t/how-to-access-guide-parameters/3995
print(list(guide.parameters()))
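# (not in the original gist; a sketch of an alternative route) the learned guide
# parameters can typically also be read from the global param store, or summarized
# with the autoguide's convenience method:
# print(dict(pyro.get_param_store().items()))
# print(guide.median())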
with pyro.poutine.trace(param_only=True) as tr:
    guide(xs, ys, yerr)
constrained_params = [site["value"] for site in tr.trace.nodes.values()]
PARAMS = [p.unconstrained() for p in constrained_params]
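For reference, below is a minimal sketch of the PyroModule pattern that the comment in the first script marks as "works fine": declaring log_amplitudes as a PyroSample whose prior is expanded to the number of lines and wrapped with to_event(1), so the site is registered on the module and shows up by name in Predictive output. The class name SpectrumEvent is made up for illustration; everything else mirrors the gist, and this is a sketch of the idiom rather than the author's final code.

class SpectrumEvent(PyroModule):
    def __init__(self, mus):
        super().__init__()
        self.mus = torch.as_tensor(mus)
        self.nlines = len(self.mus)
        # one sample site named "log_amplitudes" with event shape (nlines,)
        self.log_amplitudes = PyroSample(
            dist.Normal(1.0, 0.2).expand([self.nlines]).to_event(1)
        )
        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    def forward(self, x, y, yerr):
        # PyroSample attributes are drawn once per trace when accessed inside forward
        amps = torch.pow(10.0, self.log_amplitudes)
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            I += gaussian(x, amps[i], self.mus[i])
        I += self.baseline
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)
        return I

With this version the SVI setup from the first script is unchanged (an autoguide over the module, then svi.step(xs, ys, yerr)), and Predictive(model, guide=guide, num_samples=1)(xs, ys, yerr) should report a "log_amplitudes" entry with a trailing dimension of nlines.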