Skip to content

Instantly share code, notes, and snippets.

@sergei-mironov
Created November 10, 2023 10:24
Show Gist options
  • Save sergei-mironov/324c3c66d210faba5aa5dc9acb5e5a4e to your computer and use it in GitHub Desktop.
Save sergei-mironov/324c3c66d210faba5aa5dc9acb5e5a4e to your computer and use it in GitHub Desktop.
Studying JAX v0.4.19 dynamic API

Studying JAX v0.4.19 dynamic API

Indexing with constants (runs, bad results)

Note: related JAX test

# Experiment: index a 1-D array of run-time length with a constant integer.
# Reset the working names first, presumably so a failed step cannot silently
# reuse the results of a previous experiment.
func, mlir, jaxpr = None, None, None

def func(sz:int):
  o = jnp.ones(sz, jnp.float32)  # 1-D array of ones whose length is the argument `sz`
  return o[0]  # constant integer index; the result is a scalar

jaxpr = jax.make_jaxpr(func)(3)  # trace `func` using 3 as the example argument
print(jaxpr)
# NOTE(review): jaxpr_to_mlir, inject_functions and qjit are defined outside
# this document; the comments below are inferred from the recorded output only.
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lower the jaxpr to an MLIR module named "test"
inject_functions(mlir, ctx)  # presumably adds the @setup/@teardown functions visible in the printed module
print(mlir)
qjit(str(mlir))(3)  # compile the MLIR text and call it with sz=3; the correct result would be 1.0
{ lambda ; a:i64[]. let
    b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:f32[] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1,)
      unique_indices=True
    ] b c
  in (d,) }
module @test {
  func.func public @catalyst.entry_point(%arg0: tensor<i64>) -> tensor<f32> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
    %3 = stablehlo.dynamic_broadcast_in_dim %0, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    %4 = stablehlo.convert %arg0 : tensor<i64>
    %5 = stablehlo.constant dense<0> : tensor<i64>
    %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<i64>) -> tensor<1xi64>
    %7 = "stablehlo.gather"(%3, %6) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi64>) -> tensor<f32>
    return %7 : tensor<f32>
  }
  func.func @setup() {
    "quantum.init"() : () -> ()
    return
  }
  func.func @teardown() {
    "quantum.finalize"() : () -> ()
    return
  }
}

array(4.68817794e-310)

Indexing with a variable (runs, bad results)

# Experiment: index a 2-D array of run-time size with a run-time row index
# and a constant column index.
# Reset the working names first, presumably so a failed step cannot silently
# reuse the results of a previous experiment.
func, mlir, jaxpr = None, None, None

def func(sz:int, idx:int):
  o = jnp.ones((sz,sz), jnp.float32)  # sz-by-sz array of ones
  return o[idx,0]  # row chosen by the argument `idx`, column fixed at 0; the result is a scalar

jaxpr = jax.make_jaxpr(func)(3,0)  # trace `func` using sz=3, idx=0 as example arguments
print(jaxpr)
# NOTE(review): jaxpr_to_mlir, inject_functions and qjit are defined outside
# this document; the comments below are inferred from the recorded output only.
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lower the jaxpr to an MLIR module named "test"
inject_functions(mlir, ctx)  # presumably adds the @setup/@teardown functions visible in the printed module
print(mlir)
qjit(str(mlir))(3,0)  # compile the MLIR text and call it with sz=3, idx=0; the correct result would be 1.0
{ lambda ; a:i64[] b:i64[]. let
    c:f32[a,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 1.0
      a a
    d:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    e:bool[] = lt b 0
    f:i64[] = add b d
    g:i64[] = select_n e b f
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    h:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    i:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    j:i64[2] = concatenate[dimension=0] h i
    k:f32[] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=True
    ] c j
  in (k,) }
