@garrett361
Last active March 10, 2025 19:49
DTensor slicing
import os

import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor

if __name__ == "__main__":
    try:
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        device = torch.device(f"cuda:{local_rank}")
        mesh = dist.device_mesh.init_device_mesh(device_type="cuda", mesh_shape=(world_size,))

        # Create a DTensor with 2 * world_size logical elements, sharded over the world.
        # Every rank holds 2 tensor elements.
        shard_spec = (Shard(0),)
        x_dt = distribute_tensor(torch.arange(2 * world_size, device=device), mesh, shard_spec)

        # Create a slice with only world_size elements.
        x_dt_slice = x_dt[..., :world_size]
        # Expectation: the DTensor slice should still be sharded, as in Shard(-1), and
        # every rank should hold a single local element.
        # Actual: the DTensor slice ends up with Replicate() placement, with every rank
        # holding all world_size elements. (A workaround sketch follows after the script.)

        # For nicer printing:
        for rank in range(world_size):
            if rank == local_rank:
                print(f"\n{rank=}")
                print(f"\t{x_dt=}")
                print(f"\t{x_dt.placements=}")
                print(f"\t{x_dt.to_local().shape=}")
                print(f"\n\t{x_dt_slice=}")
                print(f"\t{x_dt_slice.placements=}")
                print(f"\t{x_dt_slice.to_local().shape=}")
            dist.barrier()
        assert x_dt.to_local().numel() == 2, f"{x_dt.to_local().numel()=}"
        assert x_dt.placements == shard_spec, f"{x_dt.placements=}, {shard_spec=}"
        # These fail:
        assert x_dt_slice.to_local().numel() == 1
        assert x_dt_slice.placements == shard_spec, f"{x_dt_slice.placements=}, {shard_spec=}"
    finally:
        torch.distributed.destroy_process_group()
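
A possible follow-up, not part of the original gist: if the goal is to recover a sharded layout for the slice, DTensor.redistribute should convert the replicated result back to Shard(0). A minimal sketch, assuming the same script context (mesh, shard_spec, and x_dt_slice as above):

# Workaround sketch: explicitly re-shard the replicated slice. redistribute()
# performs the collective communication needed to move from Replicate() to
# Shard(0), after which every rank holds a single local element again.
x_dt_resharded = x_dt_slice.redistribute(mesh, (Shard(0),))
assert x_dt_resharded.placements == shard_spec
assert x_dt_resharded.to_local().numel() == 1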
# Partial output from the original script above
⬢ [podman] ❯ torchrun --nproc-per-node 4 learn_torch/distributed/dtensor/test_dtensor_slice.py
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792]
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792] *****************************************
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0310 19:48:36.916000 3526002 torch/distributed/run.py:792] *****************************************
rank=0
x_dt=DTensor(local_tensor=tensor([0, 1], device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
x_dt.placements=(Shard(dim=0),)
x_dt.to_local().shape=torch.Size([2])
x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
x_dt_slice.placements=(Replicate(),)
x_dt_slice.to_local().shape=torch.Size([4])
rank=1
x_dt=DTensor(local_tensor=tensor([2, 3], device='cuda:1'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
x_dt.placements=(Shard(dim=0),)
x_dt.to_local().shape=torch.Size([2])
x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:1'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
x_dt_slice.placements=(Replicate(),)
x_dt_slice.to_local().shape=torch.Size([4])
rank=2
x_dt=DTensor(local_tensor=tensor([4, 5], device='cuda:2'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
x_dt.placements=(Shard(dim=0),)
x_dt.to_local().shape=torch.Size([2])
x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:2'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
x_dt_slice.placements=(Replicate(),)
x_dt_slice.to_local().shape=torch.Size([4])
rank=3
x_dt=DTensor(local_tensor=tensor([6, 7], device='cuda:3'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),))
x_dt.placements=(Shard(dim=0),)
x_dt.to_local().shape=torch.Size([2])
x_dt_slice=DTensor(local_tensor=tensor([0, 1, 2, 3], device='cuda:3'), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),))
x_dt_slice.placements=(Replicate(),)
x_dt_slice.to_local().shape=torch.Size([4])
[rank0]: Traceback (most recent call last):
[rank0]: File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank0]: assert x_dt_slice.to_local().numel() == 1
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError
[rank2]: Traceback (most recent call last):
[rank2]: File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank2]: assert x_dt_slice.to_local().numel() == 1
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: AssertionError
[rank3]: Traceback (most recent call last):
[rank3]: File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank3]: assert x_dt_slice.to_local().numel() == 1
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: AssertionError
[rank1]: Traceback (most recent call last):
[rank1]: File "/gpfs/users/goon/github/garrett361/learn_torch/learn_torch/distributed/dtensor/test_dtensor_slice.py", line 47, in <module>
[rank1]: assert x_dt_slice.to_local().numel() == 1
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError
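
For completeness, a sanity-check sketch (again not from the gist): DTensor.full_tensor() allgathers the complete logical tensor onto every rank, which should confirm that the slice's values are correct and that only its placement differs from the expectation:

# Sanity-check sketch: full_tensor() materializes the full logical tensor on
# every rank, so the slice's values can be verified independent of placement.
assert torch.equal(x_dt_slice.full_tensor(), torch.arange(world_size, device=device))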