@madebyollin
Created February 27, 2024 02:57
Add human-readable profiling markers to a PyTorch module
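import torch
from torch.profiler import record_function


def add_profiling_markers(model):
    """Monkey-patch profiling markers into an nn.Module.

    Args:
        model: an nn.Module

    Effect:
        Every forward call of every module in model.named_modules() gets
        wrapped in its own named profiling scope ("<module name>.forward"),
        which makes traces much easier to read.
    """

    def add_profiling_to_forward(name, module):
        def profiled_forward(*args, **kwargs):
            with record_function(f"{name}.forward"):
                return module._unprofiled_forward(*args, **kwargs)

        return profiled_forward

    for name, module in model.named_modules():
        # Stash the original forward under a distinctive attribute name. A
        # generic name like `_forward` clashes with modules that define their
        # own `_forward` helper, which would then silently never be patched.
        # The check also makes repeated calls safe (no double wrapping).
        # Note: the root module's name is "", so its scope reads ".forward".
        if not hasattr(module, "_unprofiled_forward"):
            module._unprofiled_forward = module.forward
            module.forward = add_profiling_to_forward(name, module)


# Usage (`model` is your nn.Module, `x` its input):
add_profiling_markers(model)
with torch.profiler.profile() as prof:
    # .cpu() copies the result to the host, which waits for pending GPU work,
    # so the whole forward pass falls inside the profiled region.
    y = model(x).cpu()
prof.export_chrome_trace("trace.json")
# Then open chrome://tracing in Chrome and load trace.json.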
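If you would rather not leave the model permanently patched, the same idea can be wrapped in a context manager that restores the original `forward` methods on exit. This is a sketch, not part of the original gist: the helper name `profiling_markers`, the choice to label the root module with its class name, and the toy `nn.Sequential` model are all hypothetical, and it reads the markers back with `prof.key_averages()` instead of opening the Chrome trace.

```python
import contextlib

import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function


def _wrap_forward(label, forward):
    """Return a forward that runs `forward` inside a named profiler scope."""

    def profiled_forward(*args, **kwargs):
        with record_function(label):
            return forward(*args, **kwargs)

    return profiled_forward


@contextlib.contextmanager
def profiling_markers(model: nn.Module):
    """Hypothetical helper: patch markers in, then restore them on exit.

    The root module (whose name in named_modules() is "") is labeled with
    its class name instead, so its scope is not just ".forward".
    """
    patched = []
    for name, module in model.named_modules():
        label = f"{name or type(module).__name__}.forward"
        had_instance_forward = "forward" in module.__dict__
        original = module.forward
        module.forward = _wrap_forward(label, original)
        patched.append((module, original, had_instance_forward))
    try:
        yield model
    finally:
        for module, original, had_instance_forward in reversed(patched):
            if had_instance_forward:
                module.forward = original  # put back a pre-existing override
            else:
                del module.forward  # fall back to the class-level forward


# Toy model and input, assumed for illustration only.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(2, 8)

with profiling_markers(model), profile(activities=[ProfilerActivity.CPU]) as prof:
    y = model(x)

# record_function scopes show up as their own rows in the profiler summary.
marker_names = sorted(e.key for e in prof.key_averages() if e.key.endswith(".forward"))
print(marker_names)
```

Entering `profiling_markers` before `profile` means profiling stops before the original methods are restored. For a model on the GPU you would add `ProfilerActivity.CUDA` to `activities` and can still call `prof.export_chrome_trace(...)` as in the usage above.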