@zhangqiaorjc
Created February 25, 2022 06:36
make_hlo
import jax

# Note: written against early-2022 JAX. jax.xla_computation and
# jax.lib.xla_bridge have since been deprecated in newer releases.


def make_hlo(f, optimize=False, metadata=False, platform=None):
  """Utility function for printing JAX-emitted HLO and XLA-compiled HLO.

  Args:
    f: JAX function to return HLO for.
    optimize: bool: whether to return platform-specific, XLA-optimized HLO.
    metadata: bool: whether to include JAX metadata information.
    platform: Optional[str]: None, 'cpu', 'gpu', 'tpu' - platform to compile
      for; None uses the default.

  Returns:
    A function that takes the same arguments as f and returns the HLO as a
    str in text format.
  """
  client = jax.lib.xla_bridge.get_backend(platform)
  print_opts = jax.lib.xla_client._xla.HloPrintOptions.short_parsable()
  print_opts.print_metadata = metadata

  def wrapped_fn(*args, **kwargs):
    # Trace f with the given arguments to obtain the JAX-emitted computation.
    c = jax.xla_computation(f)(*args, **kwargs)
    if optimize:
      # Compile for the chosen backend and print the optimized module.
      return client.compile(c).hlo_modules()[0].to_string(print_opts)
    else:
      # Print the unoptimized HLO exactly as JAX emitted it.
      return c.as_hlo_module().to_string(print_opts)

  return wrapped_fn
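
On JAX releases where `jax.xla_computation` is no longer available, a similar pair of outputs can probably be obtained through the `jax.jit(f).lower(...)` path. The following is only a minimal sketch under that assumption, not a drop-in replacement for `make_hlo`: the helper name `hlo_text` and the example function `f` are hypothetical, the `metadata` and `platform` options are not reproduced, and the unoptimized text may be printed in a different dialect (for example StableHLO) depending on the installed JAX version.

```python
import jax
import jax.numpy as jnp


def hlo_text(f, optimize=False):
    """Hypothetical helper: return IR text for f via jit -> lower -> compile."""

    def wrapped_fn(*args, **kwargs):
        # Trace and lower f for the given example arguments.
        lowered = jax.jit(f).lower(*args, **kwargs)
        if optimize:
            # Text of the module after XLA's backend-specific optimizations.
            return lowered.compile().as_text()
        # Text of the unoptimized module as emitted by JAX.
        return lowered.as_text()

    return wrapped_fn


def f(x):
    # Small example function (an assumption, chosen only for illustration).
    return jnp.sin(x) * 2.0


x = jnp.ones((4,), dtype=jnp.float32)
unoptimized = hlo_text(f)(x)
optimized = hlo_text(f, optimize=True)(x)
print(unoptimized)
print(optimized)
```

As with `make_hlo`, the helper returns a function, so it is called first with the function to inspect and then with example arguments of the shapes and dtypes of interest.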