Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created February 4, 2026 03:02
Show Gist options
  • Select an option

  • Save ezyang/f0e541a7f426f0ab3b7399bec7a041e2 to your computer and use it in GitHub Desktop.

Select an option

Save ezyang/f0e541a7f426f0ab3b7399bec7a041e2 to your computer and use it in GitHub Desktop.
import jax
import jax.numpy as jnp
from jax import lax
jax.config.update('jax_num_cpu_devices', 4)
jax.set_mesh(jax.make_mesh((2, 2), ('dp', 'tp')))
def mlp(x, w1, w3, w2):
# !!! ATTENTION !!!
rx = jax.reshard(x, jax.P(None, 'dp', None, reduced={'tp'}))
rw1 = jax.reshard(w1, jax.P(None, 'tp', reduced={'dp'}))
rw3 = jax.reshard(w3, jax.P(None, 'tp', reduced={'dp'}))
rw2 = jax.reshard(w2, jax.P('tp', None, reduced={'dp'}))
h1 = jnp.einsum("sbh,hi->sbi", rx, rw1)
h3 = jnp.einsum("sbh,hi->sbi", rx, rw3)
h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
out = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}))
return out
def forward_and_backward(x, w1, w3, w2, grad_out):
    """Run ``mlp`` forward and pull ``grad_out`` back through it with ``jax.vjp``.

    Args:
        x, w1, w3, w2: inputs passed unchanged to ``mlp``.
        grad_out: cotangent for ``mlp``'s output. It is expected to be marked
            ``reduced`` over 'tp', pairing with ``mlp``'s output being
            ``unreduced`` over 'tp'.

    Returns:
        ``(out, (grad_w1, grad_w3, grad_w2))``: the forward output and the
        cotangents for the three weights. The VJP also produces a cotangent
        for ``x``, but it is discarded.
    """
    primal_out, pullback = jax.vjp(mlp, x, w1, w3, w2)
    # The first cotangent belongs to ``x`` and is not returned.
    _, *weight_grads = pullback(grad_out)
    return primal_out, tuple(weight_grads)
# Problem sizes: (seq, batch, hidden) activations and a hidden <-> intermediate MLP.
seq, batch, hidden, intermediate = 4, 8, 16, 32


def _sharded_ones(shape, spec):
    """Return an all-ones float32 array of ``shape`` placed with sharding ``spec``."""
    return jax.device_put(jnp.ones(shape, dtype=jnp.float32), spec)


# Activations: axis 1 (batch) is split over 'dp'.
x = _sharded_ones((seq, batch, hidden), jax.P(None, 'dp', None))
# w1/w3 are split over 'tp' on their second (intermediate) axis.
# w2 is split over 'tp' on its first (intermediate) axis.
w1 = _sharded_ones((hidden, intermediate), jax.P(None, 'tp'))
w3 = _sharded_ones((hidden, intermediate), jax.P(None, 'tp'))
w2 = _sharded_ones((intermediate, hidden), jax.P('tp', None))
# Output cotangent: marked reduced over 'tp' to pair with mlp's output being
# unreduced over 'tp'.
grad_out = _sharded_ones((seq, batch, hidden), jax.P(None, 'dp', None, reduced={'tp'}))

# Lower forward+backward under jit and dump two programs:
#   1. the pre-partitioning StableHLO, which carries the Shardy sdy.* annotations;
#   2. the compiled HLO, where the partitioner has inserted the collectives.
_rule = "=" * 80
print("\n" + _rule)
print("FORWARD + BACKWARD PASS")
print(_rule)
backward_lowered = jax.jit(forward_and_backward).lower(x, w1, w3, w2, grad_out)
print("\n--- Forward+Backward StableHLO with Shardy annotations ---")
print(backward_lowered.as_text())  # default dialect is StableHLO, incl. sdy.* ops
backward_compiled = backward_lowered.compile()
print("\n--- Forward+Backward HLO (post-partitioning with collectives) ---")
print(backward_compiled.as_text())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment