Created March 4, 2024 21:03
This is a fork of Meta's torch_distributed.py that works on SageMaker HyperPod
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import sys
import time

import torch
import submitit

NUM_NODES = 2
NUM_TASKS_PER_NODE = 8
NUM_CPUS_PER_TASK = 1
PARTITION = "dev"
LOGS_DIR = "logs"


def print_env():
    for key in sorted(os.environ.keys()):
        if not (
            key.startswith(("SLURM_", "SUBMITIT_"))
            or key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
        ):
            continue
        value = os.environ[key]
        print(f"{key}={value}")


class Task:
    def __call__(self):
        # print_env()
        print("exporting PyTorch distributed environment variables")
        dist_env = submitit.helpers.TorchDistributedEnvironment().export()
        print(f"master: {dist_env.master_addr}:{dist_env.master_port}")
        print(f"rank: {dist_env.rank}")
        print(f"world size: {dist_env.world_size}")
        print(f"local rank: {dist_env.local_rank}")
        print(f"local world size: {dist_env.local_world_size}")
        # print_env()

        # Using the (default) env:// initialization method
        torch.distributed.init_process_group(backend="nccl")
        assert dist_env.rank == torch.distributed.get_rank()
        assert dist_env.world_size == torch.distributed.get_world_size()

        # Actual task / computation
        tensor = dist_env.rank * torch.ones(1).cuda()
        time.sleep(120)
        torch.distributed.all_reduce(tensor)
        if dist_env.rank == 0:
            result = list(tensor)
            print(result)
            return result

    def checkpoint(self):
        print("checkpointing")
        return submitit.helpers.DelayedSubmission(self)


def main():
    executor = submitit.AutoExecutor(folder=LOGS_DIR)
    executor.update_parameters(
        nodes=NUM_NODES,
        tasks_per_node=NUM_TASKS_PER_NODE,
        cpus_per_task=NUM_CPUS_PER_TASK,
        slurm_partition=PARTITION,
    )
    task = Task()
    job = executor.submit(task)
    submitit.helpers.monitor_jobs([job])
    print(job.results()[0])
    return 0


if __name__ == "__main__":
    sys.exit(main())
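For reference, with the defaults above (2 nodes x 8 tasks per node = 16 ranks), the all_reduce sums each rank's one-element tensor, so rank 0 should print the sum of ranks 0 through 15. A quick back-of-the-envelope check (illustrative only, not part of the script):

world_size = 2 * 8                 # NUM_NODES * NUM_TASKS_PER_NODE
expected = sum(range(world_size))  # each rank contributes its rank id: 0 + 1 + ... + 15
print(expected)                    # 120, so rank 0 should print a one-element tensor holding 120.0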
awesome!!
FYI failing with:
----------------------
Traceback (most recent call last):
File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/submitit/core/submission.py", line 55, in process_job
result = delayed.result()
File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/submitit/core/utils.py", line 133, in result
self._result = self.function(*self.args, **self.kwargs)
File "/fsx/ubuntu/torch_distributed.py", line 50, in __call__
torch.distributed.init_process_group(backend="nccl")
File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 74, in wrapper
func_return = func(*args, **kwargs)
File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1148, in init_process_group
default_pg, _ = _new_process_group_helper(
File "/fsx/ubuntu/awsome-distributed-training/3.test_cases/13.SM-dataparallel-deepspeed/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1268, in _new_process_group_helper
raise RuntimeError("Distributed package doesn't have NCCL built in")
RuntimeError: Distributed package doesn't have NCCL built in
I think PyTorch needs to be built with "distributed" (NCCL) support enabled somehow.
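The "Distributed package doesn't have NCCL built in" error usually means the torch wheel installed in that conda environment is a CPU-only build (no CUDA/NCCL support) rather than something that can be toggled at runtime. A quick sanity check to run in the same environment the job uses (a minimal sketch using standard torch introspection calls):

import torch
import torch.distributed as dist

print(torch.__version__)           # a "+cpu" suffix typically indicates a CPU-only wheel
print(torch.cuda.is_available())   # should be True on the GPU nodes
print(dist.is_available())         # distributed package compiled in
print(dist.is_nccl_available())    # NCCL backend compiled in; must be True for backend="nccl"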
To run this:
Note that you need at least 2 instances in your cluster, since the script requests NUM_NODES = 2.
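A minimal way to launch (assuming submitit is installed in the environment and the file is saved on the cluster, e.g. as torch_distributed.py) is to execute the script from a node that can submit Slurm jobs; the AutoExecutor handles the sbatch submission itself. If your HyperPod cluster's shape differs from the defaults, adjust the constants at the top of the script first, for example:

NUM_NODES = 2             # requires at least 2 instances in the target partition
NUM_TASKS_PER_NODE = 8    # typically one task per GPU on the instance type
PARTITION = "dev"         # name of the Slurm partition to submit to
LOGS_DIR = "logs"         # submitit writes job stdout/stderr and results here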