Train a MAML model with Torchmeta and higher v0.2.

# Based on the code in https://github.com/tristandeleu/pytorch-meta/tree/master/examples/maml
# Basically, we only use the dataset loaders/helpers from Torchmeta and replace usage of
# MetaModules with plain PyTorch nn.Modules, letting higher deal with making the inner loop
# unrollable and the optimizers differentiable (a minimal sketch of this pattern follows the
# imports below). This makes it easier to use an optimizer other than SGD, or any arbitrary
# third-party model, when doing MAML with this codebase.
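#
# Example invocation (the script name and paths below are placeholders):
#   python train_maml_higher.py /path/to/data --num-shots 5 --num-ways 5 \
#       --download --output-folder /path/to/checkpoints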
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import higher  # tested with higher v0.2
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

logger = logging.getLogger(__name__)
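

# Minimal sketch of the core higher pattern used in train() below (illustrative
# only and never called; the toy model and data here are made up). Inside
# innerloop_ctx, `fmodel` is a differentiable functional copy of the model and
# `diffopt.step` applies an unrolled, differentiable parameter update, so the
# outer loss can backpropagate through the inner step to the original parameters.
def _innerloop_sketch():
    toy = nn.Linear(2, 1)
    opt = torch.optim.SGD(toy.parameters(), lr=0.1)
    x, y = torch.randn(4, 2), torch.randn(4, 1)
    with higher.innerloop_ctx(toy, opt, copy_initial_weights=False) as (ftoy, diffopt):
        diffopt.step(F.mse_loss(ftoy(x), y))  # differentiable inner update
        outer_loss = F.mse_loss(ftoy(x), y)   # loss of the adapted copy
    outer_loss.backward()  # gradients reach toy.parameters()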


def conv3x3(in_channels, out_channels, **kwargs):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )


class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self, in_channels, out_features, hidden_size=64):
        super(ConvolutionalNeuralNetwork, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        self.features = nn.Sequential(
            conv3x3(in_channels, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size)
        )
        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs):
        # No `params` argument here: this is a plain nn.Module, and higher's
        # functional copy handles the per-task parameters instead.
        features = self.features(inputs)
        features = features.view((features.size(0), -1))
        logits = self.classifier(features)
        return logits
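
# Illustrative shape check (Omniglot images are 1x28x28; the four stride-2 max
# pools reduce the spatial size 28 -> 14 -> 7 -> 3 -> 1, so the flattened
# features have exactly `hidden_size` entries, matching the classifier input):
#   net = ConvolutionalNeuralNetwork(1, 5)
#   net(torch.zeros(2, 1, 28, 28)).shape  # torch.Size([2, 5])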


def get_accuracy(logits, targets):
    """Compute the accuracy (after adaptation) of MAML on the test/query points.

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Outputs/logits of the model on the query points. This tensor has shape
        `(num_examples, num_classes)`.
    targets : `torch.LongTensor` instance
        A tensor containing the targets of the query points. This tensor has
        shape `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Mean accuracy on the query points.
    """
    _, predictions = torch.max(logits, dim=-1)
    return torch.mean(predictions.eq(targets).float())
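
# Illustrative usage of get_accuracy (made-up numbers): the logits argmax to
# predictions [0, 1, 0] against targets [0, 1, 1], so 2 of 3 points are correct:
#   get_accuracy(torch.tensor([[2., 0.], [0., 2.], [2., 0.]]),
#                torch.tensor([0, 1, 1]))  # tensor(0.6667)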


def train(args):
    logger.warning('This script is an example to showcase the data-loading '
                   'features of Torchmeta in conjunction with using higher to '
                   'make models "unrollable" and optimizers differentiable, '
                   'and as such has been very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    inner_optimizer = torch.optim.SGD(model.parameters(), lr=args.step_size)
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            # Each split is an (inputs, targets) pair; inputs have shape
            # (batch_size, num_ways * shots, 1, 28, 28) and targets
            # (batch_size, num_ways * shots), where shots is num_shots for
            # 'train' (support) and test_shots for 'test' (query).
            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(zip(train_inputs, train_targets,
                                                         test_inputs, test_targets)):
                # copy_initial_weights=False lets gradients of the adapted
                # parameters flow back to the initial (meta-learned) weights.
                with higher.innerloop_ctx(
                        model, inner_optimizer, copy_initial_weights=False
                ) as (fmodel, diffopt):
                    # Single differentiable adaptation step on the support set.
                    train_logit = fmodel(train_input)
                    inner_loss = F.cross_entropy(train_logit, train_target)
                    diffopt.step(inner_loss)

                    # Evaluate the adapted model on the query set.
                    test_logit = fmodel(test_input)
                    outer_loss += F.cross_entropy(test_logit, test_target)

                    with torch.no_grad():
                        accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break
    # Save model (note: args.output_folder must already exist)
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
                                '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')
    parser.add_argument('folder', type=str,
                        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--num-shots', type=int, default=5,
                        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
                        help='Number of classes per task (N in "N-way", default: 5).')
    parser.add_argument('--step-size', type=float, default=0.4,
                        help='Step size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
                        help='Number of channels for each convolutional layer (default: 64).')
    parser.add_argument('--output-folder', type=str, default=None,
                        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
                        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
                        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_true',
                        help='Download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
                        help='Use CUDA if available.')
    args = parser.parse_args()
    args.device = torch.device('cuda' if args.use_cuda
                               and torch.cuda.is_available() else 'cpu')

    train(args)