@mcarilli
Last active June 30, 2023 12:21
Minimal example of gradient accumulation, allreducing only on step() iterations and interacting properly with torch.cuda.amp
# For single-node, run this script via
# python -m torch.distributed.launch --nproc_per_node=<ngpus this node> example.py
#
# For multinode, see https://pytorch.org/docs/stable/distributed.html#launch-utility
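# A multinode launch sketch (values are placeholders; see the link above for the details):
#   python -m torch.distributed.launch --nproc_per_node=<ngpus per node> \
#          --nnodes=<num nodes> --node_rank=<rank of this node> \
#          --master_addr=<IP of node 0> --master_port=<free port> example.py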
#
# Example showing native mixed precision tools
# (torch.cuda.amp.GradScaler and torch.cuda.amp.autocast)
# used along with native DistributedDataParallel to perform
# gradient accumulation with allreduces only when stepping.
#
# The key takeaway is that each of those tools is used orthogonally
# (just as it would be in the absence of the others).
# There are no gotchas combining them.
import torch
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed:
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
# Fake data (different in each process)
torch.manual_seed(args.local_rank)
N, D_in, D_out = 64, 1024, 16
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()
if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[args.local_rank],
                                                      output_device=args.local_rank)
iters_to_accumulate = 4
# Helper to print gradient values for debugging (uncomment to use)
# torch.set_printoptions(precision=10)
# def debug_grad_info(t, should_match_across_ranks):
#     string = ""
#     for name, param in model.named_parameters():
#         string += "iter = {}, rank = {}, should match across ranks = {}, {}.grad sum = {}\n".format(
#             t, args.local_rank, should_match_across_ranks, name, param.grad.double().sum().item())
#     print(string, flush=True)
def run_fwd_bwd():
    # Runs the forward pass under autocast.
    with torch.cuda.amp.autocast():
        y_pred = model(x)
        # You may wish to divide the loss by iters_to_accumulate to average
        # across the effective (accumulated) global batch.
        loss = loss_fn(y_pred, y) / iters_to_accumulate
    # Exit the autocast context before calling backward(); backward passes
    # under autocast are not recommended.
    scaler.scale(loss).backward()
for t in range(20):
    if (t + 1) % iters_to_accumulate == 0:
        # We will step() this iteration, so don't run forward and backward under no_sync.
        # Allow the allreduces to happen.
        run_fwd_bwd()
        # Grads DO match across ranks at this point; we're ready to step.
        # debug_grad_info(t, True)
        scaler.step(optimizer)
        optimizer.zero_grad(set_to_none=True)
        # Only call scaler.update() on iterations where we actually step()ed, as in
        # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
        scaler.update()
    else:
        # We're not stepping this iteration, so use no_sync to prevent DDP allreduces.
        # Both forward and backward run under no_sync() to get the
        # no-allreduce behavior.
        with model.no_sync():
            run_fwd_bwd()
        # Grads don't match across ranks at this point.
        # debug_grad_info(t, False)

# Double-check that param values are identical across ranks.
string = ""
for name, param in model.named_parameters():
    string += "rank = {}, {} sum = {}\n".format(
        args.local_rank, name, param.double().sum().item())
print(string, flush=True)
@Shamdan17

Thanks for this gist!

Is the goal of using model.no_sync() to avoid overhead from synchronizing gradients if we're not performing an update in this step? That is what I understood from the documentation here.

If so, this should not affect the synchronization of batch norm statistics in the forward pass of SyncBatchNorm, is that correct?

Thanks!
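
For reference, a minimal sketch (not from the gist) of how SyncBatchNorm is typically enabled alongside DDP. It assumes the same process-group setup and variables (args.local_rank, D_in, D_out, x, y, loss_fn, scaler, iters_to_accumulate) as the script above, and a hypothetical model containing a BatchNorm layer. The DDP docs describe no_sync() as disabling gradient synchronization only, which happens during backward, so it is not expected to affect SyncBatchNorm's forward-pass statistics sync.

    # Sketch: convert BatchNorm layers to SyncBatchNorm before wrapping in DDP.
    # (Hypothetical model; layer sizes chosen to match the fake data above.)
    bn_model = torch.nn.Sequential(
        torch.nn.Linear(D_in, 256),
        torch.nn.BatchNorm1d(256),
        torch.nn.Linear(256, D_out),
    ).cuda()
    bn_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bn_model)
    bn_model = torch.nn.parallel.DistributedDataParallel(bn_model,
                                                         device_ids=[args.local_rank],
                                                         output_device=args.local_rank)

    # no_sync() skips DDP's gradient allreduce for this backward pass;
    # the forward pass (including SyncBatchNorm's statistics collectives) runs as usual.
    with bn_model.no_sync():
        with torch.cuda.amp.autocast():
            loss = loss_fn(bn_model(x), y) / iters_to_accumulate
        scaler.scale(loss).backward()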
