```python
model.zero_grad()                                  # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                    # Forward pass
    loss = loss_function(predictions, labels)      # Compute loss function
    loss = loss / accumulation_steps               # Normalize our loss (if averaged)
    loss.backward()                                # Backward pass
    if (i + 1) % accumulation_steps == 0:          # Wait for several backward steps
        optimizer.step()                           # Now we can do an optimizer step
        model.zero_grad()                          # Reset gradients tensors
        if (i + 1) % evaluation_steps == 0:        # Evaluate the model when we...
            evaluate_model()                       # ...have no gradients accumulated
```
@jaideep11061982 , @vaneshieh , this code is incorrect and has multiple mistakes.
Explanation (a bit rewritten for readability):
```python
total_loss = 0
for i, (inputs, labels) in enumerate(training_set, 1):
    # In the absence of torch.no_grad(), `loss` keeps all intermediate
    # states of the forward pass, needed for the backward computation
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    # The summation here aggregates the loss from N accumulation steps,
    # so total_loss holds the intermediate states of NUM_ACCUMULATES forwards.
    # There is no memory saving here, only a leak
    total_loss += loss
    if i % NUM_ACCUMULATES == 0:
        total_loss = total_loss / NUM_ACCUMULATES
        # backward() is called on the loss from the last step only, so the
        # gradients in the model are influenced by every Nth batch alone
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss = 0
        if i % (NUM_ACCUMULATES * NUM_UPDATES) == 0:
            evaluate_model()
```
So this code updates the weights using gradients from only 1/N of the training set, while eating N times more memory than training with a single accumulation step.
Thus it is equivalent to the following (rewritten to make the mistakes more obvious):
```python
total_loss = 0
batches = []
for i, (inputs, labels) in enumerate(training_set, 1):
    batches.append((inputs, labels))
    if i % NUM_ACCUMULATES == 0:
        # Store intermediate buffers from NUM_ACCUMULATES forward passes
        losses = [loss_function(model(inputs), labels) for inputs, labels in batches]
        total_loss = sum(losses) / NUM_ACCUMULATES
        # Do backward only on the last loss
        losses[-1].backward()
        optimizer.step()
        optimizer.zero_grad()
        batches = []
        if i % (NUM_ACCUMULATES * NUM_UPDATES) == 0:
            evaluate_model()
```
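Both effects are easy to verify on a toy sketch (a hypothetical single-weight "model" with made-up data, not the original code): backward() on only the last of N summed losses yields the gradient of that batch alone, while calling backward() on each normalized loss reproduces the gradient of the mean over all N batches.

```python
import torch

torch.manual_seed(0)
# Hypothetical toy "model": one weight vector with a squared-sum loss
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
xs = [torch.randn(3) for _ in range(4)]  # 4 made-up "batches"
loss_fn = lambda w, x: (w * x).sum() ** 2

# Reference: gradient of the mean loss over all 4 batches
sum(loss_fn(w, x) for x in xs).div(4).backward()
ref = w.grad.clone()

# Buggy pattern: backward() on the last loss only
w.grad = None
losses = [loss_fn(w, x) for x in xs]
(losses[-1] / 4).backward()
last_only = w.grad.clone()

# Correct accumulation: backward() per batch; gradients add up in w.grad
w.grad = None
for x in xs:
    (loss_fn(w, x) / 4).backward()
acc = w.grad.clone()

assert torch.allclose(ref, acc, atol=1e-6)  # per-batch backward matches the mean
assert not torch.allclose(ref, last_only)   # last-loss-only does not
```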
The original code of @thomwolf is also not without mistakes, though they are less striking.
I have rewritten it to be less error-prone:
```python
# Zero saved gradients
optimizer.zero_grad()
total_loss = 0
for i, (inputs, labels) in enumerate(training_set, 1):
    # Do the forward pass and store intermediate buffers
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    # Add gradients from this batch to the saved ones;
    # divide the loss by NUM_ACCUMULATES if it is averaged over samples
    (loss / NUM_ACCUMULATES).backward()
    # Drop intermediate buffers, mandatory.
    # This guarantees that total_loss doesn't drag intermediate buffers along
    predictions.detach_()
    loss.detach_()
    total_loss += loss
    if i % NUM_ACCUMULATES == 0:
        # Update parameters using saved gradients
        optimizer.step()
        # Zero saved gradients
        optimizer.zero_grad()
        if i % (NUM_ACCUMULATES * NUM_UPDATES) == 0:
            print(f'train loss: {total_loss / NUM_ACCUMULATES / NUM_UPDATES}')
            total_loss = 0
            evaluate_model()
```
But seriously, use a higher-level wrapper for PyTorch, such as Catalyst, PyTorch Lightning, or even Ignite, as they will prevent these mistakes.
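For instance, a minimal sketch with PyTorch Lightning, whose Trainer handles accumulation via the `accumulate_grad_batches` argument (the module class and data loader names here are hypothetical placeholders):

```python
import pytorch_lightning as pl

# Lightning calls backward() per batch and only steps the
# optimizer (and zeroes gradients) every 4 batches
trainer = pl.Trainer(accumulate_grad_batches=4)
trainer.fit(MyLightningModule(), train_loader)
```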
@Alex-Mathai-98 W.r.t. @thomwolf's code, my changes are an optimization to avoid keeping gradient-linked data when it's no longer needed; they don't affect the value of the gradient.
The main difference is the detach() calls, which drop the graph used by the last backward() after the end of the loop or in between iterations.
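A small sketch (toy tensors, not the original model) of what detach buys: without it, the running total keeps a grad_fn, and thus the whole graph of every summed loss alive; after detach_(), the total carries no graph at all.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

# Without detach: total keeps a graph linking back to every loss
total = torch.tensor(0.0)
for _ in range(3):
    loss = (w * w).sum()
    total = total + loss
assert total.grad_fn is not None  # graph (and its buffers) stays alive

# With detach_: total is a plain number; each graph is freed after backward
total = torch.tensor(0.0)
for _ in range(3):
    loss = (w * w).sum()
    loss.backward()
    loss.detach_()
    total = total + loss
assert total.grad_fn is None
```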
@arquolo - okay, thank you for the clarification.
@thomwolf Thanks for the code. Just a little fix in the condition at line 7:

```python
if (i + 1) % accumulation_steps == 0:
```

This assumes that the number of batches is perfectly divisible by the accumulation steps. However, if there are, say, 10 batches, and the accumulation steps are 4, the last two batches would never make it to the optimizer.step().
And even if you add an extra condition, you will still need to adjust the normalization denominator, because the last update accumulates only 2 steps, not 4.
So the updated code would be:
```python
n = len(training_set)
remainder_batches = n % accumulation_steps  # number of leftover batches
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    remaining = n - i
    # use the smaller denominator once only the remainder batches are left
    denominator = remainder_batches if remaining <= remainder_batches else accumulation_steps
    loss = loss / denominator
    loss.backward()
    if (i + 1) % accumulation_steps == 0 or i == n - 1:  # also fire on the last iteration
        optimizer.step()
        model.zero_grad()
```
You can emulate the logic in a standalone script as follows:

```python
def get_value():
    return 5

n = 10
steps = 4
values = []
val = 0
remainder = n % steps
for i in range(n):
    a = get_value()
    remaining = n - i
    if remaining <= remainder:
        denom = n % steps
    else:
        denom = steps
    val += a / denom
    print(i, denom)
    if (i + 1) % steps == 0 or i == n - 1:
        values.append(val)
        val = 0
        print("update")
```
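As a complementary check, the per-batch weights implied by those denominators sum to 1 within every update group, so each optimizer.step() sees an unbiased mean loss regardless of group size:

```python
n, steps = 10, 4
remainder = n % steps

# weight each batch contributes to its update: 1 / denominator
weights = []
for i in range(n):
    remaining = n - i
    denom = remainder if remaining <= remainder else steps
    weights.append(1.0 / denom)

# groups of batches between optimizer steps: sizes 4, 4 and 2
groups = [weights[0:4], weights[4:8], weights[8:10]]
assert all(abs(sum(g) - 1.0) < 1e-9 for g in groups)
```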
Hi, I'm using a similar idea. Do you have any further suggestions? Thanks