Created May 29, 2024 21:05
Reduce scatter tests
""" | |
Raises a ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY after the 29th iteration on an Intel 1550 max. | |
""" | |
import argparse | |
import os | |
import torch | |
import intel_extension_for_pytorch as ipex # noqa | |
import oneccl_bindings_for_pytorch # noqa | |
import torch.distributed as dist | |
def get_args() -> argparse.Namespace: | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--dim", | |
type=int, | |
default=2**30, | |
) | |
parser.add_argument( | |
"--dtype", | |
type=str, | |
default="bfloat16", | |
) | |
parser.add_argument( | |
"--max-steps", | |
type=int, | |
default=100, | |
) | |
args = parser.parse_args() | |
return args | |
def main(dim: int, dtype: str, max_steps: int) -> None: | |
world_size = int(os.environ["WORLD_SIZE"]) | |
rank = int(os.environ["RANK"]) | |
local_rank = int(os.environ["LOCAL_RANK"]) | |
device = torch.device(f"xpu:{local_rank}") | |
torch.xpu.set_device(device) | |
# Force dim to be divisible by the world size | |
new_dim = world_size * (dim // world_size) | |
if new_dim != dim: | |
if not rank: | |
print( | |
f"Adjusting original {dim=} to {new_dim} in order to be divisible by {world_size=}", | |
flush=True, | |
) | |
dim = new_dim | |
try: | |
dist.init_process_group("ccl") | |
t_in = [ | |
torch.randn(dim // world_size, dtype=getattr(torch, dtype), device=device) | |
for _ in range(world_size) | |
] | |
t_out = torch.empty(dim // world_size, dtype=getattr(torch, dtype), device=device) | |
for step in range(1, max_steps + 1): | |
dist.reduce_scatter(t_out, t_in, op=dist.ReduceOp.SUM) | |
torch.xpu.synchronize() | |
peak_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.peak"] / 2**30 | |
current_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.current"] / 2**30 | |
print(f"[{rank=}]: {step=} memory {peak_mem_gib=}, {current_mem_gib=}", flush=True) | |
finally: | |
dist.destroy_process_group() | |
if __name__ == "__main__": | |
args = get_args() | |
main(**vars(args)) |
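For context on what the script above is exercising: with the list-based reduce_scatter, every rank supplies world_size input shards, and rank r receives the elementwise reduction of all ranks' r-th shards. Below is a minimal sketch of the same call pattern on CPU with the gloo backend, just to pin down the expected collective semantics; the file name, shapes, and launch command are illustrative assumptions, not part of the gist.

# reduce_scatter_gloo_sketch.py -- hypothetical file name.
# Launch with, e.g.: torchrun --nproc_per_node 2 reduce_scatter_gloo_sketch.py
import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group("gloo")  # CPU backend; no XPU needed for the semantics
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Each rank contributes world_size shards filled with its own rank id.
    t_in = [torch.full((4,), float(rank)) for _ in range(world_size)]
    t_out = torch.empty(4)
    dist.reduce_scatter(t_out, t_in, op=dist.ReduceOp.SUM)
    # Every output element is the sum of all rank ids: 0 + 1 + ... + (world_size - 1).
    expected = torch.full((4,), float(sum(range(world_size))))
    assert torch.allclose(t_out, expected)
    print(f"{rank=}: {t_out=}", flush=True)
    dist.destroy_process_group()

The out-of-device-memory failure is backend-specific and does not show up under gloo; this sketch only confirms what the collective should compute. The second script below switches to the tensor-based reduce_scatter_tensor.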
""" | |
Raises a RuntimeError: ProcessGroupCCL does not support reduce_scatter | |
""" | |
import argparse | |
import os | |
import torch | |
import intel_extension_for_pytorch as ipex # noqa | |
import oneccl_bindings_for_pytorch # noqa | |
import torch.distributed as dist | |
def get_args() -> argparse.Namespace: | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--dim", | |
type=int, | |
default=2**30, | |
) | |
parser.add_argument( | |
"--dtype", | |
type=str, | |
default="bfloat16", | |
) | |
parser.add_argument( | |
"--max-steps", | |
type=int, | |
default=100, | |
) | |
args = parser.parse_args() | |
return args | |
def main(dim: int, dtype: str, max_steps: int) -> None: | |
world_size = int(os.environ["WORLD_SIZE"]) | |
rank = int(os.environ["RANK"]) | |
local_rank = int(os.environ["LOCAL_RANK"]) | |
device = torch.device(f"xpu:{local_rank}") | |
torch.xpu.set_device(device) | |
# Force dim to be divisible by the world size | |
new_dim = world_size * (dim // world_size) | |
if new_dim != dim: | |
if not rank: | |
print( | |
f"Adjusting original {dim=} to {new_dim} in order to be divisible by {world_size=}", | |
flush=True, | |
) | |
dim = new_dim | |
try: | |
dist.init_process_group("ccl") | |
t_in = torch.randn(dim, dtype=getattr(torch, dtype), device=device) | |
t_out = torch.empty(dim // world_size, dtype=getattr(torch, dtype), device=device) | |
for step in range(1, max_steps + 1): | |
dist.reduce_scatter_tensor(t_out, t_in, op=dist.ReduceOp.SUM) | |
torch.xpu.synchronize() | |
peak_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.peak"] / 2**30 | |
current_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.current"] / 2**30 | |
print(f"[{rank=}]: {step=} memory {peak_mem_gib=}, {current_mem_gib=}", flush=True) | |
finally: | |
dist.destroy_process_group() | |
if __name__ == "__main__": | |
args = get_args() | |
main(**vars(args)) |
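Given the RuntimeError above, one possible workaround (an untested assumption, not something from the gist) is to emulate reduce_scatter_tensor with the list-based reduce_scatter that the first script uses, chunking the flat input into views:

import torch
import torch.distributed as dist


def reduce_scatter_tensor_fallback(
    t_out: torch.Tensor, t_in: torch.Tensor, op=dist.ReduceOp.SUM
) -> None:
    # Hypothetical helper: emulates reduce_scatter_tensor via the list-based
    # reduce_scatter. Assumes t_in.numel() == world_size * t_out.numel().
    world_size = dist.get_world_size()
    # chunk() returns views of t_in, so no extra input copies are made here;
    # the backend may still stage its own buffers internally.
    t_in_list = list(t_in.chunk(world_size))
    dist.reduce_scatter(t_out, t_in_list, op=op)

Note that this routes through the same list-based path as the first script, so on this stack it may simply trade the RuntimeError for the memory growth observed there.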