-
-
Save noamsgl/9c553ba7f594665ad089641543710fc3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%% IMPORTS | |
import torch | |
import pytorch_lightning as pl | |
import matplotlib.pyplot as plt | |
from pytorch_lightning import Trainer | |
from torch.nn import functional as F | |
import pyro | |
import pyro.distributions as dist | |
# %%
class CoolSystem(pl.LightningModule):
    """Minimal Lightning module fitting a 1-D linear regression with an L1 loss.

    The training and validation data are synthetic, ``y = 2 + x + 0.2 * eps`` with
    ``eps ~ Normal(0, 1)``, so a good fit has weight close to 1 and bias close to 2.
    """

    def __init__(self):
        super().__init__()
        # Intentionally tiny model: one input feature -> one output.
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        """Return the linear prediction for ``x`` of shape (batch, 1)."""
        return self.l1(x)

    @staticmethod
    def _l1_loss(yhat, y):
        """Mean absolute error shared by the training and validation steps."""
        return (yhat - y).abs().mean()

    @staticmethod
    def _make_loader(num_points, batch_size=2):
        """Build a DataLoader over ``x = 0..num_points-1`` and ``y = 2 + x + 0.2 * noise``."""
        x = torch.arange(num_points).float().view(-1, 1)
        noise = torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1)
        y = 2 + x + noise * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size)

    def training_step(self, batch, batch_idx):
        """Compute the L1 training loss for one batch and log it as ``train_loss``."""
        x, y = batch
        loss = self._l1_loss(self(x), y)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_idx):
        """Compute the L1 validation loss for one batch."""
        x, y = batch
        return {'val_loss': self._l1_loss(self(x), y)}

    def validation_end(self, outputs):
        """Average the per-batch validation losses.

        The average is logged under ``val_loss`` rather than the ambiguous ``loss``,
        so it matches the returned key and is not confused with the training loss.
        """
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

    def configure_optimizers(self):
        """Use a single Adam optimizer over all model parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        """100 training points, x = 0..99, batch size 2."""
        return self._make_loader(100)

    @pl.data_loader
    def val_dataloader(self):
        """10 validation points, x = 0..9, batch size 2."""
        return self._make_loader(10)
# %%
system = CoolSystem()
# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1)
trainer.fit(system)
# RESULTS: print the learned weight and bias explicitly. A bare expression is only
# displayed in an interactive cell and is silently discarded when run as a script.
for param_name, param_value in system.named_parameters():
    print(param_name, param_value.data)
# %% PYRO LIGHTNING!! | |
#%% | |
import torch | |
import pytorch_lightning as pl | |
import matplotlib.pyplot as plt | |
from pytorch_lightning import Trainer | |
from torch.nn import functional as F | |
import pyro | |
import pyro.distributions as dist | |
class PyroOptWrap(pyro.infer.SVI):
    """Pyro ``SVI`` object that can be handed to Lightning as its "optimizer".

    Lightning saves and restores optimizer state through ``state_dict`` and
    ``load_state_dict``. ``SVI`` carries no such state of its own, so both hooks
    are no-ops.

    NOTE: the variational parameters live in Pyro's global param store (see
    ``pyro.param``), so they are NOT included in Lightning checkpoints.
    """

    def state_dict(self):
        """Return an empty state, since there is nothing to checkpoint here."""
        return {}

    def load_state_dict(self, state_dict):
        """Accept and ignore the (empty) state produced by ``state_dict``."""
class PyroCoolSystem(pl.LightningModule):
    """Bayesian linear regression ``y = a + b * x`` fitted with Pyro SVI inside Lightning.

    Lightning only drives the loop. ``training_step`` calls ``SVI.step``, which
    computes gradients and updates the Pyro parameters itself, so ``backward``
    and ``optimizer_step`` are overridden as no-ops.

    Args:
        num_data: number of synthetic training points. It is also the dataset size
            used to rescale the minibatch likelihood in ``model``.
        lr: learning rate of the Pyro SGD optimizer.
    """

    def __init__(self, num_data=100, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.num_data = num_data

    def model(self, batch):
        """Generative model: priors on a and b, and a Normal(yhat, 0.2) likelihood."""
        x, y = batch
        yhat = self.forward(x)
        # Minibatch SVI: scale the batch log-likelihood by num_data / batch_size so
        # each step estimates the full-dataset ELBO. Without this, the prior is
        # over-weighted whenever a batch is smaller than the dataset.
        with pyro.poutine.scale(scale=self.num_data / x.shape[0]):
            pyro.sample("obs", dist.Normal(yhat, 0.2), obs=y)
        return yhat

    def guide(self, batch):
        """Normal guide with learnable means and a fixed scale of 0.1."""
        b_m = pyro.param("b-mean", torch.tensor(0.1))
        a_m = pyro.param("a-mean", torch.tensor(0.1))
        pyro.sample("beta", dist.Normal(b_m, 0.1))
        pyro.sample("alpha", dist.Normal(a_m, 0.1))

    def forward(self, x):
        """Sample a and b from Normal(0, 1) priors and return ``a + x * b``."""
        b = pyro.sample("beta", dist.Normal(0, 1))
        a = pyro.sample("alpha", dist.Normal(0, 1))
        return a + x * b

    def training_step(self, batch, batch_idx):
        """Take one SVI step and report its loss along with the current guide means."""
        # SVI.step runs model and guide, differentiates the ELBO and updates the params.
        loss = self.svi.step(batch)
        # Wrapped as a grad-requiring tensor, presumably so Lightning accepts it as a
        # training loss; backward() below is a no-op either way.
        loss = torch.tensor(loss).requires_grad_(True)
        tensorboard_logs = {
            'running/loss': loss,
            'param/a-mean': pyro.param("a-mean").detach(),
            'param/b-mean': pyro.param("b-mean").detach(),
        }
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Evaluate the ELBO loss on one validation batch, without a parameter update."""
        return {'val_loss': torch.tensor(self.svi.evaluate_loss(batch))}

    def validation_end(self, outputs):
        """Average the per-batch validation losses."""
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

    def configure_optimizers(self):
        """Create the SVI object and hand it to Lightning as the sole "optimizer"."""
        self.svi = PyroOptWrap(
            model=self.model,
            guide=self.guide,
            optim=pyro.optim.SGD({"lr": self.lr, "momentum": 0.0}),
            loss=pyro.infer.Trace_ELBO(),
        )
        return [self.svi]

    @staticmethod
    def _make_loader(num_points, batch_size):
        """Build a DataLoader over ``x ~ Uniform[0, 1)`` and ``y = 2 + x + 0.2 * noise``."""
        x = torch.rand((num_points,)).float().view(-1, 1)
        noise = torch.distributions.Normal(0, 1).sample((len(x),)).view(-1, 1)
        y = 2 + x + noise * 0.2
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size)

    @pl.data_loader
    def train_dataloader(self):
        """``num_data`` training points in batches of 2."""
        return self._make_loader(self.num_data, batch_size=2)

    @pl.data_loader
    def val_dataloader(self):
        """100 validation points in batches of 10."""
        return self._make_loader(100, batch_size=10)

    def optimizer_step(self, *args, **kwargs):
        # No-op: SVI.step in training_step has already updated the parameters.
        pass

    def backward(self, *args, **kwargs):
        # No-op: SVI.step in training_step has already computed the gradients.
        pass
# %%
# Start from an empty global param store so values from earlier runs do not leak in.
pyro.clear_param_store()
system = PyroCoolSystem(num_data=2)
# most basic trainer, uses good defaults
trainer = Trainer(min_epochs=1, max_epochs=100)
trainer.fit(system)
# RESULTS: learned guide means for alpha (intercept a) and beta (slope b) in yhat = a + x * b.
print("a-mean:", pyro.param("a-mean").item())
print("b-mean:", pyro.param("b-mean").item())
# %%
# %%
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment