DTensor slicing: slicing a Shard(0) DTensor along its sharded dimension unexpectedly returns a Replicate() DTensor instead of staying sharded.
import os

import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor

if __name__ == "__main__":
    try:
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        device = torch.device(f"cuda:{local_rank}")
        mesh = dist.device_mesh.init_device_mesh(device_type="cuda", mesh_shape=(world_size,))

        # Create a DTensor with 2 * world_size logical elements, sharded over the world. Every rank
        # holds 2 local tensor elements.
        shard_spec = (Shard(0),)
        x_dt = distribute_tensor(torch.arange(2 * world_size, device=device), mesh, shard_spec)

        # Create a slice with only world_size elements.
        x_dt_slice = x_dt[..., :world_size]

        # Expectation: the DTensor slice should still be sharded, as in Shard(-1), and every rank
        # should hold a single local element.
        # Actual: the DTensor slice ends up with Replicate() placement, with every rank holding
        # world_size elements.

        # For nicer printing:
        for rank in range(world_size):
            if rank == local_rank:
                print(f"\n{rank=}")
                print(f"\t{x_dt=}")
                print(f"\t{x_dt.placements=}")
                print(f"\t{x_dt.to_local().shape=}")
                print(f"\n\t{x_dt_slice=}")
                print(f"\t{x_dt_slice.placements=}")
                print(f"\t{x_dt_slice.to_local().shape=}")
            dist.barrier()

        assert x_dt.to_local().numel() == 2, f"{x_dt.to_local().numel()=}"
        assert x_dt.placements == shard_spec, f"{x_dt.placements=}, {shard_spec=}"
        # These fail:
        assert x_dt_slice.to_local().numel() == 1
        assert x_dt_slice.placements == shard_spec, f"{x_dt_slice.placements=}, {shard_spec=}"
    finally:
        torch.distributed.destroy_process_group()
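Not part of the original repro: a minimal workaround sketch, assuming the slice really does come back with Replicate() placement as the output below shows. It reuses x_dt_slice, mesh, and shard_spec from inside the try block above and explicitly redistributes the slice to the expected sharding. The replication has already happened inside the slice op, so this only restores the intended placement; it does not avoid the extra communication.

        # Hypothetical workaround (not in the original gist): re-shard the replicated slice.
        # Replicate -> Shard(0) only chunks the local copy, so this is cheap but does not
        # undo the cost of the earlier replication.
        x_dt_slice_sharded = x_dt_slice.redistribute(mesh, shard_spec)
        assert x_dt_slice_sharded.placements == shard_spec
        assert x_dt_slice_sharded.to_local().numel() == 1  # one element per rank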
# Partial output
⬢ [podman] ❯ torchrun --nproc-per-node 4 learn_torch/distributed/dtensor/test_dtensor_slice.py
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792]
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792] *****************************************
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792] *****************************************

rank=0
        x_dt=DTensor(local_tensor=tensor([0, 1], device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
        x_dt.placements=(Shard(dim=0),)
        x_dt.to_local().shape=torch.Size([2])

        x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
        x_dt_slice.placements=(Replicate(),)
        x_dt_slice.to_local().shape=torch.Size([4])

rank=1
        x_dt=DTensor(local_tensor=tensor([2, 3], device='cuda:1'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
        x_dt.placements=(Shard(dim=0),)
        x_dt.to_local().shape=torch.Size([2])

        x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:1'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
        x_dt_slice.placements=(Replicate(),)
        x_dt_slice.to_local().shape=torch.Size([4])

rank=2
        x_dt=DTensor(local_tensor=tensor([4, 5], device='cuda:2'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
        x_dt.placements=(Shard(dim=0),)
        x_dt.to_local().shape=torch.Size([2])

        x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:2'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
        x_dt_slice.placements=(Replicate(),)
        x_dt_slice.to_local().shape=torch.Size([4])

rank=3
        x_dt=DTensor(local_tensor=tensor([6, 7], device='cuda:3'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
        x_dt.placements=(Shard(dim=0),)
        x_dt.to_local().shape=torch.Size([2])

        x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:3'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
        x_dt_slice.placements=(Replicate(),)
        x_dt_slice.to_local().shape=torch.Size([4])

[rank0]: Traceback (most recent call last):
[rank0]:   File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank0]:     assert x_dt_slice.to_local().numel() == 1
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError
[rank2]: Traceback (most recent call last):
[rank2]:   File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank2]:     assert x_dt_slice.to_local().numel() == 1
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: AssertionError
[rank3]: Traceback (most recent call last):
[rank3]:   File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank3]:     assert x_dt_slice.to_local().numel() == 1
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: AssertionError
[rank1]: Traceback (most recent call last):
[rank1]:   File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank1]:     assert x_dt_slice.to_local().numel() == 1
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError