@anj-s · Created March 31, 2021
# Launches BytePS on a single local worker with 2 GPUs: spawns one process
# per GPU, sets the DMLC_*/BYTEPS_* environment variables each local rank
# needs, and runs the MNIST example under each rank.
import os
import subprocess
import sys

import torch
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    # Bind this process to its GPU before anything touches CUDA.
    torch.cuda.set_device(rank)

    # ps-lite (DMLC) topology: a single worker machine in the job.
    os.environ["DMLC_ROLE"] = "worker"
    os.environ["DMLC_WORKER_ID"] = str(0)
    os.environ["DMLC_NUM_WORKER"] = str(1)
    os.environ["NVIDIA_VISIBLE_DEVICES"] = "0,1"
    # os.environ["DMLC_ENABLE_RDMA"] = "ibverbs"

    print(f"setting local rank to {rank}")

    # Verbose logging from BytePS and NCCL for debugging.
    os.environ["BYTEPS_LOG_LEVEL"] = "INFO"
    os.environ["NCCL_DEBUG"] = "INFO"
    # os.environ["NCCL_DEBUG"] = "WARN"
    os.environ["BYTEPS_ENABLE_GDB"] = "1"

    # Each spawned process is one BytePS local rank.
    os.environ["BYTEPS_LOCAL_RANK"] = str(rank)
    os.environ["BYTEPS_LOCAL_SIZE"] = str(world_size)

    # The values below do not matter for a single local worker, but the
    # variables must still be set for BytePS to start.
    os.environ["DMLC_NUM_SERVER"] = str(1)
    os.environ["DMLC_PS_ROOT_URI"] = "10.0.0.1"
    os.environ["DMLC_PS_ROOT_PORT"] = "1234"
    os.environ["BYTEPS_CUDA_HOME"] = "/usr/local/cuda"
    os.environ["BYTEPS_NCCL_HOME"] = "/usr/local/nccl"
    if rank == 0:
        print(f"os.environ {os.environ}")

    # Every local rank runs its own copy of the training script.
    command = "python example/pytorch/train_mnist_byteps.py"
    subprocess.check_call(command, stdout=sys.stdout, stderr=sys.stderr,
                          shell=True)


if __name__ == "__main__":
    num_devices = 2
    mp.spawn(run_worker, args=(num_devices,), nprocs=num_devices, join=True)
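
For context, the launched script (example/pytorch/train_mnist_byteps.py in the BytePS repo) picks up the DMLC_*/BYTEPS_* variables set above through byteps.torch. Below is a minimal sketch of that pattern, assuming the byteps.torch API (init, local_rank, DistributedOptimizer, broadcast_parameters), which mirrors Horovod's; the model and optimizer here are hypothetical placeholders, not the gist's code.

# sketch_train.py -- hypothetical training-script side of the launcher above
import torch
import torch.nn as nn
import byteps.torch as bps

bps.init()                               # reads the DMLC_*/BYTEPS_* env vars
torch.cuda.set_device(bps.local_rank())  # one GPU per local rank

model = nn.Linear(784, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wrap the optimizer so gradients are synchronized via BytePS push/pull.
optimizer = bps.DistributedOptimizer(
    optimizer, named_parameters=model.named_parameters())

# Broadcast initial state from rank 0 so all ranks start identically.
bps.broadcast_parameters(model.state_dict(), root_rank=0)
bps.broadcast_optimizer_state(optimizer, root_rank=0)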