module @test {
  func.func public @catalyst.entry_point(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<f32> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
    %3 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %4 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
    %5 = stablehlo.concatenate %2, %4, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %0, %5, dims = [] : (tensor<f32>, tensor<2xi32>) -> tensor<?x?xf32>
    %7 = stablehlo.convert %arg0 : tensor<i64>
    %8 = stablehlo.constant dense<0> : tensor<i64>
    %9 = stablehlo.compare  LT, %arg1, %8,  SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
    %10 = stablehlo.add %arg1, %7 : tensor<i64>
    %11 = stablehlo.select %9, %10, %arg1 : tensor<i1>, tensor<i64>
    %12 = stablehlo.convert %arg0 : tensor<i64>
    %13 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor<i64>) -> tensor<1xi64>
    %14 = stablehlo.constant dense<0> : tensor<i64>
    %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor<i64>) -> tensor<1xi64>
    %16 = stablehlo.concatenate %13, %15, dim = 0 : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
    %17 = "stablehlo.gather"(%6, %16) {dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xi64>) -> tensor<f32>
    return %17 : tensor<f32>
  }
  func.func @setup() {
    "quantum.init"() : () -> ()
    return
  }
  func.func @teardown() {
    "quantum.finalize"() : () -> ()
    return
  }
}

array(5.26354425e-315)

Indexing with constant slice (denied by JAX)

# Experiment: take a constant slice of a 1-D array of run-time length.
# Reset the working names first; if tracing fails, `jaxpr` and `mlir` stay
# None and every later step fails on None as well.
func, mlir, jaxpr = None, None, None

def func(sz:int):
  o = jnp.ones(sz, jnp.float32)  # 1-D array of ones whose length is the argument `sz`
  return o[0:2]  # slice with constant start and stop

jaxpr = jax.make_jaxpr(func)(3)  # trace `func` using 3 as the example argument
print(jaxpr)
# NOTE(review): jaxpr_to_mlir, inject_functions and qjit are defined outside
# this document; their exact behavior cannot be confirmed from here.
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lower the jaxpr to an MLIR module named "test"
inject_functions(mlir, ctx)  # mutates the module in place (sets llvm.emit_c_interface on its first operation)
print(mlir)
qjit(str(mlir))(3)  # compile the MLIR text and call it with sz=3
Traceback (most recent call last):
  Cell In[72], line 1
    jaxpr = jax.make_jaxpr(func)(3)
  File /workspace/modules/jax/jax/_src/traceback_util.py:177 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/api.py:2462 in make_jaxpr_f
    jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
  File /workspace/modules/jax/jax/_src/profiler.py:340 in wrapper
    return func(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2239 in trace_to_jaxpr_dynamic2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2254 in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File /workspace/modules/jax/jax/_src/linear_util.py:191 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  Cell In[71], line 3 in func
    return o[0:2]
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4332 in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4589 in _index_to_gather
    raise IndexError(msg)
IndexError: Cannot use NumPy slice indexing on an array dimension whose size is not statically known (Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>). Try using lax.dynamic_slice/dynamic_update_slice

None
Traceback (most recent call last):
  Cell In[74], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
AttributeError: 'NoneType' object has no attribute 'out_avals'

Traceback (most recent call last):
  Cell In[75], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

None
Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[77], line 1
    qjit(str(mlir))(3)
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


Indexing fewer axes than available (MLIR lowering does not work)

# Experiment: index only the first axis of a 2-D array of run-time size.
# Reset the working names first; if a step fails, `mlir` stays None and
# every later step fails on None as well.
func, mlir, jaxpr = None, None, None

def func(sz:int):
  o = jnp.ones((sz,sz), jnp.float32)  # sz-by-sz array of ones
  return o[0]  # only the first axis is indexed; the result is the first row, of length sz

jaxpr = jax.make_jaxpr(func)(3)  # trace `func` using 3 as the example argument
print(jaxpr)
# NOTE(review): jaxpr_to_mlir, inject_functions and qjit are defined outside
# this document; their exact behavior cannot be confirmed from here.
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lower the jaxpr to an MLIR module named "test"
inject_functions(mlir, ctx)  # mutates the module in place (sets llvm.emit_c_interface on its first operation)
print(mlir)
qjit(str(mlir))(3)  # compile the MLIR text and call it with sz=3
{ lambda ; a:i64[]. let
    b:f32[a,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 1.0
      a a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:f32[a] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0,), start_index_map=(0,))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
      unique_indices=True
    ] b c
    e:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] d a
  in (e,) }
