Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
PyTorch implementation of Reptile as Ravi, (
class Reptile:
# Reptile meta-learner: holds a shared meta-model (self.model) and one learner per
# task (self.n_tasks = len(metalearners)). It moves the meta-model's parameters toward
# the average of the task learners' parameters (see training_step).
# NOTE(review): this copy was scraped from a web page and is incomplete. Indentation,
# the docstring quotes, part of the __init__ signature and the inner-loop body of
# metatraining_step are missing. Restore them from the original source before use.
# NOTE(review): Reptile is usually credited to Nichol, Achiam & Schulman (2018). The
# attribution to Ravi on the next line may be a mistake, so confirm it. "Repile" is a
# typo for "Reptile".
Repile-optimization as described by Ravi, (
def __init__(self,
# NOTE(review): the signature is truncated here. From the assignments below it
# presumably takes `model` (the meta-model) and `metalearners` (the per-task learners).
# Verify this against the original.
self.n_tasks = len(metalearners)  # one learner per task
self.model = model
self.metalearners = metalearners
def metatraining_step(self,x,y,idxs=None,steps:int=1):
# Pairs each task index in `idxs` with the matching element of x and y, selects that
# task's learner, and (presumably) runs `steps` inner-loop updates on it.
# NOTE(review): idxs defaults to None, but zip(idxs, ...) raises TypeError on None.
# Callers must currently pass idxs explicitly. Confirm whether a default such as
# range(len(x)) was intended.
for i,(idx,xp,yp) in enumerate(zip(idxs,x,y)):  # `i` is unused
metalearner = self.metalearners[idx]
for _ in range(steps):
# TODO(review): the body of this loop is missing from this copy. It is presumably one
# optimization step of `metalearner` on (xp, yp). Restore it from the original source.
def training_step(self,x,y,idxs=None,metalearning_steps:int=100,learning_rate:float=0.01):
# One outer (meta) step. First, each task learner's model is replaced via
# transfer_model. Then each meta-model parameter is moved a fraction `learning_rate`
# of the way toward the element-wise mean of the corresponding learner parameters.
# NOTE(review): x, y, idxs and metalearning_steps are not used in the visible body.
# A call to the inner loop appears to be missing from this copy, between the transfer
# loop and the update, e.g. self.metatraining_step(x, y, idxs, metalearning_steps).
# Confirm against the original.
for metalearner in self.metalearners:
# transfer_model is defined elsewhere. It presumably copies self.model's weights into
# the learner's model, so verify that.
metalearner.model = transfer_model(self.model,metalearner.model)
with torch.no_grad():  # keep the in-place parameter update out of autograd
# Walk the meta-model's parameters in lockstep with every learner's parameters. This
# assumes all models share the same architecture and parameter order. zip stops
# silently at the shortest parameter list.
for args in zip(self.model.parameters(),*[meta.model.parameters() for meta in self.metalearners]):
param = args[0]
# Element-wise mean of this parameter across all task learners.
W = torch.stack([args[i] for i in range(1,len(args))],dim=0).mean(dim=0)
# Reptile interpolation, applied in place: param <- param + lr * (W - param)
param += learning_rate * (W-param)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.