@thomwolf
Last active January 16, 2024 02:38
PyTorch gradient accumulation training loop
model.zero_grad()                                   # Reset gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradient tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
            evaluate_model()                        # ...have no gradients accumulated
@muaz-git

Thanks for the code.
Can you please explain what is meant by "Normalize our loss (if averaged)"?

@thomwolf
Author

thomwolf commented May 4, 2019

If you are using a loss that is averaged over the training samples (which is the case most of the time), you have to divide it by the number of gradient accumulation steps.
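A quick toy check of this (a sketch, not part of the gist; mse_loss stands in for any mean-reduced loss): summing the gradients of loss / accumulation_steps over equal-sized sub-batches reproduces the gradient of the loss averaged over the full effective batch.

import torch
import torch.nn.functional as F

w = torch.randn(4, requires_grad=True)
x, y = torch.randn(8, 4), torch.randn(8)        # one "big" batch of 8 samples

F.mse_loss(x @ w, y).backward()                 # loss averaged over all 8 samples
grad_big = w.grad.clone()

w.grad.zero_()
for xb, yb in zip(x.chunk(2), y.chunk(2)):      # two sub-batches of 4 samples
    loss = F.mse_loss(xb @ w, yb)               # averaged over 4 samples only
    (loss / 2).backward()                       # divide by accumulation_steps = 2

print(torch.allclose(grad_big, w.grad))         # True (up to float error)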

@Auth0rM0rgan

Auth0rM0rgan commented May 19, 2019

Hey @thomwolf,

Thanks for the tips and tricks.
Just one question: shouldn't we use optimizer.zero_grad() before loss.backward()?

@thomwolf
Author

@Auth0rM0rgan
No, otherwise you erase all the gradients accumulated in the leaves of the computation graph.
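A minimal sketch of what "accumulated in the leaves" means (toy tensor standing in for a model parameter): each backward() adds into .grad, so zeroing right before backward would throw the accumulated sum away.

import torch

w = torch.ones(3, requires_grad=True)   # a leaf tensor, like a model parameter

w.sum().backward()
print(w.grad)                           # tensor([1., 1., 1.])

(2 * w.sum()).backward()
print(w.grad)                           # tensor([3., 3., 3.]) -- backward() adds into .grad

w.grad.zero_()                          # what zero_grad() does under the hood
print(w.grad)                           # tensor([0., 0., 0.]) -- the accumulated sum is gone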

@jaideep11061982

Hi Thom,
If we want to simulate processing more images per batch but, due to GPU limitations, can only fit a batch size of 64 while we actually want to process 128 images, shouldn't we do it this way?

total_loss=0
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
  
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        total_loss = (total_loss+loss )/ accumulation_steps                # Normalize our loss (if averaged)
        loss.backward()                                 # Backward pass
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradients tensors
        total_loss=0
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
            evaluate_model()    
    else:
        total_loss = loss + total_loss

@jshin49

jshin49 commented Nov 27, 2019

Would gradient accumulation work for MAML training?

@meet-minimalist

Hey,
Thanks for the code.
I want to know how you handled batch normalization with gradient accumulation.
E.g., if we use a sub-batch size of 8 and 4 iterations of forward passes, then accumulate the gradients and backprop, this results in an effective batch size of 32.

But the problem with this setting is that in the batch normalization layers, the batch mean and batch variance used for training are computed on a batch of 8 at each of the 4 forward passes, which is not the same as computing them on a batch of 32.

So this does not simulate the same effect as training with a batch size of 32.

Please elaborate on how you handled this. Or how should it be handled? :(

To verify this, train one model with a batch size of 32 and another with the 8-batch x 4-iteration accumulation strategy, keeping the initial weights and the data feeding identical for both runs. You will definitely see a difference in the loss.
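A small sketch of this discrepancy (toy BatchNorm1d layer and random data, just for illustration): the statistics collected from four sub-batches of 8 differ from those of a single batch of 32.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 16)                 # one "full" batch of 32 samples

bn_full = nn.BatchNorm1d(16).train()    # sees the whole batch at once
bn_accum = nn.BatchNorm1d(16).train()   # sees four sub-batches of 8

bn_full(x)                              # batch stats computed over 32 samples
for chunk in x.chunk(4):                # batch stats computed over 8 samples, 4 times
    bn_accum(chunk)

# The two layers end up with different running statistics, so accumulation
# does not reproduce true 32-sample batch normalization.
print(torch.allclose(bn_full.running_mean, bn_accum.running_mean))  # almost certainly False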

@aGIToz

aGIToz commented Sep 4, 2020

This still seems to be an open issue; I have not found any definitive answer to it. @meet-minimalist, did you find a solution for handling BN with gradient accumulation?

@meet-minimalist

@aGIToz Sorry, I haven't found an exact solution as of now. But I know TensorFlow uses something like SyncBatchNorm across the 8 individual TPUs when training on a TPU cluster. I don't know how they do it, but they must have developed some workaround for this problem.
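For what it's worth, PyTorch has an analogous layer for the multi-GPU case: SyncBatchNorm computes batch statistics across all processes of a DistributedDataParallel job. It does not solve single-GPU gradient accumulation, but the conversion (toy model, just a sketch) looks like this:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Swap every BatchNorm layer for SyncBatchNorm; in a DistributedDataParallel
# job the batch statistics are then computed across all processes.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)                            # BatchNorm2d has been replaced by SyncBatchNorm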

@ziqihuang233

I rewrote the code for my project, but I found that the loss increases! If I train in the normal way, the loss decreases. I have no idea why.

@VpkPrasanna

How do we choose the gradient accumulation value? Based on what do we need to fix the value?

@wenxie18

(quoting @jaideep11061982's code above)

Hi, I'm using a similar idea. Do you have any further suggestions? Thanks.

@arquolo

arquolo commented May 20, 2022

@jaideep11061982, @vaneshieh, this code is incorrect and has multiple mistakes.
Explanation (rewritten a bit for readability):

total_loss = 0
for i, (inputs, labels) in enumerate(training_set, 1):
    # Given the absence of torch.no_grad(),
    # `loss` keeps all intermediate states of the forward pass, needed for backward computation
    predictions = model(inputs)
    loss = loss_function(predictions, labels)

    # Summation here aggregates the loss from N accumulation steps
    # Thus total_loss holds the intermediate states of NUM_ACCUMULATES forwards
    # So there is no memory saving here, but a leak
    total_loss += loss
    if i % NUM_ACCUMULATES == 0:
        total_loss = total_loss / NUM_ACCUMULATES

        # Here backward runs only on the loss from the last step, so the grads in the model are influenced only by every Nth batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss = 0

        if i % (NUM_ACCUMULATES * NUM_UPDATES) == 0:
            evaluate_model()

So this code updates the weights using only 1/N of the training set, while eating N times more memory than training with a single accumulation step.
Thus it is equivalent to the following (rewritten to make the mistakes more obvious):

total_loss = 0
batches = []
for i, (inputs, labels) in enumerate(training_set, 1):
    batches.append((inputs, labels))

    if i % NUM_ACCUMULATES == 0:
        # Store intermediate buffers from NUM_ACCUMULATES forward passes
        losses = [loss_function(model(inputs), labels) for inputs, labels in batches]
        total_loss = sum(losses) / NUM_ACCUMULATES

        # Do backward only on the last loss
        losses[-1].backward()

        optimizer.step()
        optimizer.zero_grad()
        batches = []

        if i % (NUM_ACCUMULATES * NUM_UPDATES) == 0:
            evaluate_model()

The original code of @thomwolf is also not without mistakes, though they are less striking.
I have rewritten it to be more robust:

# Zero saved gradients
optimizer.zero_grad()

total_loss = 0
for i, (inputs, labels) in enumerate(training_set, 1):
    # Do the forward pass and store intermediate buffers
    predictions = model(inputs)
    loss = loss_function(predictions, labels)

    # Add the gradients from this batch to the saved ones; divide the loss by NUM_ACCUMULATES if it's averaged over samples
    (loss / NUM_ACCUMULATES).backward()
    
    # Drop intermediate buffers (mandatory)
    # This guarantees that total_loss doesn't carry any intermediate buffers with it
    predictions.detach_()
    loss.detach_()
    total_loss += loss

    if i % NUM_ACCUMULATES == 0:
        # Update parameters using saved gradients
        optimizer.step()

        # Zero saved gradients
        optimizer.zero_grad()

        if i % (NUM_ACCUMULATES * NUM_UPDATES) == 0:
            print(f'train loss: {total_loss / NUM_ACCUMULATES / NUM_UPDATES}')
            total_loss = 0
            evaluate_model()

But seriously, use a higher-level wrapper for PyTorch, like Catalyst, PyTorch Lightning, or even Ignite, as they will prevent these mistakes.
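For example, a minimal PyTorch Lightning sketch (toy module and random data, assuming pytorch-lightning is installed): the Trainer's accumulate_grad_batches flag runs the same accumulation loop for you.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class TinyModel(pl.LightningModule):                 # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
trainer = pl.Trainer(max_epochs=1, accumulate_grad_batches=4)   # gradients accumulated over 4 batches
trainer.fit(TinyModel(), DataLoader(data, batch_size=8))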

@Alex-Mathai-98

@arquolo - I was hoping to understand the changes you suggested to @thomwolf's code. Are these changes optimizations (dropping intermediate buffers), or do they actually affect the value of the gradient? I apologise in advance if this is a rudimentary question.

@arquolo

arquolo commented Jan 16, 2023

@Alex-Mathai-98 W.r.t. @thomwolf's code, my changes are optimizations that avoid keeping gradient-linked data when it is no longer needed; they don't affect the value of the gradient.
The main difference is the detach() calls, which make sure the graph used by the last backward isn't kept alive after the end of the loop or in between iterations.

@Alex-Mathai-98

@arquolo - Okay, thank you for the clarification.

@AmmarRashed

AmmarRashed commented Jan 16, 2024

@thomwolf Thanks for the code. Just a little fix in the condition at line 7:

if (i+1) % accumulation_steps == 0:

This assumes that the number of batches is perfectly divisible by the accumulation steps. However, if there are, say, 10 batches and the accumulation steps are 4, the last two batches would never make it to optimizer.step().

And even if you add an extra condition, you still need to adjust the normalization denominator, because the last group has only 2 accumulation steps, not 4.

So the updated code would be:

n = len(training_set)
remainder_batches = n % accumulation_steps # calculate number of remainder batches

for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    remaining = n - i

    # use the remainder as the denominator once the remaining batches are <= the number of remainder batches
    denominator = remainder_batches if remaining <= remainder_batches else accumulation_steps

    loss = loss / denominator
    loss.backward()
    if (i+1) % accumulation_steps == 0 or i == n - 1: # add condition for last iteration
        optimizer.step()
        model.zero_grad()

You can emulate the logic in a standalone script as follows:

def get_value():
    return 5


n = 10
steps = 4
values = []
val = 0
remainder = n % steps
for i in range(n):
    a = get_value()
    remaining = n - i
    if remaining <= remainder:
        denom = n % steps
    else:
        denom = steps
    val += a / denom
    print(i, denom)
    if (i + 1) % steps == 0 or i == n - 1:
        values.append(val)
        val = 0
        print("update")
