pytorch profile with comms
""" | |
Minimal distributed profiling. Profiles compute and collective communications by default. Pass the | |
`--no-comms` flag to avoid collectives. Run as in | |
torchrun --nnodes=1 --nproc-per-node=2 profile_maybe_with_comms.py [--no-comms] | |
""" | |
import argparse
import os
from pathlib import Path

import torch
import torch.distributed as dist
from torch.profiler import ProfilerActivity, schedule
if torch.cuda.is_available():
    from torch import cuda as accel  # noqa

    DEVICE_TYPE = "cuda"
    BACKEND = "nccl"
else:
    import intel_extension_for_pytorch as ipex  # noqa
    from torch import xpu as accel  # noqa
    import oneccl_bindings_for_pytorch  # noqa

    DEVICE_TYPE = "xpu"
    BACKEND = "ccl"
    # Note all of the instructions for ipex profiling:
    # https://github.com/intel/intel-extension-for-pytorch/blob/1296c267c4247a7027d2103d05204b6b556b3d63/docs/tutorials/features/profiler_kineto.md#L24-L24
BATCH_SIZE = 2**14
DIM = 2**14
N_LAYERS = 5
STEPS = 3
RANK = int(os.getenv("RANK", 0))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))

class Model(torch.nn.Module):
    """Simple stack of square Linear layers."""

    def __init__(self, device):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(DIM, DIM, bias=False, device=device) for _ in range(N_LAYERS)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}")

def get_profiler() -> torch.profiler.profile:
    # Record only the final step: wait + warmup + active must fit within STEPS, and p.step() must
    # be called once per iteration in main() for the schedule to advance. repeat=1 keeps the
    # recorded window from being discarded by a new warmup cycle before the trace is exported.
    profiling_schedule = schedule(wait=0, warmup=STEPS - 1, active=1, repeat=1)
    activities = [ProfilerActivity.CPU]
    if DEVICE_TYPE == "xpu":
        activities.append(ProfilerActivity.XPU)
    elif DEVICE_TYPE == "cuda":
        activities.append(ProfilerActivity.CUDA)
    else:
        raise ValueError(f"Unexpected device type {DEVICE_TYPE=}")
    return torch.profiler.profile(
        activities=activities,
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
        schedule=profiling_schedule,
    )

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-comms", action="store_true")
    args = parser.parse_args()

    model = Model(device=DEVICE)
    torch_profiler = get_profiler()
    batch = torch.randn((BATCH_SIZE, DIM), device=DEVICE)
    with torch_profiler as p:
        for _ in range(STEPS):
            with torch.autocast(device_type=DEVICE_TYPE, dtype=torch.bfloat16):
                model(batch)
            if not args.no_comms:
                dist.all_reduce(batch)
            # Barrier shouldn't be necessary, but doesn't hurt.
            dist.barrier()
            # Advance the profiler schedule; without this the active window is never recorded.
            p.step()

    profiler_output_dir = Path(".").absolute() / "profiler"
    profiler_output_dir.mkdir(exist_ok=True)
    export_path_str = str(profiler_output_dir / "profile_dist_with_comms.chrome_trace.json.gz")
    if not RANK:
        print(f"Writing trace to {export_path_str}", flush=True)
        p.export_chrome_trace(export_path_str)

if __name__ == "__main__":
    # The collectives (and the barrier) require more than one rank.
    assert WORLD_SIZE > 1
    try:
        dist.init_process_group(backend=BACKEND)
        main()
    finally:
        dist.destroy_process_group()
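The exported gzip Chrome trace can be opened in chrome://tracing or https://ui.perfetto.dev. For a quick command-line sanity check that the trace actually contains events, a minimal sketch (the output path assumes the default written by rank 0 above; `traceEvents` is the standard top-level key of the Chrome trace format that `export_chrome_trace` emits):

import gzip
import json

# Illustrative check, not part of the profiling script: load the gzipped trace and
# count its events to confirm the active window was captured.
with gzip.open("profiler/profile_dist_with_comms.chrome_trace.json.gz", "rt") as f:
    trace = json.load(f)
print(f"Loaded {len(trace['traceEvents'])} trace events")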