Skip to content

Instantly share code, notes, and snippets.

@garrett361
Created June 6, 2024 13:28
Show Gist options
  • Save garrett361/283db45d436d1e7e67e84f89db2a0e48 to your computer and use it in GitHub Desktop.
Save garrett361/283db45d436d1e7e67e84f89db2a0e48 to your computer and use it in GitHub Desktop.
mp reduce scatter xpu
"""
Launch single-node reduce scatter with multiprocessing.
python3 mp_torch_reduce_scatter.py
"""
import os
import socket
from concurrent.futures import ProcessPoolExecutor
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import intel_extension_for_pytorch as ipex # noqa
import oneccl_bindings_for_pytorch # noqa
def get_master_port(base_port: int = 29500, port_range_size: int = 1000) -> str:
    """Return the first bindable TCP port in the probed range, as a string.

    Probes ports ``base_port`` through ``base_port + port_range_size - 1``
    on all interfaces and returns the first one that binds successfully.
    NOTE(review): this is inherently racy — the port can be taken between
    the probe and the launched job actually using it; acceptable for a
    single-node demo.

    Args:
        base_port: first port number to try.
        port_range_size: how many consecutive ports to probe.

    Returns:
        The free port number, stringified (ready for ``MASTER_PORT``).

    Raises:
        IOError: if no port in the range could be bound.
    """
    for port in range(base_port, base_port + port_range_size):
        try:
            # Fresh socket per attempt, closed deterministically by the
            # context manager even when bind() raises. (The original reused
            # one socket object across attempts and leaked it when the whole
            # range was exhausted.)
            with socket.socket() as sock:
                sock.bind(("", port))
            return str(port)
        except OSError:
            continue
    raise IOError("no free ports")
def run(local_rank: int, world_size: int, master_port: str):
    """Worker entry point: set up a CCL process group and do one reduce_scatter.

    Executed inside a spawned child process. Configures the rendezvous and
    rank bookkeeping through environment variables, pins the process to its
    XPU device, runs a single SUM reduce-scatter, then tears the process
    group down.

    Args:
        local_rank: this worker's rank (also its XPU device index).
        world_size: total number of workers participating.
        master_port: rendezvous port, as a string (see get_master_port).
    """
    # Environment consumed by torch.distributed and the oneCCL bindings.
    rendezvous_env = {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": master_port,
        "LOCAL_RANK": str(local_rank),
        "RANK": str(local_rank),
        "WORLD_SIZE": str(world_size),
        "CCL_LOCAL_RANK": str(local_rank),
        "CCL_LOCAL_SIZE": str(world_size),
    }
    os.environ.update(rendezvous_env)

    device = torch.device(f"xpu:{local_rank}")
    torch.xpu.set_device(device)
    try:
        dist.init_process_group("ccl")
        # One input shard per rank; tensor contents are deliberately left
        # uninitialized — this exercises the collective, not the arithmetic.
        input_shards = [torch.empty(1, device=device) for _ in range(world_size)]
        output = torch.empty(1, device=device)
        dist.reduce_scatter(output, input_shards, op=dist.ReduceOp.SUM)
    finally:
        # Always tear down, even if init or the collective raised.
        dist.destroy_process_group()
def main(world_size: int = 2):
    """Launch one worker process per rank and run the reduce-scatter demo.

    Args:
        world_size: number of ranks / XPU devices to spawn.
    """
    # "spawn" (not "fork") is required so each child gets a clean XPU/CCL state.
    mp_context = mp.get_context("spawn")
    master_port = get_master_port()
    with ProcessPoolExecutor(max_workers=world_size, mp_context=mp_context) as pool:
        # ProcessPoolExecutor.map needs one iterable per run() argument;
        # world_size and the port are identical for every rank.
        results = pool.map(
            run,
            list(range(world_size)),  # was a redundant [r for r in range(...)]
            [world_size] * world_size,
            [master_port] * world_size,
            timeout=30,
        )
        # run() returns None; iterating the results also re-raises any
        # exception (or TimeoutError) from the workers.
        for r in results:
            print(r, flush=True)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment