Created
April 25, 2024 14:25
-
-
Save jyingl3/66d8bc3893a9b50a04bf208175e2d617 to your computer and use it in GitHub Desktop.
PJRT plugin tutorial
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Install JAX | |
~$ pip install -U "jax[cpu]" | |
# Build the .so file | |
~$ git clone https://github.com/openxla/xla | |
~$ cd xla | |
# Optional: check out the branch that adds extra VLOG output | |
~/xla$ git checkout remotes/origin/test_626168031 | |
# Build the CPU plugin | |
~/xla$ bazel build xla/pjrt/c:pjrt_c_api_cpu_plugin.so | |
# Check the exported symbol. The output should list `T GetPjrtApi@@VERS_1.0` at the top | |
~/xla$ nm -gD bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so | grep GetPjrt | |
# Use this plugin in JAX by setting PJRT_NAMES_AND_LIBRARY_PATHS | |
~/xla$ PJRT_NAMES_AND_LIBRARY_PATHS=cpu_plugin:bazel-bin/xla/pjrt/c/pjrt_c_api_cpu_plugin.so ENABLE_PJRT_COMPATIBILITY=1 TF_CPP_VMODULE=cpu_client=3,pjrt_c_api_wrapper_impl=3 TF_CPP_MIN_LOG_LEVEL=0 python | |
>>> import jax | |
>>> from jax._src import xla_bridge | |
>>> jax.config.update("jax_platform_name", "cpu_plugin") | |
>>> client = xla_bridge.get_backend() | |
Platform 'cpu_plugin' is experimental and not all JAX functionality may be correctly supported! | |
I0000 00:00:1712356514.251375 99055 cpu_client.cc:424] TfrtCpuClient created. | |
2024-04-22 17:42:10.410579: I external/xla/xla/pjrt/pjrt_c_api_client.cc:134] PjRtCApiClient created. | |
>>> xla_bridge.backend_pjrt_c_api_version() | |
(0, 49) | |
>>> client.platform | |
'cpu' | |
>>> client.devices() | |
[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)] | |
>>> import numpy as np | |
>>> x = np.arange(12.).reshape((3, 4)).astype("float32") | |
>>> device_x = jax.device_put(x) | |
>>> @jax.jit | |
... def fn(x): | |
...     return x * x | |
>>> x_shape = jax.ShapeDtypeStruct(shape=(16, 16), dtype=jax.numpy.dtype('float32')) | |
>>> lowered = fn.lower(x_shape) | |
>>> executable = lowered.compile()._executable | |
>>> executable.as_text() | |
'HloModule jit_fn, entry_computation_layout={(f32[16,16]{1,0})->f32[16,16]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}\n\nENTRY %main.3 (Arg_0.1: f32[16,16]) -> f32[16,16] {\n %Arg_0.1 = f32[16,16]{1,0} parameter(0)\n ROOT %multiply.2 = f32[16,16]{1,0} multiply(f32[16,16]{1,0} %Arg_0.1, f32[16,16]{1,0} %Arg_0.1), metadata={op_name="jit(fn)/jit(main)/mul" source_file="<stdin>" source_line=3}\n}\n\n' | |
# JAX 1+1 | |
>>> jax.numpy.add(1, 1) | |
Array(2, dtype=int32, weak_type=True) | |
# jit | |
>>> jax.jit(lambda x: x * 2)(1.) | |
Array(2., dtype=float32, weak_type=True) | |
# pmap (4 devices in this example) | |
>>> arr = jax.numpy.arange(jax.device_count()) | |
>>> jax.pmap(lambda x: x + jax.lax.psum(x, 'i'), axis_name='i')(arr) | |
Array([6, 7, 8, 9], dtype=int32) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment