mp reduce scatter xpu
""" | |
Launch single-node reduce scatter with multiprocessing. | |
python3 mp_torch_reduce_scatter.py | |
""" | |
import os | |
import socket | |
from concurrent.futures import ProcessPoolExecutor | |
import torch | |
import torch.distributed as dist | |
import torch.multiprocessing as mp | |
import intel_extension_for_pytorch as ipex # noqa | |
import oneccl_bindings_for_pytorch # noqa | |
def get_master_port(base_port: int = 29500, port_range_size: int = 1000) -> str:
    """Return the first free port in [base_port, base_port + port_range_size)."""
    for port in range(base_port, base_port + port_range_size):
        try:
            # Create a fresh socket per attempt; a socket whose bind failed
            # should not be reused, and the context manager guarantees closure.
            with socket.socket() as sock:
                sock.bind(("", port))
            return str(port)
        except OSError:
            continue
    raise IOError("no free ports")
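

# Alternative sketch, not part of the original gist: bind to port 0 and let the
# OS assign a free ephemeral port instead of scanning a range. Either approach
# has a small race window between finding the port and the workers binding it,
# but this version is shorter. The name `get_ephemeral_port` is illustrative.
def get_ephemeral_port() -> str:
    with socket.socket() as sock:
        sock.bind(("", 0))  # port 0 asks the OS for any free port
        return str(sock.getsockname()[1])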
def run(local_rank: int, world_size: int, master_port: str):
    # Set up the rendezvous environment expected by torch.distributed and oneCCL.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = master_port
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["CCL_LOCAL_RANK"] = str(local_rank)
    os.environ["CCL_LOCAL_SIZE"] = str(world_size)
    device = torch.device(f"xpu:{local_rank}")
    torch.xpu.set_device(device)
    try:
        dist.init_process_group("ccl")
        # Fill every input chunk with this rank's index so the result is easy to
        # check: with SUM, each rank should receive 0 + 1 + ... + (world_size - 1).
        t_in = [torch.full((1,), float(local_rank), device=device) for _ in range(world_size)]
        t_out = torch.empty(1, device=device)
        dist.reduce_scatter(t_out, t_in, op=dist.ReduceOp.SUM)
        return f"rank {local_rank}: reduce_scatter ok, out={t_out.item()}"
    finally:
        dist.destroy_process_group()
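

# Alternative sketch: newer PyTorch releases also expose a flat-tensor variant,
# dist.reduce_scatter_tensor(output, input), where `input` is a single tensor of
# size world_size * output.numel() rather than a list. Whether the ccl backend
# supports it depends on the oneccl_bindings_for_pytorch version, so treat this
# as an untested illustration of the same collective:
#
#     t_in_flat = torch.full((world_size,), float(local_rank), device=device)
#     t_out = torch.empty(1, device=device)
#     dist.reduce_scatter_tensor(t_out, t_in_flat, op=dist.ReduceOp.SUM)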
def main(world_size: int = 2):
    mp_context = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=world_size, mp_context=mp_context) as pool:
        master_port = get_master_port()
        local_ranks = list(range(world_size))
        world_sizes = [world_size] * world_size
        master_ports = [master_port] * world_size
        results = pool.map(
            run,
            local_ranks,
            world_sizes,
            master_ports,
            timeout=30,
        )
        for r in results:
            print(r, flush=True)


if __name__ == "__main__":
    main()
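
# Expected output with the defaults above (assuming a node with at least two
# visible XPU devices; 1.0 is the SUM of ranks 0 and 1):
#
#     rank 0: reduce_scatter ok, out=1.0
#     rank 1: reduce_scatter ok, out=1.0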