PyTorch implementation of Reptile as described by Ravi et al. (https://openreview.net/pdf?id=rJY0-Kcll)
class Reptile:
    """
    Reptile-optimization as described by Ravi, et al.
    (https://openreview.net/pdf?id=rJY0-Kcll)

    Holds one shared model and one metalearner per task. A training step
    lets the metalearners train on their task batches and then moves the
    shared model's parameters toward the mean of the metalearners'
    parameters.

    NOTE(review): Reptile is usually attributed to Nichol et al., "On
    First-Order Meta-Learning Algorithms"; confirm that the citation above
    is the intended one.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 metalearners: List["MetaLearner"]):
        """
        Store the shared model and the per-task metalearners.

        Args:
            model: The shared model whose parameters ``training_step``
                updates in place.
            metalearners: One metalearner per task. Each must expose a
                ``model`` attribute and a ``training_step(x, y)`` method,
                since both are used by this class.
        """
        self.n_tasks = len(metalearners)
        self.model = model
        self.metalearners = metalearners

    def metatraining_step(self, x, y, idxs=None, steps: int = 1):
        """
        Run the inner training loop of the selected metalearners.

        Args:
            x: Iterable of per-task inputs.
            y: Iterable of per-task targets, aligned with ``x``.
            idxs: Iterable of positions into ``self.metalearners``, aligned
                with ``x`` and ``y``. When ``None``, the i-th batch is given
                to the i-th metalearner.
            steps: Number of ``training_step`` calls per metalearner.
        """
        task_batches = zip(x, y)
        if idxs is None:
            # The default used to be passed straight to zip() and raised
            # TypeError; fall back to positional pairing instead.
            indexed_batches = enumerate(task_batches)
        else:
            indexed_batches = zip(idxs, task_batches)

        for idx, (xp, yp) in indexed_batches:
            metalearner = self.metalearners[idx]
            for _ in range(steps):
                metalearner.training_step(xp, yp)

    def training_step(self, x, y, idxs=None,
                      metalearning_steps: int = 100,
                      learning_rate: float = 0.01):
        """
        Perform one outer update of the shared model.

        Every metalearner's model is first replaced by the result of
        ``transfer_model(self.model, metalearner.model)``, then the selected
        metalearners train for ``metalearning_steps`` steps, and finally each
        shared parameter is moved toward the mean of the corresponding
        metalearner parameters.

        Args:
            x: Iterable of per-task inputs.
            y: Iterable of per-task targets, aligned with ``x``.
            idxs: Positions of the metalearners to train; see
                ``metatraining_step``.
            metalearning_steps: Inner training steps per selected
                metalearner.
            learning_rate: Fraction of the distance to the mean parameters
                that the shared parameters move.
        """
        for metalearner in self.metalearners:
            # Presumably copies the shared weights into the task model;
            # confirm against transfer_model.
            metalearner.model = transfer_model(self.model, metalearner.model)

        self.metatraining_step(x, y, idxs, metalearning_steps)

        # The update is applied in place, so it must not be tracked by
        # autograd.
        with torch.no_grad():
            task_parameters = [meta.model.parameters()
                               for meta in self.metalearners]
            # NOTE(review): the mean runs over every metalearner, including
            # those not selected by idxs; confirm this is intended.
            for param, *adapted in zip(self.model.parameters(),
                                       *task_parameters):
                mean_adapted = torch.stack(adapted, dim=0).mean(dim=0)
                param += learning_rate * (mean_adapted - param)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment