Created March 31, 2021 20:03
Save anj-s/958a7e444100e762180bf289da8a6cab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# File used to run bytePS on a local worker with 2 GPUs | |
import os | |
import subprocess | |
import sys | |
import torch | |
import torch.multiprocessing as mp | |
def run_worker(rank, world_size):
    """Configure the BytePS/DMLC environment for one local worker process
    and launch the MNIST training script on it.

    Intended to be invoked by ``torch.multiprocessing.spawn`` once per local
    GPU. Each process pins its own CUDA device, exports the environment
    variables BytePS reads at import time, then runs the training script
    as a subprocess.

    Args:
        rank: Local rank of this worker process (0-based); also the CUDA
            device index it binds to.
        world_size: Total number of local worker processes (one per GPU).
    """
    torch.cuda.set_device(rank)

    # DMLC topology: a single logical worker machine hosting `world_size`
    # local GPU processes.
    os.environ["DMLC_ROLE"] = "worker"
    os.environ["DMLC_WORKER_ID"] = str(0)
    os.environ["DMLC_NUM_WORKER"] = str(1)
    os.environ["NVIDIA_VISIBLE_DEVICES"] = "0,1"
    # os.environ["DMLC_ENABLE_RDMA"] = "ibverbs"
    print(f"rank to be set {rank}")

    # Verbose logging for BytePS and NCCL to aid debugging.
    os.environ["BYTEPS_LOG_LEVEL"] = "INFO"
    os.environ["NCCL_DEBUG"] = "INFO"
    os.environ["BYTEPS_ENABLE_GDB"] = "1"
    os.environ["BYTEPS_LOCAL_RANK"] = str(rank)
    os.environ["BYTEPS_LOCAL_SIZE"] = str(world_size)

    # Scheduler/server settings: DMLC requires these to be present even in
    # a single-machine run; the actual values don't matter here.
    os.environ["DMLC_NUM_SERVER"] = str(1)
    os.environ["DMLC_PS_ROOT_URI"] = "10.0.0.1"
    os.environ["DMLC_PS_ROOT_PORT"] = "1234"
    os.environ["BYTEPS_CUDA_HOME"] = "/usr/local/cuda"
    os.environ["BYTEPS_NCCL_HOME"] = "/usr/local/nccl"

    if rank == 0:
        print(f"os.environ {os.environ}")

    # NOTE(review): every spawned rank launches the training script (only
    # the env dump above is rank-0-only) — confirm against the original
    # flattened indentation, but launching in all ranks is what makes
    # spawning `world_size` processes meaningful.
    #
    # Use an argv list with the default shell=False instead of a shell
    # string with shell=True: no intermediate shell, no injection risk,
    # identical behavior for this fixed command. sys.executable guarantees
    # the same interpreter as the parent rather than whatever "python"
    # resolves to on PATH.
    command = [sys.executable, "example/pytorch/train_mnist_byteps.py"]
    subprocess.check_call(command, stdout=sys.stdout, stderr=sys.stderr)
def _main():
    """Spawn one worker process per local GPU and block until all exit."""
    num_devices = 2  # local GPUs on this machine
    mp.spawn(run_worker, args=(num_devices,), nprocs=num_devices, join=True)


if __name__ == "__main__":
    _main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.