@shawnthu — forked from thomwolf/gradient_accumulation.py, created June 25, 2021
PyTorch gradient accumulation training loop
model.zero_grad()                                   # Reset gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute the loss
    loss = loss / accumulation_steps                # Normalize the loss (if it is averaged per batch)
    loss.backward()                                 # Backward pass accumulates gradients
    if (i + 1) % accumulation_steps == 0:           # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradient tensors
        if (i + 1) % evaluation_steps == 0:         # Evaluate the model when we...
            evaluate_model()                        # ...have no gradients accumulated
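
For reference, here is a minimal self-contained sketch of the same accumulation pattern. The model, optimizer, loss, and synthetic data below are hypothetical placeholders added only so the loop can run end to end; they are not part of the original gist.

import torch
import torch.nn as nn

# Hypothetical toy setup: a small linear classifier on random data,
# used only to exercise the gradient accumulation loop.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_function = nn.CrossEntropyLoss()
accumulation_steps = 4

# Synthetic "training set": 16 mini-batches of (inputs, labels).
training_set = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(16)]

model.zero_grad()                                   # Reset gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute the loss
    loss = loss / accumulation_steps                # Normalize the per-batch (averaged) loss
    loss.backward()                                 # Accumulate gradients into .grad
    if (i + 1) % accumulation_steps == 0:           # Every accumulation_steps batches...
        optimizer.step()                            # ...apply the accumulated gradients
        model.zero_grad()                           # ...and reset them for the next cycle

The effective batch size is accumulation_steps times the mini-batch size, since optimizer.step() only runs after that many backward passes.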