Created
May 30, 2024 22:23
-
-
Save qihqi/aa4fd50e5ef3cb96598433bd0f62817c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
My code:
```
import jax.numpy as jnp
import torch
import torch_xla2
import torch_xla2.interop
import torch._dynamo.config

# Register torch_xla2's tensor subclass in dynamo's
# `traceable_tensor_subclasses` set so torch.compile will trace functions
# that receive XLATensor2 inputs (the captured graph's inputs are typed
# torch_xla2_tensor_XLATensor2).
torch._dynamo.config.traceable_tensor_subclasses.add(
    torch_xla2.tensor.XLATensor2)
def f(a, b):
    """Return ``a + b`` computed via the ATen ``add`` operator.

    Calling ``torch.ops.aten.add`` directly (instead of ``a + b``) means the
    graph captured by torch.compile contains ``torch.ops.aten.add`` itself.

    Args:
        a: First tensor operand.
        b: Second tensor operand.

    Returns:
        The elementwise sum produced by ``aten.add``.
    """
    return torch.ops.aten.add(a, b)
def backend(f, init):
    """Debug ``torch.compile`` backend that prints the captured graph.

    Args:
        f: The graph module handed over by torch.compile; its generated
            Python source (``f.code``) is printed once per backend call.
        init: The second positional argument torch.compile passes to a
            backend (its example inputs); unused here.

    Returns:
        A callable that prints ``'Inside of wrapped'`` on every invocation
        and then runs ``f`` unchanged on the given positional arguments.
    """
    print('Inside of compiler', f.code)

    def wrapped(*args):
        # Trace marker so it is visible when the compiled graph actually runs.
        print('Inside of wrapped')
        return f(*args)

    return wrapped
# Compile `f` with the printing debug backend defined above.
f2 = torch.compile(f, backend=backend)

x = torch.randn(10, 10)
y = torch.randn(10, 10)

# Convert both CPU tensors into torch_xla2 tensors via the default env.
env = torch_xla2.default_env()
x, y = env.to_xla((x, y))

# NOTE(review): in the run captured in this gist, this call fails inside
# dynamo while tracing XLATensor2.__new__ (access to `elem.dtype` hits a
# nanobind function) — see the traceback in the output.
res = f2(x, y)
print(res)
# Eager result, for comparison with the compiled one.
print(x + y)
```
Output:
(xla2) hanq-macbookpro:torch_xla2 hanq$ python -m torch_xla2.dynamo | |
/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. | |
self.pid = os.fork() | |
Inside of compiler | |
def forward(self, L_a_ : torch_xla2_tensor_XLATensor2, L_b_ : torch_xla2_tensor_XLATensor2): | |
l_a_ = L_a_ | |
l_b_ = L_b_ | |
ret = torch.ops.aten.add(l_a_, l_b_); l_a_ = l_b_ = None | |
return (ret,) | |
Inside of wrapped | |
Traceback (most recent call last): | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/runpy.py", line 196, in _run_module_as_main | |
return _run_code(code, main_globals, None, | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/runpy.py", line 86, in _run_code | |
exec(code, run_globals) | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/dynamo.py", line 29, in <module> | |
res = f2(x, y) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn | |
return fn(*args, **kwargs) | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/dynamo.py", line 11, in f | |
def f(a, b): | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1068, in to_subclass | |
return t.as_subclass(cls) | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 169, in __torch_dispatch__ | |
return func(*args, **(kwargs or {})) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__ | |
return self_._op(*args, **kwargs) | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 263, in __torch_function__ | |
return self.env.dispatch(func, types, args, kwargs) | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 335, in dispatch | |
with jax.named_scope(_name_of_func(func)): | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 335, in torch_dynamo_resume_in_dispatch_at_335 | |
with jax.named_scope(_name_of_func(func)): | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 359, in torch_dynamo_resume_in_dispatch_at_335 | |
res = self.j2t_iso(res) | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 397, in j2t_iso | |
return torch_pytree.tree_map_only( | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1072, in tree_map_only | |
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 900, in tree_map | |
return treespec.unflatten(map(func, *flat_args)) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 736, in unflatten | |
leaves = list(leaves) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1018, in wrapped | |
return func(x) | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 398, in <lambda> | |
jnp.ndarray, lambda x: XLATensor2(x, self), jaxarray) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors | |
return callback(frame, cache_entry, hooks, frame_state, skip=1) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame | |
result = inner_convert( | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert | |
return _compile( | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/contextlib.py", line 79, in inner | |
return func(*args, **kwds) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile | |
guarded_code = compile_inner(code, one_graph, hooks, transform) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper | |
r = func(*args, **kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner | |
out_code = transform_code_object(code, transform) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object | |
transformations(instructions, code_options) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn | |
return fn(*args, **kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform | |
tracer.run() | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run | |
super().run() | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run | |
and self.step() | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step | |
getattr(self, inst.opname)(inst) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1311, in LOAD_ATTR | |
result = BuiltinVariable(getattr).call_function( | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 687, in call_function | |
result = handler(tx, *args, **kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1277, in call_getattr | |
return obj.var_getattr(tx, name) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 746, in var_getattr | |
).call_function(tx, [], {}) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function | |
return super().call_function(tx, args, kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function | |
return super().call_function(tx, args, kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function | |
return tx.inline_user_function_return( | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return | |
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call | |
return cls.inline_call_(parent, func, args, kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_ | |
tracer.run() | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run | |
and self.step() | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step | |
getattr(self, inst.opname)(inst) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1311, in LOAD_ATTR | |
result = BuiltinVariable(getattr).call_function( | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 687, in call_function | |
result = handler(tx, *args, **kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1277, in call_getattr | |
return obj.var_getattr(tx, name) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 744, in var_getattr | |
return variables.UserMethodVariable( | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 123, in __call__ | |
obj = type.__call__(cls, *args, **kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 296, in __init__ | |
super().__init__(fn=fn, **kwargs) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 130, in __init__ | |
assert isinstance( | |
AssertionError: expected FunctionType found nb_func <nanobind.nb_func object at 0x103254c40> | |
from user code: | |
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 104, in __new__ | |
dtype = j2t_dtype(elem.dtype) | |
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/jax/_src/array.py", line 243, in dtype | |
return self.aval.dtype | |
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information | |
You can suppress this exception and fall back to eager by setting: | |
import torch._dynamo | |
torch._dynamo.config.suppress_errors = True |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment