Skip to content

Instantly share code, notes, and snippets.

@dfyz
Created May 21, 2024 20:51
Show Gist options
  • Save dfyz/a43c071830460f9a9c5b745def3d9ac9 to your computer and use it in GitHub Desktop.
Save dfyz/a43c071830460f9a9c5b745def3d9ac9 to your computer and use it in GitHub Desktop.
A hackish example of integrating NCCL kernel-level profiling into PyTorch.
diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp
index c68eb18099..f7038051d3 100644
--- a/torch/csrc/autograd/profiler_kineto.cpp
+++ b/torch/csrc/autograd/profiler_kineto.cpp
@@ -31,6 +31,9 @@
#ifdef USE_KINETO
#include <libkineto.h>
#include <time_since_epoch.h>
+#include <dlfcn.h>
+
+#include <fstream>
#ifndef _MSC_VER
// TODO: TO be removed, once this properly works from libkineto
@@ -608,6 +611,24 @@ void enableProfilerWithEventPostProcess(
state_ptr->setEventPostProcessingCallback(std::move(cb));
}
+void ncclSetSaveTimingsState(bool enabled) {
+  // Making the autograd profiler properly depend on a custom NCCL build is a
+  // nightmare, so the entry point is looked up at runtime via dlopen/dlsym.
+  // The lookup runs only once and the function pointer is cached in a
+  // function-local static: repeated profiler start/stop cycles no longer
+  // re-run dlopen, whose handles are never dlclose()d. If the lookup throws,
+  // the static stays uninitialized and the next call retries it.
+  static const auto impl = [] {
+    void* libnccl = dlopen("libnccl.so", RTLD_NOW);
+    TORCH_CHECK(libnccl != nullptr, "libnccl.so not found");
+    auto* fn = reinterpret_cast<void (*)(bool)>(
+        dlsym(libnccl, "ncclSetSaveTimingsState"));
+    TORCH_CHECK(fn != nullptr, "ncclSetSaveTimingsState not found");
+    return fn;
+  }();
+  impl(enabled);
+}
+
void enableProfiler(
const torch::profiler::impl::ProfilerConfig& config,
const std::set<torch::profiler::impl::ActivityType>& activities,
@@ -647,9 +660,12 @@ void enableProfiler(
if (!config.global()) {
torch::profiler::impl::kineto::startTrace();
}
+  ncclSetSaveTimingsState(/*enabled=*/true);
}
std::unique_ptr<ProfilerResult> disableProfiler() {
+  // NOTE(review): if this throws, ProfilerStateBase::pop() below is skipped.
+  ncclSetSaveTimingsState(/*enabled=*/false);
auto state_ptr = ProfilerStateBase::pop();
const auto& config = state_ptr->config();
TORCH_CHECK(
@@ -916,6 +932,21 @@ ProfilerResult::~ProfilerResult() = default;
void ProfilerResult::save(const std::string& path) {
trace_->save(path);
+
+  // Hand the freshly written trace file to the custom NCCL build, presumably
+  // so it can append its kernel-level timings (per the symbol's name). The
+  // symbol is resolved once and cached in a function-local static so every
+  // save() does not re-run dlopen; a failed lookup throws and is retried on
+  // the next call.
+  static const auto appendNcclTimings = [] {
+    void* libnccl = dlopen("libnccl.so", RTLD_NOW);
+    TORCH_CHECK(libnccl != nullptr, "libnccl.so not found");
+    auto* fn = reinterpret_cast<void (*)(const char*)>(
+        dlsym(libnccl, "ncclAppendTimingsToJson"));
+    TORCH_CHECK(fn != nullptr, "ncclAppendTimingsToJson not found");
+    return fn;
+  }();
+  appendNcclTimings(path.c_str());
}
} // namespace profiler
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment