@thomwolf
Last active June 29, 2023 10:06
A simple bare MetaLearner class in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter


class MetaLearner(nn.Module):
    """ Bare Meta-learner class
    Should be added: initialization, hidden states, more control over everything
    """
    def __init__(self, model):
        super(MetaLearner, self).__init__()
        self.weights = Parameter(torch.Tensor(1, 2))

    def forward(self, forward_model, backward_model):
        """ Forward optimizer with a simple linear neural net
        Inputs:
            forward_model: PyTorch module with parameters gradient populated
            backward_model: PyTorch module identical to forward_model (but without gradients)
                updated at the Parameter level to keep track of the computation graph for meta-backward pass
        """
        # get_params is expected to yield (module, name, parameter) tuples over all parameters of a model
        f_model_iter = get_params(forward_model)
        b_model_iter = get_params(backward_model)
        for f_param_tuple, b_param_tuple in zip(f_model_iter, b_model_iter):  # loop over parameters
            # Prepare the inputs, we detach the inputs to avoid computing 2nd derivatives (re-pack in new Variable)
            (module_f, name_f, param_f) = f_param_tuple
            (module_b, name_b, param_b) = b_param_tuple
            inputs = Variable(torch.stack([param_f.grad.data, param_f.data], dim=-1))
            # Optimization step: compute new model parameters, here we apply a simple linear function
            dW = F.linear(inputs, self.weights).squeeze()
            param_b = param_b + dW
            # Update backward_model (meta-gradients can flow) and forward_model (no need for meta-gradients).
            module_b._parameters[name_b] = param_b
            param_f.data = param_b.data
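
The get_params helper used above is not defined in this gist. A minimal sketch of what it might look like, assuming it only needs to yield a (module, name, parameter) tuple for every parameter so the update loop can rewrite entries in module._parameters:

def get_params(model):
    """ Yield (module, name, parameter) tuples for every parameter of the model.
    Hypothetical stand-in for the helper assumed by MetaLearner.forward above.
    """
    for module in model.modules():
        for name, param in module._parameters.items():
            if param is not None:
                yield module, name, param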
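
For context, a rough sketch of how the class could be wired into a single meta-training step; the data, loss_fn, and model names below are placeholders and not part of the gist:

# All names here (loss_fn, x_train, y_train, x_val, y_val, forward_model, backward_model) are hypothetical.
meta_learner = MetaLearner(forward_model)
meta_optimizer = torch.optim.Adam(meta_learner.parameters(), lr=1e-3)

# Inner step: populate the gradients of the forward model on a training batch
forward_model.zero_grad()
train_loss = loss_fn(forward_model(x_train), y_train)
train_loss.backward()

# Apply the learned update rule; the computation graph is recorded on backward_model
meta_learner(forward_model, backward_model)

# Outer (meta) step: evaluate the updated parameters and backprop into meta_learner.weights
meta_loss = loss_fn(backward_model(x_val), y_val)
meta_optimizer.zero_grad()
meta_loss.backward()
meta_optimizer.step()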