Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created February 4, 2026 03:14
Show Gist options
  • Select an option

  • Save ezyang/bf6e13ad90636566d0c338956c446335 to your computer and use it in GitHub Desktop.

Select an option

Save ezyang/bf6e13ad90636566d0c338956c446335 to your computer and use it in GitHub Desktop.
# Annotated Pre-Partition StableHLO IR for mlp2.py (with grad_x)
# Forward/Backward split and source line annotations
#
# Source code reference:
# line 10: rx = jax.reshard(x, ...)
# line 11: rw1 = jax.reshard(w1, ...)
# line 12: rw3 = jax.reshard(w3, ...)
# line 13: rw2 = jax.reshard(w2, ...)
# line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1)
# line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3)
# line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
# line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=...)
#
# Function returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
# ============================================================================
module @jit_forward_and_backward attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
# Device mesh referenced by every sharding below: axis "dp" of size 2 and axis "tp" of size 2
# (2 * 2 = 4 devices, matching mhlo.num_partitions = 4 on the module).
sdy.mesh @mesh = <["dp"=2, "tp"=2]>
# ============================================================================
# MAIN FUNCTION
# ============================================================================
# Forward pass out = (silu(x @ w1) * (x @ w3)) @ w2, followed by the backward pass that
# turns the incoming gradient %arg4 into gradients for x, w1, w3 and w2.
# Returns (out, grad_x, grad_w1, grad_w3, grad_w2); out is left unreduced on tp.
func.func public @main(
%arg0: tensor<4x8x16xf32>, # x - input activations
%arg1: tensor<16x32xf32>, # w1 - first FFN weight
%arg2: tensor<16x32xf32>, # w3 - gate weight
%arg3: tensor<32x16xf32>, # w2 - second FFN weight
%arg4: tensor<4x8x16xf32> # grad_out - gradient from loss
) -> (
tensor<4x8x16xf32>, # out (unreduced on tp)
tensor<4x8x16xf32>, # grad_x
tensor<16x32xf32>, # grad_w1
tensor<16x32xf32>, # grad_w3
tensor<32x16xf32> # grad_w2
) {
# ========== SHARDING CONSTRAINTS ON INPUTS ==========
# line 10: rx = jax.reshard(x, jax.P(None, 'dp', None, reduced={'tp'}))
%0 = sdy.sharding_constraint %arg0 <@mesh, [{}, {"dp"}, {}]> : tensor<4x8x16xf32>
# line 11: rw1 = jax.reshard(w1, jax.P(None, 'tp', reduced={'dp'}))
%1 = sdy.sharding_constraint %arg1 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32>
# line 12: rw3 = jax.reshard(w3, jax.P(None, 'tp', reduced={'dp'}))
%2 = sdy.sharding_constraint %arg2 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32>
# line 13: rw2 = jax.reshard(w2, jax.P('tp', None, reduced={'dp'}))
%3 = sdy.sharding_constraint %arg3 <@mesh, [{"tp"}, {}]> : tensor<32x16xf32>
# ========== FORWARD: h1 = x @ w1 ==========
# line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1)
%4 = stablehlo.dot_general %0, %1, contracting_dims = [2] x [0] : (tensor<4x8x16xf32>, tensor<16x32xf32>) -> tensor<4x8x32xf32>
%5 = sdy.sharding_constraint %4 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# ========== FORWARD: h3 = x @ w3 ==========
# line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3)
%6 = stablehlo.dot_general %0, %2, contracting_dims = [2] x [0] : (tensor<4x8x16xf32>, tensor<16x32xf32>) -> tensor<4x8x32xf32>
%7 = sdy.sharding_constraint %6 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# ========== FORWARD: silu(h1) ==========
# line 17: jax.nn.silu(h1) - returns (silu_output, sigmoid(h1)*(1-sigmoid(h1)), sigmoid(h1))
# %8#0 feeds the forward product; %8#1 and %8#2 are kept for the @silu_8 backward call.
%8:3 = call @silu(%5) : (tensor<4x8x32xf32>) -> (tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>)
# ========== FORWARD: h = silu(h1) * h3 ==========
# line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
# All three dims are batching dims and nothing is contracted, i.e. an elementwise product.
%9 = stablehlo.dot_general %8#0, %7, batching_dims = [0, 1, 2] x [0, 1, 2], contracting_dims = [] x [] : (tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32>
%10 = sdy.sharding_constraint %9 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# ========== FORWARD: out = h @ w2 ==========
# line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=...)
# The contracted dim i is tp-sharded on both operands, so the result is constrained
# as unreduced on tp (a per-shard partial sum) and is returned in that state.
%11 = stablehlo.dot_general %10, %3, contracting_dims = [2] x [0] : (tensor<4x8x32xf32>, tensor<32x16xf32>) -> tensor<4x8x16xf32>
%12 = sdy.sharding_constraint %11 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32>
# ========== BACKWARD: g_w2 (partial, before transpose) ==========
# line 18 backward: grad_out.T @ h -> (16, 32) then transpose to (32, 16)
# Contracts over s and b; the result is constrained as unreduced on dp.
%13 = stablehlo.dot_general %arg4, %10, contracting_dims = [0, 1] x [0, 1] : (tensor<4x8x16xf32>, tensor<4x8x32xf32>) -> tensor<16x32xf32>
%14 = sdy.sharding_constraint %13 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32>
# line 18 backward: transpose to get g_w2
%15 = stablehlo.transpose %14, dims = [1, 0] : (tensor<16x32xf32>) -> tensor<32x16xf32>
%16 = sdy.sharding_constraint %15 <@mesh, [{"tp"}, {}], unreduced={"dp"}> : tensor<32x16xf32>
# ========== BACKWARD: g_h = grad_out @ w2.T ==========
# line 18 backward: gradient flows back through w2
%17 = stablehlo.dot_general %arg4, %3, contracting_dims = [2] x [1] : (tensor<4x8x16xf32>, tensor<32x16xf32>) -> tensor<4x8x32xf32>
%18 = sdy.sharding_constraint %17 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# ========== BACKWARD: g_h3 = g_h * silu(h1) ==========
# line 17 backward: gradient w.r.t. h3
%19 = stablehlo.dot_general %18, %8#0, batching_dims = [0, 1, 2] x [0, 1, 2], contracting_dims = [] x [] : (tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32>
%20 = sdy.sharding_constraint %19 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# ========== BACKWARD: g_silu = g_h * h3 ==========
# line 17 backward: gradient w.r.t. silu output (before silu backward)
%21 = stablehlo.dot_general %18, %7, batching_dims = [0, 1, 2] x [0, 1, 2], contracting_dims = [] x [] : (tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32>
%22 = sdy.sharding_constraint %21 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# ========== BACKWARD: g_h1 (through silu) ==========
# line 17 backward: gradient w.r.t. h1 = g_silu * silu'(h1)
# Operand order: sigmoid(h1)*(1-sigmoid(h1)), sigmoid(h1), h1, g_silu
%23 = call @silu_8(%8#1, %8#2, %5, %22) : (tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32>
# ========== BACKWARD: g_w3 (partial, before transpose) ==========
# line 16 backward: g_h3.T @ x -> (32, 16) then transpose to (16, 32)
%24 = stablehlo.dot_general %20, %0, contracting_dims = [0, 1] x [0, 1] : (tensor<4x8x32xf32>, tensor<4x8x16xf32>) -> tensor<32x16xf32>
%25 = sdy.sharding_constraint %24 <@mesh, [{"tp"}, {}], unreduced={"dp"}> : tensor<32x16xf32>
# line 16 backward: transpose to get g_w3
%26 = stablehlo.transpose %25, dims = [1, 0] : (tensor<32x16xf32>) -> tensor<16x32xf32>
%27 = sdy.sharding_constraint %26 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32>
# ========== BACKWARD: g_x from h3 path = g_h3 @ w3.T ==========
# line 16 backward: gradient w.r.t. x (from h3 path)
# Contracts the tp-sharded dim i, so this contribution is unreduced on tp.
%28 = stablehlo.dot_general %20, %2, contracting_dims = [2] x [1] : (tensor<4x8x32xf32>, tensor<16x32xf32>) -> tensor<4x8x16xf32>
%29 = sdy.sharding_constraint %28 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32>
# ========== BACKWARD: g_w1 (partial, before transpose) ==========
# line 15 backward: g_h1.T @ x -> (32, 16) then transpose to (16, 32)
%30 = stablehlo.dot_general %23, %0, contracting_dims = [0, 1] x [0, 1] : (tensor<4x8x32xf32>, tensor<4x8x16xf32>) -> tensor<32x16xf32>
%31 = sdy.sharding_constraint %30 <@mesh, [{"tp"}, {}], unreduced={"dp"}> : tensor<32x16xf32>
# line 15 backward: transpose to get g_w1
%32 = stablehlo.transpose %31, dims = [1, 0] : (tensor<32x16xf32>) -> tensor<16x32xf32>
%33 = sdy.sharding_constraint %32 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32>
# ========== BACKWARD: g_x from h1 path = g_h1 @ w1.T ==========
# line 15 backward: gradient w.r.t. x (from h1 path)
# Contracts the tp-sharded dim i, so this contribution is unreduced on tp.
%34 = stablehlo.dot_general %23, %1, contracting_dims = [2] x [1] : (tensor<4x8x32xf32>, tensor<16x32xf32>) -> tensor<4x8x16xf32>
%35 = sdy.sharding_constraint %34 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32>
# ========== BACKWARD: g_x = g_x_h1 + g_x_h3 ==========
# line 15+16 backward: sum gradients from both paths
# Both addends are still unreduced on tp; the tp reduction is requested once, on the sum (%40).
%36 = stablehlo.add %29, %35 : tensor<4x8x16xf32>
# ========== REDUCE WEIGHT GRADIENTS (dp reduction) ==========
# Each constraint below repeats the dimension sharding but drops unreduced={"dp"}.
# Reduce g_w2 across the dp mesh axis
%37 = sdy.sharding_constraint %16 <@mesh, [{"tp"}, {}]> : tensor<32x16xf32>
# Reduce g_w3 across the dp mesh axis
%38 = sdy.sharding_constraint %27 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32>
# Reduce g_w1 across the dp mesh axis
%39 = sdy.sharding_constraint %33 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32>
# ========== REDUCE g_x (tp reduction) ==========
# Reduce g_x across the tp mesh axis (constraint no longer carries unreduced={"tp"})
%40 = sdy.sharding_constraint %36 <@mesh, [{}, {"dp"}, {}]> : tensor<4x8x16xf32>
# ========== RETURN ==========
# Returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
return %12, %40, %39, %38, %37 : tensor<4x8x16xf32>, tensor<4x8x16xf32>, tensor<16x32xf32>, tensor<16x32xf32>, tensor<32x16xf32>
}
# ============================================================================
# SILU FORWARD FUNCTION
# silu(x) = x * sigmoid(x), where sigmoid(x) = 1 / (1 + exp(-x))
# Returns: (silu_output, sigmoid(x)*(1-sigmoid(x)), sigmoid(x))
# The extra outputs are saved for backward pass
# ============================================================================
# line 17: jax.nn.silu(h1) - forward pass
# Forward silu for a tensor sharded as [{}, {"dp"}, {"tp"}].
#   %arg0: x (called with h1 from @main)
#   Returns: (silu(x), sigmoid(x)*(1-sigmoid(x)), sigmoid(x)); the last two are the
#   residuals @main passes to @silu_8 for the backward pass.
# Every sdy.sharding_constraint carries its result type, matching the constraints in @main.
# SSA names and the three separate 1.0 constants are kept exactly as in the dump.
func.func private @silu(%arg0: tensor<4x8x32xf32>) -> (tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>) {
# Compute sigmoid(x) = 1 / (1 + exp(-x))
%0 = stablehlo.negate %arg0 : tensor<4x8x32xf32> # -x
%1 = sdy.sharding_constraint %0 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
%2 = stablehlo.exponential %1 : tensor<4x8x32xf32> # exp(-x)
%3 = sdy.sharding_constraint %2 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x8x32xf32> # 1.0
%5 = sdy.sharding_constraint %4 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
%6 = stablehlo.add %5, %3 : tensor<4x8x32xf32> # 1 + exp(-x)
%7 = sdy.sharding_constraint %6 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<4x8x32xf32> # 1.0
%9 = sdy.sharding_constraint %8 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
%10 = stablehlo.divide %9, %7 : tensor<4x8x32xf32> # sigmoid(x) = 1 / (1 + exp(-x))
%11 = sdy.sharding_constraint %10 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# Compute sigmoid(x) * (1 - sigmoid(x)) for backward
%cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
# Rank-0 constant: the constraint has an empty dimension-sharding list.
%12 = sdy.sharding_constraint %cst_1 <@mesh, []> : tensor<f32>
%13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor<f32>) -> tensor<4x8x32xf32> # 1.0
%14 = sdy.sharding_constraint %13 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
%15 = stablehlo.subtract %14, %11 : tensor<4x8x32xf32> # 1 - sigmoid(x)
%16 = sdy.sharding_constraint %15 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
%17 = stablehlo.multiply %11, %16 : tensor<4x8x32xf32> # sigmoid(x) * (1 - sigmoid(x))
%18 = sdy.sharding_constraint %17 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# Compute silu(x) = x * sigmoid(x)
%19 = stablehlo.multiply %arg0, %11 : tensor<4x8x32xf32> # x * sigmoid(x) = silu(x)
%20 = sdy.sharding_constraint %19 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
return %20, %18, %11 : tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>
# Returns: (silu(x), sigmoid(x)*(1-sigmoid(x)), sigmoid(x))
}
# ============================================================================
# SILU BACKWARD FUNCTION
# Computes g_x where silu(x) = x * sigmoid(x)
# silu'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
# ============================================================================
# line 17 backward: gradient through silu
# Backward of silu: returns g_x = g_silu * silu'(x), with
#   silu'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)).
# Both sigmoid-derived factors come in as residuals, so nothing is recomputed here.
# Every sdy.sharding_constraint carries its result type, matching the constraints in @main.
# SSA names are kept exactly as in the dump.
func.func private @silu_8(
%arg0: tensor<4x8x32xf32>, # sigmoid(x) * (1 - sigmoid(x)) - saved from forward
%arg1: tensor<4x8x32xf32>, # sigmoid(x) - saved from forward
%arg2: tensor<4x8x32xf32>, # x (h1) - the input
%arg3: tensor<4x8x32xf32> # g_silu - upstream gradient
) -> tensor<4x8x32xf32> {
# Shared factor of the second term: x * g_silu
%0 = stablehlo.multiply %arg2, %arg3 : tensor<4x8x32xf32> # x * g_silu
%1 = sdy.sharding_constraint %0 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# First term: g_silu * sigmoid(x)
%2 = stablehlo.multiply %arg3, %arg1 : tensor<4x8x32xf32> # g_silu * sigmoid(x)
%3 = sdy.sharding_constraint %2 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# Second term: x * g_silu * sigmoid(x) * (1 - sigmoid(x))
%4 = stablehlo.multiply %1, %arg0 : tensor<4x8x32xf32> # x * g_silu * sigmoid(x) * (1 - sigmoid(x))
%5 = sdy.sharding_constraint %4 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32>
# g_x = g_silu * sigmoid(x) + x * g_silu * sigmoid(x) * (1 - sigmoid(x))
%6 = stablehlo.add %3, %5 : tensor<4x8x32xf32>
return %6 : tensor<4x8x32xf32>
}
}
# ============================================================================
# EXECUTION FLOW SUMMARY
# ============================================================================
#
# FORWARD PASS:
# 1. %4/%5: h1 = x @ w1 [line 15]
# 2. %6/%7: h3 = x @ w3 [line 16]
# 3. %8: silu(h1) = h1 * sigmoid(h1) [line 17 - silu]
# 4. %9/%10: h = silu(h1) * h3 [line 17 - elementwise mul]
# 5. %11/%12: out = h @ w2 [line 18]
#
# BACKWARD PASS:
# 1. %17/%18: g_h = grad_out @ w2.T [line 18 backward]
# 2. %19/%20: g_h3 = g_h * silu(h1) [line 17 backward - h3 gradient]
# 3. %21/%22: g_silu = g_h * h3 [line 17 backward - silu output gradient]
# 4. %23: g_h1 = silu_backward(g_silu) [line 17 backward - through silu]
# 5. %28/%29: g_x_h3 = g_h3 @ w3.T [line 16 backward - x gradient from h3]
# 6. %34/%35: g_x_h1 = g_h1 @ w1.T [line 15 backward - x gradient from h1]
# 7. %36: g_x = g_x_h1 + g_x_h3 [sum of x gradients]
# 8. %30-33: g_w1 = x.T @ g_h1 [line 15 backward - w1 gradient]
# 9. %24-27: g_w3 = x.T @ g_h3 [line 16 backward - w3 gradient]
# 10. %13-16: g_w2 = h.T @ grad_out [line 18 backward - w2 gradient]
#
# OUTPUTS: (out, grad_x, grad_w1, grad_w3, grad_w2)
# ============================================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment