#%% IMPORTS
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
from torch.nn import functional as F
import pyro
import pyro.distributions as dist
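# NOTE: this gist targets an older pytorch_lightning API; the @pl.data_loader
# decorator and the validation_end hook used below were removed in later
# releases (validation_end became validation_epoch_end).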
# %%
class CoolSystem(pl.LightningModule):
    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat = self.forward(x)
        loss = (yhat - y).abs().mean()  # mean absolute error
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        yhat = self.forward(x)
        loss = (yhat - y).abs().mean()
        return {'val_loss': loss}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning-rate schedulers
        # (LBFGS is supported automatically, no need for a closure function)
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        # synthetic data: y = 2 + x plus Gaussian noise with std 0.2
        x = torch.arange(100).float().view(-1, 1)
        y = 2 + x + torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        dataloader = torch.utils.data.DataLoader(dataset=ds, batch_size=2)
        return dataloader

    @pl.data_loader
    def val_dataloader(self):
        x = torch.arange(10).float().view(-1, 1)
        y = 2 + x + torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        dataloader = torch.utils.data.DataLoader(dataset=ds, batch_size=2)
        return dataloader
# %%
system = CoolSystem()
# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1)
trainer.fit(system)
# RESULTS
list(system.parameters())
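# %% Optional sanity check (not in the original gist): visualize the fit.
# matplotlib is already imported above; `system` is the model trained in the
# previous cell, and fit quality depends on how long the trainer actually ran.
with torch.no_grad():
    x = torch.arange(100).float().view(-1, 1)
    y = 2 + x + torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
    plt.scatter(x, y, s=5, label="data")
    plt.plot(x, system(x), color="red", label="fit")
    plt.legend()
    plt.show()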
# %% PYRO LIGHTNING!!
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
from torch.nn import functional as F
import pyro
import pyro.distributions as dist
class PyroOptWrap(pyro.infer.SVI):
    """Thin wrapper so Lightning can treat an SVI object like an optimizer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def state_dict(self):
        # Lightning checkpointing calls state_dict() on its optimizers;
        # SVI has no optimizer state dict of its own, so return an empty one.
        return {}
class PyroCoolSystem(pl.LightningModule):
    def __init__(self, num_data=100, lr=1e-3):
        super(PyroCoolSystem, self).__init__()
        self.lr = lr
        self.num_data = num_data

    def model(self, batch):
        x, y = batch
        yhat = self.forward(x)
        obsdistr = dist.Normal(yhat, 0.2)  # .to_event(1)
        pyro.sample("obs", obsdistr, obs=y)
        return yhat

    def guide(self, batch):
        # variational posterior: Normals with learnable means, fixed scale
        b_m = pyro.param("b-mean", torch.tensor(0.1))
        a_m = pyro.param("a-mean", torch.tensor(0.1))
        b = pyro.sample("beta", dist.Normal(b_m, 0.1))
        a = pyro.sample("alpha", dist.Normal(a_m, 0.1))

    def forward(self, x):
        # prior: standard Normal on slope and intercept
        b = pyro.sample("beta", dist.Normal(0., 1.))
        a = pyro.sample("alpha", dist.Normal(0., 1.))
        yhat = a + x * b
        return yhat

    def training_step(self, batch, batch_idx):
        # SVI performs the gradient step itself; Lightning only sees the loss
        loss = self.svi.step(batch)
        # svi.step returns a float; wrap it in a tensor with requires_grad
        # so Lightning's (stubbed-out) backward pass does not complain
        loss = torch.tensor(loss).requires_grad_(True)
        tensorboard_logs = {'running/loss': loss,
                            'param/a-mean': pyro.param("a-mean"),
                            'param/b-mean': pyro.param("b-mean")}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self.svi.evaluate_loss(batch)
        loss = torch.tensor(loss).requires_grad_(True)
        return {'val_loss': loss}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # the SVI object stands in for the optimizer here; PyroOptWrap gives
        # it the state_dict() Lightning expects
        self.svi = PyroOptWrap(model=self.model,
                               guide=self.guide,
                               optim=pyro.optim.SGD({"lr": self.lr, "momentum": 0.0}),
                               loss=pyro.infer.Trace_ELBO())
        return [self.svi]

    @pl.data_loader
    def train_dataloader(self):
        x = torch.rand((self.num_data,)).float().view(-1, 1)
        y = 2 + x + torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        dataloader = torch.utils.data.DataLoader(dataset=ds, batch_size=2)
        return dataloader

    @pl.data_loader
    def val_dataloader(self):
        x = torch.rand((100,)).float().view(-1, 1)
        y = 2 + x + torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1) * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        dataloader = torch.utils.data.DataLoader(dataset=ds, batch_size=10)
        return dataloader

    def optimizer_step(self, *args, **kwargs):
        # no-op: pyro's SVI.step() already applied the parameter update
        pass

    def backward(self, *args, **kwargs):
        # no-op: gradients are computed inside SVI.step()
        pass
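# %% For reference (not part of the original gist): the equivalent SVI loop
# written without Lightning. training_step above delegates to svi.step(batch),
# which runs the guide, replays the model, computes the ELBO, and updates the
# pyro.params itself; that is exactly why backward and optimizer_step are
# stubbed out in the class above.
def plain_svi_demo(num_steps=100):
    pyro.clear_param_store()
    system = PyroCoolSystem()
    svi = pyro.infer.SVI(model=system.model,
                         guide=system.guide,
                         optim=pyro.optim.SGD({"lr": 1e-3, "momentum": 0.0}),
                         loss=pyro.infer.Trace_ELBO())
    # same generating process as the dataloaders: y = 2 + x plus noise
    x = torch.rand((10, 1))
    y = 2 + x + 0.2 * torch.randn_like(x)
    for _ in range(num_steps):
        svi.step((x, y))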
# %%
pyro.clear_param_store()
system = PyroCoolSystem(num_data=2)
# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1, max_epochs=100)
trainer.fit(system)
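# %% Optional check (not in the original gist): after fitting, the variational
# means should drift toward the generating process y = 2 + x, i.e. "a-mean"
# toward 2 and "b-mean" toward 1, given enough epochs and a suitable lr.
print("a-mean:", pyro.param("a-mean").item())
print("b-mean:", pyro.param("b-mean").item())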