log (gist created August 30, 2024 21:55)
(pinupdate) root@6e1dc6c462da:/pytorch/xla# pip install jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Looking in links: https://storage.googleapis.com/jax-releases/jax_nightly_releases.html, https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Collecting jax
  Downloading jax-0.4.31-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib
  Downloading jaxlib-0.4.31-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)
Collecting ml-dtypes>=0.2.0 (from jax)
  Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Requirement already satisfied: numpy>=1.24 in /root/miniconda3/envs/pinupdate/lib/python3.10/site-packages (from jax) (2.1.0)
Collecting opt-einsum (from jax)
  Using cached opt_einsum-3.3.0-py3-none-any.whl.metadata (6.5 kB)
Collecting scipy>=1.10 (from jax)
  Downloading scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Downloading jax-0.4.31-py3-none-any.whl (2.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 50.8 MB/s eta 0:00:00
Downloading jaxlib-0.4.31-cp310-cp310-manylinux2014_x86_64.whl (88.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 88.1/88.1 MB 86.9 MB/s eta 0:00:00
Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
Downloading scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (41.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 41.2/41.2 MB 126.2 MB/s eta 0:00:00
Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Installing collected packages: scipy, opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.31 jaxlib-0.4.31 ml-dtypes-0.4.0 opt-einsum-3.3.0 scipy-1.14.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
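Note: the nightly index resolves to the jax/jaxlib 0.4.31 wheels and leaves the already-installed NumPy 2.1.0 untouched, while the torch build under /pytorch (per the warning emitted by the next command) was compiled against NumPy 1.x. A quick version check, shown only as an illustrative sketch, makes the mismatch visible before running the tests:

```python
# Illustrative sketch only: confirm which versions the test run below will see.
# NumPy 2.1.0 satisfies jax's "numpy>=1.24" requirement, but the local torch
# build warns (next command) that it was compiled against NumPy 1.x.
import numpy, jax, jaxlib

print("numpy ", numpy.__version__)   # 2.1.0 in this environment
print("jax   ", jax.__version__)     # 0.4.31
print("jaxlib", jaxlib.__version__)  # 0.4.31
```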
(pinupdate) root@6e1dc6c462da:/pytorch/xla# PJRT_DEVICE=TPU python test/test_pallas.py
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 5, in <module>
    import torch
  File "/pytorch/torch/__init__.py", line 2479, in <module>
    from torch import (
  File "/pytorch/torch/export/__init__.py", line 64, in <module>
    from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
  File "/pytorch/torch/export/dynamic_shapes.py", line 22, in <module>
    from .exported_program import ExportedProgram
  File "/pytorch/torch/export/exported_program.py", line 27, in <module>
    from torch._higher_order_ops.utils import autograd_not_implemented
  File "/pytorch/torch/_higher_order_ops/__init__.py", line 1, in <module>
    from torch._higher_order_ops.cond import cond
  File "/pytorch/torch/_higher_order_ops/cond.py", line 6, in <module>
    import torch._subclasses.functional_tensor
  File "/pytorch/torch/_subclasses/functional_tensor.py", line 44, in <module>
    class FunctionalTensor(torch.Tensor):
  File "/pytorch/torch/_subclasses/functional_tensor.py", line 271, in FunctionalTensor
    cpu = _conversion_method_template(device=torch.device("cpu"))
/pytorch/torch/_subclasses/functional_tensor.py:271: UserWarning: Failed to initialize NumPy:
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
(Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
.....F.FF.E.FF.EEEE...loc("/reduce_max[axes=(1,)]"(callsite("_flash_attention_kernel_single_batch_single_step"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":557:0) at callsite("_flash_attention_kernel"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":338:0) at callsite("_flash_attention_impl"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":746:0) at callsite("_flash_attention"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":216:0) at callsite("flash_attention"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":198:0) at "<module>"("/mnt/disks/ssd/work/pallas/pallas_add.py":50:0)))))))): error: 'vector.multi_reduction' op requires attribute 'reduction_dims'
E..E
======================================================================
ERROR: test_flash_attention_wrapper_segment_ids_1 (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 781, in test_flash_attention_wrapper_segment_ids_1
    jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
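This error and the three identical "RuntimeError: Numpy is not available" errors further down share one root cause: torch's NumPy bridge never initialized (see the "UserWarning: Failed to initialize NumPy" above), so the failure happens inside `q.numpy()` before JAX is ever involved. A minimal repro sketch, assuming the same NumPy 1.x/2.x mismatch:

```python
# Minimal repro sketch, assuming torch was compiled against NumPy 1.x while
# NumPy 2.1.0 is installed: tensor.numpy() raises because the bridge did not
# initialize; jnp.array never even sees the data.
import torch
import jax.numpy as jnp

q = torch.randn(3, 128)
try:
    jax_q = jnp.array(q.numpy(), dtype=jnp.float32)  # same pattern as the tests
except RuntimeError as e:
    print(e)  # "Numpy is not available"

# Possible stopgap while the versions are mismatched (illustration only, slow):
# route through Python lists, which does not touch the NumPy bridge.
jax_q = jnp.array(q.tolist(), dtype=jnp.float32)
```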
======================================================================
ERROR: test_paged_attention_wrapper (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 527, in test_paged_attention_wrapper
    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
======================================================================
ERROR: test_paged_attention_wrapper_with_attn_logits_soft_cap (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 739, in test_paged_attention_wrapper_with_attn_logits_soft_cap
    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
======================================================================
ERROR: test_paged_attention_wrapper_with_dynamo (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 664, in test_paged_attention_wrapper_with_dynamo
    output = compiled_paged_attention(
  File "/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/pytorch/xla/test/test_pallas.py", line 650, in paged_attention_wrapper
    def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
  File "/pytorch/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/pytorch/torch/_functorch/aot_autograd.py", line 1100, in forward
    return compiled_fn(full_args)
  File "/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 321, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/pytorch/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
    outs = compiled_fn(args)
  File "/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
    return compiled_fn(runtime_args)
  File "/pytorch/torch/_functorch/_aot_autograd/utils.py", line 98, in g
    return f(*args)
  File "/pytorch/torch/_dynamo/backends/torchxla.py", line 37, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 723, in extract_compiled_graph
    return extract_compiled_graph_helper(xla_model, xla_args)
  File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 839, in extract_compiled_graph_helper
    return partition_fx_graph_for_cpu_fallback(xla_model, xla_args,
  File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 785, in partition_fx_graph_for_cpu_fallback
    extract_internal(fused_module), node.args, None)
  File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 529, in extract_internal
    xla_args_need_update) = extract_graph_helper(xla_model,
  File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 473, in extract_graph_helper
    torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
RuntimeError: Bad StatusOr access: INTERNAL: Mosaic failed to compile TPU kernel: Bad rhs type in tpu.matmul
at location: loc("hd,td->ht/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=(Precision.HIGHEST, Precision.HIGHEST) preferred_element_type=float32]"(callsite("flash_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":258:0) at callsite("paged_flash_attention_kernel"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":223:0) at callsite("body"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":323:0) at callsite("paged_flash_attention_kernel_inline_seq_dim"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":359:0) at callsite("paged_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":620:0) at callsite("trace_pallas"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":135:0) at callsite("paged_attention"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":492:0) at callsite("test_paged_attention_wrapper"("/pytorch/xla/test/test_pallas.py":518:0) at "<module>"("/pytorch/xla/test/test_pallas.py":1031:0)))))))))))
The MLIR operation involved:
%2342 = "tpu.matmul"(%2338, %2340, %2341) <{precision = #tpu.contract_precision<fp32>, transpose_lhs = false, transpose_rhs = true}> : (vector<8x128xf32>, vector<128x128xbf16>, vector<8x128xf32>) -> vector<8x128xf32>
... additional diagnostics were skipped.
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
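The diagnostic above shows the shape of the problem: tpu.matmul receives an f32 lhs (vector<8x128xf32>) against a bf16 rhs (vector<128x128xbf16>). The sketch below is not the Pallas kernel itself, just a plain-JAX illustration of the same mixed-dtype "hd,td->ht" contraction and of how casting the operands to a common dtype avoids handing Mosaic a mixed-type matmul:

```python
# Plain-JAX illustration only (not the paged-attention kernel): the reported
# tpu.matmul mixes an f32 lhs with a bf16 rhs for the "hd,td->ht" contraction.
import jax.numpy as jnp

q = jnp.ones((8, 128), dtype=jnp.float32)     # lhs, like vector<8x128xf32>
k = jnp.ones((128, 128), dtype=jnp.bfloat16)  # rhs, like vector<128x128xbf16>

# Mixed-dtype contraction, matching the dot_general in the location above.
out_mixed = jnp.einsum("hd,td->ht", q, k, preferred_element_type=jnp.float32)

# Casting to a common dtype first keeps both matmul operands the same type.
out_cast = jnp.einsum("hd,td->ht", q, k.astype(jnp.float32),
                      preferred_element_type=jnp.float32)
```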
======================================================================
ERROR: test_paged_attention_wrapper_with_megacore_modes (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 593, in test_paged_attention_wrapper_with_megacore_modes
    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
======================================================================
ERROR: test_tpu_custom_call_pallas_flash_attention (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 134, in test_tpu_custom_call_pallas_flash_attention
    self.assertTrue(torch.allclose(o[0].cpu(), expected_o.cpu()))
RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Failed to parse the Mosaic module
======================================================================
ERROR: test_tpu_custom_call_pallas_wrap_flash_attention (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 212, in test_tpu_custom_call_pallas_wrap_flash_attention
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
RuntimeError: Bad StatusOr access: INTERNAL: Mosaic failed to compile TPU kernel: Bad lhs type in tpu.matmul
at location: loc("/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=(Precision.HIGHEST, Precision.HIGHEST) preferred_element_type=float32]"(callsite("_flash_attention_kernel_single_batch_single_step"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":515:0) at callsite("_flash_attention_kernel"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":338:0) at callsite("_flash_attention_impl"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":737:0) at callsite("_flash_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":216:0) at callsite("flash_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":198:0) at callsite("trace_pallas"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":135:0) at callsite("wrapped_kernel"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":155:0) at callsite("test_tpu_custom_call_pallas_wrap_flash_attention"("/pytorch/xla/test/test_pallas.py":210:0) at "<module>"("/pytorch/xla/test/test_pallas.py":1031:0)))))))))))
The MLIR operation involved:
%570 = "tpu.matmul"(%567, %568, %569) <{precision = #tpu.contract_precision<fp32>, transpose_lhs = false, transpose_rhs = true}> : (vector<128x128xbf16>, vector<128x128xbf16>, vector<128x128xf32>) -> vector<128x128xf32>
... additional diagnostics were skipped.
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
======================================================================
FAIL: test_flash_attention_backward (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 485, in test_flash_attention_backward
    self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_sm_scale_backward (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 922, in test_flash_attention_sm_scale_backward
    self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_wrapper (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 226, in test_flash_attention_wrapper
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_wrapper_sm_scale (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 880, in test_flash_attention_wrapper_sm_scale
    self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_wrapper_with_dynamo (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/pytorch/xla/test/test_pallas.py", line 247, in test_flash_attention_wrapper_with_dynamo
    self.assertTrue(torch.allclose(o_no_causal.cpu(), expected_o.cpu()))
AssertionError: False is not true
----------------------------------------------------------------------
Ran 26 tests in 12.163s
FAILED (failures=5, errors=7)
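The five FAIL cases all come from bare `assertTrue(torch.allclose(...))` checks, so the report only says "False is not true". As a debugging aid (a sketch, not a change to the test suite), `torch.testing.assert_close` reports the mismatch statistics directly:

```python
# Sketch of a more informative comparison than assertTrue(torch.allclose(...)):
# torch.testing.assert_close prints how many elements differ and by how much.
import torch

actual = torch.zeros(4)
expected = torch.full((4,), 1e-3)
try:
    torch.testing.assert_close(actual, expected, atol=1e-05, rtol=1e-05)
except AssertionError as e:
    print(e)  # greatest absolute/relative difference and mismatch count
```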
(pinupdate) root@6e1dc6c462da:/pytorch/xla# pip install numpy
Requirement already satisfied: numpy in /root/miniconda3/envs/pinupdate/lib/python3.10/site-packages (2.1.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
(pinupdate) root@6e1dc6c462da:/pytorch/xla#
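The final `pip install numpy` is a no-op: NumPy 2.1.0 already satisfies the unconstrained requirement, so the mismatch with the NumPy 1.x torch build persists. The warning text earlier in the log suggests the constraint `numpy<2` (for example `pip install "numpy<2"`); after such a downgrade, a check like the following would be expected to report a 1.x version:

```python
# Illustration only: verify the interpreter now sees a NumPy 1.x release and
# that torch's tensor-to-ndarray bridge works again.
import numpy
import torch

print(numpy.__version__)      # expected to start with "1." after the pin
print(torch.ones(2).numpy())  # should no longer raise "Numpy is not available"
```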