Studying JAX v0.4.19 dynamic API

Indexing with constants (runs, bad results)

Note: related JAX test

func, mlir, jaxpr = None, None, None  # reset, so a failing step below cannot silently reuse the previous section's objects

def func(sz:int):  # sz is traced as a dynamic size (jaxpr input a:i64[])
  o = jnp.ones(sz, jnp.float32)  # dynamically sized f32[a], traced as broadcast_in_dim
  return o[0]  # constant scalar index, traced to a gather

jaxpr = jax.make_jaxpr(func)(3)  # 3 is only a sample value; the size stays symbolic. NOTE(review): assumes jax_dynamic_shapes is enabled earlier in the session (not shown here)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # Catalyst helper: jaxpr -> StableHLO module @test; only the module and context are kept
inject_functions(mlir, ctx)  # Catalyst helper: sets llvm.emit_c_interface on the entry point (the printed module also shows @setup/@teardown)
{ lambda ; a:i64[]. let
    b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:f32[] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
    ] b c
  in (d,) }
module @test {
  func.func public @catalyst.entry_point(%arg0: tensor<i64>) -> tensor<f32> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
    %3 = stablehlo.dynamic_broadcast_in_dim %0, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    %4 = stablehlo.convert %arg0 : tensor<i64>
    %5 = stablehlo.constant dense<0> : tensor<i64>
    %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<i64>) -> tensor<1xi64>
    %7 = "stablehlo.gather"(%3, %6) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi64>) -> tensor<f32>
    return %7 : tensor<f32>
  func.func @setup() {
    "quantum.init"() : () -> ()
  func.func @teardown() {
    "quantum.finalize"() : () -> ()


Indexing with a variable (runs, bad results)

func, mlir, jaxpr = None, None, None  # reset, so a failing step below cannot silently reuse the previous section's objects

def func(sz:int, idx:int):  # both arguments are traced as i64[] scalars (a and b in the jaxpr)
  o = jnp.ones((sz,sz), jnp.float32)  # dynamically sized f32[a,a]
  return o[idx,0]  # traced index: JAX wraps a negative idx (lt / add / select_n) before the gather

jaxpr = jax.make_jaxpr(func)(3,0)  # sample values only; the shape and the index stay symbolic
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # Catalyst helper: jaxpr -> StableHLO module @test
inject_functions(mlir, ctx)  # Catalyst helper: sets llvm.emit_c_interface on the entry point
{ lambda ; a:i64[] b:i64[]. let
    c:f32[a,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 1.0
      a a
    d:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    e:bool[] = lt b 0
    f:i64[] = add b d
    g:i64[] = select_n e b f
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    h:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    i:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    j:i64[2] = concatenate[dimension=0] h i
    k:f32[] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
      slice_sizes=(1, 1)
    ] c j
  in (k,) }
module @test {
  func.func public @catalyst.entry_point(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<f32> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
    %3 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %4 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
    %5 = stablehlo.concatenate %2, %4, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %0, %5, dims = [] : (tensor<f32>, tensor<2xi32>) -> tensor<?x?xf32>
    %7 = stablehlo.convert %arg0 : tensor<i64>
    %8 = stablehlo.constant dense<0> : tensor<i64>
    %9 = stablehlo.compare  LT, %arg1, %8,  SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
    %10 = stablehlo.add %arg1, %7 : tensor<i64>
    %11 = stablehlo.select %9, %10, %arg1 : tensor<i1>, tensor<i64>
    %12 = stablehlo.convert %arg0 : tensor<i64>
    %13 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<i64>) -> tensor<1xi64>
    %14 = stablehlo.constant dense<0> : tensor<i64>
    %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor<i64>) -> tensor<1xi64>
    %16 = stablehlo.concatenate %13, %15, dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
    %17 = "stablehlo.gather"(%6, %16) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xi64>) -> tensor<f32>
    return %17 : tensor<f32>
  func.func @setup() {
    "quantum.init"() : () -> ()
  func.func @teardown() {
    "quantum.finalize"() : () -> ()


Indexing with constant slice (denied by JAX)

func, mlir, jaxpr = None, None, None  # reset, so a failing step below cannot silently reuse the previous section's objects

def func(sz:int):
  o = jnp.ones(sz, jnp.float32)
  return o[0:2]  # static slice on a dynamically sized dimension: JAX raises IndexError while tracing

jaxpr = jax.make_jaxpr(func)(3)  # raises (see traceback), so jaxpr stays None
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # consequently fails: jaxpr is None
inject_functions(mlir, ctx)  # consequently fails: mlir is None
Traceback (most recent call last):
  Cell In[72], line 1
    jaxpr = jax.make_jaxpr(func)(3)
  File /workspace/modules/jax/jax/_src/ in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/ in make_jaxpr_f
    jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
  File /workspace/modules/jax/jax/_src/ in wrapper
    return func(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/interpreters/ in trace_to_jaxpr_dynamic2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File /workspace/modules/jax/jax/_src/interpreters/ in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File /workspace/modules/jax/jax/_src/ in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  Cell In[71], line 3 in func
    return o[0:2]
  File /workspace/modules/jax/jax/_src/numpy/ in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/ in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/ in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/ in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File /workspace/modules/jax/jax/_src/numpy/ in _index_to_gather
    raise IndexError(msg)
IndexError: Cannot use NumPy slice indexing on an array dimension whose size is not statically known (Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>). Try using lax.dynamic_slice/dynamic_update_slice

Traceback (most recent call last):
  Cell In[74], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
AttributeError: 'NoneType' object has no attribute 'out_avals'

Traceback (most recent call last):
  Cell In[75], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[77], line 1
  File /workspace/modules/catalyst/frontend/catalyst/ in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/ in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/ in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

Indexing fewer axes than available (MLIR lowering fails)

func, mlir, jaxpr = None, None, None  # reset, so a failing step below cannot silently reuse the previous section's objects

def func(sz:int):
  o = jnp.ones((sz,sz), jnp.float32)  # dynamically sized f32[a,a]
  return o[0]  # leaves one axis unindexed: the gather's slice_sizes then contain a tracer

jaxpr = jax.make_jaxpr(func)(3)  # tracing succeeds (jaxpr printed below)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # fails in JAX's _gather_lower on the traced slice size, so mlir stays None
inject_functions(mlir, ctx)  # consequently fails: mlir is None
{ lambda ; a:i64[]. let
    b:f32[a,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 1.0
      a a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:f32[a] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,))
      slice_sizes=(1, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    ] b c
    e:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] d a
  in (e,) }
Traceback (most recent call last):
  File /workspace/modules/jax/jax/_src/ in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'type'

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File /opt/venv/bin/ipython3:8
  File /opt/venv/lib/python3.10/site-packages/IPython/ in start_ipython
    return launch_new_instance(argv=argv, **kwargs)
  File /opt/venv/lib/python3.10/site-packages/traitlets/config/ in launch_instance
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ in start
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ in mainloop
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ in interact
    self.run_cell(code, store_history=True)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_cell
    result = self._run_cell(
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in _run_cell
    result = runner(coro)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in _pseudo_sync_runner
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  Cell In[82], line 1
    jaxpr = jax.make_jaxpr(func)(3)
  Cell In[81], line 3 in func
    return o[0]
  File /workspace/modules/jax/jax/_src/numpy/ in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/ in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/ in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/ in _gather
    y = lax.gather(
JaxStackTraceBeforeTransformation: AttributeError: DynamicJaxprTracer has no attribute type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[84], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in custom_lower_jaxpr_to_module
  File /workspace/modules/jax/jax/_src/interpreters/ in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File /workspace/modules/jax/jax/_src/interpreters/ in jaxpr_subcomp
    ans = rule(rule_ctx, *rule_inputs, **eqn.params)
  File /workspace/modules/jax/jax/_src/lax/ in _gather_lower
    slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
  File /workspace/modules/jax/jax/_src/interpreters/ in eval_dynamic_shape_as_tensor
    return shape_tensor(eval_dynamic_shape(ctx, shape))
  File /workspace/modules/jax/jax/_src/interpreters/ in shape_tensor
    ds = map(lower_dim, sizes)
  File /workspace/modules/jax/jax/_src/interpreters/ in lower_dim
    if d.type != i32_type:
  File /workspace/modules/jax/jax/_src/ in __getattr__
    raise AttributeError(
AttributeError: DynamicJaxprTracer has no attribute type

Traceback (most recent call last):
  Cell In[85], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[87], line 1
  File /workspace/modules/catalyst/frontend/catalyst/ in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/ in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/ in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

Indexing with a colon (accepted by JAX; Catalyst compilation fails)

func, mlir, jaxpr = None, None, None  # reset, so a failing step below cannot silently reuse the previous section's objects

def func(sz:int):
  o = jnp.ones((sz,sz), jnp.float32)  # NOTE(review): the jaxpr recorded below is 1-D (f32[a]), i.e. it matches jnp.ones(sz, ...), not this 2-D shape; re-run to confirm which version produced it
  return o[:]  # full slice: no gather is emitted, only a broadcast_in_dim

jaxpr = jax.make_jaxpr(func)(3)  # tracing succeeds
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lowering succeeds; the failure comes later, in Catalyst's compiler pipeline
inject_functions(mlir, ctx)  # Catalyst helper: sets llvm.emit_c_interface on the entry point
{ lambda ; a:i64[]. let
    b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
    c:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] b a
  in (c,) }
module @test {
  func.func public @catalyst.entry_point(%arg0: tensor<i64>) -> tensor<?xf32> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
    %3 = stablehlo.dynamic_broadcast_in_dim %0, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    %4 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    return %6 : tensor<?xf32>
  func.func @setup() {
    "quantum.init"() : () -> ()
  func.func @teardown() {
    "quantum.finalize"() : () -> ()

Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:9:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim'
    %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
catalyst_module:9:10: note: see current operation: %17 = "mhlo.dynamic_broadcast_in_dim"(%11, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
Invalid pass with name 'FinalizingBufferize' failed
While processing 'mlir::detail::OpToOpPassAdaptor' pass of the 'BufferizationPass' pipeline
Failed to lower MLIR module

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[57], line 1
  File /workspace/modules/catalyst/frontend/catalyst/ in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/ in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/ in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:9:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim'
    %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
catalyst_module:9:10: note: see current operation: %17 = "mhlo.dynamic_broadcast_in_dim"(%11, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
Invalid pass with name 'FinalizingBufferize' failed
While processing 'mlir::detail::OpToOpPassAdaptor' pass of the 'BufferizationPass' pipeline
Failed to lower MLIR module

Indexing with a colon in one dimension (MLIR lowering failed)

func, mlir, jaxpr = None, None, None  # reset, so a failing step below cannot silently reuse the previous section's objects

def func(sz:int):
  o = jnp.ones((sz,sz,sz), jnp.float32)  # dynamically sized f32[a,a,a]
  return o[0,:,0]  # the colon keeps a dynamic axis, so the gather's slice_sizes contain a tracer

jaxpr = jax.make_jaxpr(func)(3)  # tracing succeeds (jaxpr printed below)
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # fails in JAX's _gather_lower on the traced slice size, so mlir stays None
inject_functions(mlir, ctx)  # consequently fails: mlir is None
{ lambda ; a:i64[]. let
    b:f32[a,a,a] = broadcast_in_dim[
      shape=(None, None, None)
    ] 1.0 a a a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    e:i64[2] = concatenate[dimension=0] c d
    f:f32[a] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2))
      slice_sizes=(1, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1)
    ] b e
    g:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] f a
  in (g,) }
Traceback (most recent call last):
  File /workspace/modules/jax/jax/_src/ in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'type'

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File /opt/venv/bin/ipython3:8
  File /opt/venv/lib/python3.10/site-packages/IPython/ in start_ipython
    return launch_new_instance(argv=argv, **kwargs)
  File /opt/venv/lib/python3.10/site-packages/traitlets/config/ in launch_instance
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ in start
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ in mainloop
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ in interact
    self.run_cell(code, store_history=True)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_cell
    result = self._run_cell(
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in _run_cell
    result = runner(coro)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in _pseudo_sync_runner
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File /opt/venv/lib/python3.10/site-packages/IPython/core/ in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  Cell In[112], line 1
    jaxpr = jax.make_jaxpr(func)(3)
  Cell In[111], line 3 in func
    return o[0,:,0]
  File /workspace/modules/jax/jax/_src/numpy/ in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/ in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/ in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/ in _gather
    y = lax.gather(
JaxStackTraceBeforeTransformation: AttributeError: DynamicJaxprTracer has no attribute type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[114], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in custom_lower_jaxpr_to_module
  File /workspace/modules/jax/jax/_src/interpreters/ in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File /workspace/modules/jax/jax/_src/interpreters/ in jaxpr_subcomp
    ans = rule(rule_ctx, *rule_inputs, **eqn.params)
  File /workspace/modules/jax/jax/_src/lax/ in _gather_lower
    slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
  File /workspace/modules/jax/jax/_src/interpreters/ in eval_dynamic_shape_as_tensor
    return shape_tensor(eval_dynamic_shape(ctx, shape))
  File /workspace/modules/jax/jax/_src/interpreters/ in shape_tensor
    ds = map(lower_dim, sizes)
  File /workspace/modules/jax/jax/_src/interpreters/ in lower_dim
    if d.type != i32_type:
  File /workspace/modules/jax/jax/_src/ in __getattr__
    raise AttributeError(
AttributeError: DynamicJaxprTracer has no attribute type

Traceback (most recent call last):
  Cell In[115], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[117], line 1
  File /workspace/modules/catalyst/frontend/catalyst/ in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/ in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/ in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

Indexing with variable slice (denied by JAX)

func, mlir, jaxpr = None, None, None  # reset, so a failing step below cannot silently reuse the previous section's objects

def func(sz:int, idx:int):
  o = jnp.ones(sz, jnp.float32)
  return o[0:idx]  # traced slice bound: JAX's NumPy-style indexing requires static start/stop/step

jaxpr = jax.make_jaxpr(func)(3,0)  # raises IndexError (see traceback), so jaxpr stays None
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # consequently fails: jaxpr is None
inject_functions(mlir, ctx)  # consequently fails: mlir is None
Traceback (most recent call last):
  Cell In[102], line 1
    jaxpr = jax.make_jaxpr(func)(3,0)
  File /workspace/modules/jax/jax/_src/ in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/ in make_jaxpr_f
    jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
  File /workspace/modules/jax/jax/_src/ in wrapper
    return func(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/interpreters/ in trace_to_jaxpr_dynamic2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File /workspace/modules/jax/jax/_src/interpreters/ in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File /workspace/modules/jax/jax/_src/ in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  Cell In[101], line 3 in func
    return o[0:idx]
  File /workspace/modules/jax/jax/_src/numpy/ in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/ in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/ in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/ in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File /workspace/modules/jax/jax/_src/numpy/ in _index_to_gather
    raise IndexError(msg)
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Traceback (most recent call last):
  Cell In[104], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
AttributeError: 'NoneType' object has no attribute 'out_avals'

Traceback (most recent call last):
  Cell In[105], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/ in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[107], line 1
  File /workspace/modules/catalyst/frontend/catalyst/ in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/ in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/ in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/ in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
Failed to parse module as LLVM source

