@xmfan · Created January 12, 2024 19:21
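# Minimal repro: DDP + torch.compile (Inductor) with compiled autograd
# enabled via torch._dynamo.utils.maybe_enable_compiled_autograd.
# Assumes a multi-GPU host and a launch via torchrun (which sets RANK and
# WORLD_SIZE), e.g.: torchrun --nproc_per_node=2 <this_script>.py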
import os

import torch
import torch.distributed as dist
from torch._dynamo.utils import maybe_enable_compiled_autograd
from torch.nn.parallel import DistributedDataParallel as DDP

# One process per GPU; RANK and WORLD_SIZE are provided by the launcher.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
# Small two-layer model, compiled with Inductor and wrapped in DDP.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.Linear(10, 1),
)
model = model.to(torch.device("cuda"))
model = torch.compile(model, backend="inductor")
model = DDP(model)
x = torch.randn(10, requires_grad=True, device="cuda")

# Run forward and backward with compiled autograd enabled. The rank 0
# breakpoints are paired with barriers so the other ranks wait while rank 0
# is paused for inspection, just before and just after the backward pass.
with maybe_enable_compiled_autograd(True):
    out = model(x)
    loss = out.sum()
    if rank == 0:
        breakpoint()
    dist.barrier()
    loss.backward()
    if rank == 0:
        breakpoint()
    dist.barrier()

opt = torch.optim.SGD(model.parameters(), lr=0.01)
opt.step()