Skip to content

Instantly share code, notes, and snippets.

@a-y-khan
Last active December 16, 2019 02:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save a-y-khan/8693d2b186227561a4baf4d03ce75c34 to your computer and use it in GitHub Desktop.
Save a-y-khan/8693d2b186227561a4baf4d03ce75c34 to your computer and use it in GitHub Desktop.
Simple, contrived example to trigger MLFlowLogger pickle bug
from argparse import Namespace
import itertools
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import pytorch_lightning as pl
from pytorch_lightning.logging.mlflow_logger import MLFlowLogger
class BasicDataset(data.Dataset):
    """Toy dataset whose every item is the complete XOR truth table.

    Each item is an ``(inputs, targets)`` tuple:

    * ``inputs``  -- 4x2 float tensor with all two-bit input combinations.
    * ``targets`` -- 4x1 float tensor with the matching XOR results.

    The very same tuple object is repeated ``length`` times, so one pass over
    the dataset replays the truth table ``length`` times.

    Args:
        length: Number of items in the dataset (default 1000).
    """

    def __init__(self, length=1000):
        # Zero-argument super() binds to ``self``.  The one-argument form
        # ``super(BasicDataset)`` yields an unbound super object, so calling
        # ``__init__`` on it never reaches the parent initializer.
        super().__init__()
        inputs = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
        targets = torch.tensor([[0], [1], [1], [0]], dtype=torch.float)
        self.tensors = list(itertools.repeat((inputs, targets), length))

    def __getitem__(self, index):
        """Return the ``(inputs, targets)`` tuple stored at ``index``."""
        return self.tensors[index]

    def __len__(self):
        """Return the number of stored items (``length`` at construction)."""
        return len(self.tensors)
class XORGateModel(pl.LightningModule):
    """Small two-layer network trained to reproduce the XOR function.

    Layout: Linear(2 -> 3) -> Sigmoid -> Linear(3 -> 1), with no activation on
    the final layer.  Training minimises a summed mean-squared-error loss using
    plain SGD at ``LEARNING_RATE``.
    """

    LEARNING_RATE = 0.1

    # default checkpoint expects hparams!
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # Registration order below determines the order of self.parameters()
        # handed to the optimizer, so it is kept as-is.
        self.hidden = nn.Linear(2, 3, bias=True)
        self.output = nn.Linear(3, 1, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.loss_function = nn.MSELoss(reduction='sum')

    def forward(self, input):
        """Map inputs whose last dimension is 2 to a single raw output each."""
        activated = self.sigmoid(self.hidden(input))
        return self.output(activated)

    def training_step(self, batch, batch_index):
        """Compute the summed MSE loss for one ``(features, targets)`` batch."""
        features, targets = batch
        predictions = self.forward(features)
        loss = self.loss_function(predictions, targets)
        return {'loss': loss}

    def configure_optimizers(self):
        """Build an SGD optimizer over all parameters at the class learning rate."""
        step_size = XORGateModel.LEARNING_RATE
        return torch.optim.SGD(self.parameters(), lr=step_size)

    @pl.data_loader
    def train_dataloader(self):
        """Wrap a fresh BasicDataset in a DataLoader with default settings."""
        return data.DataLoader(BasicDataset())

    @pl.data_loader
    def val_dataloader(self):
        """Wrap a fresh BasicDataset in a DataLoader with default settings."""
        # NOTE(review): no validation_step is defined on this model, so this
        # loader may go unused -- confirm against the Lightning version in use.
        return data.DataLoader(BasicDataset())
def main():
    """Fit XORGateModel with the 'ddp' backend while logging to MLflow.

    The tracking URI is read from the ``MLFLOW_TRACKING_URI`` environment
    variable.  When it is unset, ``None`` is passed, which is MLFlowLogger's
    own default argument (presumably letting MLflow pick its default tracking
    location -- confirm for the installed Lightning/MLflow versions).  This
    replaces the previous behavior of crashing with a bare ``KeyError``.
    """
    test_hparams = Namespace()
    model = XORGateModel(test_hparams)
    logger = MLFlowLogger(experiment_name='test_lightning_logger',
                          tracking_uri=os.environ.get('MLFLOW_TRACKING_URI'))
    # gpus='-1' asks Lightning for all available GPUs.  The 'ddp' backend
    # launches worker processes, presumably the path on which the logger gets
    # pickled and the MLFlowLogger pickle bug surfaces.
    trainer = pl.Trainer(logger=logger, distributed_backend='ddp', gpus='-1')
    trainer.fit(model)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment