Skip to content

Instantly share code, notes, and snippets.

@qihqi
Created May 30, 2024 22:23
Show Gist options
  • Save qihqi/aa4fd50e5ef3cb96598433bd0f62817c to your computer and use it in GitHub Desktop.
Save qihqi/aa4fd50e5ef3cb96598433bd0f62817c to your computer and use it in GitHub Desktop.
My code:
```
import jax.numpy as jnp
import torch
import torch_xla2
import torch_xla2.interop
import torch._dynamo.config
# Register XLATensor2 in Dynamo's set of traceable tensor subclasses so that
# torch.compile can trace functions whose inputs are XLATensor2 instances
# (the compiled graph printed in the output below types its inputs as
# torch_xla2_tensor_XLATensor2).
_traceable = torch._dynamo.config.traceable_tensor_subclasses
_traceable.add(torch_xla2.tensor.XLATensor2)
del _traceable  # keep the module namespace unchanged
def f(a, b):
    """Return ``a + b`` computed by calling the ATen ``add`` op directly.

    Dynamo records this call as ``torch.ops.aten.add`` in the captured graph
    (see the printed ``forward`` in the output below).

    Args:
        a: First operand (a tensor; in this script, an XLATensor2).
        b: Second operand.

    Returns:
        The elementwise sum produced by ``torch.ops.aten.add``.
    """
    return torch.ops.aten.add(a, b)
def backend(f, init):
    """Debug ``torch.compile`` backend: print the captured graph, then run it as-is.

    Args:
        f: The graph object handed over by Dynamo. Its ``code`` attribute
            (the generated source of the graph) is printed when this backend
            is invoked. Note: this parameter shadows the module-level ``f``.
        init: Second argument supplied by Dynamo (presumably the example
            inputs — TODO confirm); unused here.

    Returns:
        A callable that prints a marker on every invocation and then forwards
        its positional arguments to ``f`` unchanged.
    """
    print('Inside of compiler', f.code)

    def wrapped(*args):
        # Printed every time the compiled function is executed.
        print('Inside of wrapped')
        return f(*args)

    return wrapped
# Wrap `f` with torch.compile using the printing debug backend.
f2 = torch.compile(f, backend=backend)

# Two random 10x10 tensors (drawn in order: x first, then y), then moved
# into the default torch_xla2 environment.
x, y = torch.randn(10, 10), torch.randn(10, 10)
env = torch_xla2.default_env()
x, y = env.to_xla((x, y))

# NOTE(review): in the captured output below, this call fails. After the
# compiled graph runs, Dynamo's output conversion (`to_subclass`) re-enters
# torch_xla2's dispatch, Dynamo then tries to trace `XLATensor2.__new__`,
# and reading the jax array's `.dtype` raises
# "AssertionError: expected FunctionType found nb_func".
res = f2(x, y)
print(res)

# Eager result, for comparison with the compiled one.
print(x + y)
```
Output:
(xla2) hanq-macbookpro:torch_xla2 hanq$ python -m torch_xla2.dynamo
/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
Inside of compiler
def forward(self, L_a_ : torch_xla2_tensor_XLATensor2, L_b_ : torch_xla2_tensor_XLATensor2):
l_a_ = L_a_
l_b_ = L_b_
ret = torch.ops.aten.add(l_a_, l_b_); l_a_ = l_b_ = None
return (ret,)
Inside of wrapped
Traceback (most recent call last):
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/dynamo.py", line 29, in <module>
res = f2(x, y)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/dynamo.py", line 11, in f
def f(a, b):
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1068, in to_subclass
return t.as_subclass(cls)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 169, in __torch_dispatch__
return func(*args, **(kwargs or {}))
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__
return self_._op(*args, **kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 263, in __torch_function__
return self.env.dispatch(func, types, args, kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 335, in dispatch
with jax.named_scope(_name_of_func(func)):
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 335, in torch_dynamo_resume_in_dispatch_at_335
with jax.named_scope(_name_of_func(func)):
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 359, in torch_dynamo_resume_in_dispatch_at_335
res = self.j2t_iso(res)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 397, in j2t_iso
return torch_pytree.tree_map_only(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1072, in tree_map_only
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 900, in tree_map
return treespec.unflatten(map(func, *flat_args))
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 736, in unflatten
leaves = list(leaves)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1018, in wrapped
return func(x)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 398, in <lambda>
jnp.ndarray, lambda x: XLATensor2(x, self), jaxarray)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
result = inner_convert(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
tracer.run()
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
super().run()
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1311, in LOAD_ATTR
result = BuiltinVariable(getattr).call_function(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 687, in call_function
result = handler(tx, *args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1277, in call_getattr
return obj.var_getattr(tx, name)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 746, in var_getattr
).call_function(tx, [], {})
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
return super().call_function(tx, args, kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
return super().call_function(tx, args, kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
tracer.run()
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1311, in LOAD_ATTR
result = BuiltinVariable(getattr).call_function(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 687, in call_function
result = handler(tx, *args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1277, in call_getattr
return obj.var_getattr(tx, name)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 744, in var_getattr
return variables.UserMethodVariable(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 123, in __call__
obj = type.__call__(cls, *args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 296, in __init__
super().__init__(fn=fn, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 130, in __init__
assert isinstance(
AssertionError: expected FunctionType found nb_func <nanobind.nb_func object at 0x103254c40>
from user code:
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 104, in __new__
dtype = j2t_dtype(elem.dtype)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/jax/_src/array.py", line 243, in dtype
return self.aval.dtype
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment