A simple bare MetaLearner class in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable

def get_params(module):
    """ get_params is not defined in this gist; this minimal recursive version
        (an assumption, matching how it is used below) yields (module, name, param)
        triples for every parameter in the module tree.
    """
    for name, param in module._parameters.items():
        if param is not None:
            yield module, name, param
    for child in module.children():
        for result in get_params(child):
            yield result

class MetaLearner(nn.Module):
    """ Bare Meta-learner class
        Should be added: initialization, hidden states, more control over everything
    """
    def __init__(self, model):
        super(MetaLearner, self).__init__()
        # One weight per input feature (gradient, parameter value) of the learned update rule
        self.weights = Parameter(torch.Tensor(1, 2))

    def forward(self, forward_model, backward_model):
        """ Forward optimizer with a simple linear neural net
            Inputs:
                forward_model: PyTorch module with parameters gradient populated
                backward_model: PyTorch module identical to forward_model (but without gradients)
                    updated at the Parameter level to keep track of the computation graph for meta-backward pass
        """
        f_model_iter = get_params(forward_model)
        b_model_iter = get_params(backward_model)
        for f_param_tuple, b_param_tuple in zip(f_model_iter, b_model_iter):  # loop over parameters
            # Prepare the inputs; we detach them to avoid computing 2nd derivatives (re-pack in a new Variable)
            (module_f, name_f, param_f) = f_param_tuple
            (module_b, name_b, param_b) = b_param_tuple
            inputs = Variable(torch.stack([param_f.grad.data, param_f.data], dim=-1))
            # Optimization step: compute new model parameters, here we apply a simple linear function
            dW = F.linear(inputs, self.weights).squeeze(dim=-1)  # squeeze only the stacked dim, not size-1 param dims
            param_b = param_b + dW
            # Update backward_model (meta-gradients can flow) and forward_model (no need for meta-gradients)
            module_b._parameters[name_b] = param_b
            param_f.data = param_b.data
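
For context, here is a minimal sketch of how this meta-learner might be driven, assuming a toy regression task; the model, data, loss, inner-loop length, weight initialization, and the outer Adam optimizer are illustrative assumptions, not part of the gist:

import copy

# Hypothetical usage sketch (toy task; names below are illustrative, not from the gist)
forward_model = nn.Linear(4, 1)                  # model whose weights are optimized
backward_model = copy.deepcopy(forward_model)    # identical copy; tracks the meta-graph
meta_learner = MetaLearner(forward_model)
meta_learner.weights.data.normal_(0, 0.01)       # the gist leaves initialization to the user
meta_optim = torch.optim.Adam(meta_learner.parameters(), lr=1e-3)

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(5):                               # a few inner optimization steps
    loss = F.mse_loss(forward_model(x), y)
    forward_model.zero_grad()
    loss.backward()                              # populate param.grad on forward_model
    meta_learner(forward_model, backward_model)  # apply the learned update rule

# Meta-backward pass: the final model's loss flows back through every update step
meta_loss = F.mse_loss(backward_model(x), y)
meta_optim.zero_grad()
meta_loss.backward()                             # gradients reach meta_learner.weights
meta_optim.step()

Note the design choice the gist's comments point at: updating module_b._parameters[name_b] with the result of the computation (rather than writing into param.data) keeps each new parameter tensor a node in the autograd graph, which is exactly what lets meta_loss.backward() flow through the whole chain of updates back to self.weights.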