@yf225 · Last active June 17, 2024 23:36
"""
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.run --standalone --nproc_per_node=1 test_dtensor_param_type.py
"""
import functools
import torch
from torch.distributed._tensor import DTensor, Replicate, init_device_mesh
class FSDPParam:
    def __init__(self):
        device_mesh = init_device_mesh("cuda", (1,))
        replica_placement = [Replicate()]
        local_tensor = torch.zeros(3, 3, device="cuda")
        dtensor = DTensor.from_local(local_tensor, device_mesh=device_mesh, placements=replica_placement)
        self.sharded_param = torch.nn.Parameter(dtensor)
        print(f"type(self.sharded_param): {type(self.sharded_param)}")  # prints "<class 'torch.distributed._tensor.api.DTensor'>"
        print(f"isinstance(self.sharded_param, torch.nn.Parameter): {isinstance(self.sharded_param, torch.nn.Parameter)}")  # prints True
p = FSDPParam()
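
# A minimal follow-up sketch, not part of the original gist (the "MyModule" name
# below is hypothetical). Because the wrapped value still passes the
# isinstance(..., nn.Parameter) check, assigning it as a module attribute
# registers it like any ordinary parameter, while type() continues to report
# the DTensor subclass. Assumes the same single-GPU torch.distributed.run
# launch shown in the docstring above.
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = p.sharded_param  # nn.Parameter wrapping a DTensor

m = MyModule()
for name, param in m.named_parameters():
    print(f"{name}: {type(param)}, requires_grad={param.requires_grad}")
    # expected: weight: <class 'torch.distributed._tensor.api.DTensor'>, requires_grad=True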