Created
February 4, 2026 03:14
-
-
Save ezyang/bf6e13ad90636566d0c338956c446335 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Annotated Pre-Partition StableHLO IR for mlp2.py (with grad_x) | |
| # Forward/Backward split and source line annotations | |
| # | |
| # Source code reference: | |
| # line 10: rx = jax.reshard(x, ...) | |
| # line 11: rw1 = jax.reshard(w1, ...) | |
| # line 12: rw3 = jax.reshard(w3, ...) | |
| # line 13: rw2 = jax.reshard(w2, ...) | |
| # line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1) | |
| # line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3) | |
| # line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) | |
| # line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=...) | |
| # | |
| # Function returns: (out, grad_x, grad_w1, grad_w3, grad_w2) | |
| # ============================================================================ | |
| module @jit_forward_and_backward attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} { | |
| sdy.mesh @mesh = <["dp"=2, "tp"=2]> | |
| # ============================================================================ | |
| # MAIN FUNCTION | |
| # Computes the forward pass (out) and, from the supplied upstream gradient | |
| # (%arg4), the gradients w.r.t. x, w1, w3 and w2. | |
| # Results are returned in the order (out, grad_x, grad_w1, grad_w3, grad_w2). | |
| # ============================================================================ | |
| func.func public @main( | |
| %arg0: tensor<4x8x16xf32>, # x - input activations | |
| %arg1: tensor<16x32xf32>, # w1 - weight producing h1 (the branch passed through silu) | |
| %arg2: tensor<16x32xf32>, # w3 - weight producing h3 (multiplied elementwise with silu(h1)) | |
| %arg3: tensor<32x16xf32>, # w2 - weight of the final projection h @ w2 | |
| %arg4: tensor<4x8x16xf32> # grad_out - upstream gradient w.r.t. out | |
| ) -> ( | |
| tensor<4x8x16xf32>, # out (unreduced on tp) | |
| tensor<4x8x16xf32>, # grad_x | |
| tensor<16x32xf32>, # grad_w1 | |
| tensor<16x32xf32>, # grad_w3 | |
| tensor<32x16xf32> # grad_w2 | |
| ) { | |
| # ========== SHARDING CONSTRAINTS ON INPUTS ========== | |
| # NOTE(review): the reduced={...} arguments quoted from the Python source below | |
| # have no visible counterpart in these constraints, which carry only the | |
| # per-dimension sharding - confirm against the raw dump. | |
| # line 10: rx = jax.reshard(x, jax.P(None, 'dp', None, reduced={'tp'})) | |
| %0 = sdy.sharding_constraint %arg0 <@mesh, [{}, {"dp"}, {}]> : tensor<4x8x16xf32> | |
| # line 11: rw1 = jax.reshard(w1, jax.P(None, 'tp', reduced={'dp'})) | |
| %1 = sdy.sharding_constraint %arg1 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32> | |
| # line 12: rw3 = jax.reshard(w3, jax.P(None, 'tp', reduced={'dp'})) | |
| %2 = sdy.sharding_constraint %arg2 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32> | |
| # line 13: rw2 = jax.reshard(w2, jax.P('tp', None, reduced={'dp'})) | |
| %3 = sdy.sharding_constraint %arg3 <@mesh, [{"tp"}, {}]> : tensor<32x16xf32> | |
| # ========== FORWARD: h1 = x @ w1 ========== | |
| # line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1) | |
| %4 = stablehlo.dot_general %0, %1, contracting_dims = [2] x [0] : (tensor<4x8x16xf32>, tensor<16x32xf32>) -> tensor<4x8x32xf32> | |
| %5 = sdy.sharding_constraint %4 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32> | |
| # ========== FORWARD: h3 = x @ w3 ========== | |
| # line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3) | |
| %6 = stablehlo.dot_general %0, %2, contracting_dims = [2] x [0] : (tensor<4x8x16xf32>, tensor<16x32xf32>) -> tensor<4x8x32xf32> | |
| %7 = sdy.sharding_constraint %6 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32> | |
| # ========== FORWARD: silu(h1) ========== | |
| # line 17: jax.nn.silu(h1) - returns (silu_output, sigmoid(h1)*(1-sigmoid(h1)), sigmoid(h1)) | |
| # Results 1 and 2 are only consumed by the backward call to @silu_8 below. | |
| %8:3 = call @silu(%5) : (tensor<4x8x32xf32>) -> (tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>) | |
| # ========== FORWARD: h = silu(h1) * h3 ========== | |
| # line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) | |
| # All three dims are batching dims and nothing is contracted, so this | |
| # dot_general is an elementwise product. | |
| %9 = stablehlo.dot_general %8#0, %7, batching_dims = [0, 1, 2] x [0, 1, 2], contracting_dims = [] x [] : (tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32> | |
| %10 = sdy.sharding_constraint %9 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32> | |
| # ========== FORWARD: out = h @ w2 ========== | |
| # line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=...) | |
| # The contracted dim (i) is the one sharded on "tp" in both operands, and the | |
| # result is marked unreduced on "tp". | |
| %11 = stablehlo.dot_general %10, %3, contracting_dims = [2] x [0] : (tensor<4x8x32xf32>, tensor<32x16xf32>) -> tensor<4x8x16xf32> | |
| %12 = sdy.sharding_constraint %11 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32> | |
| # ========== BACKWARD: g_w2 (partial, before transpose) ========== | |
| # line 18 backward: grad_out.T @ h (contracting dims 0 and 1) -> (16, 32), then transpose to (32, 16) | |
| %13 = stablehlo.dot_general %arg4, %10, contracting_dims = [0, 1] x [0, 1] : (tensor<4x8x16xf32>, tensor<4x8x32xf32>) -> tensor<16x32xf32> | |
| %14 = sdy.sharding_constraint %13 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32> | |
| # line 18 backward: transpose to get g_w2 | |
| %15 = stablehlo.transpose %14, dims = [1, 0] : (tensor<16x32xf32>) -> tensor<32x16xf32> | |
| %16 = sdy.sharding_constraint %15 <@mesh, [{"tp"}, {}], unreduced={"dp"}> : tensor<32x16xf32> | |
| # ========== BACKWARD: g_h = grad_out @ w2.T ========== | |
| # line 18 backward: gradient flows back through w2 | |
| %17 = stablehlo.dot_general %arg4, %3, contracting_dims = [2] x [1] : (tensor<4x8x16xf32>, tensor<32x16xf32>) -> tensor<4x8x32xf32> | |
| %18 = sdy.sharding_constraint %17 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32> | |
| # ========== BACKWARD: g_h3 = g_h * silu(h1) ========== | |
| # line 17 backward: gradient w.r.t. h3 (elementwise product, same all-batching form as the forward) | |
| %19 = stablehlo.dot_general %18, %8#0, batching_dims = [0, 1, 2] x [0, 1, 2], contracting_dims = [] x [] : (tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32> | |
| %20 = sdy.sharding_constraint %19 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32> | |
| # ========== BACKWARD: g_silu = g_h * h3 ========== | |
| # line 17 backward: gradient w.r.t. silu output (before silu backward) | |
| %21 = stablehlo.dot_general %18, %7, batching_dims = [0, 1, 2] x [0, 1, 2], contracting_dims = [] x [] : (tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32> | |
| %22 = sdy.sharding_constraint %21 <@mesh, [{}, {"dp"}, {"tp"}]> : tensor<4x8x32xf32> | |
| # ========== BACKWARD: g_h1 (through silu) ========== | |
| # line 17 backward: gradient w.r.t. h1 = g_silu * silu'(h1) | |
| # Arguments: sigmoid(h1)*(1-sigmoid(h1)) and sigmoid(h1) saved by @silu, then h1, then g_silu. | |
| %23 = call @silu_8(%8#1, %8#2, %5, %22) : (tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>) -> tensor<4x8x32xf32> | |
| # ========== BACKWARD: g_w3 (partial, before transpose) ========== | |
| # line 16 backward: g_h3.T @ x (contracting dims 0 and 1) -> (32, 16), then transpose to (16, 32) | |
| %24 = stablehlo.dot_general %20, %0, contracting_dims = [0, 1] x [0, 1] : (tensor<4x8x32xf32>, tensor<4x8x16xf32>) -> tensor<32x16xf32> | |
| %25 = sdy.sharding_constraint %24 <@mesh, [{"tp"}, {}], unreduced={"dp"}> : tensor<32x16xf32> | |
| # line 16 backward: transpose to get g_w3 | |
| %26 = stablehlo.transpose %25, dims = [1, 0] : (tensor<32x16xf32>) -> tensor<16x32xf32> | |
| %27 = sdy.sharding_constraint %26 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32> | |
| # ========== BACKWARD: g_x from h3 path = g_h3 @ w3.T ========== | |
| # line 16 backward: gradient w.r.t. x (from h3 path), left unreduced on tp | |
| %28 = stablehlo.dot_general %20, %2, contracting_dims = [2] x [1] : (tensor<4x8x32xf32>, tensor<16x32xf32>) -> tensor<4x8x16xf32> | |
| %29 = sdy.sharding_constraint %28 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32> | |
| # ========== BACKWARD: g_w1 (partial, before transpose) ========== | |
| # line 15 backward: g_h1.T @ x (contracting dims 0 and 1) -> (32, 16), then transpose to (16, 32) | |
| %30 = stablehlo.dot_general %23, %0, contracting_dims = [0, 1] x [0, 1] : (tensor<4x8x32xf32>, tensor<4x8x16xf32>) -> tensor<32x16xf32> | |
| %31 = sdy.sharding_constraint %30 <@mesh, [{"tp"}, {}], unreduced={"dp"}> : tensor<32x16xf32> | |
| # line 15 backward: transpose to get g_w1 | |
| %32 = stablehlo.transpose %31, dims = [1, 0] : (tensor<32x16xf32>) -> tensor<16x32xf32> | |
| %33 = sdy.sharding_constraint %32 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32> | |
| # ========== BACKWARD: g_x from h1 path = g_h1 @ w1.T ========== | |
| # line 15 backward: gradient w.r.t. x (from h1 path), left unreduced on tp | |
| %34 = stablehlo.dot_general %23, %1, contracting_dims = [2] x [1] : (tensor<4x8x32xf32>, tensor<16x32xf32>) -> tensor<4x8x16xf32> | |
| %35 = sdy.sharding_constraint %34 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32> | |
| # ========== BACKWARD: g_x = g_x_h1 + g_x_h3 ========== | |
| # line 15+16 backward: sum gradients from both paths (both operands are still unreduced on tp) | |
| %36 = stablehlo.add %29, %35 : tensor<4x8x16xf32> | |
| # ========== REDUCE WEIGHT GRADIENTS (dp reduction) ========== | |
| # Each reduction below is expressed by re-constraining the value with the same | |
| # dimension sharding but without unreduced={"dp"}. | |
| # NOTE(review): presumably the partitioner turns this into a cross-dp reduction - confirm. | |
| # Reduce g_w2 across dp dimension | |
| %37 = sdy.sharding_constraint %16 <@mesh, [{"tp"}, {}]> : tensor<32x16xf32> | |
| # Reduce g_w3 across dp dimension | |
| %38 = sdy.sharding_constraint %27 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32> | |
| # Reduce g_w1 across dp dimension | |
| %39 = sdy.sharding_constraint %33 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32> | |
| # ========== REDUCE g_x (tp reduction) ========== | |
| # Reduce g_x across tp dimension (the constraint no longer carries unreduced={"tp"}) | |
| %40 = sdy.sharding_constraint %36 <@mesh, [{}, {"dp"}, {}]> : tensor<4x8x16xf32> | |
| # ========== RETURN ========== | |
| # Returns: (out, grad_x, grad_w1, grad_w3, grad_w2) | |
| # Note that out (%12) is returned still marked unreduced on tp. | |
| return %12, %40, %39, %38, %37 : tensor<4x8x16xf32>, tensor<4x8x16xf32>, tensor<16x32xf32>, tensor<16x32xf32>, tensor<32x16xf32> | |
| } | |
| # ============================================================================ | |
| # SILU FORWARD FUNCTION | |
| # silu(x) = x * sigmoid(x), where sigmoid(x) = 1 / (1 + exp(-x)) | |
| # Returns: (silu_output, sigmoid(x)*(1-sigmoid(x)), sigmoid(x)) | |
| # The second and third outputs are saved for the backward pass (@silu_8) | |
| # Every tensor intermediate is constrained to <@mesh, [{}, {"dp"}, {"tp"}]>. | |
| # NOTE(review): most sdy.sharding_constraint lines in this function are shown | |
| # without a trailing ": tensor<...>" type, unlike %12 here and the constraints | |
| # in @main - presumably elided in this listing; confirm against the raw dump. | |
| # ============================================================================ | |
| # line 17: jax.nn.silu(h1) - forward pass | |
| func.func private @silu(%arg0: tensor<4x8x32xf32>) -> (tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32>) { | |
| # Compute sigmoid(x) = 1 / (1 + exp(-x)) | |
| %0 = stablehlo.negate %arg0 : tensor<4x8x32xf32> # -x | |
| %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| %2 = stablehlo.exponential %1 : tensor<4x8x32xf32> # exp(-x) | |
| %3 = sdy.sharding_constraint %2 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # Scalar 1.0 broadcast to the full shape: the addend in the denominator | |
| %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32> | |
| %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x8x32xf32> # 1.0 | |
| %5 = sdy.sharding_constraint %4 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| %6 = stablehlo.add %5, %3 : tensor<4x8x32xf32> # 1 + exp(-x) | |
| %7 = sdy.sharding_constraint %6 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # A second, separately materialized 1.0: the numerator of the division | |
| %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32> | |
| %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<f32>) -> tensor<4x8x32xf32> # 1.0 | |
| %9 = sdy.sharding_constraint %8 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| %10 = stablehlo.divide %9, %7 : tensor<4x8x32xf32> # sigmoid(x) = 1 / (1 + exp(-x)) | |
| %11 = sdy.sharding_constraint %10 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # Compute sigmoid(x) * (1 - sigmoid(x)) for backward | |
| # Third 1.0 constant; this one is constrained as a scalar (empty dimension | |
| # list) before it is broadcast. | |
| %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32> | |
| %12 = sdy.sharding_constraint %cst_1 <@mesh, []> : tensor<f32> | |
| %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor<f32>) -> tensor<4x8x32xf32> # 1.0 | |
| %14 = sdy.sharding_constraint %13 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| %15 = stablehlo.subtract %14, %11 : tensor<4x8x32xf32> # 1 - sigmoid(x) | |
| %16 = sdy.sharding_constraint %15 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| %17 = stablehlo.multiply %11, %16 : tensor<4x8x32xf32> # sigmoid(x) * (1 - sigmoid(x)) | |
| %18 = sdy.sharding_constraint %17 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # Compute silu(x) = x * sigmoid(x) | |
| %19 = stablehlo.multiply %arg0, %11 : tensor<4x8x32xf32> # x * sigmoid(x) = silu(x) | |
| %20 = sdy.sharding_constraint %19 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # %20 = silu(x), %18 = sigmoid(x)*(1-sigmoid(x)), %11 = sigmoid(x) | |
| return %20, %18, %11 : tensor<4x8x32xf32>, tensor<4x8x32xf32>, tensor<4x8x32xf32> | |
| # Returns: (silu(x), sigmoid(x)*(1-sigmoid(x)), sigmoid(x)) | |
| } | |
| # ============================================================================ | |
| # SILU BACKWARD FUNCTION | |
| # Computes g_x = g_silu * silu'(x), where silu(x) = x * sigmoid(x) and | |
| # silu'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) | |
| # The two sigmoid-derived factors are taken as inputs rather than recomputed. | |
| # NOTE(review): the sdy.sharding_constraint lines in this function are shown | |
| # without a trailing ": tensor<...>" type, unlike those in @main - presumably | |
| # elided in this listing; confirm against the raw dump. | |
| # ============================================================================ | |
| # line 17 backward: gradient through silu | |
| func.func private @silu_8( | |
| %arg0: tensor<4x8x32xf32>, # sigmoid(x) * (1 - sigmoid(x)) - saved from forward | |
| %arg1: tensor<4x8x32xf32>, # sigmoid(x) - saved from forward | |
| %arg2: tensor<4x8x32xf32>, # x (h1) - the input | |
| %arg3: tensor<4x8x32xf32> # g_silu - upstream gradient | |
| ) -> tensor<4x8x32xf32> { | |
| # Term 1 (first half): x * g_silu; multiplied by sigmoid(x)*(1-sigmoid(x)) at %4 below | |
| %0 = stablehlo.multiply %arg2, %arg3 : tensor<4x8x32xf32> # x * g_silu | |
| %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # Term 2: g_silu * sigmoid(x) | |
| %2 = stablehlo.multiply %arg3, %arg1 : tensor<4x8x32xf32> # g_silu * sigmoid(x) | |
| %3 = sdy.sharding_constraint %2 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # Term 1 (completed): g_silu * x * sigmoid(x) * (1 - sigmoid(x)) | |
| %4 = stablehlo.multiply %1, %arg0 : tensor<4x8x32xf32> # x * g_silu * sigmoid(x) * (1 - sigmoid(x)) | |
| %5 = sdy.sharding_constraint %4 <@mesh, [{}, {"dp"}, {"tp"}]> | |
| # g_x = g_silu * sigmoid(x) + x * g_silu * sigmoid(x) * (1 - sigmoid(x)) | |
| # The sum is returned directly; no sharding constraint is applied to %6 here. | |
| %6 = stablehlo.add %3, %5 : tensor<4x8x32xf32> | |
| return %6 : tensor<4x8x32xf32> | |
| } | |
| } | |
| # ============================================================================ | |
| # EXECUTION FLOW SUMMARY | |
| # ============================================================================ | |
| # | |
| # FORWARD PASS: | |
| # 1. %4/%5: h1 = x @ w1 [line 15] | |
| # 2. %6/%7: h3 = x @ w3 [line 16] | |
| # 3. %8: silu(h1) = h1 * sigmoid(h1) [line 17 - silu] | |
| # 4. %9/%10: h = silu(h1) * h3 [line 17 - elementwise mul] | |
| # 5. %11/%12: out = h @ w2 [line 18] | |
| # | |
| # BACKWARD PASS (grouped logically; not the order in which the ops appear in the IR, which emits g_w2 first): | |
| # 1. %17/%18: g_h = grad_out @ w2.T [line 18 backward] | |
| # 2. %19/%20: g_h3 = g_h * silu(h1) [line 17 backward - h3 gradient] | |
| # 3. %21/%22: g_silu = g_h * h3 [line 17 backward - gradient w.r.t. silu output] | |
| # 4. %23: g_h1 = silu_backward(g_silu) [line 17 backward - through silu] | |
| # 5. %28/%29: g_x_h3 = g_h3 @ w3.T [line 16 backward - x gradient from h3] | |
| # 6. %34/%35: g_x_h1 = g_h1 @ w1.T [line 15 backward - x gradient from h1] | |
| # 7. %36: g_x = g_x_h1 + g_x_h3 [sum of x gradients] | |
| # 8. %30-33: g_w1 = x.T @ g_h1 [line 15 backward - w1 gradient] | |
| # 9. %24-27: g_w3 = x.T @ g_h3 [line 16 backward - w3 gradient] | |
| # 10. %13-16: g_w2 = h.T @ grad_out [line 18 backward - w2 gradient] | |
| # | |
| # OUTPUTS: (out, grad_x, grad_w1, grad_w3, grad_w2) | |
| # ============================================================================ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment