@youkaichao
Created May 21, 2024 04:27
Enable verbose CUDA graph dumps for PyTorch

audit.c:
#define _GNU_SOURCE
#include <stdio.h>
#include <link.h>
#include <stdbool.h>
#include <string.h>
#include <stdlib.h>

// Minimal stand-ins for the CUDA runtime types so this file does not need the CUDA headers
typedef int cudaError_t;
typedef void* cudaGraph_t;

static cudaError_t (*real_cudaGraphDebugDotPrint)(cudaGraph_t graph, const char* path, unsigned int flags) = NULL;

// Hooked function: forwards to the real cudaGraphDebugDotPrint, but overrides
// the flags argument so the DOT dump is as detailed as possible
cudaError_t my_cudaGraphDebugDotPrint(cudaGraph_t graph, const char* path, unsigned int flags) {
    if (real_cudaGraphDebugDotPrint == NULL) {
        fprintf(stderr, "Error: real_cudaGraphDebugDotPrint is NULL\n");
        return -1;
    }
    // Call the original function with flags always set to 1 (cudaGraphDebugDotFlagsVerbose)
    return real_cudaGraphDebugDotPrint(graph, path, 1);
}

// rtld-audit entry point: report which audit interface version we support
unsigned int la_version(unsigned int version) {
    printf("Audit module loaded.\n");
    return version;
}

// Called for each object search; return the name unchanged so the search proceeds normally
char *la_objsearch(const char *name, uintptr_t *cookie, unsigned int flag) {
    return (char *)name;
}

// Called when an object is loaded; audit symbol bindings both to and from it
unsigned int la_objopen(struct link_map *map, Lmid_t lmid, uintptr_t *cookie) {
    return LA_FLG_BINDTO | LA_FLG_BINDFROM;
}

// Called when a symbol is bound between audited objects
uintptr_t la_symbind64(Elf64_Sym *sym, unsigned int ndx, uintptr_t *refcook,
                       uintptr_t *defcook, unsigned int *flags, const char *symname) {
    if (strcmp(symname, "cudaGraphDebugDotPrint") == 0) {
        printf("Symbol bound: %s\n", symname);
        // Remember the real function's address...
        real_cudaGraphDebugDotPrint = (cudaError_t (*)(cudaGraph_t, const char*, unsigned int))sym->st_value;
        // ...and redirect the binding to the hooked function
        return (uintptr_t)my_cudaGraphDebugDotPrint;
    }
    return sym->st_value; // Leave every other symbol untouched
}
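
For reference, the hard-coded 1 is the value of cudaGraphDebugDotFlagsVerbose in the CUDA runtime's cudaGraphDebugDotFlags enum; hooking the call is useful when the caller (here, PyTorch's debug_dump) does not let you choose the flags. If you do control the call site, the same verbose dump is a direct runtime API call. A minimal sketch, not part of the gist itself, assuming a normal CUDA build and an already-captured cudaGraph_t:

#include <cuda_runtime.h>

// Sketch: what the audit hook effectively achieves when you can pass the flag yourself.
cudaError_t dump_graph_verbose(cudaGraph_t graph, const char *path) {
    // cudaGraphDebugDotFlagsVerbose (value 1) requests the most detailed
    // per-node output in the generated DOT file.
    return cudaGraphDebugDotPrint(graph, path, cudaGraphDebugDotFlagsVerbose);
}

The resulting .dot file can be rendered with Graphviz.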
@youkaichao (Author):

Usage:

  • Compile it into a shared library: gcc -fPIC -shared -o audit.so audit.c
  • Run the example script with the environment variable set: LD_AUDIT=./audit.so python test.py (the expected console output is noted below)

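When the audit library is loaded, "Audit module loaded." is printed at startup, and "Symbol bound: cudaGraphDebugDotPrint" appears once the symbol is bound (both messages come from the printf calls in audit.c above). If the first message does not appear, the audit library was not loaded.
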
Example script:

# test.py
import torch
g = torch.cuda.CUDAGraph()
g.enable_debug_mode()

# Placeholder input used for capture
static_input = torch.empty((5,), device="cuda")

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Capture the graph.
# To allow capture, torch.cuda.graph automatically sets a side stream as the current stream inside the context.
with torch.cuda.graph(g):
    static_output = static_input * 2

g.debug_dump("cuda_graph_hooked.dot")

# Fills the graph's input memory with new data to compute on
static_input.copy_(torch.full((5,), 3, device="cuda"))
g.replay()
# static_output holds the results
print(static_output)  # full of 3 * 2 = 6

# Fills the graph's input memory with more data to compute on
static_input.copy_(torch.full((5,), 4, device="cuda"))
g.replay()
print(static_output)  # full of 4 * 2 = 8
