@sean-smith
Created March 4, 2024 21:03
This is a fork of Meta's torch_distributed.py that works on SageMaker HyperPod
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import sys
import time

import torch

import submitit

NUM_NODES = 2
NUM_TASKS_PER_NODE = 8
NUM_CPUS_PER_TASK = 1
PARTITION = "dev"
LOGS_DIR = "logs"


def print_env():
    for key in sorted(os.environ.keys()):
        if not (
            key.startswith(("SLURM_", "SUBMITIT_"))
            or key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
        ):
            continue
        value = os.environ[key]
        print(f"{key}={value}")


class Task:
    def __call__(self):
        # print_env()
        print("exporting PyTorch distributed environment variables")
        dist_env = submitit.helpers.TorchDistributedEnvironment().export()
        print(f"master: {dist_env.master_addr}:{dist_env.master_port}")
        print(f"rank: {dist_env.rank}")
        print(f"world size: {dist_env.world_size}")
        print(f"local rank: {dist_env.local_rank}")
        print(f"local world size: {dist_env.local_world_size}")
        # print_env()

        # Using the (default) env:// initialization method
        torch.distributed.init_process_group(backend="nccl")
        assert dist_env.rank == torch.distributed.get_rank()
        assert dist_env.world_size == torch.distributed.get_world_size()

        # Actual task / computation
        tensor = dist_env.rank * torch.ones(1).cuda()
        time.sleep(120)
        torch.distributed.all_reduce(tensor)
        if dist_env.rank == 0:
            result = list(tensor)
            print(result)
            return result

    def checkpoint(self):
        print("checkpointing")
        return submitit.helpers.DelayedSubmission(self)


def main():
    executor = submitit.AutoExecutor(folder=LOGS_DIR)
    executor.update_parameters(
        nodes=NUM_NODES,
        tasks_per_node=NUM_TASKS_PER_NODE,
        cpus_per_task=NUM_CPUS_PER_TASK,
        slurm_partition=PARTITION,
    )
    task = Task()
    job = executor.submit(task)
    submitit.helpers.monitor_jobs([job])
    print(job.results()[0])
    return 0


if __name__ == "__main__":
    sys.exit(main())
@sean-smith (Author)

To run this:

python3 torch_distributed.py

Note you need at least 2 instances in your cluster.
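
A quick way to confirm the partition has enough nodes before submitting (a sketch, assuming Slurm's sinfo is on the PATH and the partition is named "dev" as in the script; adjust for your cluster):

import subprocess

# Ask Slurm for "<node_count> <state>" lines for the partition, e.g. "2 idle"
out = subprocess.run(
    ["sinfo", "-p", "dev", "-h", "-o", "%D %t"],
    capture_output=True, text=True, check=True,
).stdout
total_nodes = sum(int(line.split()[0]) for line in out.splitlines() if line.strip())
print(f"nodes in partition: {total_nodes}")
assert total_nodes >= 2, "this example needs at least 2 nodes in the partition"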

@cfregly commented Mar 7, 2024

awesome!!

@sean-smith (Author)

FYI failing with:

----------------------
Traceback (most recent call last):
  File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/submitit/core/submission.py", line 55, in process_job
    result = delayed.result()
  File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/submitit/core/utils.py", line 133, in result
    self._result = self.function(*self.args, **self.kwargs)
  File "/fsx/ubuntu/torch_distributed.py", line 50, in __call__
    torch.distributed.init_process_group(backend="nccl")
  File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 74, in wrapper
    func_return = func(*args, **kwargs)
  File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1148, in init_process_group
    default_pg, _ = _new_process_group_helper(
  File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1268, in _new_process_group_helper
    raise RuntimeError("Distributed package doesn't have NCCL built in")
RuntimeError: Distributed package doesn't have NCCL built in

I think this PyTorch build was compiled without NCCL support, so the distributed NCCL backend isn't available in that conda environment.
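
One way to check whether the installed PyTorch build actually ships the NCCL backend (run this inside the same conda environment the job uses):

import torch
import torch.distributed as dist

print("torch version:      ", torch.__version__)
print("CUDA available:     ", torch.cuda.is_available())
print("distributed support:", dist.is_available())
print("NCCL backend built: ", dist.is_nccl_available())

If the last line prints False, that build was compiled without NCCL (e.g. a CPU-only wheel), and installing a CUDA-enabled PyTorch build should resolve the error.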
