@garrett361
fsdp and ddp min tests
"""
Basic FSDP/DDP applied to a linear model.
"""
import argparse
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel import DistributedDataParallel as DDP

# Pick the accelerator module, device type, and distributed backend:
# CUDA/NCCL when available, otherwise Intel XPU with the oneCCL bindings.
if torch.cuda.is_available():
    accel = torch.cuda
    DEVICE_TYPE = "cuda"
    BACKEND = "nccl"
else:
    import intel_extension_for_pytorch as ipex  # noqa
    import oneccl_bindings_for_pytorch as torch_ccl  # noqa

    print(
        f"Using Versions: {torch.__version__=}, {ipex.__version__=}, {torch_ccl.__version__=}",
        flush=True,
    )
    accel = torch.xpu
    DEVICE_TYPE = "xpu"
    BACKEND = "ccl"


class LinearModel(nn.Module):
    def __init__(self, d_model: int, n_layers: int, device: torch.device) -> None:
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, d_model, device=device) for _ in range(n_layers)]
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        out = inputs
        for lin in self.linears:
            out = lin(out)
        return out


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "-d",
        "--d-model",
        type=int,
        default=2**14,
    )
    parser.add_argument(
        "-n",
        "--n-layers",
        type=int,
        default=4,
    )
    parser.add_argument(
        "-m",
        "--max-steps",
        type=int,
        default=500,
    )
    parser.add_argument(
        "-p", "--parallelization", type=str, default="fsdp", help="One of 'fsdp' or 'ddp'."
    )
    args = parser.parse_args()
    return args


def main(
    batch_size: int,
    d_model: int,
    n_layers: int,
    max_steps: int,
    parallelization: str,
) -> None:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
    accel.set_device(device)
    try:
        if world_size > 1:
            dist.init_process_group(BACKEND)
        inputs = torch.randn(batch_size, d_model, device=device)
        model = LinearModel(d_model, n_layers, device)
        # Print model details pre-sharding (only matters for FSDP)
        if not rank:
            print(
                40 * "*",
                f"Running {max_steps=}, {batch_size=} across {world_size} GPUs with\n",
                f"{model=}\n",
                f"Num parameters: {sum(p.numel() for p in model.parameters()):.2e}",
                f"GiB: {sum(p.numel() * p.element_size() for p in model.parameters()) / 2 ** 30:.2e}",
                40 * "*",
                flush=True,
                sep="\n",
            )
        if world_size > 1:
            # Wrap the model with the requested parallelization strategy.
            if parallelization == "fsdp":
                model = FSDP(
                    model, device_id=device, auto_wrap_policy=ModuleWrapPolicy([nn.Linear])
                )
            elif parallelization == "ddp":
                model = DDP(model)
            else:
                raise ValueError(
                    f"parallelization should be one of 'ddp' or 'fsdp' (default), not {parallelization=}"
                )
            # Only meaningful when a process group exists.
            dist.barrier()
        for step in range(1, max_steps + 1):
            # Forward + backward only; no optimizer step is needed for the memory test.
            loss = model(inputs).pow(2).mean()
            loss.backward()
            accel.synchronize()
            peak_mem_GiB = accel.memory_stats()["allocated_bytes.all.peak"] / 2**30
            current_mem_GiB = accel.memory_stats()["allocated_bytes.all.current"] / 2**30
            print(f"[{rank=}]: {step=} memory {peak_mem_GiB=}, {current_mem_GiB=}", flush=True)
    finally:
        if world_size > 1:
            dist.destroy_process_group()


if __name__ == "__main__":
    args = get_args()
    main(**vars(args))
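
Usage: a minimal sketch of launching the script on a single node with four devices, assuming it is saved as fsdp_ddp_min.py (hypothetical filename). torchrun supplies the RANK, LOCAL_RANK, and WORLD_SIZE environment variables that main reads:

torchrun --nproc_per_node 4 fsdp_ddp_min.py --parallelization fsdp
torchrun --nproc_per_node 4 fsdp_ddp_min.py --parallelization ddp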