Skip to content

Instantly share code, notes, and snippets.

@d4l3k
Created August 1, 2022 18:43
Show Gist options
  • Save d4l3k/8abe20f855b1823d9386ad276ed35f60 to your computer and use it in GitHub Desktop.
Save d4l3k/8abe20f855b1823d9386ad276ed35f60 to your computer and use it in GitHub Desktop.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
For distributed training, TorchX relies on the scheduler's gang scheduling
capabilities to schedule ``n`` copies of nodes. Once launched, the application
is expected to be written in a way that leverages this topology, for instance,
with PyTorch's
`DDP <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_.
You can express a variety of node topologies with TorchX by specifying multiple
:py:class:`torchx.specs.Role` in your component's AppDef. Each role maps to
a homogeneous group of nodes that performs a "role" (function) in the overall
training. Scheduling-wise, TorchX launches each role as a sub-gang.
A DDP-style training job has a single role: trainers. Whereas a
training job that uses parameter servers will have two roles: parameter server, trainer.
You can specify different entrypoint (executable), num replicas, resource requirements,
and more for each role.
DDP Builtin
----------------
DDP-style trainers are common and easy to templetize since they are homogeneous
single role AppDefs, so there is a builtin: ``dist.ddp``. Assuming your DDP
training script is called ``main.py``, launch it as:
.. code:: shell-session
# locally, 1 node x 4 workers
$ torchx run -s local_cwd dist.ddp -j 1x4 --script main.py
# locally, 2 node x 4 workers (8 total)
$ torchx run -s local_cwd dist.ddp -j 2x4 --script main.py
# remote (optionally pass --rdzv_port to use a different master port than the default 29500)
$ torchx run -s kubernetes -cfg queue=default dist.ddp \\
-j 2x4 \\
--script main.py \\
Note that the only difference compared to the local launch is the scheduler (``-s``).
The ``dist.ddp`` builtin uses ``torchelastic`` (more specifically ``torch.distributed.run``)
under the hood. Read more about torchelastic `here <https://pytorch.org/docs/stable/elastic/run.html>`_.
Components APIs
-----------------
"""
import os
import shlex
from pathlib import Path
from typing import Dict, Iterable, List, Optional
import torchx
import torchx.specs as specs
from torchx.specs import macros
_TORCH_DEBUG_FLAGS: Dict[str, str] = {
"CUDA_LAUNCH_BLOCKING": "1",
"NCCL_DESYNC_DEBUG": "1",
"TORCH_DISTRIBUTED_DEBUG": "DETAIL",
"TORCH_SHOW_CPP_STACKTRACES": "1",
}
"""
These are commonly set environment variables to debug PyTorch execution.
* ``CUDA_LAUNCH_BLOCKING``: Read more `here <https://docs.nvidia.com/cuda/cuda-gdb/index.html#set-cuda-launch-blocking>`__.
* ``NCCL_DESYNC_DEBUG``
* ``TORCH_DISTRIBUTED_DEBUG``: Read more `here <https://pytorch.org/docs/stable/distributed.html#torch-distributed-debug>`__.
* ``TORCH_SHOW_CPP_STACKTRACES``: Read more `here <https://pytorch.org/docs/stable/distributed.html#torch-distributed-debug>`__.
"""
def ddp(
*script_args: str,
script: Optional[str] = None,
m: Optional[str] = None,
image: str = torchx.IMAGE,
name: Optional[str] = None,
h: Optional[str] = None,
cpu: int = 2,
gpu: int = 0,
memMB: int = 1024,
j: str = "1x2",
env: Optional[Dict[str, str]] = None,
max_retries: int = 0,
rdzv_port: int = 29500,
mounts: Optional[List[str]] = None,
debug: bool = False,
) -> specs.AppDef:
"""
Distributed data parallel style application (one role, multi-replica).
Uses `torch.distributed.run <https://pytorch.org/docs/stable/distributed.elastic.html>`_
to launch and coordinate PyTorch worker processes. Defaults to using ``c10d`` rendezvous backend
on rendezvous_endpoint ``$rank_0_host:$rdzv_port``. Note that ``rdzv_port`` parameter is ignored
when running on single node, and instead we use port 0 which instructs torchelastic to chose
a free random port on the host.
Note: (cpu, gpu, memMB) parameters are mutually exclusive with ``h`` (named resource) where
``h`` takes precedence if specified for setting resource requirements.
See `registering named resources <https://pytorch.org/torchx/latest/advanced.html#registering-named-resources>`_.
Args:
script_args: arguments to the main module
script: script or binary to run within the image
m: the python module path to run
image: image (e.g. docker)
name: job name override (uses the script name if not specified)
cpu: number of cpus per replica
gpu: number of gpus per replica
memMB: cpu memory in MB per replica
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
max_retries: the number of scheduler retries allowed
rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous.
Only takes effect when running multi-node. When running single node, this parameter
is ignored and a random free port is chosen.
mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
See scheduler documentation for more info.
debug: whether to run with preset debug flags enabled
"""
if (script is None) == (m is None):
raise ValueError("exactly one of --script and -m must be specified")
rep = j.split("x")
if len(rep) == 1: # num replicas only
nnodes = 1
nproc_per_node = int(rep[0])
elif len(rep) == 2:
nnodes = int(rep[0])
nproc_per_node = int(rep[1])
else:
raise ValueError(f"Invalid format for -j, usage example: 1x4. Given: {j}")
if script:
# script name/module no extension
role_name = Path(script).stem
elif m:
role_name = m.rpartition(".")[2]
else:
raise ValueError("failed to compute role_name")
rdzv_backend = "c10d"
rdzv_backend = "static"
if nnodes == 1:
# using port 0 makes elastic chose a free random port which is ok
# for single-node jobs since all workers run under a single agent
rdzv_endpoint = "localhost:0"
else:
rdzv_endpoint = _noquote(f"$${macros.rank0_env}:{rdzv_port}")
if env is None:
env = {}
env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING"))
if debug:
env.update(_TORCH_DEBUG_FLAGS)
cmd = [
"python",
"-m",
"torch.distributed.run",
"--rdzv_backend",
rdzv_backend,
"--rdzv_endpoint",
rdzv_endpoint,
"--rdzv_id",
f"{macros.app_id}",
"--nnodes",
str(nnodes),
"--nproc_per_node",
str(nproc_per_node),
"--node_rank", macros.replica_id,
"--tee",
"3",
"--role",
"",
]
if script is not None:
cmd += [script]
elif m is not None:
cmd += ["-m", m]
cmd += script_args
return specs.AppDef(
name=name or role_name,
roles=[
specs.Role(
name=role_name,
image=image,
entrypoint="bash",
num_replicas=nnodes,
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
args=["-c", _args_join(cmd)],
env=env,
port_map={
"c10d": 29500,
},
max_retries=max_retries,
mounts=specs.parse_mounts(mounts) if mounts else [],
)
],
)
def _args_join(args: Iterable[str]) -> str:
"""
_args_join is like shlex.join but if the argument is wrapped in _noquote
it'll not quote that argument.
"""
quoted = [arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args]
return " ".join(quoted)
class _noquote(str):
"""
_noquote is a wrapper around str that indicates that the argument shouldn't
be passed through shlex.quote.
"""
pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment