@eavae
Last active April 22, 2024 08:56
A simple example of PyTorch DDP (DistributedDataParallel).
import os
import torch
import torch.distributed as distributed
import torch.multiprocessing as multiprocessing
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel


def example(rank, world_size):
    batch_size = 8
    dim_in = 4
    dim_out = 2

    # Join the process group; the NCCL backend requires one CUDA GPU per rank.
    distributed.init_process_group("nccl", rank=rank, world_size=world_size)
    device = torch.device(rank)

    net = nn.Linear(dim_in, dim_out).to(device)
    print(f"Model Initialized: Rank {rank} has weights {net.weight}")

    # Wrapping the module in DDP broadcasts rank 0's parameters to all other ranks.
    model = DistributedDataParallel(net, device_ids=[device])
    print(f"DDP Initialized: Rank {rank} has weights {net.weight}")

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # Forward pass on per-rank random data.
    outputs = model(torch.randn(batch_size, dim_in).to(device))
    labels = torch.randn(batch_size, dim_out).to(device)

    # Backward pass; DDP all-reduces (averages) gradients across ranks.
    loss_fn(outputs, labels).backward()
    optimizer.step()
    print(f"Grad Synced: Rank {rank} has grad: {net.weight.grad}")

    distributed.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    multiprocessing.spawn(example, args=(world_size,), nprocs=world_size, join=True)
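
As a minimal usage sketch (assuming the script is saved as ddp_example.py, a hypothetical filename, and the machine has at least two CUDA GPUs, since the NCCL backend is GPU-only), launching it directly spawns both ranks:

python ddp_example.py

Each rank should print identical weights after the DistributedDataParallel wrapper is constructed (parameters are broadcast from rank 0) and identical gradients after backward() (gradients are averaged by all-reduce), while the weights printed before wrapping may differ between ranks.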