Skip to content

Instantly share code, notes, and snippets.

@rohan-varma
Last active October 2, 2020 21:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rohan-varma/2ff1d6051440d2c18e96fe57904b55d9 to your computer and use it in GitHub Desktop.
"""Minimal per-process DistributedDataParallel (DDP) trainer.

Intended to be launched once per GPU by a distributed launcher (e.g.
``torch.distributed.launch``), which passes ``--local_rank`` to each worker
and sets the rendezvous environment variables (MASTER_ADDR, MASTER_PORT,
RANK, WORLD_SIZE) that ``init_method="env://"`` reads.
"""
import argparse
import sys

import torch
import torch.distributed as dist


def main() -> None:
    """Join the process group and wrap a toy model in DDP.

    Side effects: binds this process to a CUDA device, initializes the NCCL
    process group, and prints progress to stdout plus one line to stderr.
    """
    parser = argparse.ArgumentParser()
    # The launcher passes each worker its GPU index on this node.
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    print(f"trainer got local_rank {args.local_rank}")

    # Pin this process to its GPU *before* creating the NCCL process group,
    # so NCCL communicators are built on the right device.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    print(f"trainer {args.local_rank} initialized process group")

    # Toy single-parameter model, moved to this rank's GPU and replicated
    # across ranks by DDP.
    model = torch.nn.Linear(1, 1)
    model = model.to(args.local_rank)
    ddp = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank]
    )
    print(f"trainer {args.local_rank} done initializing DDP")

    # Emit one line on stderr — presumably to exercise the launcher's
    # stream handling; verify against how the launcher captures output.
    print("stderr", file=sys.stderr)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment