@llandsmeer
Created October 8, 2023 13:09
JAX to MLIR using IREE example
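```python
import jax
import iree.compiler as compiler


def YOUR_FUNCTION(x):
    return x + 1


# Example inputs: only their shapes and dtypes matter for tracing/lowering.
input_sample = [
    jax.numpy.array([0.,]),
]

# Trace and lower the jitted function. as_text() returns the module as
# StableHLO MLIR text (a string, not a serialized HLO proto).
lowered = jax.jit(YOUR_FUNCTION).lower(*input_sample)
stablehlo_text = lowered.as_text()

# Hand the StableHLO module to IREE and ask for MLIR text back instead of
# the default VM flatbuffer. compile_str returns bytes.
mlir = compiler.compile_str(
    stablehlo_text,
    target_backends=['llvm-cpu'],
    output_generic_mlir=True,
    input_type=compiler.InputType.STABLEHLO_XLA,
    output_format=compiler.OutputFormat.MLIR_TEXT,
    strip_debug_ops=True,
    strip_source_map=True,
    output_mlir_debuginfo=False,  # omit location info from the printed MLIR
)
print(mlir.decode('utf8'))
```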
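To inspect the StableHLO that JAX hands to IREE, you can run the lowering step on its own, without IREE installed. The sketch below is a variation and not part of the original gist. It assumes a JAX version whose `Lowered.as_text()` defaults to StableHLO, and it passes a `jax.ShapeDtypeStruct` placeholder to `.lower()` in place of a concrete array, since lowering only needs shapes and dtypes. The helper name `lower_to_stablehlo` is hypothetical.

```python
import jax
import jax.numpy as jnp


def YOUR_FUNCTION(x):
    return x + 1


def lower_to_stablehlo(fn, *arg_specs):
    """Hypothetical helper: trace and lower `fn` for the given abstract
    argument specs and return the module as StableHLO MLIR text."""
    return jax.jit(fn).lower(*arg_specs).as_text()


# Only shape and dtype are needed for lowering, so no real data is required.
spec = jax.ShapeDtypeStruct((1,), jnp.float32)
stablehlo_text = lower_to_stablehlo(YOUR_FUNCTION, spec)
print(stablehlo_text)
```

The resulting string is the same kind of input the gist passes as the first argument to `compiler.compile_str`. Printing it before compilation is a quick way to check what IREE receives.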