Minimally profile comms/compute overlap
""" | |
Minimal profiling script for profiling compute/comms overlap. | |
torchrun --nnodes=1 --nproc-per-node=2 profile_comms_compute_overlap.py [--no-comms] | |
""" | |
import argparse | |
import os | |
from pathlib import Path | |
import torch | |
import torch.distributed as dist | |
from torch.profiler import ProfilerActivity | |
if torch.cuda.is_available(): | |
from torch import cuda as accel # noqa | |
DEVICE_TYPE = "cuda" | |
BACKEND = "nccl" | |
else: | |
import intel_extension_for_pytorch as ipex # noqa | |
from torch import xpu as accel # noqa | |
import oneccl_bindings_for_pytorch # noqa | |
DEVICE_TYPE = "xpu" | |
BACKEND = "ccl" | |
# Note all of the instructions for ipex profiling | |
# https://github.com/intel/intel-extension-for-pytorch/blob/1296c267c4247a7027d2103d05204b6b556b3d63/docs/tutorials/features/profiler_kineto.md#L24-L24 | |
BATCH_SIZE = 2**14 | |
COMMS_BATCH_SIZE = 2**17 | |
DIM = 2**14 | |
N_LAYERS = 3 | |
STEPS = 3 | |
RANK = int(os.getenv("RANK", 0)) | |
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) | |
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) | |
DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}") | |
class Model(torch.nn.Module): | |
def __init__(self, dim: int, device: torch.device) -> None: | |
super().__init__() | |
self.layers = torch.nn.ModuleList( | |
[torch.nn.Linear(DIM, DIM, bias=False, device=device) for _ in range(N_LAYERS)] | |
) | |
def forward(self, x) -> torch.Tensor: | |
for layer in self.layers: | |
x = layer(x) | |
return x | |
def get_profiler() -> torch.profiler.profile: | |
activities = [ProfilerActivity.CPU] | |
if DEVICE_TYPE == "xpu": | |
activities.append(ProfilerActivity.XPU) | |
elif DEVICE_TYPE == "cuda": | |
activities.append(ProfilerActivity.CUDA) | |
else: | |
raise ValueError(f"Unexpected device type {DEVICE_TYPE=}") | |
return torch.profiler.profile( | |
activities=activities, | |
record_shapes=False, | |
profile_memory=False, | |
with_stack=False, | |
) | |
def run_one_iter(model, batch, comms_batch, comms_stream) -> None: | |
with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16): | |
model(batch) | |
with accel.stream(comms_stream): | |
dist.all_reduce(comms_batch) | |
def main() -> None: | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--warmup", type=int, default=2) | |
parser.add_argument("--active", type=int, default=3) | |
parser.add_argument("--model-batch-size", type=int, default=BATCH_SIZE) | |
parser.add_argument("--comms-batch-size", type=int, default=COMMS_BATCH_SIZE) | |
parser.add_argument("--dim", type=int, default=DIM) | |
parser.add_argument("--base-dir", type=str, default=".") | |
args = parser.parse_args() | |
model = Model(device=DEVICE, dim=args.dim) | |
torch_profiler = get_profiler() | |
batch = torch.randn((args.model_batch_size, args.dim), device=DEVICE) | |
comms_batch = torch.randn(args.comms_batch_size, args.dim, device=DEVICE) | |
comms_stream = accel.Stream(device=DEVICE) | |
# Warmups | |
for _ in range(args.warmup): | |
run_one_iter(model, batch, comms_batch, comms_stream) | |
dist.barrier() | |
accel.synchronize() | |
# Profile | |
with torch_profiler as p: | |
for _ in range(args.active): | |
run_one_iter(model, batch, comms_batch, comms_stream) | |
# Write out traces | |
profiler_output_dir = Path(args.base_dir).absolute() / "profiler" | |
profiler_output_dir.mkdir(exist_ok=True) | |
file_name = f"profile_comms_compute_overlap.rank_{RANK}.no_step.chrome_trace.json.gz" | |
export_path_str = str(profiler_output_dir / file_name) | |
p.export_chrome_trace(export_path_str) | |
if __name__ == "__main__": | |
assert WORLD_SIZE > 1 | |
try: | |
dist.init_process_group(backend=BACKEND) | |
main() | |
finally: | |
dist.destroy_process_group() |
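
As a usage sketch (not part of the gist itself): after a run, each rank writes its trace under ./profiler/. The snippet below, which assumes the script's default file naming and rank 0's output, loads the gzipped chrome trace and prints the ten longest complete events. Overlap shows up in the trace as compute kernels and the all-reduce occupying the same wall-clock span on different streams.

import gzip
import json

# Assumed path from the script's defaults (rank 0's trace).
trace_path = "profiler/profile_comms_compute_overlap.rank_0.no_step.chrome_trace.json.gz"
with gzip.open(trace_path, "rt") as f:
    trace = json.load(f)

# Chrome-trace "complete" events ("ph" == "X") carry a start ("ts") and a
# duration ("dur"), both in microseconds.
events = [e for e in trace["traceEvents"] if e.get("ph") == "X" and "dur" in e]
for event in sorted(events, key=lambda e: e["dur"], reverse=True)[:10]:
    print(f"{event['dur']:>12} us  {event['name']}")

The exported .json.gz can also be opened directly in chrome://tracing or Perfetto (https://ui.perfetto.dev) to eyeball the stream timelines.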