Created
September 11, 2023 13:35
-
-
Save albertbuchard/08ae0f608e9831b5e46f2f06c2699c1f to your computer and use it in GitHub Desktop.
Distributed Data Parallel Debugging - Pytorch / PDB - ddp_pdb()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DistributedIdentity:
    """
    Singleton class to hold distributed identity information.
    Handles SLURM, torchrun, and local runs.

    Looks for the following environment variables:
        - RANK
        - WORLD_SIZE
        - LOCAL_RANK
        - SLURM_PROCID
        - SLURM_NTASKS
        - SLURM_LOCALID
        - SLURM_GPUS_PER_NODE
        - SLURMD_NODENAME
        - SLURM_CPUS_PER_TASK
        - GROUP_RANK
        - LOCAL_WORLD_SIZE
        - TORCHELASTIC_RESTART_COUNT
        - TORCHELASTIC_MAX_RESTARTS
        - TORCHELASTIC_RUN_ID
        - MASTER_ADDR
        - MASTER_PORT

    Resolution order:
        - job_id: SLURM_JOB_ID, else "local-" + str(os.getpid())
        - rank: RANK, else SLURM_PROCID, else 0
        - world_size: WORLD_SIZE, else SLURM_NTASKS, else 1
        - local_rank: LOCAL_RANK, else SLURM_LOCALID, else rank % world_size
        - nodename: SLURMD_NODENAME, else socket.gethostname()
    """

    _instance = None
    # Identity fields; populated on first instantiation, None until then.
    rank = None
    gpus_per_node = None
    local_rank = None
    world_size = None
    nodename = None
    cpu_per_task = None
    local_world_size = None
    master_addr = None
    master_port = None
    is_torchelastic = None
    torch_restart_count = None
    torch_max_restarts = None
    torch_runid = None
    group_rank = None
    job_id = None

    @staticmethod
    def _int_env(name):
        """Return int(os.environ[name]), or None when the variable is unset."""
        value = os.environ.get(name)
        return int(value) if value is not None else None

    def __new__(cls, *args, **kwargs):
        if not isinstance(cls._instance, cls):
            # Bug fix: object.__new__ rejects extra arguments when __init__ is
            # not overridden, so any positional/keyword argument used to raise
            # TypeError here. Do not forward *args/**kwargs to object.__new__.
            cls._instance = super(DistributedIdentity, cls).__new__(cls)

            slurm_job_id = os.environ.get("SLURM_JOB_ID")
            if slurm_job_id is not None:
                cls.job_id = str(slurm_job_id)
            else:
                cls.job_id = "local-" + str(os.getpid())

            cls.rank = cls._int_env("RANK")
            if cls.rank is None:
                cls.rank = cls._int_env("SLURM_PROCID")
            if cls.rank is None:
                cls.rank = 0

            cls.world_size = cls._int_env("WORLD_SIZE")
            if cls.world_size is None:
                cls.world_size = cls._int_env("SLURM_NTASKS")
            if cls.world_size is None:
                cls.world_size = 1

            cls.local_rank = cls._int_env("LOCAL_RANK")
            if cls.local_rank is None:
                cls.local_rank = cls._int_env("SLURM_LOCALID")
            if cls.local_rank is None:
                # NOTE(review): modulo the *global* world size, not the number
                # of tasks per node -- kept as-is for compatibility, but this
                # is only a rough fallback when no launcher variable exists.
                cls.local_rank = cls.rank % cls.world_size

            # NOTE(review): SLURM may render this as e.g. "a100:4"; int() would
            # then raise, same as the original behavior -- confirm site config.
            cls.gpus_per_node = cls._int_env("SLURM_GPUS_PER_NODE")

            nodename = os.environ.get("SLURMD_NODENAME")
            cls.nodename = nodename if nodename is not None else gethostname()

            cls.group_rank = cls._int_env("GROUP_RANK")
            cls.local_world_size = cls._int_env("LOCAL_WORLD_SIZE")
            cls.torch_restart_count = cls._int_env("TORCHELASTIC_RESTART_COUNT")
            cls.torch_max_restarts = cls._int_env("TORCHELASTIC_MAX_RESTARTS")

            run_id = os.environ.get("TORCHELASTIC_RUN_ID")
            cls.is_torchelastic = run_id not in (None, "", "none")
            if cls.is_torchelastic:
                print("TORCHELASTIC_RUN_ID found. Setting torch_runid to that value.",
                      run_id)
                # Bug fix: torchrun run ids are arbitrary strings; int() used
                # to crash on non-numeric ids. Keep the int conversion for
                # numeric ids (backward compatible), fall back to the raw
                # string otherwise.
                try:
                    cls.torch_runid = int(run_id)
                except ValueError:
                    cls.torch_runid = run_id

            cls.cpu_per_task = cls._int_env("SLURM_CPUS_PER_TASK")
            if cls.cpu_per_task is None and cls.is_torchelastic:
                print("SLURM_CPUS_PER_TASK not found. Using max(1, (os.cpu_count() // cls.world_size) - 1)")
                print(f"os.cpu_count(): {os.cpu_count()}")
                print(f"cls.world_size: {cls.world_size}")
                print(
                    f"max(1, (os.cpu_count() // cls.world_size) - 1): {max(1, (os.cpu_count() // cls.world_size) - 1)}")
                cls.cpu_per_task = max(1, (os.cpu_count() // cls.world_size) - 1)

            cls.master_addr = os.environ.get('MASTER_ADDR', None)
            cls.master_port = os.environ.get('MASTER_PORT', None)
        return cls._instance

    @property
    def ddp_available(self):
        """True when torch.distributed can form a >1-process group here."""
        return (dist.is_available() and self.world_size > 1 and
                (self.is_torchelastic or (self.master_addr is not None and self.master_port is not None)))

    @property
    def is_slurm(self):
        """True when running under a SLURM allocation."""
        return os.environ.get("SLURM_JOB_ID", None) is not None

    @property
    def available_cpu_count(self):
        """CPU budget for this task: SLURM_CPUS_PER_TASK if known, else host count."""
        if self.cpu_per_task is not None:
            return self.cpu_per_task
        return min(os.cpu_count(), multiprocessing.cpu_count())

    def __str__(self):
        return f"Rank {self.rank} (local rank {self.local_rank}) of {self.world_size} on {self.nodename}"

    def __repr__(self):
        return self.__str__()
# Module-level singleton: importing this module eagerly resolves the current
# process's distributed identity from the environment variables above.
dist_identity = DistributedIdentity()
def safe_barrier():
    """
    Synchronize all ranks with dist.barrier(), but only when the default
    process group has been initialized; otherwise a silent no-op (e.g. in
    single-process / non-distributed runs).
    """
    if not dist.is_initialized():
        return
    dist.barrier()
class CustomPdb(pdb.Pdb):
    """
    Pdb subclass that skips a fixed number of stack frames before handing
    control to the debugger, so that set_trace() wrappers land the user in
    the caller's frame rather than in the wrapper itself.
    """

    def __init__(self, skip_frames=0, *args, **kwargs):
        """
        :param skip_frames: number of outer frames to walk up (via f_back)
            before starting the interactive session.
        Remaining arguments are forwarded unchanged to pdb.Pdb.
        """
        super().__init__(*args, **kwargs)
        self.skip_frames = skip_frames

    def interaction(self, frame, traceback):
        # Bug fix: the original advanced `frame = frame.f_back` *before*
        # checking it, so with skip_frames >= stack depth it handed pdb a
        # None frame (and raised AttributeError if `frame` was already
        # None). Stop while we still hold a valid frame instead.
        for _ in range(self.skip_frames):
            if frame is None or frame.f_back is None:
                break
            frame = frame.f_back
        super().interaction(frame, traceback)
def ddp_pdb(rank=0):
    """
    Drop into an interactive debugger on exactly one rank of a distributed
    run, while every other rank waits at a barrier (when initialized).

    :param rank: the global rank that should get the debugger prompt.
    """
    if rank == dist_identity.rank:
        # skip_frames=1 so the prompt opens in the caller, not in ddp_pdb.
        CustomPdb(skip_frames=1).set_trace()
    safe_barrier()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment