# Created October 24, 2023 09:29
# Vocab sharding using DTensors
from math import ceil
from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_module, distribute_tensor
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.embedding_ops import embedding_rules
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec, _Partial
from torch.nn import functional as F
import torch
import torch.distributed.distributed_c10d as c10d
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch import Tensor
from torch.distributed._tensor import (
from torch.distributed._tensor.op_schema import OpSchema, OpStrategy, RuntimeSchemaInfo
from torch.distributed._tensor.ops.math_ops import (
from torch.distributed._tensor.ops.utils import register_op_strategy
from torch.distributed._tensor.placement_types import Shard
from torch.autograd.function import once_differentiable
from torch.distributed.tensor.parallel._utils import _prepare_input_validate
aten = torch.ops.aten
### Sharding propagation rules
# Custom sharding-propagation rule for the embedding op.
# When both the weight spec and the input spec are sharded on dim 0, the
# output spec is built on the input's mesh and marked `_Partial()`; every
# other placement combination is delegated to the stock `embedding_rules`.
# NOTE(review): this text looks truncated -- the `OutputSharding(` call below
# is never closed, and the two bare `DTensorSpec(...)` lines presumably
# belonged to a dropped wrapper (e.g. a schema-suggestion list). Restore the
# missing lines from the original before running.
def embedding_rules_custom(op_schema: OpSchema) -> OutputSharding:
# Exactly two argument specs are expected: the weight's and the input's.
weight_spec, inp_spec = op_schema.args_spec
# Special case: weight and input both sharded along dim 0 on a 1-D mesh.
if weight_spec.placements == (Shard(0),) and inp_spec.placements == (Shard(0),):
return OutputSharding(
# Output is marked partial on the input's mesh.
output_spec=DTensorSpec(mesh=inp_spec.mesh, placements=(_Partial(),)),
# Weight and input specs, restated with the same Shard(0) placements.
DTensorSpec(mesh=weight_spec.mesh, placements=(Shard(0),)),
DTensorSpec(mesh=inp_spec.mesh, placements=(Shard(0),)),
# Fallback: defer to the current (library-provided) embedding rules.
return embedding_rules(op_schema=op_schema)
# Reduction sharding strategy registered for the `aten.max` overloads
# (`default`, `dim` and `out`).
# NOTE(review): the function is named `mean_reduction_strategy` although it is
# registered for `aten.max.*` -- presumably adapted from a mean strategy; confirm.
# NOTE(review): the final `common_reduction_strategy(` call is cut off in this
# text; its arguments and closing parenthesis are missing.
@register_op_strategy([aten.max.default, aten.max.dim, aten.max.out], schema_info=RuntimeSchemaInfo(1))
def mean_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
args_schema = op_schema.args_schema
# The first positional argument is the reduced tensor; it must already carry
# an OpStrategy.
input_strategy = args_schema[0]
assert isinstance(input_strategy, OpStrategy)
# Optional second argument: the dimension(s) to reduce over.
dims = None
if len(op_schema.args_schema) > 1:
dims = _infer_reduction_dims(args_schema[1], input_strategy.output_ndim)
# No dims inferred -> reduce over every dimension of the input.
reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
# Optional third argument: keep-dim flag; treated as False when absent.
keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
return common_reduction_strategy(
### Module and function changes
class Embedding(nn.Module):
    """Embedding table whose weight can be sharded along the vocabulary axis.

    The lookup itself is ``F.embedding(input, self.weight)``. With
    ``zero_OOR=True``, indices that fall outside ``[0, local_rows - 1]`` (where
    ``local_rows`` is the first dimension of the local weight shard) do not
    raise: they are looked up as index 0 and their output rows are then
    overwritten with zeros.

    Args:
        num_embeddings: Number of rows of the weight (vocabulary size).
        embedding_dim: Length of each embedding vector.
        device: Forwarded to ``torch.empty`` when allocating the weight.
        dtype: Forwarded to ``torch.empty`` when allocating the weight.
        zero_OOR: If true, out-of-range indices yield all-zero rows.

    Note:
        The weight is allocated with ``torch.empty`` and is therefore
        uninitialized; nothing in this class fills it in.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, device=None, dtype=None, zero_OOR=False) -> None:
        # nn.Module's bookkeeping must exist before a Parameter can be assigned.
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(
            torch.empty((num_embeddings, embedding_dim), device=device, dtype=dtype), requires_grad=True
        )
        # Number of vocabulary shards; updated by `parallelize`.
        self.tp_shards = 1
        self.device_mesh: Optional[DeviceMesh] = None
        self.zero_OOR = zero_OOR

    def forward(self, input: Tensor) -> Tensor:
        """Look up ``input`` indices in the weight.

        Args:
            input: Integer index tensor.

        Returns:
            Tensor of shape ``input.shape + (embedding_dim,)``; rows belonging
            to out-of-range indices are zero when ``zero_OOR`` is set.
        """
        if self.zero_OOR:
            # Range check is against the local shard, not the global vocabulary.
            weight_shape = self.weight.to_local().shape if isinstance(self.weight, DTensor) else self.weight.shape
            max_possible_index = weight_shape[0] - 1
            OOR_indices = (input < 0) | (input > max_possible_index)
            # Point invalid indices at row 0 so the lookup itself cannot fail.
            input = torch.where(OOR_indices, 0, input)
        output = F.embedding(input, self.weight)
        if self.zero_OOR:
            # Blank out the rows that were looked up with a placeholder index.
            output = torch.where(OOR_indices.unsqueeze(-1), 0, output)
        return output

    def parallelize(self, device_mesh: DeviceMesh) -> None:
        """Shard the weight over ``device_mesh`` along dim 0, in place.

        Also installs an input hook that shifts incoming indices by this rank's
        vocabulary offset, so they index into the local weight shard.

        Args:
            device_mesh: Mesh to shard over; its size becomes ``tp_shards``.
        """
        self.tp_shards = device_mesh.size()
        # Captured before sharding, while `weight` is still the full table.
        global_vocab_size = self.weight.shape[0]

        def partition_embedding_vocab_fn(name: str, module: nn.Module, device_mesh: DeviceMesh):
            # Shard in vocab axis
            weight: torch.Tensor = module.weight
            weight = distribute_tensor(weight, device_mesh, [Shard(0)])
            module.register_parameter("weight", nn.Parameter(weight))

        @_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
        def vocab_shard_input_fn(
            inputs: Tuple[Union[torch.Tensor, DTensor], ...],
            device_mesh: Optional[DeviceMesh] = None,
        ) -> DTensor:
            # NOTE(review): `inputs` is used as a single tensor here; presumably
            # `_prepare_input_validate` has already unwrapped the tuple -- confirm.
            input = inputs
            # Rows per shard, rounded up.
            vocab_shard_size = ceil(global_vocab_size / self.tp_shards)
            # Adjust indices so they align with local embedding indices
            # NOTE(review): uses the global rank; assumes the mesh covers all
            # ranks in order -- confirm for sub-meshes.
            offset = vocab_shard_size * dist.get_rank()
            input_adjusted = input - offset
            input_adjusted = DTensor.from_local(input_adjusted, device_mesh=device_mesh, placements=[Shard(0)])
            return input_adjusted

        # Acts inplace on module
        distribute_module(self, device_mesh, partition_embedding_vocab_fn, vocab_shard_input_fn)
class crossentropylosssharded(torch.autograd.Function):
    """Unreduced cross-entropy (log-softmax + negative log-likelihood) per row.

    ``forward`` takes 2-D logits ``(N, C)`` and ``N`` integer class ids and
    returns the ``N`` individual losses. ``backward`` implements the analytic
    gradient ``softmax(x) - one_hot(target)`` scaled by the incoming gradient.
    """

    @staticmethod
    def forward(ctx, input: Tensor, target: Tensor):
        """Forward has two collectives: max and sum.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            input: Logits of shape ``(N, C)``.
            target: Class index per row; reshaped to ``(N, 1)`` for the gather.

        Returns:
            1-D tensor with one loss value per row.
        """
        # log_stable_softmax = x-max(x) - log(sum(e^(x-max(x))))
        max_input, _ = input.max(dim=1)
        translated = input - max_input.reshape(-1, 1)
        denominator = torch.log(torch.exp(translated).sum(dim=1))
        logsoftmax = translated - denominator.reshape(-1, 1)
        # Pick -log p[target] for every row.
        loss = torch.gather(-logsoftmax, 1, target.reshape(-1, 1)).flatten()
        ctx.save_for_backward(target, logsoftmax)
        return loss

    @staticmethod
    @once_differentiable  # backward reuses a forward intermediate; only first-order grads are valid
    def backward(ctx, loss_grad: Tensor):
        """Return the gradient w.r.t. ``input`` (and ``None`` for ``target``).

        Args:
            ctx: Autograd context holding ``target`` and the log-softmax.
            loss_grad: Upstream gradient, one value per row.
        """
        target, logsoftmax = ctx.saved_tensors
        # If:
        #   y = softmax(x)
        #   z = nll_loss(log(y))
        # Then, for a given scalar loss l:
        #   dl/dx = dl/dz @ dz/dx
        #   dl/dx = dl/dz . (y - I_t)
        # Where I_t is one-hot encoded target matrix. '.' is element-wise multiplication with broadcasting
        softmax = torch.exp(logsoftmax)
        # `softmax` is a fresh tensor, so the in-place subtraction is safe.
        rows = torch.arange(len(softmax), device=softmax.device)
        softmax[rows, target] -= 1
        input_grad = loss_grad.reshape(-1, 1) * softmax
        return input_grad, None
class CrossEntropyLossSharded(nn.Module):
    """Module wrapper around ``crossentropylosssharded`` with optional OOR masking.

    With ``zero_OOR=True``, targets outside ``[0, local_classes - 1]`` (where
    ``local_classes`` is dim 1 of the local input) are replaced by class 0 for
    the loss computation and their resulting loss entries are set to zero.

    Args:
        zero_OOR: If true, out-of-range targets contribute a loss of zero.
    """

    def __init__(self, zero_OOR=False) -> None:
        # Required so the module can be called and inspected like any nn.Module.
        super().__init__()
        self.zero_OOR = zero_OOR

    def forward(self, input, target):
        """Compute the unreduced per-row loss.

        Args:
            input: Logits with classes on dim 1 (tensor or DTensor).
            target: Class index per row.

        Returns:
            One loss value per row; zero where the target was out of range and
            ``zero_OOR`` is set.
        """
        if self.zero_OOR:
            # Range check is against the local class shard of the logits.
            input_shape = input.to_local().shape if isinstance(input, DTensor) else input.shape
            max_possible_index = input_shape[1] - 1
            OOR_indices = (target < 0) | (target > max_possible_index)
            # Use class 0 as a placeholder so the gather cannot go out of bounds.
            target = torch.where(OOR_indices, 0, target)
        loss = crossentropylosssharded.apply(input, target)
        if self.zero_OOR:
            # Discard the losses computed from placeholder targets.
            loss = torch.where(OOR_indices, 0, loss)
        return loss

    def parallelize(self, classes, device_mesh: DeviceMesh) -> None:
        """Install an input hook that maps global targets to local class ids.

        Args:
            classes: Global number of classes.
            device_mesh: Mesh the classes are sharded over.
        """
        tp_shards = device_mesh.size()

        @_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
        def vocab_shard_input_fn(
            inputs: Tuple[Union[torch.Tensor, DTensor], ...],
            device_mesh: Optional[DeviceMesh] = None,
        ) -> DTensor:
            # NOTE(review): assumes `inputs` arrives here as an (input, target)
            # pair -- confirm against what `_prepare_input_validate` passes on.
            input, target = inputs
            # Classes per shard, rounded up.
            vocab_shard_size = ceil(classes / tp_shards)
            # Adjust indices so they align with local embedding indices
            # NOTE(review): uses the global rank; assumes the mesh covers all
            # ranks in order -- confirm for sub-meshes.
            offset = vocab_shard_size * dist.get_rank()
            target_adjusted = target - offset
            target_adjusted = DTensor.from_local(target_adjusted, device_mesh=device_mesh, placements=[Shard(0)])
            return input, target_adjusted

        # Acts inplace on module
        distribute_module(self, device_mesh, input_fn=vocab_shard_input_fn)
### Example
import os
import torch.distributed as dist
def worker(rank: int, world_size: int, port: int):
    """Run the vocab-sharding example on one rank of a gloo/CPU process group.

    Args:
        rank: Rank of this process (supplied by ``mp.spawn``).
        world_size: Total number of processes in the group.
        port: TCP port used for the ``MASTER_PORT`` rendezvous setting.
    """
    # Setup env in worker to prevent pollution of parent's env
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)

    # Init dist with gloo/cpu backend
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        device_mesh = DeviceMesh("cpu", torch.arange(world_size))

        # Inputs
        items = 3
        classes = 4
        hidden = 2
        # NOTE(review): every rank draws its own random indices yet wraps them
        # as Replicate() below -- presumably ranks should share a seed; confirm.
        token_ids = torch.randint(classes, (items,))
        target = torch.randint(classes, (items,))

        # DTensors
        token_ids = DTensor.from_local(token_ids, device_mesh, [Replicate()])
        target = DTensor.from_local(target, device_mesh, [Replicate()])

        # Model
        # NOTE(review): `embedding.parallelize(device_mesh)` is never called
        # here, so the embedding weight stays a plain tensor -- confirm intent.
        embedding = Embedding(classes, hidden, zero_OOR=True)
        loss_mod = CrossEntropyLossSharded(zero_OOR=True)
        loss_mod.parallelize(classes, device_mesh)

        # Embedding layer
        x = embedding(token_ids)
        # Transformer layers here...
        # Head layer
        x = x @ embedding.weight.T  # Head has tied weight with embedding
        # The loss consumes the logits produced by the head, not the raw indices.
        loss = loss_mod(x, target)
    finally:
        # Always release the process group, even if the example raises.
        dist.destroy_process_group()
if __name__ == "__main__":
    # Example: launch two CPU workers that rendezvous on a fixed local port.
    world_size, port = 2, 8345
    mp.spawn(worker, args=(world_size, port), nprocs=world_size)
