# Simple learning step
class MetaLearner:
    """
    A plain supervised learning loop wrapped in a class.

    This is nothing more than a regular learning flow. However, we create
    this class, as we plan on using separate (meta-)learners for each task.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 loss_fn: Callable,
                 optimizer: torch.optim.Optimizer):
        """
        Args:
            model: The network to train; called as ``model(x)``.
            loss_fn: Called as ``loss_fn(y_pred, y)``; its result is summed
                to a scalar before back-propagation.
            optimizer: Object whose ``step()`` applies the parameter update;
                presumably constructed over ``model.parameters()`` — TODO
                confirm at call sites.
        """
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def training_step(self, x, y):
        """Run one forward / backward / update step on a single batch.

        Args:
            x: Model input batch.
            y: Targets passed as the second argument to ``loss_fn``.

        Returns:
            The summed loss computed *before* the optimizer update, detached
            from the autograd graph so callers can log it without keeping
            the graph alive.
        """
        y_pred = self.model(x)
        # .sum() reduces any per-element loss to the scalar backward() needs.
        loss = self.loss_fn(y_pred, y).sum()
        # Clear gradients left on the model's parameters by the previous step.
        self.model.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.detach()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment