PyTorch Performance Analysis Tools

Goals:

  • Show tools available for analyzing perf
  • Use them to analyze a model from torch/hub
  • Look for compiler/runtime opportunities

Ideal outcome of the analysis: prioritize optimizations by:

  • Expected improvement (+ generality of improvement)
  • Effort to implement

Helpful to know what machine resources are under stress. Some examples from other teams I've worked on:

  • NUMA traffic (HHVM)
  • I-TLB, I-cache (HHVM)
  • Page cache (Redex)
  • Network latency (Buck)
  • Filesystem accesses (Buck)
  • PCI-e traffic (Glow)
  • Memory map creation (Glow)

Benchmarks

Torch Hub: https://github.com/pytorch/hub/

Benchmarks in hub/benchmarks

BERT: hub/benchmarks/models/BERT-pytorch

  1. Generally, where's the low-hanging fruit for perf?
  2. Case study: how much would BERT benefit from fused reductions?

pytest

cd hub/benchmarks
pytest test_bench.py -k test_train[BERT-pytorch-cuda]
------------------------------------------------- benchmark 'hub': 1 tests ------------------------------------------------
Name (time in ms)                     Min      Max     Mean  StdDev   Median     IQR  Outliers      OPS  Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------
test_train[BERT-pytorch-cuda]     63.3723  66.5398  64.7023  0.9612  64.5468  1.3138       5;0  15.4554      16           1
---------------------------------------------------------------------------------------------------------------------------

Useful options:

--benchmark-min-time [s]   # Minimum run time, reduces noise
-s                         # Show benchmark output
--co                       # List available benchmarks
-k                         # Select benchmarks by substring
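
For example, a longer, quieter run of just the BERT training benchmark might look like this (5 s is just an illustrative minimum time):

pytest test_bench.py -k test_train[BERT-pytorch-cuda] --benchmark-min-time 5 -s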

Enable JIT:

class Model:
    def __init__(self, device=None, jit=True):
        self.device = device
        self.jit = jit  # jit=True runs the TorchScript (JIT) version of the model

Use legacy executor/fuser:

torch._C._jit_set_profiling_executor(False)

TODO: We should build in options to control these (a sketch of a helper is below). I forget them all the time. But different fusers/executors don't always work nicely in the same process.
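
A minimal sketch of such a helper, assuming the usual torch._C switches (the exact set needed varies by PyTorch version):

import torch

def use_legacy_executor_and_fuser():
    # Fall back from the profiling graph executor to the legacy one.
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
    # Let the legacy fuser generate fused kernels on GPU.
    torch._C._jit_override_can_fuse_on_gpu(True)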

Environment variables:

  • PYTORCH_JIT_LOG_LEVEL
    • >>profiling_graph_executor_impl, >>graph_executor: Optimization passes
    • >>kernel, cuda_codegen: Tensor expressions and generated CUDA-C
  • PYTORCH_FUSION_DEBUG
    • 1: Fused kernel C code (CPU and CUDA)
    • 2: Assembly (CPU)

PYTORCH_JIT_LOG_LEVEL=">>graph_executor" gives good info but way too much.
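
For example (the component names and the ':' separator are from memory, so adjust as needed), to dump the tensor expressions and fused-kernel source in one run:

PYTORCH_JIT_LOG_LEVEL=">>kernel:>>cuda_codegen" PYTORCH_FUSION_DEBUG=1 python hubconf.py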

Ad hoc logging. Right after compilation:

--- a/torch/csrc/jit/runtime/graph_executor.cpp
+++ b/torch/csrc/jit/runtime/graph_executor.cpp
@@ -570,6 +570,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
         return it->second;
       }
       auto plan = compileSpec(spec);
+      std::clog << "Compiled:\n" << *plan.graph << "\n";
       auto r = plan_cache.emplace(std::move(spec), std::move(plan));
       logging::getLogger()->addStatValue(
           logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);

Output: P139751642

Direct benchmark runner

cd models/BERT-pytorch
python hubconf.py

Manual perf counters

import time

m = Model(device='cuda', jit=True)
for _ in range(4):
    s = time.perf_counter()
    m.train()
    e = time.perf_counter()
    print("time (ms): {:.2f}".format((e - s) * 1000))

time (ms): 2649.00
time (ms): 59.07
time (ms): 57.54
time (ms): 56.84

(The first iteration is dominated by JIT compilation and other warmup.)
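
One caveat (not part of the measurement above): CUDA work launches asynchronously, so a wall-clock timer can stop before the GPU has finished. A variant that synchronizes before reading the clock:

import time
import torch

s = time.perf_counter()
m.train()
torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
e = time.perf_counter()
print("time (ms): {:.2f}".format((e - s) * 1000))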

Dump stats at exit (C++)

static int sum_counter = 0;
struct AtExit {
  ~AtExit() { printf("calls to sum to size: %d\n", sum_counter); }
} atExit;
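
The counter presumably gets bumped at the call site being counted; hypothetically, inside the aten::_grad_sum_to_size operator body shown under record_function below:

} else {
  sum_counter++;  // count one real sum-to-size call
  push(stack, at::sum_to(self.toTensor(), size.toIntVector()));
}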

gdb scripts

nvidia-smi

Running the training loop continuously while inspecting device utilization shows the GPU at about 33% utilization.

nvidia-smi dmon
# gpu   pwr  temp    sm   mem   enc   dec  mclk  pclk
# Idx     W     C     %     %     %     %   MHz   MHz
    0    73    38    33     5     0     0  3004   949
    0    73    38    32     5     0     0  3004   949
    0    73    38    33     5     0     0  3004   949
    0    73    38    33     5     0     0  3004   949
    0    73    38    33     5     0     0  3004   949
    0    72    38    34     5     0     0  3004   949
    0    72    39    33     5     0     0  3004   949
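
To reproduce this, something like the following keeps the GPU busy while nvidia-smi dmon samples in a second terminal (assuming the benchmark's hubconf.py exposes Model, as in the snippets above):

from hubconf import Model  # assumption: the benchmark's hubconf.py defines Model

m = Model(device='cuda', jit=True)
while True:  # run training steps forever; watch `nvidia-smi dmon` in another shell
    m.train()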

torch.autograd.profiler

m = Model(device='cuda', jit=True)
m.train()
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    s = time.perf_counter()
    m.train()
    e = time.perf_counter()
print(prof.key_averages().table(sort_by="cuda_time_total"))
print("time (ms): {:.2f}".format((e - s) * 1000))

Results: P139718714

  • There doesn't seem to be a CUDA self time.
  • The CUDA profiling overhead is about 1x total runtime. (CPU-only is better.)

Chrome trace

prof.export_chrome_trace("bert.json")
jf upload bert.json

Trace: https://lookaside.facebook.com/intern/diff/file/data/?number=300190110&download=1

(Does anybody watching happen to know what value the CUDA markers add?)

record_function

Python:

with torch.autograd.profiler.record_function("forward"):
    next_sent_output, mask_lm_output = trainer.model.forward(*self.example_inputs)
with torch.autograd.profiler.record_function("loss"):
    next_loss = trainer.criterion(next_sent_output, self.is_next)
    mask_loss = trainer.criterion(mask_lm_output.transpose(1, 2), self.bert_label)
    loss = next_loss + mask_loss
with torch.autograd.profiler.record_function("zero_grad"):
    trainer.optim_schedule.zero_grad()
with torch.autograd.profiler.record_function("backward"):
    loss.backward()
with torch.autograd.profiler.record_function("optimizer"):
    trainer.optim_schedule.step_and_update_lr()

C++ (RECORD_FUNCTION inside an operator implementation):

Operator(
    "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
    [](Stack* stack) {
      RECORD_FUNCTION("aten::_grad_sum_to_size", last(stack, 2));
      IValue self, size;
      pop(stack, self, size);
      if (size.isNone()) {
        push(stack, std::move(self));
      } else {
        push(stack, at::sum_to(self.toTensor(), size.toIntVector()));
      }
    },
    aliasAnalysisFromSchema()),

Trace: https://lookaside.facebook.com/intern/diff/file/data/?number=300192580&download=1

  • (Note: I removed use_cuda because it distorts the trace w/o a ton of added value)
  • 40% of time in optimizer which looks mostly elementwise
  • 6% of time spent zeroing gradients
  • 2% is aten::_grad_sum_to_size

nvprof

m = Model(device='cuda', jit=True)
m.train()
with torch.cuda.profiler.profile():
    with torch.autograd.profiler.emit_nvtx():
        m.train()

  • The first CUDA API call after enabling profiling is dog slow. A warmup run might help, but the slow call shows up in the stats no matter what.

nvprof --profile-from-start off -- python hubconf.py

P139720772

  • Top 2 kernels (11%, 6%) are elementwise
  • reduce sum is next (6%)
  • device-device copies (5%)
  • But total GPU kernel time is only 12 ms, out of 57 ms overall?
  • API calls are ~40ms (even discounting the crazy slow initial cudaLaunchKernel).
    • Should I trust the API timings?
    • Does 12 ms (device) + 40 ms (host API) = 52 ms actually explain the overall time well?
    • Should we look at CPU side overhead?

nvprof trace

Get Edward Yang's nvprof2json (https://github.com/ezyang/nvprof2json)

nvprof --profile-from-start off -o bert.nvprof -- python hubconf.py
nvprof2json bert.nvprof > bert.nvprof.json

Trace: https://lookaside.facebook.com/intern/diff/file/data/?number=300405276&download=1

  • nvvp is the NVIDIA-developed trace viewer. I like it less, but that may be an issue of familiarity. On macOS, you will need a special JDK from Azul; see the Visual Profiler JRE requirements.

"Limit study"

Want to know how fast we could make aten::_grad_sum_to_size? Replace its implementation with aten::empty and remeasure (see the sketch below).
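
A sketch of that replacement, applied to the operator body shown under record_function above (illustrative only; the real edit may differ):

       if (size.isNone()) {
         push(stack, std::move(self));
       } else {
-        push(stack, at::sum_to(self.toTensor(), size.toIntVector()));
+        push(stack, at::empty(size.toIntVector(), self.toTensor().options()));
       }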

Baseline:

Empty:

Clearly this reduction is not a bottleneck (?).

CPU profiling

CPU is easier to interpret since everything is synchronous.

m = Model(device='cpu', jit=True)
m.train()
with torch.autograd.profiler.profile() as prof:
    m.train()
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
prof.export_chrome_trace("bert_train_cpu.json")

It's less slow than I expected (127 ms).

import cProfile

torch._C._jit_set_profiling_executor(False)
m = Model(device='cpu', jit=True)
m.train()
cProfile.run("m.train()", "bert_train_cpu.stats")

import pstats
pstats.Stats("bert_train_cpu.stats").strip_dirs().sort_stats("time").print_stats()

pip install flameprof
flameprof bert_train_cpu.stats

perf

OMP_NUM_THREADS=1 perf.real record -g -- python hubconf.py
perf.real report

A not-very-serious attempt at estimating framework overhead, by filtering the math/kernel frames out of the perf samples and counting what's left:

% perf.real script -i perf.flat.data | cut -c82- | sed 's/+0x[0-9a-f]*//' | grep -v ^at::parallel_for | grep -v ^c10::function_ref | grep -v ^sgemm | grep -v at::CPUGeneratorImpl::random64 | grep -v ::uniform_real_distribution | grep -v ^Sleef | grep -v ::vectorized_loop | grep -v ::bernoulli_distribution | wc -l
23247
% perf.real script -i perf.flat.data | wc -l
51702

So roughly 45% of samples (23247 / 51702) fall outside the math kernels, i.e. something like 50% overhead?
