Skip to content

Instantly share code, notes, and snippets.

@jyingl3
Created April 25, 2024 14:25
Show Gist options
  • Save jyingl3/66d8bc3893a9b50a04bf208175e2d617 to your computer and use it in GitHub Desktop.
Save jyingl3/66d8bc3893a9b50a04bf208175e2d617 to your computer and use it in GitHub Desktop.
PJRT plugin tutorial
# Install JAX (CPU build)
~$ pip install -U "jax[cpu]"
# Build the .so file
~$ git clone https://github.com/openxla/xla
# Optional: check out the branch that adds extra VLOG output
~/xla$ git checkout remotes/origin/test_626168031
# Build the CPU plugin
~/xla$ bazel build xla/pjrt/c:pjrt_c_api_cpu_plugin.so
# Check the exported symbols. The output should include `T GetPjrtApi@@VERS_1.0` at the top
~/xla$ nm -gD bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so | grep GetPjrt
# Use this plugin in JAX by setting PJRT_NAMES_AND_LIBRARY_PATHS
~/xla$ PJRT_NAMES_AND_LIBRARY_PATHS=cpu_plugin:bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so ENABLE_PJRT_COMPATIBILITY=1 TF_CPP_VMODULE=cpu_client=3,pjrt_c_api_wrapper_impl=3 TF_CPP_MIN_LOG_LEVEL=0 python
>>> import jax
>>> from jax._src import xla_bridge
>>> jax.config.update("jax_platform_name", "cpu_plugin")
>>> client = xla_bridge.get_backend()
Platform 'cpu_plugin' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1712356514.251375 99055 cpu_client.cc:424] TfrtCpuClient created.
2024-04-22 17:42:10.410579: I external/xla/xla/pjrt/pjrt_c_api_client.cc:134] PjRtCApiClient created.
>>> xla_bridge.backend_pjrt_c_api_version()
(0, 49)
>>> client.platform
'cpu'
>>> client.devices()
[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
>>> import numpy as np
>>> x = np.arange(12.).reshape((3, 4)).astype("float32")
>>> device_x = jax.device_put(x)
>>> @jax.jit
... def fn(x):
...     return x * x
>>> x_shape = jax.ShapeDtypeStruct(shape=(16, 16), dtype=jax.numpy.dtype('float32'))
>>> lowered = fn.lower(x_shape)
>>> executable = lowered.compile()._executable
>>> executable.as_text()
'HloModule jit_fn, entry_computation_layout={(f32[16,16]{1,0})->f32[16,16]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}\n\nENTRY %main.3 (Arg_0.1: f32[16,16]) -> f32[16,16] {\n %Arg_0.1 = f32[16,16]{1,0} parameter(0)\n ROOT %multiply.2 = f32[16,16]{1,0} multiply(f32[16,16]{1,0} %Arg_0.1, f32[16,16]{1,0} %Arg_0.1), metadata={op_name="jit(fn)/jit(main)/mul" source_file="<stdin>" source_line=3}\n}\n\n'
# JAX 1+1
>>> jax.numpy.add(1, 1)
Array(2, dtype=int32, weak_type=True)
# jit
>>> jax.jit(lambda x: x * 2)(1.)
Array(2., dtype=float32, weak_type=True)
# pmap (4 devices in this example)
>>> arr = jax.numpy.arange(jax.device_count())
>>> jax.pmap(lambda x: x + jax.lax.psum(x, 'i'), axis_name='i')(arr)
Array([6, 7, 8, 9], dtype=int32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment