@benoitdescamps
Last active September 12, 2020 18:21
PyTorch implementation of Reptile, as described by Ravi et al. (https://openreview.net/pdf?id=rJY0-Kcll)
import torch
from typing import List

# Note: MetaLearner and transfer_model are not defined in this gist; they are
# assumed to provide per-task inner-loop training and parameter copying
# (one possible sketch of both is given below).

class Reptile:
    """
    Reptile optimization as described by Ravi et al.
    (https://openreview.net/pdf?id=rJY0-Kcll)
    """
    def __init__(self,
                 model: torch.nn.Module,
                 metalearners: List["MetaLearner"]):
        self.n_tasks = len(metalearners)
        self.model = model
        self.metalearners = metalearners

    def metatraining_step(self, x, y, idxs=None, steps: int = 1):
        """Inner loop: adapt each selected task learner on its own batch for `steps` updates."""
        if idxs is None:
            idxs = range(self.n_tasks)
        for idx, xp, yp in zip(idxs, x, y):
            metalearner = self.metalearners[idx]
            for _ in range(steps):
                metalearner.training_step(xp, yp)

    def training_step(self, x, y, idxs=None, metalearning_steps: int = 100,
                      learning_rate: float = 0.01):
        # Copy the current meta-parameters into every task learner before adaptation.
        for metalearner in self.metalearners:
            metalearner.model = transfer_model(self.model, metalearner.model)
        # Adapt the task learners on their respective batches.
        self.metatraining_step(x, y, idxs, metalearning_steps)
        # Meta-update: move each meta-parameter towards the mean of the
        # corresponding adapted task parameters.
        with torch.no_grad():
            for args in zip(self.model.parameters(),
                            *[meta.model.parameters() for meta in self.metalearners]):
                param = args[0]
                W = torch.stack(args[1:], dim=0).mean(dim=0)
                param += learning_rate * (W - param)
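
The gist leaves MetaLearner and transfer_model undefined. The sketch below is one minimal way they could look, assuming each MetaLearner keeps a deep copy of the meta-model and takes plain SGD steps on a regression loss; the constructor signature, the MSE loss, the SGD optimizer, and the in-place load_state_dict copy are illustrative assumptions, not part of the original gist.

import copy
import torch

class MetaLearner:
    """Hypothetical per-task learner; the gist's actual implementation is not shown."""
    def __init__(self, model: torch.nn.Module, lr: float = 0.01):
        self.model = copy.deepcopy(model)   # task-specific copy of the meta-model
        self.loss_fn = torch.nn.MSELoss()   # assumed loss; replace to suit the task
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)

    def training_step(self, x, y):
        # One inner-loop gradient step on this task's batch.
        self.optimizer.zero_grad()
        loss = self.loss_fn(self.model(x), y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

def transfer_model(source: torch.nn.Module, target: torch.nn.Module) -> torch.nn.Module:
    # Copy the meta-parameters into the task model in place; returning the same
    # module keeps the task learner's optimizer pointing at valid parameters.
    target.load_state_dict(source.state_dict())
    return target

# Hypothetical usage with random data: x[i], y[i] is one batch for task i.
meta_model = torch.nn.Linear(4, 1)
reptile = Reptile(meta_model, [MetaLearner(meta_model) for _ in range(3)])
x = [torch.randn(8, 4) for _ in range(3)]
y = [torch.randn(8, 1) for _ in range(3)]
reptile.training_step(x, y, metalearning_steps=10, learning_rate=0.1)

One design point worth noting: transfer_model must update the task model in place (rather than returning a fresh module), otherwise each MetaLearner's optimizer would keep referencing the old, discarded parameters.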