Reduce scatter tests
"""
Raises a ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY after the 29th iteration on an Intel Data Center GPU Max 1550.
"""
import argparse
import os

import torch
import intel_extension_for_pytorch as ipex  # noqa
import oneccl_bindings_for_pytorch  # noqa
import torch.distributed as dist


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dim",
        type=int,
        default=2**30,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=100,
    )
    args = parser.parse_args()
    return args


def main(dim: int, dtype: str, max_steps: int) -> None:
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"xpu:{local_rank}")
    torch.xpu.set_device(device)

    # Force dim to be divisible by the world size
    new_dim = world_size * (dim // world_size)
    if new_dim != dim:
        if not rank:
            print(
                f"Adjusting original {dim=} to {new_dim} in order to be divisible by {world_size=}",
                flush=True,
            )
        dim = new_dim

    try:
        dist.init_process_group("ccl")
        t_in = [
            torch.randn(dim // world_size, dtype=getattr(torch, dtype), device=device)
            for _ in range(world_size)
        ]
        t_out = torch.empty(dim // world_size, dtype=getattr(torch, dtype), device=device)
        for step in range(1, max_steps + 1):
            dist.reduce_scatter(t_out, t_in, op=dist.ReduceOp.SUM)
            torch.xpu.synchronize()
            peak_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.peak"] / 2**30
            current_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.current"] / 2**30
            print(f"[{rank=}]: {step=} memory {peak_mem_gib=}, {current_mem_gib=}", flush=True)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    args = get_args()
    main(**vars(args))
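
As a reference point for the memory numbers the first script prints, a quick footprint estimate for the default arguments is sketched below; the world size used here is an assumption chosen for illustration, not something the gist fixes.

# Rough steady-state footprint of the tensors allocated above (illustrative only).
world_size = 8                                    # assumed value, not from the gist
dim = 2**30                                       # default --dim
bytes_per_elem = 2                                # bfloat16
shard_bytes = (dim // world_size) * bytes_per_elem
print(shard_bytes / 2**30)                        # 0.25 GiB per shard
print((world_size + 1) * shard_bytes / 2**30)     # ~2.25 GiB for t_in (8 shards) + t_out

If current_mem_gib keeps climbing well past that total each step, the growth would come from buffers retained by the collective itself rather than from anything this script allocates.
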
"""
Raises a RuntimeError: ProcessGroupCCL does not support reduce_scatter
"""
import argparse
import os

import torch
import intel_extension_for_pytorch as ipex  # noqa
import oneccl_bindings_for_pytorch  # noqa
import torch.distributed as dist


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dim",
        type=int,
        default=2**30,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=100,
    )
    args = parser.parse_args()
    return args


def main(dim: int, dtype: str, max_steps: int) -> None:
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"xpu:{local_rank}")
    torch.xpu.set_device(device)

    # Force dim to be divisible by the world size
    new_dim = world_size * (dim // world_size)
    if new_dim != dim:
        if not rank:
            print(
                f"Adjusting original {dim=} to {new_dim} in order to be divisible by {world_size=}",
                flush=True,
            )
        dim = new_dim

    try:
        dist.init_process_group("ccl")
        t_in = torch.randn(dim, dtype=getattr(torch, dtype), device=device)
        t_out = torch.empty(dim // world_size, dtype=getattr(torch, dtype), device=device)
        for step in range(1, max_steps + 1):
            dist.reduce_scatter_tensor(t_out, t_in, op=dist.ReduceOp.SUM)
            torch.xpu.synchronize()
            peak_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.peak"] / 2**30
            current_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.current"] / 2**30
            print(f"[{rank=}]: {step=} memory {peak_mem_gib=}, {current_mem_gib=}", flush=True)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    args = get_args()
    main(**vars(args))
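
The unsupported-op error above can be sidestepped by falling back to the list-based collective that the first script uses. A minimal sketch of such a fallback, assuming a contiguous 1D input divisible by the world size; the helper name is hypothetical and not part of the gist:

def reduce_scatter_tensor_fallback(t_out: torch.Tensor, t_in: torch.Tensor) -> None:
    # Hypothetical fallback: emulate reduce_scatter_tensor by viewing the flat
    # input as world_size equal chunks (views, no copies) and dispatching the
    # list-based dist.reduce_scatter that ProcessGroupCCL does accept.
    world_size = dist.get_world_size()
    chunks = list(t_in.chunk(world_size))
    dist.reduce_scatter(t_out, chunks, op=dist.ReduceOp.SUM)

Note that, per the first script, the list-based path still runs out of device memory after enough iterations, so this only avoids the missing-op error, not the apparent leak.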