Simple gist on how to train a meta-learner in PyTorch
from torch.autograd import Variable

def train(forward_model, backward_model, optimizer, meta_optimizer, train_data, meta_epochs):
    """ Train a meta-learner
    Inputs:
      forward_model, backward_model: two identical PyTorch modules (can have shared Tensors)
      optimizer: a neural net to be used as optimizer (an instance of the MetaLearner class)
      meta_optimizer: an optimizer for the optimizer neural net, e.g. ADAM
      train_data: an iterator over an epoch of training data
      meta_epochs: number of meta-training steps
    Assumes loss_func is defined in the enclosing scope (e.g. nn.CrossEntropyLoss()).
    To be added: initialization, early stopping, checkpointing, more control over everything
    """
    for meta_epoch in range(meta_epochs):       # Meta-training loop (train the optimizer)
        optimizer.zero_grad()
        losses = []
        for inputs, labels in train_data:       # Meta-forward pass (train the model)
            forward_model.zero_grad()           # Forward pass
            inputs = Variable(inputs)
            labels = Variable(labels)
            output = forward_model(inputs)
            loss = loss_func(output, labels)    # Compute loss
            losses.append(loss)
            loss.backward()                     # Backward pass to add gradients to the forward_model
            optimizer(forward_model,            # Optimizer step (update the models)
                      backward_model)
        meta_loss = sum(losses)                 # Compute a simple meta-loss
        meta_loss.backward()                    # Meta-backward pass
        meta_optimizer.step()                   # Meta-optimizer step
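
The MetaLearner class referenced in the docstring is not included in this snippet. The sketch below is a minimal, hypothetical stand-in for its interface only (the class name is from the docstring; the single learned step size, the _install_nonleaf_param helper, and the update rule are assumptions, not the gist's actual implementation): called as optimizer(forward_model, backward_model), it reads the gradients accumulated on forward_model and installs updated weights into backward_model as non-leaf tensors, so that meta_loss.backward() can flow back into the optimizer's own parameters.

import torch
import torch.nn as nn

def _install_nonleaf_param(module, name, value):
    # Hypothetical helper: replace a (possibly nested) Parameter with a plain
    # tensor attribute, so the updated weight stays inside the autograd graph.
    parts = name.split('.')
    for part in parts[:-1]:
        module = getattr(module, part)
    del module._parameters[parts[-1]]   # deregister the leaf Parameter
    setattr(module, parts[-1], value)   # install the non-leaf tensor

class MetaLearner(nn.Module):
    # Toy learned optimizer: a single learnable step size. The full gist uses
    # a richer model over the gradients; only the calling convention below
    # matches what train() expects.
    def __init__(self):
        super(MetaLearner, self).__init__()
        self.lr = nn.Parameter(torch.tensor(0.01))

    def forward(self, forward_model, backward_model):
        # Materialize the pairs first, since we mutate _parameters below.
        pairs = list(zip(forward_model.named_parameters(),
                         backward_model.named_parameters()))
        for (name, f_p), (_, b_p) in pairs:
            new_p = b_p - self.lr * f_p.grad.detach()   # SGD-like update with a learned lr
            _install_nonleaf_param(backward_model, name, new_p)
        # Caveat: after this call the replaced weights are plain attributes and
        # no longer appear in backward_model.named_parameters(); handling
        # repeated updates takes more machinery than this sketch shows.

With such a class, optimizer = MetaLearner() and meta_optimizer = torch.optim.Adam(optimizer.parameters(), lr=1e-3) would match the signature train() expects.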
Excuse me, but should the optimizer step be optimizer(forward_model, backward_model) or optimizer.forward(forward_model, backward_model)?
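
A note on that question: the two are equivalent in effect here, but calling the module instance is the idiomatic form. nn.Module.__call__ dispatches to forward and additionally runs any registered hooks, which calling .forward() directly would skip. A minimal illustration:

import torch
import torch.nn as nn

class Doubler(nn.Module):
    def forward(self, x):
        return 2 * x

m = Doubler()
x = torch.tensor(3.0)
# m(x) routes through nn.Module.__call__, which invokes m.forward(x)
# and also runs any registered hooks.
assert m(x) == m.forward(x)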