Skip to content

Instantly share code, notes, and snippets.

@NumesSanguis
Created July 1, 2020 07:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NumesSanguis/6d929ce7816e4a2514beb0d00127db2c to your computer and use it in GitHub Desktop.
from functools import partial

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
# Hyper-parameter configuration for LitLake.
# `partial` pre-binds the optimizer settings (e.g. lr) while deferring the
# model parameters, which only exist once configure_optimizers() runs.
hparams = {
    "criterion": torch.nn.BCELoss(),            # loss function instance
    "optimizer": partial(optim.Adam, lr=0.001),  # optimizer factory
    "filters": 64,
    "layers": 2,
}
class EmptyDataset(Dataset):
    """Minimal placeholder Dataset: 32 identical dummy samples for smoke tests.

    Bug fixed: the original referenced ``np`` without importing numpy, so the
    first ``__getitem__`` call raised ``NameError``.
    """

    def __init__(self, transform=None):
        # `transform` is accepted for API symmetry with real datasets but unused.
        pass

    def __len__(self):
        # Fixed arbitrary length so a DataLoader has something to iterate.
        return 32

    def __getitem__(self, idx):
        # NOTE(review): LitLake.training_step unpacks `x, y = batch`, while this
        # returns a dict — presumably a repro skeleton; verify against caller.
        return {"input": np.array([1]), "output": "nothing"}
class LitLake(pl.LightningModule):
    """Skeleton LightningModule used to reproduce/fix a hparams lookup bug.

    Bug fixed: the original indexed ``self.hparams["optimizer"]``, but the
    Lightning Trainer may convert the assigned dict into a ``Namespace``,
    so dict-style indexing raised ``KeyError: 'optimizer'`` by the time
    ``configure_optimizers`` ran.  We now keep our own plain-dict reference
    (``self._config``) and index that instead.
    """

    def __init__(self, hparams: dict, transforms: dict = None):
        super().__init__()
        # Private plain-dict copy, immune to Lightning's hparams conversion.
        self._config = dict(hparams)
        self.hparams = hparams  # kept for Lightning logging/checkpointing
        print("self.hparams\n", self.hparams)

    def forward(self, x):
        # Placeholder forward pass (returns None) — repro skeleton only.
        pass

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        # NOTE(review): EmptyDataset yields dicts, so this tuple unpack
        # presumably needs batch["input"] / batch["output"] — verify intent.
        x, y = batch
        y_hat = self(x)
        loss = self._config["criterion"](y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        print("self.hparams\n", self.hparams)
        # Index the private dict, not self.hparams (see class docstring).
        optimizer = self._config["optimizer"](self.parameters())
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        # Small batch, single worker: just enough to drive the repro.
        return DataLoader(EmptyDataset(), batch_size=4, num_workers=1)
# Instantiate the model from the shared hyper-parameter dict.
model = LitLake(hparams=hparams)

# Bare-bones trainer using library defaults (pass gpus=1, num_nodes=1 to scale).
trainer = Trainer()
trainer.fit(model)  # original repro raised KeyError: 'optimizer' here
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment