@woshiyyya · Created April 9, 2024 18:14
Test Async Actor DDP
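Test script: it launches single-GPU Ray async actors, manually wires up a torch.distributed process group across them via environment variables (DDP with the NCCL backend), runs a dummy training loop on each worker, and then cancels the in-flight training tasks with ray.cancel to observe how async actor tasks behave when cancelled.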
import asyncio
import os
import time
from collections import defaultdict

import ray
import torch
import torch.nn as nn
from ray.train._internal.utils import get_address_and_port
from torch.nn.parallel import DistributedDataParallel as DDP

@ray.remote(num_gpus=1)
class TrainWorker:
    def __init__(self):
        self.metadata = [1, 2, 3]

    def get_worker_info(self):
        # Node IP and a free port; the rank-0 worker's pair later becomes the
        # torch.distributed rendezvous address (MASTER_ADDR / MASTER_PORT).
        node_ip, port = get_address_and_port()
        return {
            "gpu_id": min(ray.get_gpu_ids()),  # the single GPU Ray assigned to this actor
            "node_ip": node_ip,
            "port": port,
        }
    async def train_func(self, worker_info, global_info):
        from ray.train.torch.config import _setup_torch_process_group

        # Populate the environment variables that torch.distributed's
        # env:// rendezvous reads when initializing the process group.
        os.environ["CUDA_VISIBLE_DEVICES"] = worker_info["cuda_visible_devices"]
        os.environ["MASTER_ADDR"] = global_info["master_addr"]
        os.environ["MASTER_PORT"] = str(global_info["master_port"])
        os.environ["RANK"] = str(worker_info["world_rank"])
        os.environ["LOCAL_RANK"] = str(worker_info["local_rank"])
        os.environ["WORLD_SIZE"] = str(global_info["world_size"])
        os.environ["LOCAL_WORLD_SIZE"] = str(worker_info["local_world_size"])
        os.environ["NODE_RANK"] = str(worker_info["node_rank"])

        _setup_torch_process_group(
            backend="nccl",
            world_rank=worker_info["world_rank"],
            world_size=global_info["world_size"],
            init_method="env://",
            timeout_s=10,
        )
        print(worker_info["gpu_id"])

        # Model training. Indexing the device by gpu_id assumes the node's GPU
        # ids are 0..k-1, so they match the device ordinals after
        # CUDA_VISIBLE_DEVICES is reset above.
        device = torch.device(f"cuda:{worker_info['gpu_id']}")
        model = nn.Linear(10, 10).to(device)
        model = DDP(model)
        try:
            # Simulate a long-running training loop.
            for i in range(100):
                input_tensor = torch.FloatTensor(100, 10).to(device)
                output = model(input_tensor)
                loss = output.sum()
                loss.backward()
                time.sleep(1)
                print(f"step {i}")
        except (Exception, asyncio.CancelledError):
            # asyncio.CancelledError is a BaseException on Python 3.8+, so it
            # has to be caught explicitly when the task is cancelled mid-loop.
            print("My coroutine was cancelled")
            raise

def allocate_ranks(worker_bundle):
    # Sort workers so that all workers on the same node are contiguous and
    # ordered by GPU id, then assign world/node/local ranks deterministically.
    worker_bundle.sort(key=lambda x: (x[1]["node_ip"], x[1]["gpu_id"]))
    node_ranks = {}
    local_world_size = defaultdict(int)
    local_visible_devices = defaultdict(list)
    for i, w in enumerate(worker_bundle):
        node_ip = w[1]["node_ip"]
        w[1]["world_rank"] = i
        # Each node gets an index in the order it first appears.
        if node_ip not in node_ranks:
            node_ranks[node_ip] = len(node_ranks)
        w[1]["node_rank"] = node_ranks[node_ip]
        w[1]["local_rank"] = local_world_size[node_ip]
        local_world_size[node_ip] += 1
        local_visible_devices[node_ip].append(str(w[1]["gpu_id"]))

    for w in worker_bundle:
        node_ip = w[1]["node_ip"]
        w[1]["local_world_size"] = local_world_size[node_ip]
        w[1]["cuda_visible_devices"] = ",".join(local_visible_devices[node_ip])

    for w in worker_bundle:
        print(w)

    # Rank 0 (the first worker after sorting) hosts the rendezvous.
    return {
        "master_addr": worker_bundle[0][1]["node_ip"],
        "master_port": worker_bundle[0][1]["port"],
        "world_size": len(worker_bundle),
    }
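
# Illustrative example, assuming a hypothetical cluster of 2 nodes with 2 GPUs
# each (IPs 10.0.0.1 / 10.0.0.2, GPU ids 0 and 1 per node). allocate_ranks
# would assign:
#
#   node_ip    gpu_id  world_rank  node_rank  local_rank  local_world_size  cuda_visible_devices
#   10.0.0.1   0       0           0          0           2                 "0,1"
#   10.0.0.1   1       1           0          1           2                 "0,1"
#   10.0.0.2   0       2           1          0           2                 "0,1"
#   10.0.0.2   1       3           1          1           2                 "0,1"
#
# and return {"master_addr": "10.0.0.1", "master_port": <rank-0 port>, "world_size": 4}.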

def launch_workers(n):
    # Create n single-GPU actors and collect their GPU/IP/port metadata.
    workers = [TrainWorker.remote() for _ in range(n)]
    workers_info = ray.get([worker.get_worker_info.remote() for worker in workers])
    return list(zip(workers, workers_info))

def run_tasks_then_cancel(worker_bundle, global_info):
    # Kick off training on every worker, let it run briefly, then cancel.
    train_tasks = [
        worker.train_func.remote(worker_info=worker_info, global_info=global_info)
        for worker, worker_info in worker_bundle
    ]
    print("RUNNING....")
    time.sleep(5)
    print("CANCEL...")
    for task in train_tasks:
        ray.cancel(task, force=True)

if __name__ == "__main__":
    worker_bundle = launch_workers(8)
    global_info = allocate_ranks(worker_bundle)
    run_tasks_then_cancel(worker_bundle, global_info)
    time.sleep(100)
    # run_tasks_then_cancel(worker_bundle, global_info)

    # new_worker_bundle = launch_workers(4)
    # worker_bundle.extend(new_worker_bundle)
    # global_info = allocate_ranks(worker_bundle)
    # run_tasks_then_cancel(worker_bundle, global_info)
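
# Note: launch_workers(8) creates eight actors that each request num_gpus=1, so
# the Ray cluster needs at least 8 GPUs. Ray initializes implicitly on the first
# .remote() call; to attach to an existing cluster explicitly, one could call
# ray.init(address="auto") beforehand (not part of the script above).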