@thomwolf
Last active December 13, 2022 19:15
Using DistributedDataParallel
import argparse

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Each process runs on 1 GPU device specified by the local_rank argument.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()

# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')

# Encapsulate the model on the GPU assigned to the current process
device = torch.device('cuda', args.local_rank)
model = model.to(device)
distrib_model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
)

# Restricts data loading to a subset of the dataset exclusive to the current process
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler)

for inputs, labels in dataloader:
    optimizer.zero_grad()                                  # Reset gradients from the previous step
    predictions = distrib_model(inputs.to(device))         # Forward pass
    loss = loss_function(predictions, labels.to(device))   # Compute loss function
    loss.backward()                                        # Backward pass
    optimizer.step()                                       # Optimizer step
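In a complete script, one process per GPU is started with the torch.distributed.launch helper, which supplies the --local_rank argument parsed above (newer PyTorch releases use torchrun with the LOCAL_RANK environment variable instead). When training for several epochs, DistributedSampler.set_epoch should be called at the start of each epoch so that the shuffling differs between epochs. Below is a minimal sketch of that outer loop, reusing the model, sampler, dataloader, optimizer and loss_function objects assumed by the snippet above; the script name and epoch count are placeholders.

# Hypothetical launch command for a single node with 8 GPUs (one process per GPU):
#   python -m torch.distributed.launch --nproc_per_node=8 train_script.py
# The launcher passes --local_rank to each process.

num_epochs = 3  # placeholder value

for epoch in range(num_epochs):
    # Tell the DistributedSampler which epoch this is so each epoch
    # shuffles the per-process data shards differently.
    sampler.set_epoch(epoch)
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        predictions = distrib_model(inputs.to(device))
        loss = loss_function(predictions, labels.to(device))
        loss.backward()
        optimizer.step()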
@monajalal

Hi Tom, should we use DistributedSampler if we have 1 node with 8 GPUs and 24 CPUs when using DDP, or is it only for when we have more than one node? If we should use it, what would happen if we used a normal DataLoader without a DistributedSampler? Thanks a lot for the code snippet.
