@charlieoneill11
Created January 11, 2022 00:19
import torch


def training_loop(n_epochs, optimiser, model, loss_fn, X_train, X_val, y_train, y_val):
    for epoch in range(1, n_epochs + 1):
        output_train = model(X_train)  # forward pass
        loss_train = loss_fn(output_train, y_train)  # calculate training loss

        # validation is for monitoring only, so skip building its autograd graph
        with torch.no_grad():
            output_val = model(X_val)
            loss_val = loss_fn(output_val, y_val)

        optimiser.zero_grad()  # set gradients to zero
        loss_train.backward()  # backward pass
        optimiser.step()  # update model parameters

        if epoch == 1 or epoch % 10000 == 0:
            print(f"Epoch {epoch}, Training loss {loss_train.item():.4f},"
                  f" Validation loss {loss_val.item():.4f}")
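
A minimal usage sketch for a loop of this shape. Everything beyond the loop's signature is an assumption made for illustration: the synthetic line `y = 3x + 0.5`, the 80/20 split, the `nn.Linear(1, 1)` model, `nn.MSELoss`, and SGD with `lr=0.1` are not part of the original gist. The loop is restated here so the block runs on its own, and returning the final losses is an addition for the demo.

```python
import torch
from torch import nn


def training_loop(n_epochs, optimiser, model, loss_fn, X_train, X_val, y_train, y_val):
    # Same steps as the loop above; the return value is added for this demo.
    for epoch in range(1, n_epochs + 1):
        loss_train = loss_fn(model(X_train), y_train)
        with torch.no_grad():  # validation loss is only monitored
            loss_val = loss_fn(model(X_val), y_val)
        optimiser.zero_grad()
        loss_train.backward()
        optimiser.step()
    return loss_train.item(), loss_val.item()


torch.manual_seed(0)

# Hypothetical data: a noiseless line, shaped (N, 1) as nn.Linear expects.
X = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 3.0 * X + 0.5

# Random 80/20 train/validation split.
perm = torch.randperm(100)
train_idx, val_idx = perm[:80], perm[80:]

model = nn.Linear(1, 1)
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)

final_train, final_val = training_loop(
    500, optimiser, model, nn.MSELoss(),
    X[train_idx], X[val_idx], y[train_idx], y[val_idx],
)
print(f"train {final_train:.6f}, val {final_val:.6f}")
print(f"weight {model.weight.item():.3f}, bias {model.bias.item():.3f}")
```

Because the optimiser is passed in rather than created inside the loop, swapping SGD for another `torch.optim` optimiser only changes the call site.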