Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
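PyTorch implementation of Reptile (Nichol, Achiam & Schulman, 2018, "On First-Order Meta-Learning Algorithms", https://arxiv.org/abs/1803.02999). Note: the originally cited paper, Ravi & Larochelle (https://openreview.net/pdf?id=rJY0-Kcll), describes a different, LSTM-based meta-learner.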
class Reptile:
    """
    Reptile meta-optimization (Nichol, Achiam & Schulman, 2018,
    "On First-Order Meta-Learning Algorithms", https://arxiv.org/abs/1803.02999).

    (An earlier version of this docstring cited Ravi & Larochelle,
    https://openreview.net/pdf?id=rJY0-Kcll, which describes a different,
    LSTM-based meta-learner.)

    One meta-step resets the meta-learners from the shared ``model``, trains
    the selected meta-learners on their own task data, and then moves the
    shared weights a fraction ``learning_rate`` of the way towards the mean
    of the adapted weights:

        phi <- phi + learning_rate * (mean_i(W_i) - phi)

    Args:
        model: the shared (meta) model whose parameters are meta-learned.
        metalearners: one learner per task. Each must expose a ``model``
            attribute and a ``training_step(x, y)`` method (the only members
            this class uses).
    """

    def __init__(self,
                 model: torch.nn.Module,
                 metalearners: "List[MetaLearner]"):
        # Number of meta-learners; not read by the methods of this class.
        self.n_tasks = len(metalearners)
        self.model = model
        self.metalearners = metalearners

    def metatraining_step(self, x, y, idxs=None, steps: int = 1):
        """Run the inner-loop updates on the selected meta-learners.

        Args:
            x: per-entry inputs; ``x[k]`` is fed to meta-learner ``idxs[k]``
                (presumably one batch tensor per task — confirm with callers).
            y: per-entry targets, aligned with ``x``.
            idxs: meta-learner index for each entry of ``x``/``y``. ``None``
                (the default) pairs entries with learners positionally,
                i.e. ``range(len(x))``. Iteration stops at the shortest of
                ``idxs``, ``x`` and ``y``.
            steps: number of ``training_step`` calls per entry.

        Returns:
            The distinct meta-learner indices that were trained, in order of
            first use. An index repeated in ``idxs`` trains that learner again
            on each of its entries but is listed once.
        """
        if idxs is None:
            # ``zip(None, ...)`` used to raise TypeError; default to a
            # positional entry -> learner mapping instead.
            idxs = range(len(x))
        trained = []
        for idx, xp, yp in zip(idxs, x, y):
            metalearner = self.metalearners[idx]
            for _ in range(steps):
                metalearner.training_step(xp, yp)
            if idx not in trained:
                trained.append(idx)
        return trained

    def training_step(self, x, y, idxs=None, metalearning_steps: int = 100, learning_rate: float = 0.01):
        """Perform one Reptile meta-update of ``self.model`` in place.

        Steps:
            1. Re-initialise every meta-learner's model from the shared model
               via ``transfer_model`` (presumably a weight copy — confirm).
            2. Train the selected meta-learners (see ``metatraining_step``).
            3. Move the shared weights towards the mean of the *trained*
               learners' weights. Untrained learners are excluded: they still
               hold the shared weights, so averaging them in would silently
               shrink the step by (trained / total).

        Args:
            x, y, idxs: forwarded to ``metatraining_step``.
            metalearning_steps: inner-loop ``training_step`` calls per entry.
            learning_rate: outer (meta) step size, the fraction of the distance
                to the mean adapted weights that the shared weights move.
        """
        for metalearner in self.metalearners:
            metalearner.model = transfer_model(self.model, metalearner.model)
        trained = self.metatraining_step(x, y, idxs, metalearning_steps)
        if not trained:
            # Nothing was adapted, so there is no direction to move in (and
            # torch.stack on an empty list would raise).
            return
        adapted_models = [self.metalearners[i].model for i in trained]
        self._move_towards_mean(self.model, adapted_models, learning_rate)

    @staticmethod
    def _move_towards_mean(model, adapted_models, learning_rate):
        """Interpolate ``model``'s parameters towards the element-wise mean of
        the corresponding parameters of ``adapted_models``, in place.

        Parameters are matched by position in ``.parameters()`` order, so
        all models are assumed to share the same architecture.
        """
        with torch.no_grad():
            for param, *adapted in zip(model.parameters(),
                                       *(m.parameters() for m in adapted_models)):
                mean = torch.stack(adapted, dim=0).mean(dim=0)
                param += learning_rate * (mean - param)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.