Traceback (most recent call last):
  File /workspace/modules/jax/jax/_src/core.py:744 in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'type'

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File /opt/venv/bin/ipython3:8
    sys.exit(start_ipython())
  File /opt/venv/lib/python3.10/site-packages/IPython/__init__.py:128 in start_ipython
    return launch_new_instance(argv=argv, **kwargs)
  File /opt/venv/lib/python3.10/site-packages/traitlets/config/application.py:1043 in launch_instance
    app.start()
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ipapp.py:318 in start
    self.shell.mainloop()
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:888 in mainloop
    self.interact()
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:881 in interact
    self.run_cell(code, store_history=True)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3006 in run_cell
    result = self._run_cell(
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3061 in _run_cell
    result = runner(coro)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:129 in _pseudo_sync_runner
    coro.send(None)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3266 in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3445 in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3505 in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  Cell In[82], line 1
    jaxpr = jax.make_jaxpr(func)(3)
  Cell In[81], line 3 in func
    return o[0]
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4350 in _gather
    y = lax.gather(
JaxStackTraceBeforeTransformation: AttributeError: DynamicJaxprTracer has no attribute type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[84], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
  File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:299 in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
  File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:367 in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1216 in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1433 in jaxpr_subcomp
    ans = rule(rule_ctx, *rule_inputs, **eqn.params)
  File /workspace/modules/jax/jax/_src/lax/slicing.py:1827 in _gather_lower
    slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:672 in eval_dynamic_shape_as_tensor
    return shape_tensor(eval_dynamic_shape(ctx, shape))
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:96 in shape_tensor
    ds = map(lower_dim, sizes)
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:93 in lower_dim
    if d.type != i32_type:
  File /workspace/modules/jax/jax/_src/core.py:746 in __getattr__
    raise AttributeError(
AttributeError: DynamicJaxprTracer has no attribute type

Traceback (most recent call last):
  Cell In[85], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

None
Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[87], line 1
    qjit(str(mlir))(3)
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


Indexing with a colon (accepted by JAX, compilation fails)

# Experiment: index an array of run-time size with a bare colon (full slice).
# Reset the working names first, presumably so a failed step cannot silently
# reuse the results of a previous experiment.
func, mlir, jaxpr = None, None, None

# NOTE(review): the jaxpr recorded for this experiment has type f32[a] (1-D),
# which does not match the 2-D array built here -- confirm which version of
# `func` was actually run.
def func(sz:int):
  o = jnp.ones((sz,sz), jnp.float32)  # sz-by-sz array of ones
  return o[:]  # full slice; selects every element

jaxpr = jax.make_jaxpr(func)(3)  # trace `func` using 3 as the example argument
print(jaxpr)
# NOTE(review): jaxpr_to_mlir, inject_functions and qjit are defined outside
# this document; the comments below are inferred from the recorded output only.
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lower the jaxpr to an MLIR module named "test"
inject_functions(mlir, ctx)  # presumably adds the @setup/@teardown functions visible in the printed module
print(mlir)
qjit(str(mlir))(3)  # compile the MLIR text and call it with sz=3
{ lambda ; a:i64[]. let
    b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
    c:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] b a
  in (c,) }
module @test {
  func.func public @catalyst.entry_point(%arg0: tensor<i64>) -> tensor<?xf32> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %2 = stablehlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
    %3 = stablehlo.dynamic_broadcast_in_dim %0, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    %4 = stablehlo.convert %arg0 : (tensor<i64>) -> tensor<i32>
    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    return %6 : tensor<?xf32>
  }
  func.func @setup() {
    "quantum.init"() : () -> ()
    return
  }
  func.func @teardown() {
    "quantum.finalize"() : () -> ()
    return
  }
}

Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:9:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim'
    %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
         ^
catalyst_module:9:10: note: see current operation: %17 = "mhlo.dynamic_broadcast_in_dim"(%11, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
Invalid pass with name 'FinalizingBufferize' failed
While processing 'mlir::detail::OpToOpPassAdaptor' pass of the 'BufferizationPass' pipeline
Failed to lower MLIR module


The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[57], line 1
    qjit(str(mlir))(3)
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:9:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim'
    %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
         ^
catalyst_module:9:10: note: see current operation: %17 = "mhlo.dynamic_broadcast_in_dim"(%11, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
Invalid pass with name 'FinalizingBufferize' failed
While processing 'mlir::detail::OpToOpPassAdaptor' pass of the 'BufferizationPass' pipeline
Failed to lower MLIR module


Indexing with a colon in one dimension (MLIR lowering failed)

# Experiment: index a 3-D array of run-time size with constants on the outer
# axes and a full slice on the middle axis.
# Reset the working names first; if a step fails, `mlir` stays None and
# every later step fails on None as well.
func, mlir, jaxpr = None, None, None

def func(sz:int):
  o = jnp.ones((sz,sz,sz), jnp.float32)  # sz-by-sz-by-sz array of ones
  return o[0,:,0]  # constants on axes 0 and 2, full slice on axis 1; the result has length sz

jaxpr = jax.make_jaxpr(func)(3)  # trace `func` using 3 as the example argument
print(jaxpr)
# NOTE(review): jaxpr_to_mlir, inject_functions and qjit are defined outside
# this document; their exact behavior cannot be confirmed from here.
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lower the jaxpr to an MLIR module named "test"
inject_functions(mlir, ctx)  # mutates the module in place (sets llvm.emit_c_interface on its first operation)
print(mlir)
qjit(str(mlir))(3)  # compile the MLIR text and call it with sz=3
{ lambda ; a:i64[]. let
    b:f32[a,a,a] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(None, None, None)
    ] 1.0 a a a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    _:i64[] = convert_element_type[new_dtype=int64 weak_type=False] a
    c:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    d:i64[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
    e:i64[2] = concatenate[dimension=0] c d
    f:f32[a] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2))
      fill_value=None
      indices_are_sorted=True
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1)
      unique_indices=True
    ] b e
    g:f32[a] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None,)] f a
  in (g,) }
Traceback (most recent call last):
  File /workspace/modules/jax/jax/_src/core.py:744 in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute 'type'

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File /opt/venv/bin/ipython3:8
    sys.exit(start_ipython())
  File /opt/venv/lib/python3.10/site-packages/IPython/__init__.py:128 in start_ipython
    return launch_new_instance(argv=argv, **kwargs)
  File /opt/venv/lib/python3.10/site-packages/traitlets/config/application.py:1043 in launch_instance
    app.start()
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/ipapp.py:318 in start
    self.shell.mainloop()
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:888 in mainloop
    self.interact()
  File /opt/venv/lib/python3.10/site-packages/IPython/terminal/interactiveshell.py:881 in interact
    self.run_cell(code, store_history=True)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3006 in run_cell
    result = self._run_cell(
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3061 in _run_cell
    result = runner(coro)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:129 in _pseudo_sync_runner
    coro.send(None)
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3266 in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3445 in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File /opt/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3505 in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  Cell In[112], line 1
    jaxpr = jax.make_jaxpr(func)(3)
  Cell In[111], line 3 in func
    return o[0,:,0]
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4350 in _gather
    y = lax.gather(
JaxStackTraceBeforeTransformation: AttributeError: DynamicJaxprTracer has no attribute type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[114], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
  File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:299 in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
  File /workspace/modules/catalyst/frontend/catalyst/utils/jax_extras.py:367 in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1216 in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:1433 in jaxpr_subcomp
    ans = rule(rule_ctx, *rule_inputs, **eqn.params)
  File /workspace/modules/jax/jax/_src/lax/slicing.py:1827 in _gather_lower
    slice_sizes = mlir.eval_dynamic_shape_as_tensor(ctx, slice_sizes)
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:672 in eval_dynamic_shape_as_tensor
    return shape_tensor(eval_dynamic_shape(ctx, shape))
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:96 in shape_tensor
    ds = map(lower_dim, sizes)
  File /workspace/modules/jax/jax/_src/interpreters/mlir.py:93 in lower_dim
    if d.type != i32_type:
  File /workspace/modules/jax/jax/_src/core.py:746 in __getattr__
    raise AttributeError(
AttributeError: DynamicJaxprTracer has no attribute type

Traceback (most recent call last):
  Cell In[115], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

None
Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[117], line 1
    qjit(str(mlir))(3)
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


Indexing with variable slice (denied by JAX)

# Experiment: slice a 1-D array of run-time length with a run-time stop index.
# Reset the working names first; if tracing fails, `jaxpr` and `mlir` stay
# None and every later step fails on None as well.
func, mlir, jaxpr = None, None, None

def func(sz:int, idx:int):
  o = jnp.ones(sz, jnp.float32)  # 1-D array of ones whose length is the argument `sz`
  return o[0:idx]  # slice whose stop is the argument `idx`

jaxpr = jax.make_jaxpr(func)(3,0)  # trace `func` using sz=3, idx=0 as example arguments
print(jaxpr)
# NOTE(review): jaxpr_to_mlir, inject_functions and qjit are defined outside
# this document; their exact behavior cannot be confirmed from here.
mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)  # lower the jaxpr to an MLIR module named "test"
inject_functions(mlir, ctx)  # mutates the module in place (sets llvm.emit_c_interface on its first operation)
print(mlir)
qjit(str(mlir))(3,0)  # compile the MLIR text and call it with sz=3, idx=0
Traceback (most recent call last):
  Cell In[102], line 1
    jaxpr = jax.make_jaxpr(func)(3,0)
  File /workspace/modules/jax/jax/_src/traceback_util.py:177 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/api.py:2462 in make_jaxpr_f
    jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
  File /workspace/modules/jax/jax/_src/profiler.py:340 in wrapper
    return func(*args, **kwargs)
  File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2239 in trace_to_jaxpr_dynamic2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File /workspace/modules/jax/jax/_src/interpreters/partial_eval.py:2254 in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File /workspace/modules/jax/jax/_src/linear_util.py:191 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  Cell In[101], line 3 in func
    return o[0:idx]
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:728 in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File /workspace/modules/jax/jax/_src/numpy/array_methods.py:341 in _getitem
    return lax_numpy._rewriting_take(self, item)
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4323 in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4332 in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File /workspace/modules/jax/jax/_src/numpy/lax_numpy.py:4584 in _index_to_gather
    raise IndexError(msg)
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

None
Traceback (most recent call last):
  Cell In[104], line 1
    mlir,ctx,_,_ = jaxpr_to_mlir("test", jaxpr, jaxpr.out_avals[0].shape)
AttributeError: 'NoneType' object has no attribute 'out_avals'

Traceback (most recent call last):
  Cell In[105], line 1
    inject_functions(mlir, ctx)
  File /workspace/modules/catalyst/frontend/catalyst/utils/gen_mlir.py:58 in inject_functions
    module.body.operations[0].attributes["llvm.emit_c_interface"] = ir.UnitAttr.get(context=ctx)
AttributeError: 'NoneType' object has no attribute 'body'

None
Traceback (most recent call last):
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:379 in run_from_ir
    compiler_output = run_compiler_driver(
RuntimeError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  Cell In[107], line 1
    qjit(str(mlir))(3,0)
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:666 in __call__
    function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:641 in _ensure_real_arguments_and_formal_parameters_are_compatible
    function = self.compile()
  File /workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py:574 in compile
    shared_object, llvm_ir, inferred_func_data = self.compiler.run_from_ir(
  File /workspace/modules/catalyst/frontend/catalyst/compiler.py:389 in run_from_ir
    raise CompileError(*e.args) from e
CompileError: Compilation failed:
catalyst_module:1:1: error: custom op 'None' is unknown (tried 'builtin.None' as well)
None
^
Failed to parse module as MLIR source, retrying parsing as LLVM source
catalyst_module: catalyst_module:1:1: error: expected top-level entity
None
^
Failed to parse module as LLVM source


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment