@zhuangh
Last active May 4, 2024 23:51
single_gpu_ddp.py
# python single_gpu_ddp.py
# https://discuss.pytorch.org/t/single-machine-single-gpu-distributed-best-practices/169243
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    # Configure the distributed environment.
    # 'gloo' can be used in environments where 'nccl' is not supported, e.g. on CPUs.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    # Tear down the distributed environment.
    dist.destroy_process_group()


def example(rank, world_size, device="cpu"):
    setup(rank, world_size)
    if not isinstance(device, torch.device):
        device = torch.device(device)

    # Create a simple model and move it to the target device.
    model = nn.Linear(10, 1)
    model.to(device)

    # Wrap the model in DistributedDataParallel.
    # device_ids only applies to single-device CUDA modules; pass None on CPU.
    ddp_model = DDP(model, device_ids=[device] if device.type == "cuda" else None)

    # Create dummy input data matching the model dimensions on the target device.
    inputs = torch.randn(64, 10, device=device)
    targets = torch.randn(64, 1, device=device)

    # Forward pass
    outputs = ddp_model(inputs)

    # Compute the loss
    loss_fn = nn.MSELoss()
    loss = loss_fn(outputs, targets)

    # Backward pass. DDP already all-reduces and averages the gradients during
    # backward(), so the explicit all_reduce below is redundant and only
    # illustrates what DDP does under the hood.
    loss.backward()

    # Aggregate gradients with all_reduce (sum), then scale by the number of
    # processes to get the average.
    for param in ddp_model.parameters():
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= world_size

    # Update the model parameters with a manual SGD step (learning rate 0.01).
    with torch.no_grad():
        for param in ddp_model.parameters():
            param -= 0.01 * param.grad

    print(loss.item())
    cleanup()


def main():
    # Number of processes to spawn; all of them share the single device chosen below.
    world_size = 2  # Example: use 2 processes.
    device = torch.device('cuda:0')
    # device = torch.device('cpu')  # the gloo backend also works on CPU.
    mp.spawn(example, args=(world_size, device), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
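Since DDP already averages gradients across processes inside loss.backward(), the manual all_reduce and parameter update in example() are redundant. For comparison, here is a minimal sketch (not part of the original gist) of the same single training step using torch.optim.SGD and DDP's built-in averaging; example_with_optimizer is a hypothetical name, and it reuses the setup() and cleanup() helpers from the file above.

# A sketch of the same step with an optimizer instead of manual gradient handling.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def example_with_optimizer(rank, world_size, device="cpu"):
    # Reuses setup() and cleanup() from single_gpu_ddp.py above.
    setup(rank, world_size)
    if not isinstance(device, torch.device):
        device = torch.device(device)

    model = nn.Linear(10, 1).to(device)
    ddp_model = DDP(model, device_ids=[device] if device.type == "cuda" else None)

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)  # same lr as the manual update
    loss_fn = nn.MSELoss()

    inputs = torch.randn(64, 10, device=device)
    targets = torch.randn(64, 1, device=device)

    optimizer.zero_grad()
    loss = loss_fn(ddp_model(inputs), targets)
    loss.backward()   # DDP all-reduces and averages gradients across ranks here.
    optimizer.step()  # Plain SGD step on the already-averaged gradients.

    print(loss.item())
    cleanup()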