@albertbuchard
Created September 11, 2023 13:35
Distributed Data Parallel Debugging - PyTorch / PDB - ddp_pdb()
import multiprocessing
import os
import pdb
from socket import gethostname

import torch.distributed as dist


class DistributedIdentity:
    """
    Singleton class holding distributed identity information.
    Handles SLURM, torchrun, and local runs.

    Looks for the following environment variables:
        - RANK
        - WORLD_SIZE
        - LOCAL_RANK
        - SLURM_PROCID
        - SLURM_NTASKS
        - SLURM_LOCALID
        - SLURM_GPUS_PER_NODE
        - SLURMD_NODENAME
        - SLURM_CPUS_PER_TASK
        - GROUP_RANK
        - LOCAL_WORLD_SIZE
        - TORCHELASTIC_RESTART_COUNT
        - TORCHELASTIC_MAX_RESTARTS
        - TORCHELASTIC_RUN_ID
        - MASTER_ADDR
        - MASTER_PORT

    If SLURM_JOB_ID is set, job_id is set to that value;
    otherwise job_id is set to "local-" + str(os.getpid()).
    If RANK is set, rank is set to that value;
    otherwise, if SLURM_PROCID is set, rank is set to that value;
    otherwise rank defaults to 0.
    If WORLD_SIZE is set, world_size is set to that value;
    otherwise, if SLURM_NTASKS is set, world_size is set to that value;
    otherwise world_size defaults to 1.
    If LOCAL_RANK is set, local_rank is set to that value;
    otherwise, if SLURM_LOCALID is set, local_rank is set to that value.
    If SLURMD_NODENAME is set, nodename is set to that value;
    otherwise nodename is set to the result of socket.gethostname().
    """
    _instance = None
    rank = None
    gpus_per_node = None
    local_rank = None
    world_size = None
    nodename = None
    cpu_per_task = None
    local_world_size = None
    master_addr = None
    master_port = None
    is_torchelastic = None
    torch_restart_count = None
    torch_max_restarts = None
    torch_runid = None
    group_rank = None
    job_id = None
    def __new__(cls, *args, **kwargs):
        if not isinstance(cls._instance, cls):
            cls._instance = super(DistributedIdentity, cls).__new__(cls, *args, **kwargs)
            if os.environ.get("SLURM_JOB_ID", None) is not None:
                cls.job_id = str(os.environ["SLURM_JOB_ID"])
            else:
                cls.job_id = "local-" + str(os.getpid())
            if os.environ.get("RANK", None) is not None:
                cls.rank = int(os.environ["RANK"])
            elif os.environ.get("SLURM_PROCID", None) is not None:
                cls.rank = int(os.environ["SLURM_PROCID"])
            else:
                cls.rank = 0
            if os.environ.get("WORLD_SIZE", None) is not None:
                cls.world_size = int(os.environ["WORLD_SIZE"])
            elif os.environ.get("SLURM_NTASKS", None) is not None:
                cls.world_size = int(os.environ["SLURM_NTASKS"])
            else:
                cls.world_size = 1
            if os.environ.get("LOCAL_RANK", None) is not None:
                cls.local_rank = int(os.environ["LOCAL_RANK"])
            elif os.environ.get("SLURM_LOCALID", None) is not None:
                cls.local_rank = int(os.environ["SLURM_LOCALID"])
            elif cls.rank is not None and cls.world_size is not None:
                cls.local_rank = cls.rank % cls.world_size
            if os.environ.get("SLURM_GPUS_PER_NODE", None) is not None:
                cls.gpus_per_node = int(os.environ["SLURM_GPUS_PER_NODE"])
            if os.environ.get("SLURMD_NODENAME", None) is not None:
                cls.nodename = os.environ["SLURMD_NODENAME"]
            else:
                cls.nodename = gethostname()
            if os.environ.get("GROUP_RANK", None) is not None:
                cls.group_rank = int(os.environ["GROUP_RANK"])
            if os.environ.get("LOCAL_WORLD_SIZE", None) is not None:
                cls.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
            if os.environ.get("TORCHELASTIC_RESTART_COUNT", None) is not None:
                cls.torch_restart_count = int(os.environ["TORCHELASTIC_RESTART_COUNT"])
            if os.environ.get("TORCHELASTIC_MAX_RESTARTS", None) is not None:
                cls.torch_max_restarts = int(os.environ["TORCHELASTIC_MAX_RESTARTS"])
            cls.is_torchelastic = False
            if (os.environ.get("TORCHELASTIC_RUN_ID", None) is not None and
                    os.environ.get("TORCHELASTIC_RUN_ID", None) != "" and
                    os.environ.get("TORCHELASTIC_RUN_ID", None) != "none"):
                cls.is_torchelastic = True
                print("TORCHELASTIC_RUN_ID found. Setting torch_runid to that value.",
                      os.environ["TORCHELASTIC_RUN_ID"])
                cls.torch_runid = int(os.environ["TORCHELASTIC_RUN_ID"])
            if os.environ.get("SLURM_CPUS_PER_TASK", None) is not None:
                cls.cpu_per_task = int(os.environ["SLURM_CPUS_PER_TASK"])
            elif cls.is_torchelastic:
                print("SLURM_CPUS_PER_TASK not found. Using max(1, (os.cpu_count() // cls.world_size) - 1)")
                print(f"os.cpu_count(): {os.cpu_count()}")
                print(f"cls.world_size: {cls.world_size}")
                print(
                    f"max(1, (os.cpu_count() // cls.world_size) - 1): {max(1, (os.cpu_count() // cls.world_size) - 1)}")
                cls.cpu_per_task = max(1, (os.cpu_count() // cls.world_size) - 1)
            cls.master_addr = os.environ.get("MASTER_ADDR", None)
            cls.master_port = os.environ.get("MASTER_PORT", None)
        return cls._instance
    @property
    def ddp_available(self):
        return (dist.is_available() and self.world_size > 1 and
                (self.is_torchelastic or (self.master_addr is not None and self.master_port is not None)))

    @property
    def is_slurm(self):
        return os.environ.get("SLURM_JOB_ID", None) is not None

    @property
    def available_cpu_count(self):
        if self.cpu_per_task is not None:
            return self.cpu_per_task
        return min(os.cpu_count(), multiprocessing.cpu_count())

    def __str__(self):
        return f"Rank {self.rank} (local rank {self.local_rank}) of {self.world_size} on {self.nodename}"

    def __repr__(self):
        return self.__str__()
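
# Illustrative note (not part of the original gist): with a torchrun launch such as
#     torchrun --nproc_per_node=2 train.py
# the launcher populates RANK / WORLD_SIZE / LOCAL_RANK, so those values take
# precedence over any SLURM_* variables, whereas a bare `python train.py` falls
# back to rank 0, world_size 1, and a job_id of the form "local-<pid>".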

dist_identity = DistributedIdentity()


def safe_barrier():
    """
    Executes a barrier only if distributed is initialized.
    """
    if dist.is_initialized():
        dist.barrier()

class CustomPdb(pdb.Pdb):
    """
    Custom debugger with frame skipping capability.
    """

    def __init__(self, skip_frames=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_frames = skip_frames

    def interaction(self, frame, traceback):
        # Walk up the stack so the debugger starts in an outer frame
        # rather than in the helper that called set_trace().
        for _ in range(self.skip_frames):
            frame = frame.f_back
            if not frame:
                break
        super().interaction(frame, traceback)

def ddp_pdb(rank=0):
    """
    Distributed debugger: drops the given rank into pdb while every other
    rank waits at a barrier.
    """
    if dist_identity.rank == rank:
        # skip_frames=1 so the debugger opens in the caller's frame rather
        # than inside ddp_pdb itself.
        debugger = CustomPdb(skip_frames=1)
        debugger.set_trace()
    safe_barrier()
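
# Example usage (illustrative sketch, not part of the original gist). With no
# launcher, dist_identity reports rank 0 of 1 and safe_barrier() is a no-op,
# so this simply prints the identity and drops rank 0 into pdb.
if __name__ == "__main__":
    print(dist_identity)  # e.g. "Rank 0 (local rank 0) of 1 on <hostname>"
    ddp_pdb(rank=0)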