Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created April 19, 2022 18:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/ef1e85bbc7e9db77245bd7032d7ca971 to your computer and use it in GitHub Desktop.
Save jamesr66a/ef1e85bbc7e9db77245bd7032d7ca971 to your computer and use it in GitHub Desktop.
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index e2f033d72a..7b3a97991d 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -222,6 +222,56 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
else:
setattr(to_module, field, from_obj)
+class _WrappedCall:
+    """Callable that runs a traced Module class's ``__call__`` and improves its errors.
+
+    Calls are forwarded to ``cls_call`` when it is not ``None``; otherwise the
+    parent implementation is looked up dynamically via ``super(cls, obj)``.
+    When the innermost frame of a raised exception comes from FX-generated code
+    (its filename contains ``eval_with_key``), an excerpt of the generated
+    source is printed to stderr and the exception is re-raised without its
+    traceback; any other exception is re-raised unchanged.
+    """
+
+    def __init__(self, cls, cls_call):
+        # ``cls``: the class whose ``__call__`` is being wrapped.
+        # ``cls_call``: that class's own ``__call__``, or ``None`` to defer to
+        # ``super(cls, obj).__call__`` at call time.
+        self.cls = cls
+        self.cls_call = cls_call
+
+    # Previously, if an error occurred when valid symbolically-traced code was
+    # run with an invalid input, the user would see the source of the error as
+    # coming from `File "<eval_with_key_N>"`, where N is some number. This
+    # helper builds a more informative message instead: the traceback itself,
+    # a note that the error occurred in a traced Module's generated forward
+    # function, and an excerpt of the generated source (the line before the
+    # faulty line, the faulty line, a "<--- HERE" marker, then up to two lines
+    # after it).
+    @staticmethod
+    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+        tb_repr = traceback.format_exc()
+        err_lineno = frame_summary.lineno
+        if err_lineno is None:
+            # Nothing to point at; fall back to the plain traceback rather than
+            # raising inside the caller's ``except`` block.
+            return tb_repr
+
+        # ``FrameSummary.line`` is Optional; never let the error reporter raise
+        # and mask the user's original exception.
+        err_line_len = len(frame_summary.line or "")
+        all_src_lines = linecache.getlines(frame_summary.filename)
+
+        custom_msg = ("Call using an FX-traced Module, "
+                      f"line {err_lineno} of the traced Module's "
+                      "generated forward function:")
+        # ``err_lineno`` is 1-based, so ``all_src_lines[err_lineno - 1]`` is the
+        # faulty line. Clamp the start so an error on line 1 does not produce a
+        # wrap-around (negative) slice. Lines from linecache normally keep their
+        # trailing "\n"; strip it so the final "\n".join does not insert blank
+        # lines between the excerpt and the marker.
+        start = max(err_lineno - 2, 0)
+        before_err = "".join(all_src_lines[start:err_lineno]).rstrip("\n")
+        marker = "~" * err_line_len + "~~~ <--- HERE"
+        after_err = "".join(all_src_lines[err_lineno:err_lineno + 2]).rstrip("\n")
+
+        return "\n".join([tb_repr, custom_msg, before_err, marker, after_err])
+
+    def __call__(self, obj, *args, **kwargs):
+        try:
+            if self.cls_call is not None:
+                return self.cls_call(obj, *args, **kwargs)
+            # No class-level ``__call__`` was captured: resolve the parent's at
+            # call time instead of holding a reference to it.
+            return super(self.cls, obj).__call__(*args, **kwargs)
+        except Exception as e:
+            assert e.__traceback__
+            # The last extracted entry is the innermost frame, i.e. the one in
+            # which the exception was actually raised.
+            innermost: traceback.FrameSummary = \
+                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]  # type: ignore[arg-type]
+            if "eval_with_key" in innermost.filename:
+                print(_WrappedCall._generate_error_message(innermost),
+                      file=sys.stderr)
+                # The generated-source excerpt has been printed; drop the
+                # traceback through the generated code before re-raising.
+                raise e.with_traceback(None)
+            raise e
+
@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
"""
@@ -587,51 +637,13 @@ class {module_name}(torch.nn.Module):
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
- # Previously, if an error occurred when valid
- # symbolically-traced code was run with an invalid input, the
- # user would see the source of the error as coming from
- # `File "<eval_with_key_N">`, where N is some number. We use
- # this function to generate a more informative error message. We
- # return the traceback itself, a message explaining that the
- # error occurred in a traced Module's generated forward
- # function, and five lines of context surrounding the faulty
- # line
- def generate_error_message(frame_summary: traceback.FrameSummary) -> str:
- # auxiliary variables (for readability)
- err_lineno = frame_summary.lineno
- err_line_len = len(frame_summary.line)
- all_src_lines = linecache.getlines(frame_summary.filename)
-
- # constituent substrings of the error message
- tb_repr = traceback.format_exc()
- custom_msg = ("Call using an FX-traced Module, "
- f"line {err_lineno} of the traced Module's "
- "generated forward function:")
- before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
- marker = "~" * err_line_len + "~~~ <--- HERE"
- err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
-
- # joined message
- return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
-
- def wrapped_call(self, *args, **kwargs):
- try:
- if cls_call is not None:
- return cls_call(self, *args, **kwargs)
- else:
- return super(cls, self).__call__(*args, **kwargs)
- except Exception as e:
- assert e.__traceback__
- topmost_framesummary: traceback.FrameSummary = \
- traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
- if "eval_with_key" in topmost_framesummary.filename:
- print(generate_error_message(topmost_framesummary),
- file=sys.stderr)
- raise e.with_traceback(None)
- else:
- raise e
-
- cls.__call__ = wrapped_call
+ if '_wrapped_call' not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call)
+
+ def call_wrapped(self, *args, **kwargs):
+ return self._wrapped_call(self, *args, **kwargs)
+
+ cls.__call__ = call_wrapped
return python_code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment