Minimally profile comms/compute overlap
"""
Minimal profiling script for profiling compute/comms overlap.
torchrun --nnodes=1 --nproc-per-node=2 profile_comms_compute_overlap.py [--no-comms]
"""
import argparse
import os
from pathlib import Path
import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity
if torch.cuda.is_available():
    from torch import cuda as accel  # noqa

    DEVICE_TYPE = "cuda"
    BACKEND = "nccl"
else:
    import intel_extension_for_pytorch as ipex  # noqa
    from torch import xpu as accel  # noqa
    import oneccl_bindings_for_pytorch  # noqa

    DEVICE_TYPE = "xpu"
    BACKEND = "ccl"
# See the setup instructions for ipex kineto profiling:
# https://github.com/intel/intel-extension-for-pytorch/blob/1296c267c4247a7027d2103d05204b6b556b3d63/docs/tutorials/features/profiler_kineto.md#L24-L24
BATCH_SIZE = 2**14
COMMS_BATCH_SIZE = 2**17
DIM = 2**14
N_LAYERS = 3
STEPS = 3
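# With the defaults above, the all_reduce payload is COMMS_BATCH_SIZE * DIM =
# 2**17 * 2**14 = 2**31 fp32 elements, i.e. 2**33 bytes = 8 GiB per rank:
# a long enough transfer to overlap with several layers' worth of matmuls.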
RANK = int(os.getenv("RANK", 0))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}")

class Model(torch.nn.Module):
    def __init__(self, dim: int, device: torch.device) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(dim, dim, bias=False, device=device) for _ in range(N_LAYERS)]
        )

    def forward(self, x) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x
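
# Rough per-forward compute with the defaults: each Linear is a
# (2**14 x 2**14) @ (2**14 x 2**14) matmul, i.e. 2 * (2**14)**3 ~= 8.8e12
# FLOPs per layer, or ~26 TFLOPs across N_LAYERS = 3, so the default stream
# has substantial work for the collective to hide behind.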

def get_profiler() -> torch.profiler.profile:
    activities = [ProfilerActivity.CPU]
    if DEVICE_TYPE == "xpu":
        activities.append(ProfilerActivity.XPU)
    elif DEVICE_TYPE == "cuda":
        activities.append(ProfilerActivity.CUDA)
    else:
        raise ValueError(f"Unexpected device type {DEVICE_TYPE=}")
    return torch.profiler.profile(
        activities=activities,
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
    )
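
# A step-scheduled variant (a sketch, not what this script uses -- hence the
# "no_step" tag in the exported trace name below): torch.profiler can also
# gate collection with a schedule, in which case each profiled iteration must
# be followed by p.step():
#
#     def get_scheduled_profiler() -> torch.profiler.profile:
#         return torch.profiler.profile(
#             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#             schedule=torch.profiler.schedule(wait=1, warmup=2, active=3),
#             on_trace_ready=torch.profiler.tensorboard_trace_handler("profiler"),
#         )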

def run_one_iter(model, batch, comms_batch, comms_stream, no_comms: bool = False) -> None:
    # Matmuls run on the default stream...
    with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
        model(batch)
    # ...while the collective is enqueued on a side stream, free to overlap.
    if not no_comms:
        with accel.stream(comms_stream):
            dist.all_reduce(comms_batch)
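
# NB: the collective above carries no event dependency on the default stream;
# that independence is exactly what permits overlap. If later compute consumed
# the reduced tensor, you would re-synchronize first -- a sketch:
#
#     accel.current_stream().wait_stream(comms_stream)
#     comms_batch.record_stream(comms_stream)  # keep the caching allocator safe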

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--active", type=int, default=3)
    parser.add_argument("--model-batch-size", type=int, default=BATCH_SIZE)
    parser.add_argument("--comms-batch-size", type=int, default=COMMS_BATCH_SIZE)
    parser.add_argument("--dim", type=int, default=DIM)
    parser.add_argument("--base-dir", type=str, default=".")
    parser.add_argument("--no-comms", action="store_true", help="Skip the all_reduce.")
    args = parser.parse_args()

    model = Model(device=DEVICE, dim=args.dim)
    torch_profiler = get_profiler()
    batch = torch.randn((args.model_batch_size, args.dim), device=DEVICE)
    comms_batch = torch.randn(args.comms_batch_size, args.dim, device=DEVICE)
    comms_stream = accel.Stream(device=DEVICE)

    # Warmups
    for _ in range(args.warmup):
        run_one_iter(model, batch, comms_batch, comms_stream, args.no_comms)
    dist.barrier()
    accel.synchronize()

    # Profile
    with torch_profiler as p:
        for _ in range(args.active):
            run_one_iter(model, batch, comms_batch, comms_stream, args.no_comms)

    # Write out traces (viewable in chrome://tracing or https://ui.perfetto.dev)
    profiler_output_dir = Path(args.base_dir).absolute() / "profiler"
    profiler_output_dir.mkdir(exist_ok=True)
    file_name = f"profile_comms_compute_overlap.rank_{RANK}.no_step.chrome_trace.json.gz"
    export_path_str = str(profiler_output_dir / file_name)
    p.export_chrome_trace(export_path_str)

if __name__ == "__main__":
    assert WORLD_SIZE > 1, "Launch via torchrun with more than one rank"
    try:
        dist.init_process_group(backend=BACKEND)
        main()
    finally:
        dist.destroy_process_group()
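
# --- Optional post-hoc check (a sketch, not part of the original script) ---
# Rough overlap measurement from the exported trace. Assumptions: the torch
# profiler emits complete events (ph == "X") with microsecond "ts"/"dur"
# fields, GPU kernels carry cat == "kernel", and "tid" distinguishes streams.
import gzip
import json


def overlapped_us(trace_path: str) -> float:
    """Total time (us) during which kernels on different streams overlap."""
    with gzip.open(trace_path, "rt") as f:
        events = json.load(f)["traceEvents"]
    kernels = [e for e in events if e.get("ph") == "X" and e.get("cat") == "kernel"]
    total = 0.0
    # O(n^2) pairwise interval intersection; fine for a trace this small.
    for i, a in enumerate(kernels):
        for b in kernels[i + 1:]:
            if a.get("tid") == b.get("tid"):
                continue
            lo = max(a["ts"], b["ts"])
            hi = min(a["ts"] + a["dur"], b["ts"] + b["dur"])
            total += max(0.0, hi - lo)
    return total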