Created
February 4, 2026 03:02
-
-
Save ezyang/f0e541a7f426f0ab3b7399bec7a041e2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax
import jax.numpy as jnp
from jax import lax  # NOTE(review): not referenced anywhere in this script

# Expose 4 virtual CPU devices so the 2x2 mesh below can be built on one host.
# Presumably this has to happen before anything initializes the JAX backend.
jax.config.update('jax_num_cpu_devices', 4)

# Build a 2x2 mesh with axes 'dp' and 'tp' and install it as the ambient mesh;
# the bare PartitionSpecs (jax.P) used throughout the script name these axes.
_mesh = jax.make_mesh((2, 2), ('dp', 'tp'))
jax.set_mesh(_mesh)
def mlp(x, w1, w3, w2):
    """SwiGLU-style MLP whose output is left unreduced over the 'tp' mesh axis.

    Computes ``(silu(x @ w1) * (x @ w3)) @ w2`` with einsums, after resharding
    every operand with explicit ``reduced`` annotations.

    Args:
      x: activations with einsum subscripts ``sbh`` (rank 3). Resharded so
        dim 1 is split over 'dp' and the array is marked ``reduced`` on 'tp'.
      w1, w3: rank-2 ``hi`` weights. Resharded so dim 1 is split over 'tp'
        and marked ``reduced`` on 'dp'.
      w2: rank-2 ``ih`` weight. Resharded so dim 0 is split over 'tp' and
        marked ``reduced`` on 'dp'.

    Returns:
      An ``sbh`` array with sharding ``P(None, 'dp', None)``, declared
      ``unreduced`` over 'tp'. The last contraction runs over the
      intermediate (``i``) dimension, which is split across 'tp', so each
      'tp' shard holds a partial sum and no reduction is requested here.

    Assumes a global mesh with axes 'dp' and 'tp' is already installed.
    """
    # NOTE(review): these reshards appear to be the point of this experiment.
    # As I understand JAX's explicit-sharding semantics (confirm against the
    # installed version), marking an input ``reduced`` over an axis keeps it
    # replicated there while letting its cotangent come back ``unreduced``
    # (not yet summed) over that axis in the backward pass.
    weight_cols = jax.P(None, 'tp', reduced={'dp'})
    x_r = jax.reshard(x, jax.P(None, 'dp', None, reduced={'tp'}))
    w1_r, w3_r = (jax.reshard(w, weight_cols) for w in (w1, w3))
    w2_r = jax.reshard(w2, jax.P('tp', None, reduced={'dp'}))

    # Two up-projections into the intermediate dimension (split over 'tp').
    gate_pre = jnp.einsum("sbh,hi->sbi", x_r, w1_r)
    linear = jnp.einsum("sbh,hi->sbi", x_r, w3_r)
    # Elementwise gating, written as an einsum. It is kept in that form
    # because swapping it for `*` could change the HLO this script prints.
    gated = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(gate_pre), linear)

    # Down-projection contracting the 'tp'-split dimension; the result is
    # explicitly requested to stay unreduced over 'tp'.
    return jnp.einsum(
        "sbi,ih->sbh",
        gated,
        w2_r,
        out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}),
    )
def forward_and_backward(x, w1, w3, w2, grad_out):
    """Run ``mlp`` forward and pull ``grad_out`` back through it via ``jax.vjp``.

    Args:
      x, w1, w3, w2: inputs passed unchanged to ``mlp``.
      grad_out: cotangent for ``mlp``'s output. ``mlp`` declares its output
        unreduced over 'tp', so this cotangent is expected to be reduced
        over 'tp'.

    Returns:
      ``(out, (grad_w1, grad_w3, grad_w2))``. The cotangent for ``x`` is
      produced by the pullback but discarded; only weight gradients are
      returned.
    """
    out, pullback = jax.vjp(mlp, x, w1, w3, w2)
    # The pullback yields one cotangent per primal input. Drop the one for x.
    # NOTE(review): because it is unused, its computation may be eliminated
    # from the compiled HLO; keep that in mind when reading the dump.
    _, *weight_grads = pullback(grad_out)
    return out, tuple(weight_grads)
# Problem sizes. x is laid out (seq, batch, hidden) to match mlp's "sbh" subscripts.
seq, batch, hidden, intermediate = 4, 8, 16, 32


def _sharded_ones(shape, spec):
    """Return a float32 array of ones with ``shape``, placed with sharding ``spec``."""
    return jax.device_put(jnp.ones(shape, dtype=jnp.float32), spec)


# Inputs start in plain (non-reduced) layouts; mlp reshards them itself.
x = _sharded_ones((seq, batch, hidden), jax.P(None, 'dp', None))
w1 = _sharded_ones((hidden, intermediate), jax.P(None, 'tp'))
w3 = _sharded_ones((hidden, intermediate), jax.P(None, 'tp'))
w2 = _sharded_ones((intermediate, hidden), jax.P('tp', None))

# The output cotangent is marked reduced on 'tp', mirroring mlp's output,
# which is declared unreduced on 'tp'.
grad_out = _sharded_ones(
    (seq, batch, hidden), jax.P(None, 'dp', None, reduced={'tp'})
)
# Lower forward_and_backward, then compile it, printing the IR from each stage.
_RULE = "=" * 80
print("\n" + _RULE)
print("FORWARD + BACKWARD PASS")
print(_RULE)

backward_lowered = jax.jit(forward_and_backward).lower(x, w1, w3, w2, grad_out)
print("\n--- Forward+Backward StableHLO with Shardy annotations ---")
# as_text() emits StableHLO by default, which includes the sdy.* sharding ops.
print(backward_lowered.as_text())

# Compiling yields post-partitioning HLO, where the inserted collectives show up.
backward_compiled = backward_lowered.compile()
print("\n--- Forward+Backward HLO (post-partitioning with collectives) ---")
print(backward_compiled.as_text())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment