@thomwolf
Last active July 5, 2019 15:43
Simple gist on how to train a meta-learner in PyTorch
from torch.autograd import Variable


def train(forward_model, backward_model, optimizer, meta_optimizer, train_data, meta_epochs):
    """ Train a meta-learner
    Inputs:
      forward_model, backward_model: Two identical PyTorch modules (can have shared Tensors)
      optimizer: a neural net to be used as optimizer (an instance of the MetaLearner class)
      meta_optimizer: an optimizer for the optimizer neural net, e.g. ADAM
      train_data: an iterator over an epoch of training data
      meta_epochs: number of meta-training steps
    To be added: initialization, early stopping, checkpointing, more control over everything
    """
    for meta_epoch in range(meta_epochs):    # Meta-training loop (train the optimizer)
        optimizer.zero_grad()
        losses = []
        for inputs, labels in train_data:    # Meta-forward pass (train the model)
            forward_model.zero_grad()        # Forward pass
            inputs = Variable(inputs)
            labels = Variable(labels)
            output = forward_model(inputs)
            loss = loss_func(output, labels) # Compute loss (loss_func is assumed to be defined in the surrounding code)
            losses.append(loss)
            loss.backward()                  # Backward pass to add gradients to the forward_model
            optimizer(forward_model,         # Optimizer step (update the models)
                      backward_model)
        meta_loss = sum(losses)              # Compute a simple meta-loss
        meta_loss.backward()                 # Meta-backward pass
        meta_optimizer.step()                # Meta-optimizer step
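For context, here is a minimal usage sketch of the train function above. It assumes the MetaLearner class and its constructor come from the accompanying tutorial (they are not part of this gist), and the toy model, random data, and hyper-parameters below are invented purely for illustration.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# loss_func is referenced as a global inside train(); a standard criterion is picked here.
loss_func = nn.CrossEntropyLoss()

# Two identical copies of a toy model (the models in the tutorial can share Tensors).
forward_model = nn.Linear(10, 2)
backward_model = nn.Linear(10, 2)
backward_model.load_state_dict(forward_model.state_dict())

# The learned optimizer (hypothetical: MetaLearner is defined in the tutorial;
# its constructor arguments are omitted here).
optimizer = MetaLearner()
# A standard optimizer, e.g. ADAM, over the learned optimizer's own weights.
meta_optimizer = torch.optim.Adam(optimizer.parameters(), lr=1e-3)

# A small random dataset standing in for one epoch of training data.
features = torch.randn(64, 10)
targets = torch.randint(0, 2, (64,))
train_data = DataLoader(TensorDataset(features, targets), batch_size=8)

train(forward_model, backward_model, optimizer, meta_optimizer, train_data, meta_epochs=5)

Because MetaLearner is an nn.Module, calling optimizer(forward_model, backward_model) goes through nn.Module.__call__ and dispatches to MetaLearner.forward; calling the module directly is the idiomatic form.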
@tk1363704

Excuse me, but should the code on the 22nd line be optimizer(forward_model, backward_model) or optimizer.forward(forward_model, backward_model)?
