Test Async Actor DDP
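This script launches eight single-GPU Ray actors, assigns them torch.distributed ranks, starts a dummy DDP training loop on each worker as an async actor task, and cancels those tasks with ray.cancel() a few seconds later.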
import asyncio
import os
import time
from collections import defaultdict

import ray
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from ray.train._internal.utils import get_address_and_port

@ray.remote(num_gpus=1)
class TrainWorker:
    def __init__(self):
        self.metadata = [1, 2, 3]

    def get_worker_info(self):
        node_ip, port = get_address_and_port()
        return {
            "gpu_id": min(ray.get_gpu_ids()),
            "node_ip": node_ip,
            "port": port,
        }
    async def train_func(self, worker_info, global_info):
        from ray.train.torch.config import _setup_torch_process_group

        # Export the rendezvous and rank information expected by torch.distributed.
        os.environ["CUDA_VISIBLE_DEVICES"] = worker_info["cuda_visible_devices"]
        os.environ["MASTER_ADDR"] = global_info["master_addr"]
        os.environ["MASTER_PORT"] = str(global_info["master_port"])
        os.environ["RANK"] = str(worker_info["world_rank"])
        os.environ["LOCAL_RANK"] = str(worker_info["local_rank"])
        os.environ["WORLD_SIZE"] = str(global_info["world_size"])
        os.environ["LOCAL_WORLD_SIZE"] = str(worker_info["local_world_size"])
        os.environ["NODE_RANK"] = str(worker_info["node_rank"])
        _setup_torch_process_group(
            backend="nccl",
            world_rank=worker_info["world_rank"],
            world_size=global_info["world_size"],
            init_method="env://",
            timeout_s=10,
        )
        print(worker_info["gpu_id"])

        # Model training
        device = torch.device(f"cuda:{worker_info['gpu_id']}")
        model = nn.Linear(10, 10).to(device)
        model = DDP(model)
        try:
            # Simulate a long-running training loop.
            for i in range(100):
                input_tensor = torch.FloatTensor(100, 10).to(device)
                output = model(input_tensor)
                loss = output.sum()
                loss.backward()
                time.sleep(1)
                print(f"step {i}")
        except asyncio.CancelledError:
            # asyncio.CancelledError is not an Exception subclass in Python 3.8+,
            # so it must be caught explicitly to observe the cancellation.
            print("My coroutine was cancelled")
            raise

def allocate_ranks(worker_bundle):
    # Sort workers by (node_ip, gpu_id) so rank assignment is deterministic.
    worker_bundle.sort(key=lambda x: (x[1]["node_ip"], x[1]["gpu_id"]))
    node_ranks = {}
    local_world_size = defaultdict(int)
    local_visible_devices = defaultdict(list)
    for i, w in enumerate(worker_bundle):
        node_ip = w[1]["node_ip"]
        if node_ip not in node_ranks:
            node_ranks[node_ip] = len(node_ranks)
        w[1]["world_rank"] = i
        w[1]["node_rank"] = node_ranks[node_ip]
        w[1]["local_rank"] = local_world_size[node_ip]
        local_world_size[node_ip] += 1
        local_visible_devices[node_ip].append(str(w[1]["gpu_id"]))
    for w in worker_bundle:
        node_ip = w[1]["node_ip"]
        w[1]["local_world_size"] = local_world_size[node_ip]
        w[1]["cuda_visible_devices"] = ",".join(local_visible_devices[node_ip])
    for w in worker_bundle:
        print(w)
    # Rank 0 (the first worker after sorting) hosts the rendezvous address.
    return {
        "master_addr": worker_bundle[0][1]["node_ip"],
        "master_port": worker_bundle[0][1]["port"],
        "world_size": len(worker_bundle),
    }

def launch_workers(n):
    workers = [TrainWorker.remote() for _ in range(n)]
    workers_info = ray.get([worker.get_worker_info.remote() for worker in workers])
    return list(zip(workers, workers_info))

def run_tasks_then_cancel(worker_bundle, global_info):
    train_tasks = [
        worker.train_func.remote(worker_info=worker_info, global_info=global_info)
        for worker, worker_info in worker_bundle
    ]
    print("RUNNING....")
    time.sleep(5)
    print("CANCEL...")
    for task in train_tasks:
        ray.cancel(task, force=True)

if __name__ == "__main__":
    worker_bundle = launch_workers(8)
    global_info = allocate_ranks(worker_bundle)
    run_tasks_then_cancel(worker_bundle, global_info)
    time.sleep(100)
    # run_tasks_then_cancel(worker_bundle, global_info)
    # new_worker_bundle = launch_workers(4)
    # worker_bundle.extend(new_worker_bundle)
    # global_info = allocate_ranks(worker_bundle)
    # run_tasks_then_cancel(worker_bundle, global_info)
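The commented-out lines at the end sketch a follow-up experiment: re-running the cancelled training tasks, and growing the worker group from 8 to 12 actors before reallocating ranks and launching again.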