@joshlk
Created October 24, 2023 09:29
Vocab sharding using DTensors
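# Vocab-parallel embedding and sharded cross-entropy loss built on PyTorch DTensor.
# The embedding weight is sharded along the vocab dim (Shard(0)); out-of-range indices
# are masked so each rank only looks up the rows it owns, and the per-rank partial
# results are summed via a _Partial() placement. The sharded cross-entropy forward
# needs two collectives (a MAX for numerical stability and a SUM for the softmax
# denominator), which the op rules registered below let DTensor propagate.
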
import os
from math import ceil
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd.function import once_differentiable
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Replicate,
    Shard,
    distribute_module,
    distribute_tensor,
)
from torch.distributed._tensor.op_schema import OpSchema, OpStrategy, OutputSharding, RuntimeSchemaInfo
from torch.distributed._tensor.ops.embedding_ops import embedding_rules
from torch.distributed._tensor.ops.math_ops import _infer_reduction_dims, common_reduction_strategy
from torch.distributed._tensor.ops.utils import register_op_strategy, register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec, _Partial
from torch.distributed.tensor.parallel._utils import _prepare_input_validate

aten = torch.ops.aten

### Sharding propagation rules
@register_prop_rule(aten.embedding.default)
def embedding_rules_custom(op_schema: OpSchema) -> OutputSharding:
    weight_spec, inp_spec = op_schema.args_spec
    if weight_spec.placements == (Shard(0),) and inp_spec.placements == (Shard(0),):
        # Vocab-sharded weight and vocab-sharded indices: each rank produces a partial
        # result that is summed across the mesh
        return OutputSharding(
            output_spec=DTensorSpec(mesh=inp_spec.mesh, placements=(_Partial(),)),
            schema_suggestions=[
                OpSchema(
                    op=op_schema.op,
                    args_schema=(
                        DTensorSpec(mesh=weight_spec.mesh, placements=(Shard(0),)),
                        DTensorSpec(mesh=inp_spec.mesh, placements=(Shard(0),)),
                    ),
                    kwargs_schema={},
                )
            ],
        )
    # Otherwise fall back to the current (built-in) embedding rules
    return embedding_rules(op_schema=op_schema)
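

# A minimal single-process sketch of the arithmetic the rule above relies on: when the
# weight is vocab-sharded and out-of-range indices are masked to zero, summing the
# per-shard partial lookups reproduces the full embedding lookup.
# `_demo_vocab_shard_partial_sum` is a hypothetical helper and is not called by the example.
def _demo_vocab_shard_partial_sum(vocab: int = 8, hidden: int = 4, shards: int = 2) -> None:
    torch.manual_seed(0)
    weight = torch.randn(vocab, hidden)
    indices = torch.randint(vocab, (5,))
    shard_size = ceil(vocab / shards)
    partial_sum = torch.zeros(len(indices), hidden)
    for rank in range(shards):
        # Rows owned by this shard, and the indices re-based to the local range
        local_weight = weight[rank * shard_size : (rank + 1) * shard_size]
        local_idx = indices - rank * shard_size
        oor = (local_idx < 0) | (local_idx >= local_weight.shape[0])
        local_out = F.embedding(torch.where(oor, 0, local_idx), local_weight)
        # Zero rows this shard does not own, then accumulate (the _Partial() sum)
        partial_sum += torch.where(oor.unsqueeze(-1), 0, local_out)
    torch.testing.assert_close(partial_sum, F.embedding(indices, weight))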
@register_op_strategy([aten.max.default, aten.max.dim, aten.max.out], schema_info=RuntimeSchemaInfo(1))
def max_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    args_schema = op_schema.args_schema
    input_strategy = args_schema[0]
    assert isinstance(input_strategy, OpStrategy)
    dims = None
    if len(op_schema.args_schema) > 1:
        dims = _infer_reduction_dims(args_schema[1], input_strategy.output_ndim)
    reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
    keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
    return common_reduction_strategy(
        mesh,
        input_strategy,
        reduce_dims,
        keep_dim=keep_dim,
        reduction_linear=True,
        reduction_op=c10d.ReduceOp.MAX,
    )
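

# A minimal sketch of why a MAX reduction works for `aten.max` over vocab-sharded logits:
# the row-wise max of the full tensor equals the elementwise max of the per-shard
# row-wise maxima, so the reduction is linear under ReduceOp.MAX.
# `_demo_sharded_max` is a hypothetical helper and is not called by the example.
def _demo_sharded_max(items: int = 3, classes: int = 8, shards: int = 2) -> None:
    torch.manual_seed(0)
    logits = torch.randn(items, classes)
    per_shard_max = torch.stack([chunk.max(dim=1).values for chunk in logits.chunk(shards, dim=1)])
    torch.testing.assert_close(per_shard_max.max(dim=0).values, logits.max(dim=1).values)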
### Module and function changes
class Embedding(nn.Module):
    """Same as nn.Embedding but with a `zero_OOR` option that zeroes lookups for out-of-range (OOR) indices."""

    def __init__(self, num_embeddings: int, embedding_dim: int, device=None, dtype=None, zero_OOR=False) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(
            torch.empty((num_embeddings, embedding_dim), device=device, dtype=dtype), requires_grad=True
        )
        nn.init.normal_(self.weight)
        self.tp_shards = 1
        self.device_mesh: Optional[DeviceMesh] = None
        self.zero_OOR = zero_OOR

    def forward(self, input: Tensor) -> Tensor:
        if self.zero_OOR:
            # Redirect out-of-range indices to row 0 for the lookup, then zero their output
            weight_shape = self.weight.to_local().shape if isinstance(self.weight, DTensor) else self.weight.shape
            max_possible_index = weight_shape[0] - 1
            OOR_indices = (input < 0) | (input > max_possible_index)
            input = torch.where(OOR_indices, 0, input)
        output = F.embedding(
            input=input,
            weight=self.weight,
        )
        if self.zero_OOR:
            output = torch.where(OOR_indices.unsqueeze(-1), 0, output)
        return output
    def parallelize(self, device_mesh: DeviceMesh) -> None:
        self.tp_shards = device_mesh.size()
        global_vocab_size = self.weight.shape[0]

        def partition_embedding_vocab_fn(name: str, module: nn.Module, device_mesh: DeviceMesh):
            # Shard in vocab axis
            weight: torch.Tensor = module.weight
            weight = distribute_tensor(weight, device_mesh, [Shard(0)])
            module.register_parameter("weight", nn.Parameter(weight))

        @_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
        def vocab_shard_input_fn(
            inputs: Tuple[Union[torch.Tensor, DTensor], ...],
            device_mesh: Optional[DeviceMesh] = None,
        ) -> DTensor:
            input = inputs
            vocab_shard_size = ceil(global_vocab_size / self.tp_shards)
            # Adjust indices so they align with local embedding indices
            offset = vocab_shard_size * dist.get_rank()
            input_adjusted = input - offset
            input_adjusted = DTensor.from_local(input_adjusted, device_mesh=device_mesh, placements=[Shard(0)])
            return input_adjusted

        # Acts inplace on module
        distribute_module(self, device_mesh, partition_embedding_vocab_fn, vocab_shard_input_fn)
class crossentropylosssharded(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: Tensor, target: Tensor):
        """Forward has two collectives: max and sum"""
        # log_stable_softmax = x - max(x) - log(sum(e^(x - max(x))))
        max_input, _ = input.max(dim=1)
        translated = input - max_input.reshape(-1, 1)
        denominator = torch.log(torch.exp(translated).sum(dim=1))
        logsoftmax = translated - denominator.reshape(-1, 1)
        loss = torch.gather(-logsoftmax, 1, target.reshape(-1, 1)).flatten()
        ctx.save_for_backward(target, logsoftmax)
        return loss

    @staticmethod
    @once_differentiable
    def backward(ctx, loss_grad: Tensor):
        target, logsoftmax = ctx.saved_tensors
        # If:
        #   y = softmax(x)
        #   z = nll_loss(y)
        # then, for a given scalar loss l:
        #   dl/dx = dl/dz @ dz/dx
        #   dl/dx = dl/dz . (y - I_t)
        # where I_t is the one-hot encoded target matrix and '.' is element-wise
        # multiplication with broadcasting
        softmax = torch.exp(logsoftmax)
        softmax[torch.arange(len(softmax)), target] -= 1
        input_grad = loss_grad.reshape(-1, 1) * softmax
        return input_grad, None
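

# A minimal single-process sanity check of the Function above: on unsharded logits the
# loss and the input gradient should match F.cross_entropy(..., reduction="none").
# `_check_cross_entropy_against_reference` is a hypothetical helper and is not called by
# the example.
def _check_cross_entropy_against_reference(items: int = 4, classes: int = 6) -> None:
    torch.manual_seed(0)
    logits = torch.randn(items, classes, requires_grad=True)
    ref_logits = logits.detach().clone().requires_grad_(True)
    target = torch.randint(classes, (items,))
    loss = crossentropylosssharded.apply(logits, target)
    ref = F.cross_entropy(ref_logits, target, reduction="none")
    torch.testing.assert_close(loss, ref)
    loss.sum().backward()
    ref.sum().backward()
    torch.testing.assert_close(logits.grad, ref_logits.grad)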
class CrossEntropyLossSharded(nn.Module):
    def __init__(self, zero_OOR=False) -> None:
        super().__init__()
        self.zero_OOR = zero_OOR

    def forward(self, input, target):
        if self.zero_OOR:
            # Redirect out-of-range targets to class 0 for the local loss, then zero their loss
            input_shape = input.to_local().shape if isinstance(input, DTensor) else input.shape
            max_possible_index = input_shape[1] - 1
            OOR_indices = (target < 0) | (target > max_possible_index)
            target = torch.where(OOR_indices, 0, target)
        loss = crossentropylosssharded.apply(input, target)
        if self.zero_OOR:
            loss = torch.where(OOR_indices, 0, loss)
        return loss
    def parallelize(self, classes, device_mesh: DeviceMesh) -> None:
        tp_shards = device_mesh.size()

        @_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
        def vocab_shard_input_fn(
            inputs: Tuple[Union[torch.Tensor, DTensor], ...],
            device_mesh: Optional[DeviceMesh] = None,
        ) -> DTensor:
            input, target = inputs
            vocab_shard_size = ceil(classes / tp_shards)
            # Adjust indices so they align with local embedding indices
            offset = vocab_shard_size * dist.get_rank()
            target_adjusted = target - offset
            target_adjusted = DTensor.from_local(target_adjusted, device_mesh=device_mesh, placements=[Shard(0)])
            return input, target_adjusted

        # Acts inplace on module
        distribute_module(self, device_mesh, input_fn=vocab_shard_input_fn)
### Example
def worker(rank: int, world_size: int, port: int):
    # Setup env in worker to prevent pollution of parent's env
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)

    # Init dist with gloo/cpu backend
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    device_mesh = DeviceMesh("cpu", torch.arange(world_size))

    # Inputs
    items = 3
    classes = 4
    hidden = 2
    classes_tp = classes // world_size
    torch.manual_seed(42)
    input = torch.randint(classes, (items,))  # torch.rand((items, classes), requires_grad=True)
    target = torch.randint(classes, (items,))

    # DTensors
    input = DTensor.from_local(input, device_mesh, [Replicate()])
    target = DTensor.from_local(target, device_mesh, [Replicate()])

    # Model
    embedding = Embedding(classes, hidden, zero_OOR=True)
    loss_mod = CrossEntropyLossSharded(zero_OOR=True)
    embedding.parallelize(device_mesh)
    loss_mod.parallelize(classes, device_mesh)

    # Embedding layer
    x = embedding(input)
    # Transformer layers here...
    # Head layer
    x = x @ embedding.weight.T  # Head has tied weight with embedding
    loss = loss_mod(x, target)
    print(loss)
if __name__ == "__main__":
    # Example
    world_size = 2
    port = 8345
    mp.spawn(worker, nprocs=world_size, args=(world_size, port))