@thomwolf
Last active January 16, 2024 02:38
PyTorch gradient accumulation training loop
model.zero_grad()                                   # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradients tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
            evaluate_model()                        # ...have no gradients accumulated
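
For reference, here is a self-contained sketch of the same loop that can be run as-is. The toy model, optimizer, data, and hyperparameters are illustrative placeholders (and the evaluation step is omitted), not part of the original gist:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)                                      # Toy model (placeholder)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

# Placeholder training set: a list of (inputs, labels) mini-batches
training_set = [(torch.randn(16, 8), torch.randint(0, 2, (16,))) for _ in range(12)]

model.zero_grad()                                            # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                              # Forward pass
    loss = loss_function(predictions, labels)                # Compute loss function
    loss = loss / accumulation_steps                         # Normalize the (mean-reduced) loss
    loss.backward()                                          # Backward pass; gradients accumulate
    if (i + 1) % accumulation_steps == 0:                    # Wait for several backward steps
        optimizer.step()                                     # Now we can do an optimizer step
        model.zero_grad()                                    # Reset gradients tensors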

arquolo commented Jan 16, 2023

@Alex-Mathai-98 With respect to @thomwolf's code, my changes are an optimization so that gradient-linked data is not kept around once it is no longer needed; they do not affect the value of the gradient.
The main difference is the detach() calls, which prevent the graph used by the last backward pass from being kept alive after the end of the loop or between iterations.
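
For illustration, a minimal sketch of that pattern (not arquolo's actual code; it reuses the placeholder names from the loop above and only shows where detach() and explicit deletion come in):

running_loss = 0.0
model.zero_grad()
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)
    loss = loss_function(predictions, labels) / accumulation_steps
    loss.backward()
    running_loss += loss.detach().item()    # keep only a graph-free value for logging
    del predictions, loss                   # drop graph-linked tensors between iterations
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        model.zero_grad()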

@Alex-Mathai-98

@arquolo - okay, thank you for the clarification.

AmmarRashed commented Jan 16, 2024

@thomwolf Thanks for the code. Just a little fix in the condition at line 7:

if (i+1) % accumulation_steps == 0:

This assumes that the number of batches is perfectly divisible by the accumulation steps. However, if there are, say, 10 batches and the accumulation steps are 4, the last two batches would never make it to the optimizer.step().

And even if you add an extra condition for the last iteration, you still need to adjust the normalization denominator, because the final accumulation covers only 2 batches, not 4.

So the updated code would be:

n = len(training_set)
remainder_batches = n % accumulation_steps # calculate number of remainder batches

for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    remaining = n - i

    # update the denominator if the remaining batches are leq number of remainder batches 
    denominator = remainder_batches if remaining <= remainder_batches else accumulation_steps

    loss = loss / denominator
    loss.backward()
    if (i+1) % accumulation_steps == 0 or i == n - 1: # add condition for last iteration
        optimizer.step()
        model.zero_grad()

You can emulate the logic in a standalone script as follows:

def get_value():
    return 5


n = 10
steps = 4
values = []
val = 0
remainder = n % steps
for i in range(n):
    a = get_value()
    remaining = n - i
    if remaining <= remainder:
        denom = n % steps
    else:
        denom = steps
    val += a / denom
    print(i, denom)
    if (i + 1) % steps == 0 or i == n - 1:
        values.append(val)
        val = 0
        print("update")
