HloModule jit__slice_and_replicate, is_scheduled=true, entry_computation_layout={(f32[256,512,256]{2,1,0:T(8,128)}, s32[]{:T(128)})->f32[524288]{0:T(1024)}}, allow_spmd_sharding_propagation_to_parameters={false,true}, allow_spmd_sharding_propagation_to_output={true}, num_partitions=4
ENTRY main.11_spmd {
constant = s32[]{:T(128)} constant(0)
constant.1 = s32[]{:T(128)} constant(134217728)
param = f32[256,512,256]{2,1,0:T(8,128)} parameter(0), sharding={devices=[4,1,1]<=[4]}, metadata={op_name="in_array"}
param.1 = s32[]{:T(128)} parameter(1), sharding={replicated}, metadata={op_name="offset"}
bitcast.3 = f32[16384,2,8,128]{3,2,1,0:T(8,128)} bitcast(param)
copy.2 = f32[16384,2,8,128]{3,1,2,0:T(2,128)S(3)} copy(bitcast.3)
bitcast.2 = f32[33554432]{0:T(1024)S(3)} bitcast(copy.2)
all-gather = f32[134217728]{0:T(1024)} all-gather(bitcast.2), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(_slice_and_replicate)/jit(main)/dynamic_slice[slice_sizes=(524288,)]" source_file="/home/dlwh/slice.py" source_line=24}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"0"},"scoped_memory_configs":[],"used_scoped_memory_configs":[]}
copy.3 = s32[]{:T(128)S(6)} copy(param.1)
compare.0 = pred[]{:T(512)S(6)} compare(copy.3, constant), direction=LT, metadata={op_name="jit(_slice_and_replicate)/jit(main)/lt" source_file="/home/dlwh/slice.py" source_line=24}
add.0 = s32[]{:T(128)S(6)} add(copy.3, constant.1), metadata={op_name="jit(_slice_and_replicate)/jit(main)/add" source_file="/home/dlwh/slice.py" source_line=24}
select.0 = s32[]{:T(128)S(6)} select(compare.0, add.0, copy.3), metadata={op_name="jit(_slice_and_replicate)/jit(main)/select_n" source_file="/home/dlwh/slice.py" source_line=24}
ROOT dynamic-slice.0 = f32[524288]{0:T(1024)} dynamic-slice(all-gather, select.0), dynamic_slice_sizes={524288}, metadata={op_name="jit(_slice_and_replicate)/jit(main)/dynamic_slice[slice_sizes=(524288,)]" source_file="/home/dlwh/slice.py" source_line=24}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]}
} // main.11_spmd
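The HLO above is the interesting part: to serve a dynamic_slice with a runtime offset on an array sharded 4 ways, the SPMD partitioner first all-gathers the entire flattened array (f32[134217728]) onto every device and only then takes the 524288-element slice. A quick size check (a sketch; the constants are read off the dump above, f32 = 4 bytes):

# Sizes read off the HLO dump above (f32 = 4 bytes)
total_elems = 1024 * 512 * 256   # 134217728, matches constant.1 and the all-gather shape
slice_elems = 512 * 1024         # 524288, matches dynamic_slice_sizes={524288}
print(total_elems * 4 / 2**20, "MiB gathered per device")  # 512.0 MiB
print(slice_elems * 4 / 2**20, "MiB actually sliced out")  # 2.0 MiB

The script below reproduces this behavior.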
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def test_jax_slicing(mesh):  # Test on the specified TPU mesh
    arr = jnp.zeros((1024, 512, 256))
    in_sharding = NamedSharding(mesh, PartitionSpec("x"))
    arr = jax.device_put(arr, in_sharding)
    arr_size = arr.size

    # Sharding setup: the sliced-out chunk should be fully replicated
    shardings = [None]
    sharding = NamedSharding(mesh, PartitionSpec(*shardings))

    # size must be static: dynamic_slice requires static slice sizes
    @functools.partial(jax.jit, static_argnums=(2,))
    def _slice_and_replicate(in_array, offset, size):
        in_array = jnp.ravel(in_array)
        in_array = jax.lax.dynamic_slice(in_array, [offset], [size])
        in_array = jax.lax.with_sharding_constraint(in_array, sharding)
        return in_array

    copy_chunk_size = 512 * 1024
    num_slices = (arr_size + copy_chunk_size - 1) // copy_chunk_size
    slices = []
    for _ in range(num_slices):  # Simplified for demonstration
        offset = 0  # Assuming only one slice in this test
        slice_size = min(copy_chunk_size, arr_size - offset)
        chunk = _slice_and_replicate(arr, offset, slice_size)
        # Explicit device_put is often not needed with jit and sharding constraints
        slices.append(np.array(chunk))

    # Assertions for testing
    # assert len(slices) == 1, "Expected one slice in this test"
    # assert slices[0].shape == (slice_size,), "Unexpected slice shape"


# TPU configuration (crucial)
devices = jax.devices("tpu")  # Get all available TPU devices
mesh_shape = (len(devices),)  # 1D mesh over all chips (e.g. a TPU v4-8)
mesh = Mesh(np.array(devices).reshape(mesh_shape), ("x",))

# Run the test
with mesh:
    test_jax_slicing(mesh)
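For reference, compiled HLO text like the dump at the top can be pulled from a jitted function directly (a minimal sketch with a stand-in function f; the real dump was presumably produced from _slice_and_replicate, e.g. via XLA_FLAGS=--xla_dump_to=<dir>):

# Minimal sketch: print the compiled HLO for a stand-in jitted function
f = jax.jit(lambda a, o: jax.lax.dynamic_slice(a, [o], [4]))
print(f.lower(jnp.zeros(16), 0).compile().as_text())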