@ManfeiBai
Created August 30, 2024 21:55
log
(pinupdate) root@6e1dc6c462da:/pytorch/xla# pip install jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Looking in links: https://storage.googleapis.com/jax-releases/jax_nightly_releases.html, https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Collecting jax
Downloading jax-0.4.31-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib
Downloading jaxlib-0.4.31-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)
Collecting ml-dtypes>=0.2.0 (from jax)
Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Requirement already satisfied: numpy>=1.24 in /root/miniconda3/envs/pinupdate/lib/python3.10/site-packages (from jax) (2.1.0)
Collecting opt-einsum (from jax)
Using cached opt_einsum-3.3.0-py3-none-any.whl.metadata (6.5 kB)
Collecting scipy>=1.10 (from jax)
Downloading scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Downloading jax-0.4.31-py3-none-any.whl (2.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 50.8 MB/s eta 0:00:00
Downloading jaxlib-0.4.31-cp310-cp310-manylinux2014_x86_64.whl (88.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 88.1/88.1 MB 86.9 MB/s eta 0:00:00
Using cached ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
Downloading scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (41.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 41.2/41.2 MB 126.2 MB/s eta 0:00:00
Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Installing collected packages: scipy, opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.31 jaxlib-0.4.31 ml-dtypes-0.4.0 opt-einsum-3.3.0 scipy-1.14.1
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
(pinupdate) root@6e1dc6c462da:/pytorch/xla# PJRT_DEVICE=TPU python test/test_pallas.py
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 5, in <module>
import torch
File "/pytorch/torch/__init__.py", line 2479, in <module>
from torch import (
File "/pytorch/torch/export/__init__.py", line 64, in <module>
from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
File "/pytorch/torch/export/dynamic_shapes.py", line 22, in <module>
from .exported_program import ExportedProgram
File "/pytorch/torch/export/exported_program.py", line 27, in <module>
from torch._higher_order_ops.utils import autograd_not_implemented
File "/pytorch/torch/_higher_order_ops/__init__.py", line 1, in <module>
from torch._higher_order_ops.cond import cond
File "/pytorch/torch/_higher_order_ops/cond.py", line 6, in <module>
import torch._subclasses.functional_tensor
File "/pytorch/torch/_subclasses/functional_tensor.py", line 44, in <module>
class FunctionalTensor(torch.Tensor):
File "/pytorch/torch/_subclasses/functional_tensor.py", line 271, in FunctionalTensor
cpu = _conversion_method_template(device=torch.device("cpu"))
/pytorch/torch/_subclasses/functional_tensor.py:271: UserWarning: Failed to initialize NumPy:
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
(Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
cpu = _conversion_method_template(device=torch.device("cpu"))
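Note: the warning blocks above are torch failing to initialize its NumPy bridge; this torch build was compiled against NumPy 1.x while NumPy 2.1.0 is installed, and the warning's own suggestion is to downgrade to 'numpy<2' or rebuild the affected module against NumPy 2. A minimal diagnostic sketch (assuming only that torch imports, as it does above):

import torch

try:
    torch.zeros(1).numpy()
    print("torch's NumPy bridge is working")
except RuntimeError as err:
    # With the NumPy 1.x/2.x mismatch above, this prints "Numpy is not available",
    # the same error the test failures below report.
    print(f"NumPy bridge failed to initialize: {err}")
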
.....F.FF.E.FF.EEEE...loc("/reduce_max[axes=(1,)]"(callsite("_flash_attention_kernel_single_batch_single_step"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":557:0) at callsite("_flash_attention_kernel"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":338:0) at callsite("_flash_attention_impl"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":746:0) at callsite("_flash_attention"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":216:0) at callsite("flash_attention"("/home/jwtan/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":198:0) at "<module>"("/mnt/disks/ssd/work/pallas/pallas_add.py":50:0)))))))): error: 'vector.multi_reduction' op requires attribute 'reduction_dims'
E..E
======================================================================
ERROR: test_flash_attention_wrapper_segment_ids_1 (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 781, in test_flash_attention_wrapper_segment_ids_1
jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
======================================================================
ERROR: test_paged_attention_wrapper (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 527, in test_paged_attention_wrapper
q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
======================================================================
ERROR: test_paged_attention_wrapper_with_attn_logits_soft_cap (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 739, in test_paged_attention_wrapper_with_attn_logits_soft_cap
q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
======================================================================
ERROR: test_paged_attention_wrapper_with_dynamo (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 664, in test_paged_attention_wrapper_with_dynamo
output = compiled_paged_attention(
File "/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/pytorch/xla/test/test_pallas.py", line 650, in paged_attention_wrapper
def paged_attention_wrapper(q, k, v, seq_lens, page_indices,
File "/pytorch/torch/_dynamo/eval_frame.py", line 632, in _fn
return fn(*args, **kwargs)
File "/pytorch/torch/_functorch/aot_autograd.py", line 1100, in forward
return compiled_fn(full_args)
File "/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 321, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
File "/pytorch/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
outs = compiled_fn(args)
File "/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
return compiled_fn(runtime_args)
File "/pytorch/torch/_functorch/_aot_autograd/utils.py", line 98, in g
return f(*args)
File "/pytorch/torch/_dynamo/backends/torchxla.py", line 37, in fwd
compiled_graph = bridge.extract_compiled_graph(model, args)
File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 723, in extract_compiled_graph
return extract_compiled_graph_helper(xla_model, xla_args)
File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 839, in extract_compiled_graph_helper
return partition_fx_graph_for_cpu_fallback(xla_model, xla_args,
File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 785, in partition_fx_graph_for_cpu_fallback
extract_internal(fused_module), node.args, None)
File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 529, in extract_internal
xla_args_need_update) = extract_graph_helper(xla_model,
File "/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py", line 473, in extract_graph_helper
torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
RuntimeError: Bad StatusOr access: INTERNAL: Mosaic failed to compile TPU kernel: Bad rhs type in tpu.matmul
at location: loc("hd,td->ht/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=(Precision.HIGHEST, Precision.HIGHEST) preferred_element_type=float32]"(callsite("flash_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":258:0) at callsite("paged_flash_attention_kernel"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":223:0) at callsite("body"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":323:0) at callsite("paged_flash_attention_kernel_inline_seq_dim"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":359:0) at callsite("paged_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py":620:0) at callsite("trace_pallas"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":135:0) at callsite("paged_attention"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":492:0) at callsite("test_paged_attention_wrapper"("/pytorch/xla/test/test_pallas.py":518:0) at "<module>"("/pytorch/xla/test/test_pallas.py":1031:0)))))))))))
The MLIR operation involved:
%2342 = "tpu.matmul"(%2338, %2340, %2341) <{precision = #tpu.contract_precision<fp32>, transpose_lhs = false, transpose_rhs = true}> : (vector<8x128xf32>, vector<128x128xbf16>, vector<8x128xf32>) -> vector<8x128xf32>
... additional diagnostics were skipped.
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
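Note: the failing op above mixes operand element types: the lhs tile is vector<8x128xf32> while the rhs is vector<128x128xbf16>, and this Mosaic build rejects the mixed-precision tpu.matmul. A hedged sketch of one possible workaround at the JAX level (an assumption, not the project's actual fix; the helper name is hypothetical): cast the attention inputs to a single dtype before the Pallas kernel is traced.

import jax.numpy as jnp

def to_common_dtype(q, k, v, dtype=jnp.float32):
    # Cast query, key and value to one dtype so the lowered tpu.matmul
    # sees matching lhs/rhs element types.
    return q.astype(dtype), k.astype(dtype), v.astype(dtype)
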
======================================================================
ERROR: test_paged_attention_wrapper_with_megacore_modes (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 593, in test_paged_attention_wrapper_with_megacore_modes
q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
RuntimeError: Numpy is not available
======================================================================
ERROR: test_tpu_custom_call_pallas_flash_attention (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 134, in test_tpu_custom_call_pallas_flash_attention
self.assertTrue(torch.allclose(o[0].cpu(), expected_o.cpu()))
RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Failed to parse the Mosaic module
======================================================================
ERROR: test_tpu_custom_call_pallas_wrap_flash_attention (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 212, in test_tpu_custom_call_pallas_wrap_flash_attention
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
RuntimeError: Bad StatusOr access: INTERNAL: Mosaic failed to compile TPU kernel: Bad lhs type in tpu.matmul
at location: loc("/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=(Precision.HIGHEST, Precision.HIGHEST) preferred_element_type=float32]"(callsite("_flash_attention_kernel_single_batch_single_step"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":515:0) at callsite("_flash_attention_kernel"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":338:0) at callsite("_flash_attention_impl"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":737:0) at callsite("_flash_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":216:0) at callsite("flash_attention"("/root/miniconda3/envs/pinupdate/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py":198:0) at callsite("trace_pallas"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":135:0) at callsite("wrapped_kernel"("/pytorch/xla/torch_xla/experimental/custom_kernel.py":155:0) at callsite("test_tpu_custom_call_pallas_wrap_flash_attention"("/pytorch/xla/test/test_pallas.py":210:0) at "<module>"("/pytorch/xla/test/test_pallas.py":1031:0)))))))))))
The MLIR operation involved:
%570 = "tpu.matmul"(%567, %568, %569) <{precision = #tpu.contract_precision<fp32>, transpose_lhs = false, transpose_rhs = true}> : (vector<128x128xbf16>, vector<128x128xbf16>, vector<128x128xf32>) -> vector<128x128xf32>
... additional diagnostics were skipped.
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
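Note: both Mosaic compilation failures ("Bad rhs type" and "Bad lhs type" in tpu.matmul) may indicate that the freshly installed jax/jaxlib 0.4.31 and the libtpu pinned by torch_xla are out of step (an assumption based on the error text, not something the log proves). A quick sketch for recording the version combination when filing the bug the message asks for:

import jax
import torch
import torch_xla

# jax.print_environment_info() reports the jax/jaxlib/numpy/python versions
# and the detected devices; torch and torch_xla versions are printed alongside.
jax.print_environment_info()
print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)
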
======================================================================
FAIL: test_flash_attention_backward (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 485, in test_flash_attention_backward
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_sm_scale_backward (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 922, in test_flash_attention_sm_scale_backward
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_wrapper (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 226, in test_flash_attention_wrapper
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_wrapper_sm_scale (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 880, in test_flash_attention_wrapper_sm_scale
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
AssertionError: False is not true
======================================================================
FAIL: test_flash_attention_wrapper_with_dynamo (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/pytorch/xla/test/test_pallas.py", line 247, in test_flash_attention_wrapper_with_dynamo
self.assertTrue(torch.allclose(o_no_causal.cpu(), expected_o.cpu()))
AssertionError: False is not true
----------------------------------------------------------------------
Ran 26 tests in 12.163s
FAILED (failures=5, errors=7)
(pinupdate) root@6e1dc6c462da:/pytorch/xla# pip install numpy
Requirement already satisfied: numpy in /root/miniconda3/envs/pinupdate/lib/python3.10/site-packages (2.1.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
(pinupdate) root@6e1dc6c462da:/pytorch/xla#
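Note: the "pip install numpy" above is a no-op, since NumPy 2.1.0 already satisfies the unconstrained requirement. Per the warning torch printed earlier, the pin needs an upper bound. A sketch under the assumption that downgrading (rather than rebuilding torch against NumPy 2) is the intended next step:

import subprocess
import sys

# Reinstall NumPy below 2.0 in the active (pinupdate) environment,
# matching the "downgrade to 'numpy<2'" suggestion in the warning above.
subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy<2"])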