@pietrolesci
Created October 29, 2020 10:36
First attempt, not run
import logging
import higher
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.metrics.functional.classification import accuracy
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from config import parser
logger = logging.getLogger(__name__)

class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self, in_channels, out_features, hidden_size=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        self.features = nn.Sequential(
            self.conv3x3(in_channels, hidden_size),
            self.conv3x3(hidden_size, hidden_size),
            self.conv3x3(hidden_size, hidden_size),
            self.conv3x3(hidden_size, hidden_size),
        )
        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs, params=None):
        features = self.features(inputs)
        features = features.view((features.size(0), -1))
        logits = self.classifier(features)
        return logits

    def conv3x3(self, in_channels, out_channels, **kwargs):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
            nn.BatchNorm2d(out_channels, momentum=1.0, track_running_stats=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
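
def _sanity_check_cnn():
    # A minimal sanity-check sketch (not part of the original gist). torchmeta's
    # Omniglot images are 1x28x28; each conv3x3 block halves the spatial size
    # via MaxPool2d(2), so 28 -> 14 -> 7 -> 3 -> 1, and the flattened feature
    # vector has exactly `hidden_size` entries, matching the Linear classifier.
    net = ConvolutionalNeuralNetwork(in_channels=1, out_features=5)
    dummy = torch.zeros(4, 1, 28, 28)  # (batch, channels, height, width)
    logits = net(dummy)
    assert logits.shape == (4, 5)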

class MAML(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.accuracy = accuracy

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # Each batch is a batch of tasks. For every task, adapt a functional
        # copy of the model on the task's "train" split, then accumulate the
        # meta-loss on the task's "test" split.
        meta_optimizer, inner_optimizer = self.optimizers()
        train_inputs, train_targets = batch["train"]
        test_inputs, test_targets = batch["test"]
        outer_loss = torch.tensor(0.0, device=self.device)
        for task_idx, (train_input, train_target, test_input, test_target) in enumerate(
            zip(train_inputs, train_targets, test_inputs, test_targets)
        ):
            with higher.innerloop_ctx(
                self.model, inner_optimizer, copy_initial_weights=False
            ) as (fmodel, diffopt):
                train_logit = fmodel(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)
                diffopt.step(inner_loss)
                test_logit = fmodel(test_input)
                outer_loss += F.cross_entropy(test_logit, test_target)
                self.log_dict(
                    {
                        "outer_loss": outer_loss,
                        "accuracy": self.accuracy(test_logit, test_target),
                    }
                )
        # Average the meta-loss over the tasks in the batch. `args` is read
        # from module scope; it is set in the __main__ block below.
        outer_loss.div_(args.batch_size)
        # Manual optimization: backpropagate through the unrolled inner loops,
        # then update the meta-parameters.
        self.manual_backward(outer_loss, meta_optimizer)
        meta_optimizer.step()
        meta_optimizer.zero_grad()
        return outer_loss

    def configure_optimizers(self):
        meta_optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        inner_optimizer = torch.optim.SGD(self.parameters(), lr=args.step_size)
        return [meta_optimizer, inner_optimizer]
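
def _higher_inner_loop_demo():
    # An illustrative sketch (not in the original gist) of the mechanics that
    # MAML.training_step relies on: higher.innerloop_ctx yields a functional
    # copy of the model (fmodel) and a differentiable optimizer (diffopt), so
    # the gradient of the post-adaptation loss flows back to the original
    # parameters. Toy 1D regression; every name here is hypothetical.
    model = nn.Linear(1, 1)
    inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 1), torch.randn(8, 1)
    with higher.innerloop_ctx(
        model, inner_opt, copy_initial_weights=False
    ) as (fmodel, diffopt):
        diffopt.step(F.mse_loss(fmodel(x), y))  # one differentiable inner step
        outer_loss = F.mse_loss(fmodel(x), y)   # loss after adaptation
    outer_loss.backward()  # gradients reach the original model.parameters()
    assert all(p.grad is not None for p in model.parameters())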

class OmniglotDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str,
        shots: int,
        ways: int,
        shuffle_ds: bool,
        test_shots: int,
        meta_train: bool,
        download: bool,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.shots = shots
        self.ways = ways
        self.shuffle_ds = shuffle_ds
        self.test_shots = test_shots
        self.meta_train = meta_train
        self.download = download
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Build the meta-dataset of `ways`-way, `shots`-shot classification tasks.
        self.task_dataset = omniglot(
            self.data_dir,
            shots=self.shots,
            ways=self.ways,
            shuffle=self.shuffle_ds,
            test_shots=self.test_shots,
            meta_train=self.meta_train,
            download=self.download,
        )

    def train_dataloader(self):
        return BatchMetaDataLoader(
            self.task_dataset,
            batch_size=self.batch_size,  # number of tasks per batch
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )
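
# A hedged sketch (not in the original gist) of the batch structure that
# BatchMetaDataLoader yields and training_step unpacks: a dict with "train"
# and "test" splits, each a (inputs, targets) pair batched over tasks. Shapes
# below assume torchmeta's Omniglot defaults (1x28x28 images) with shots=1,
# ways=5, test_shots=15, batch_size=16; exact sizes vary with the arguments:
#
#     dm = OmniglotDataModule("data", shots=1, ways=5, shuffle_ds=True,
#                             test_shots=15, meta_train=True, download=True,
#                             batch_size=16, shuffle=True, num_workers=0)
#     dm.setup()
#     batch = next(iter(dm.train_dataloader()))
#     train_inputs, train_targets = batch["train"]
#     train_inputs.shape   # torch.Size([16, 5, 1, 28, 28]) -> (tasks, shots * ways, C, H, W)
#     train_targets.shape  # torch.Size([16, 5])
#     test_inputs, _ = batch["test"]
#     test_inputs.shape    # torch.Size([16, 75, 1, 28, 28]) -> test_shots * ways = 75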

if __name__ == "__main__":
    logger.warning(
        "This script is an example to showcase the data-loading "
        "features of Torchmeta in conjunction with using higher to "
        'make models "unrollable" and optimizers differentiable, '
        "and as such has been very lightly tested."
    )
    args = parser.parse_args()

    dm = OmniglotDataModule(
        "data",
        shots=args.num_shots,
        ways=args.num_ways,
        shuffle_ds=True,
        test_shots=15,
        meta_train=True,
        download=args.download,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
    )
    model = MAML(
        model=ConvolutionalNeuralNetwork(1, args.num_ways, hidden_size=args.hidden_size)
    )
    trainer = Trainer(
        automatic_optimization=False,
        profiler=True,
        max_epochs=args.n_epochs,
        fast_dev_run=False,
        num_sanity_val_steps=2,
    )
    trainer.fit(model, datamodule=dm)
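
# The `config` module imported at the top is not included in this gist. Below
# is a minimal sketch of a compatible config.py, inferred from the `args`
# attributes used above (num_shots, num_ways, download, batch_size,
# hidden_size, n_epochs, step_size). The defaults are illustrative
# assumptions, not the author's values:
#
#     # config.py (hypothetical)
#     import argparse
#
#     parser = argparse.ArgumentParser(description="MAML on Omniglot with torchmeta + higher")
#     parser.add_argument("--num_shots", type=int, default=1)
#     parser.add_argument("--num_ways", type=int, default=5)
#     parser.add_argument("--download", action="store_true")
#     parser.add_argument("--batch_size", type=int, default=16)
#     parser.add_argument("--hidden_size", type=int, default=64)
#     parser.add_argument("--n_epochs", type=int, default=1)
#     parser.add_argument("--step_size", type=float, default=0.4)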