Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created May 16, 2022 18:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/2aa05613519ed6f8d80ad3246f084955 to your computer and use it in GitHub Desktop.
Save jamesr66a/2aa05613519ed6f8d80ad3246f084955 to your computer and use it in GitHub Desktop.
import torch
import torch.fx
def foo(x):
    """Apply ``torch.relu`` to ``x`` inside a profiler scope named ``'fooo'``.

    The ``record_function`` class is looked up on ``torch.autograd.profiler``
    each time the function runs, so whatever object is bound to that
    attribute at call time is the one that gets instantiated.

    Args:
        x: Input passed straight to ``torch.relu``.

    Returns:
        The result of ``torch.relu(x)``.
    """
    with torch.autograd.profiler.record_function('fooo'):
        return torch.relu(x)
class RecordFunctionTracer(torch.fx.Tracer):
    """FX tracer that turns ``record_function`` scopes into graph nodes.

    While ``trace`` is running, ``torch.autograd.profiler.record_function`` is
    temporarily rebound to a subclass whose ``__enter__``/``__exit__`` emit
    ``call_function`` nodes targeting
    ``torch.ops.profiler._record_function_enter`` and
    ``torch.ops.profiler._record_function_exit``.
    """

    def trace(self, root, concrete_args=None):
        """Trace ``root`` with ``record_function`` scopes captured in the graph.

        Args:
            root: Forwarded unchanged to ``torch.fx.Tracer.trace``.
            concrete_args: Forwarded unchanged to ``torch.fx.Tracer.trace``.

        Returns:
            Whatever ``torch.fx.Tracer.trace`` returns for these arguments.
        """

        # Inside the nested class, ``self`` is the enclosing tracer and
        # ``_self`` is the context-manager instance.
        class bound_FX_record_function(torch.autograd.profiler.record_function):
            def __init__(_self, name, args=None):
                _self.tracer = self
                # Zero-argument super() resolves against this class itself.
                # The previous two-argument form named the module attribute,
                # which only pointed at this subclass while it was patched in.
                super().__init__(name, args)

            def __enter__(_self):
                # Record the "enter" op instead of running it; the resulting
                # proxy is kept so the matching "exit" node can consume it.
                _self.handle = self.create_proxy(
                    'call_function',
                    torch.ops.profiler._record_function_enter,
                    args=(_self.name, _self.args),
                    kwargs={},
                )
                return _self

            def __exit__(_self, exc_type, exc_value, traceback):
                if _self.run_callbacks_on_exit:
                    self.create_proxy(
                        'call_function',
                        torch.ops.profiler._record_function_exit,
                        args=(_self.handle,),
                        kwargs={},
                    )

        # Swap the module attribute for the duration of the trace only, and
        # always restore it, even if tracing raises.
        old_record_function = torch.autograd.profiler.record_function
        torch.autograd.profiler.record_function = bound_FX_record_function
        try:
            return super().trace(root, concrete_args)
        finally:
            torch.autograd.profiler.record_function = old_record_function
# Demo: trace `foo` with the custom tracer, wrap the resulting graph in a
# GraphModule, and print the Python code generated for it.
tracer = RecordFunctionTracer()
graph = tracer.trace(foo)
# NOTE(review): `tracer.root` is presumably populated by `trace()` above, so
# this statement must stay after the trace call — confirm against torch.fx.
gm = torch.fx.GraphModule(tracer.root, graph)
print(gm.code)
# The bare string below is not assigned or used; it appears to be the
# author's record of what the `print(gm.code)` call above emits.
"""
def forward(self, x):
_record_function_enter = torch.ops.profiler._record_function_enter('fooo', None)
relu = torch.relu(x); x = None
_record_function_exit = torch.ops.profiler._record_function_exit(_record_function_enter); _record_function_enter = None
return relu
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment