@benoitdescamps
Created September 12, 2020 18:19
Simple learning step
from typing import Callable

import torch


class MetaLearner:
    """
    This is nothing more than a regular learning flow. However, we wrap it in a
    class because we plan to use a separate (meta-)learner for each task.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 loss_fn: Callable,
                 optimizer):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def training_step(self, x, y):
        # Forward pass; the loss is summed in case loss_fn returns per-sample values.
        y_pred = self.model(x)
        loss = self.loss_fn(y_pred, y).sum()
        # Clear stale gradients, backpropagate, then update the parameters.
        self.model.zero_grad()
        loss.backward()
        self.optimizer.step()
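
A minimal usage sketch of the learning step, with the class repeated so the snippet runs on its own. The toy regression task (y = 2x + 1), the `torch.nn.Linear` model, the SGD optimizer, the learning rate, and the step count are all assumptions chosen for illustration; none of them come from the gist. `MSELoss(reduction="none")` is used here so that the `.sum()` inside `training_step` actually reduces a per-sample loss tensor.

```python
from typing import Callable

import torch


class MetaLearner:
    """Same learning flow as in the gist, repeated so this sketch is self-contained."""

    def __init__(self, model: torch.nn.Module, loss_fn: Callable, optimizer):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def training_step(self, x, y):
        y_pred = self.model(x)
        loss = self.loss_fn(y_pred, y).sum()
        self.model.zero_grad()
        loss.backward()
        self.optimizer.step()


torch.manual_seed(0)

# Hypothetical toy task (an assumption for this example): fit y = 2x + 1.
x = torch.linspace(-1.0, 1.0, 32).unsqueeze(1)
y = 2.0 * x + 1.0

model = torch.nn.Linear(1, 1)
loss_fn = torch.nn.MSELoss(reduction="none")  # per-sample losses, summed in training_step
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
learner = MetaLearner(model, loss_fn, optimizer)

# Summed loss before any training step.
with torch.no_grad():
    loss_before = loss_fn(model(x), y).sum().item()

for _ in range(200):
    learner.training_step(x, y)

# Summed loss after training; it should be much smaller than loss_before.
with torch.no_grad():
    loss_after = loss_fn(model(x), y).sum().item()

print(f"summed loss before: {loss_before:.4f}, after: {loss_after:.6f}")
```

Because each learner holds its own model and optimizer, one such object could be created per task, which matches the intent stated in the class docstring.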