
@garrett361
Created April 15, 2024 20:50
pytorch profile with comms
"""
Minimal distributed profiling. Profiles compute and collective communications by default. Pass the
`--no-comms` flag to avoid collectives. Run as in
torchrun --nnodes=1 --nproc-per-node=2 profile_maybe_with_comms.py [--no-comms]
"""
import argparse
import os
from pathlib import Path
import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity, schedule

if torch.cuda.is_available():
    from torch import cuda as accel  # noqa

    DEVICE_TYPE = "cuda"
    BACKEND = "nccl"
else:
    import intel_extension_for_pytorch as ipex  # noqa
    from torch import xpu as accel  # noqa
    import oneccl_bindings_for_pytorch  # noqa

    DEVICE_TYPE = "xpu"
    BACKEND = "ccl"
    # See the full instructions for ipex profiling:
    # https://github.com/intel/intel-extension-for-pytorch/blob/1296c267c4247a7027d2103d05204b6b556b3d63/docs/tutorials/features/profiler_kineto.md#L24-L24

BATCH_SIZE = 2**14
DIM: int = 2**14
N_LAYERS = 5
STEPS = 3
RANK = int(os.getenv("RANK", 0))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))


class Model(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.params = torch.nn.ModuleList(
            [torch.nn.Linear(DIM, DIM, bias=False, device=device) for _ in range(N_LAYERS)]
        )

    def forward(self, x):
        for param in self.params:
            x = param(x)
        return x


DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}")


def get_profiler() -> torch.profiler.profile:
    # Record only the final iteration; all earlier iterations are treated as warmup.
    # The schedule only advances when profiler.step() is called once per iteration
    # (see the training loop in main()).
    profiling_schedule = schedule(wait=0, warmup=STEPS - 1, active=1)
    activities = [ProfilerActivity.CPU]
    if DEVICE_TYPE == "xpu":
        activities.append(ProfilerActivity.XPU)
    elif DEVICE_TYPE == "cuda":
        activities.append(ProfilerActivity.CUDA)
    else:
        raise ValueError(f"Unexpected device type {DEVICE_TYPE=}")
    return torch.profiler.profile(
        activities=activities,
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
        schedule=profiling_schedule,
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-comms", action="store_true")
    args = parser.parse_args()

    model = Model(device=DEVICE)
    torch_profiler = get_profiler()
    batch = torch.randn((BATCH_SIZE, DIM), device=DEVICE)

    with torch_profiler as p:
        for _ in range(STEPS):
            with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
                model(batch)
            if not args.no_comms:
                dist.all_reduce(batch)
            # Barrier shouldn't be necessary, but doesn't hurt.
            dist.barrier()
            # Advance the profiling schedule; without this call the active step is
            # never reached and the exported trace is empty.
            p.step()

    profiler_output_dir = Path(".").absolute() / "profiler"
    profiler_output_dir.mkdir(exist_ok=True)
    export_path_str = str(profiler_output_dir / "profile_dist_with_comms.chrome_trace.json.gz")
    if not RANK:
        print(f"Writing trace to {export_path_str}", flush=True)
        p.export_chrome_trace(export_path_str)
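        # Optionally, also print a per-op summary table (a sketch: the XPU sort key
        # assumes a recent torch build; the sort_by argument can be dropped if the
        # key is unavailable).
        sort_key = "self_cuda_time_total" if DEVICE_TYPE == "cuda" else "self_xpu_time_total"
        print(p.key_averages().table(sort_by=sort_key, row_limit=20))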


if __name__ == "__main__":
    assert WORLD_SIZE > 1
    try:
        dist.init_process_group(backend=BACKEND)
        main()
    finally:
        dist.destroy_process_group()
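

# To sanity-check the exported trace from a separate Python session (a sketch; the
# path assumes the default output location used above), something like the following
# confirms that all-reduce events were captured:
#
#   import gzip, json
#   with gzip.open("profiler/profile_dist_with_comms.chrome_trace.json.gz", "rt") as f:
#       trace = json.load(f)
#   names = [event.get("name", "") for event in trace.get("traceEvents", [])]
#   print(sum("allreduce" in name.lower().replace("_", "") for name in names), "all-reduce events")
#
# The trace itself can be viewed at https://ui.perfetto.dev or chrome://tracing.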