Skip to content

Instantly share code, notes, and snippets.

@merrymercy
Created June 17, 2022 00:40
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save merrymercy/47f744b395173fca805ab9d93e66c59e to your computer and use it in GitHub Desktop.
Save merrymercy/47f744b395173fca805ab9d93e66c59e to your computer and use it in GitHub Desktop.
Use jax.pjit to partition an embedding table
"""Test embedding table partitioning in XLA with jax.pjit.

References:
- Introduction to pjit. https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html
"""
from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental import PartitionSpec as P
from jax.experimental.pjit import pjit
import numpy as np
def run_embedding(mesh_shape, in_axis_resources, out_axis_resources):
    """Run a pjit-partitioned embedding lookup and print its compiled HLO.

    A random ``(1024, 512)`` float32 embedding table and 8 random row
    indices are generated with a fixed seed, the lookup is executed on a
    device mesh, the result is checked against a plain NumPy lookup, and the
    HLO text of the compiled executable is printed to stdout.

    Args:
        mesh_shape: Shape of the device mesh. The mesh is created with the
            single axis name ``'x'``, so this is expected to be a 1-tuple
            such as ``(8,)``.
        in_axis_resources: Partition specs for ``(indices, embedding)``,
            forwarded to ``pjit``.
        out_axis_resources: Partition spec for the lookup result, forwarded
            to ``pjit``.

    Raises:
        ValueError: If fewer devices are available than the mesh requires.
        AssertionError: If the partitioned lookup disagrees with NumPy.
    """
    num_mesh_devices = int(np.prod(mesh_shape))
    available_devices = jax.devices()
    if len(available_devices) < num_mesh_devices:
        raise ValueError(
            f"mesh_shape {tuple(mesh_shape)} needs {num_mesh_devices} "
            f"devices, but only {len(available_devices)} are available")
    # Take exactly as many devices as the mesh needs, so the reshape also
    # works on hosts that expose more devices than the mesh uses.
    devices = np.asarray(
        available_devices[:num_mesh_devices]).reshape(*mesh_shape)
    mesh = maps.Mesh(devices, ('x',))

    @partial(pjit,
             in_axis_resources=in_axis_resources,
             out_axis_resources=out_axis_resources)
    def f(indices, embedding):
        # Row gather: one embedding vector per index.
        return jnp.take(embedding, indices, axis=0)

    batch_size = 8
    vocab_size = 1024
    feature_size = 512

    np.random.seed(0)
    # randint's upper bound is exclusive, so every row in
    # [0, vocab_size - 1] can be drawn, including the last one.
    indices = np.random.randint(
        0, vocab_size, size=(batch_size,), dtype=np.int32)
    embedding = np.random.randn(vocab_size, feature_size).astype(np.float32)

    # pjit needs an active mesh both to execute and to lower/compile.
    with mesh:
        out = f(indices, embedding)
        executable = f.lower(indices, embedding).compile().runtime_executable()

    # The partitioned gather must match an unpartitioned NumPy lookup.
    np.testing.assert_allclose(np.asarray(out), embedding[indices])

    print("=" * 20 + " HLO " + "=" * 20)
    print(executable.hlo_modules()[0].to_string())
def test_embed_col_partition():
    """Embedding lookup with the table sharded along its second axis.

    The indices are replicated, the table's second (feature) dimension is
    split over mesh axis 'x' on an 8-device mesh, and the output keeps the
    same second-axis sharding.
    """
    indices_spec = P(None,)
    table_spec = P(None, 'x')
    output_spec = P(None, 'x')
    run_embedding(
        (8,),
        in_axis_resources=(indices_spec, table_spec),
        out_axis_resources=output_spec,
    )
def test_embed_row_partition():
    """Embedding lookup with the table sharded along its first axis.

    The indices are replicated, the table's first (vocabulary) dimension is
    split over mesh axis 'x' on an 8-device mesh, and the output is
    requested fully replicated.
    """
    indices_spec = P(None,)
    table_spec = P('x', None)
    output_spec = P(None, None)
    run_embedding(
        (8,),
        in_axis_resources=(indices_spec, table_spec),
        out_axis_resources=output_spec,
    )
if __name__ == "__main__":
    # Script entry point: run the column-sharded demo, then the row-sharded one.
    for partition_demo in (test_embed_col_partition, test_embed_row_partition):
        partition_demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment