Created
February 4, 2026 03:14
-
-
Save ezyang/20715de362d7851598018500b23d1ed0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Annotated Post-Partition HLO for mlp2.py (with grad_x) | |
| # Forward/Backward split and source line annotations | |
| # | |
| # Stack Frame ID to Source Line Mapping: | |
| # 5 → line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1) | |
| # 6 → line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) [silu part] | |
| # 8 → line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3) | |
| # 9 → line 17: h = jnp.einsum("sbi,sbi->sbi", ...) [elementwise multiply] | |
| # 11 → line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, ...) | |
| # | |
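| # "jvp(...)" = op traced in the forward/linearization pass (note: a few jvp-tagged ops, e.g. the | |
| # silu derivative factor sigmoid*(1-sigmoid), are evaluated inside backward fusions) | |
| # "transpose(jvp(...))" = backward pass operation | |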
| # | |
| # Function returns: (out, grad_x, grad_w1, grad_w3, grad_w2) | |
| # ============================================================================ | |
| # ---------------------------------------------------------------------------- | |
| # FUSED COMPUTATION: Backward - transpose grad_out for g_w2 computation | |
| # Source: line 18 backward (grad_out preparation for weight gradient) | |
| # ---------------------------------------------------------------------------- | |
| %fused_computation (param_0.2: f32[4,4,16]) -> f32[16,16] { | |
| # param_0.2 = grad_out (%param.9 at the call site) | |
| %param_0.2 = f32[4,4,16]{2,1,0} parameter(0) | |
| # Transpose grad_out: [4,4,16] -> [16,4,4] -> [16,16] | |
| %transpose.84 = f32[16,4,4]{0,2,1} transpose(%param_0.2), dimensions={2,0,1}, metadata={op_name="grad_out"} | |
| # copy converts the {0,2,1} layout to row-major {2,1,0} so the bitcast below is a plain reshape | |
| %copy.13 = f32[16,4,4]{2,1,0} copy(%transpose.84), metadata={op_name="grad_out"} | |
| # Result: grad_out^T as [16 (hidden), 16 (s*b)]; the call site feeds it to %dot.29 (g_w2) as rhs | |
| ROOT %bitcast.15 = f32[16,16]{1,0} bitcast(%copy.13), metadata={op_name="grad_out"} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # FUSED COMPUTATION 1: Backward - g_h3 (for g_w3 path, with transpose) | |
| # Source: line 17 backward - computing gradient for h3 | |
| # g_h3 = g_h * silu(h1) [where h = silu(h1) * h3] | |
| # This version includes transpose for weight gradient computation | |
| # ---------------------------------------------------------------------------- | |
| # line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) - backward for h3 | |
| %fused_computation.1 (param_0.8: f32[16,16], param_1.3: f32[4,4,16], param_2.1: f32[16,16]) -> f32[16,16] { | |
| # param_0.8 = g_h (from line 18 backward) | |
| # param_1.3 = sigmoid(h1) (saved from forward, line 17 silu) | |
| # param_2.1 = h1 (from line 15 forward) | |
| %param_0.8 = f32[16,16]{1,0} parameter(0) | |
| # line 18 backward: g_h coming from output gradient | |
| %bitcast.17 = f32[4,4,16]{2,1,0} bitcast(%param_0.8), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/dot_general" stack_frame_id=11} | |
| %param_2.1 = f32[16,16]{1,0} parameter(2) | |
| # line 15 forward: h1 = x @ w1 | |
| %bitcast.18 = f32[4,4,16]{2,1,0} bitcast(%param_2.1), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5} | |
| %param_1.3 = f32[4,4,16]{2,1,0} parameter(1) | |
| # line 17 forward: silu(h1) = sigmoid(h1) * h1 (recomputed here from h1 and sigmoid(h1)) | |
| %mul.40 = f32[4,4,16]{2,1,0} multiply(%bitcast.18, %param_1.3), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6} | |
| # line 17 backward: g_h3 = g_h * silu(h1) | |
| %multiply.4 = f32[4,4,16]{2,1,0} multiply(%bitcast.17, %mul.40), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| # Transpose [4,4,16] -> [16,4,4], convert to row-major, then view as [16 (i), 16 (s*b)] | |
| %transpose.85 = f32[16,4,4]{0,2,1} transpose(%multiply.4), dimensions={2,0,1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| %copy.14 = f32[16,4,4]{2,1,0} copy(%transpose.85), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| # Result: g_h3^T; the call site feeds it to %dot.28 (g_w3 = x.T @ g_h3) | |
| ROOT %bitcast.16 = f32[16,16]{1,0} bitcast(%copy.14), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # FUSED COMPUTATION 2: Backward - g_h1 (for g_w1 path, with transpose) | |
| # Source: line 17 backward - computing gradient for h1 through silu | |
| # g_h1 = g_h * h3 * silu'(h1) where silu'(x) = sigmoid(x) + x*sigmoid(x)*(1-sigmoid(x)) | |
| # This version includes transpose for weight gradient computation | |
| # ---------------------------------------------------------------------------- | |
| # line 17: backward through silu - g_h1 | |
| %fused_computation.2 (param_0.13: f32[4,4,16], param_1.9: f32[16,16], param_2.7: f32[16,16], param_3.7: f32[16,16]) -> f32[16,16] { | |
| # param_0.13 = sigmoid(h1) (saved activation) | |
| # param_1.9 = g_h (from line 18 backward) | |
| # param_2.7 = h3 (from line 16 forward) | |
| # param_3.7 = h1 (from line 15 forward) | |
| %param_1.9 = f32[16,16]{1,0} parameter(1) | |
| %param_2.7 = f32[16,16]{1,0} parameter(2) | |
| # line 17 backward: g_silu = g_h * h3 (gradient through sbi,sbi->sbi w.r.t. first operand) | |
| %multiply.5 = f32[16,16]{1,0} multiply(%param_1.9, %param_2.7), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| %bitcast.20 = f32[4,4,16]{2,1,0} bitcast(%multiply.5), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| %param_0.13 = f32[4,4,16]{2,1,0} parameter(0) | |
| # line 17 backward through silu: term1 = g_silu * sigmoid(h1) | |
| %mul.44 = f32[4,4,16]{2,1,0} multiply(%bitcast.20, %param_0.13), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6} | |
| %param_3.7 = f32[16,16]{1,0} parameter(3) | |
| # line 15 forward: h1 | |
| %bitcast.21 = f32[4,4,16]{2,1,0} bitcast(%param_3.7), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5} | |
| # line 17 backward through silu: g_silu * h1 | |
| %mul.43 = f32[4,4,16]{2,1,0} multiply(%bitcast.21, %bitcast.20), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6} | |
| # line 17: constant 1 for silu derivative | |
| %constant.35 = f32[] constant(1), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/reshard" stack_frame_id=6} | |
| %jvp_jit_silu__.2 = f32[4,4,16]{2,1,0} broadcast(%constant.35), dimensions={}, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))"} | |
| # line 17 silu derivative: (1 - sigmoid(h1)); metadata tags this jvp(jit(silu)), but it is evaluated inside this backward fusion | |
| %sub.16 = f32[4,4,16]{2,1,0} subtract(%jvp_jit_silu__.2, %param_0.13), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/sub" stack_frame_id=6} | |
| # line 17 silu derivative: sigmoid(h1) * (1 - sigmoid(h1)) | |
| %mul.42 = f32[4,4,16]{2,1,0} multiply(%param_0.13, %sub.16), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6} | |
| # line 17 backward: term2 = g_silu * h1 * sigmoid(h1) * (1-sigmoid(h1)) | |
| %mul.41 = f32[4,4,16]{2,1,0} multiply(%mul.43, %mul.42), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6} | |
| # line 17 backward: g_h1 = term1 + term2 | |
| %add_any.7 = f32[4,4,16]{2,1,0} add(%mul.44, %mul.41), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| # Transpose [4,4,16] -> [16,4,4], convert to row-major, then view as [16 (i), 16 (s*b)] | |
| %transpose.86 = f32[16,4,4]{0,2,1} transpose(%add_any.7), dimensions={2,0,1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| %copy.15 = f32[16,4,4]{2,1,0} copy(%transpose.86), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| # Result: g_h1^T; the call site feeds it to %dot.27 (g_w1 = x.T @ g_h1) | |
| ROOT %bitcast.19 = f32[16,16]{1,0} bitcast(%copy.15), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # FUSED COMPUTATION 3: Backward - g_h1 (for g_x path, no transpose) | |
| # Source: line 17 backward - computing gradient for h1 for input gradient | |
| # Same computation as fused_computation.2 but without the transpose at the end | |
| # Used for computing g_x = g_h1 @ w1.T | |
| # ---------------------------------------------------------------------------- | |
| # line 17: backward through silu - g_h1 for grad_x computation | |
| %fused_computation.3 (param_0.16: f32[4,4,16], param_1.15: f32[16,16], param_2.13: f32[16,16], param_3.15: f32[16,16]) -> f32[16,16] { | |
| # param_0.16 = sigmoid(h1) (%add_divide_fusion at the call site) | |
| # param_1.15 = g_h (%dot.16, line 18 backward) | |
| # param_2.13 = h3 (%dot.13, line 16 forward) | |
| # param_3.15 = h1 (%dot.12, line 15 forward) | |
| %param_1.15 = f32[16,16]{1,0} parameter(1) | |
| %param_2.13 = f32[16,16]{1,0} parameter(2) | |
| # line 17 backward: g_silu = g_h * h3 | |
| %multiply.6 = f32[16,16]{1,0} multiply(%param_1.15, %param_2.13), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| %bitcast.23 = f32[4,4,16]{2,1,0} bitcast(%multiply.6), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| %param_0.16 = f32[4,4,16]{2,1,0} parameter(0) | |
| # line 17 backward: term1 = g_silu * sigmoid(h1) | |
| %mul.48 = f32[4,4,16]{2,1,0} multiply(%bitcast.23, %param_0.16), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6} | |
| %param_3.15 = f32[16,16]{1,0} parameter(3) | |
| # line 15 forward: h1 | |
| %bitcast.24 = f32[4,4,16]{2,1,0} bitcast(%param_3.15), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5} | |
| # line 17 backward: g_silu * h1 | |
| %mul.47 = f32[4,4,16]{2,1,0} multiply(%bitcast.24, %bitcast.23), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6} | |
| # line 17: constant 1 for silu derivative | |
| %constant.36 = f32[] constant(1), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/reshard" stack_frame_id=6} | |
| %jvp_jit_silu__.3 = f32[4,4,16]{2,1,0} broadcast(%constant.36), dimensions={}, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))"} | |
| # line 17 backward: (1 - sigmoid(h1)) | |
| %sub.17 = f32[4,4,16]{2,1,0} subtract(%jvp_jit_silu__.3, %param_0.16), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/sub" stack_frame_id=6} | |
| # line 17 backward: sigmoid(h1) * (1 - sigmoid(h1)) | |
| %mul.46 = f32[4,4,16]{2,1,0} multiply(%param_0.16, %sub.17), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6} | |
| # line 17 backward: term2 = g_silu * h1 * sigmoid(h1) * (1 - sigmoid(h1)) | |
| %mul.45 = f32[4,4,16]{2,1,0} multiply(%mul.47, %mul.46), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6} | |
| # line 17 backward: g_h1 = term1 + term2 (no transpose - for grad_x path) | |
| %add_any.8 = f32[4,4,16]{2,1,0} add(%mul.48, %mul.45), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| # Result: g_h1 as [16 (s*b), 16 (i)]; the call site feeds it to %dot.20 (g_x_h1 = g_h1 @ w1.T) | |
| ROOT %bitcast.22 = f32[16,16]{1,0} bitcast(%add_any.8), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # FUSED COMPUTATION 4: Backward - g_h3 (for g_x path, no transpose) | |
| # Source: line 17 backward - computing gradient for h3 for input gradient | |
| # g_h3 = g_h * silu(h1) - without transpose, for grad_x computation | |
| # ---------------------------------------------------------------------------- | |
| # line 17: backward for h3 - for grad_x computation | |
| %fused_computation.4 (param_0.20: f32[16,16], param_1.19: f32[4,4,16], param_2.15: f32[16,16]) -> f32[16,16] { | |
| # param_0.20 = g_h (%dot.16 at the call site, line 18 backward) | |
| # param_1.19 = sigmoid(h1) (%add_divide_fusion) | |
| # param_2.15 = h1 (%dot.12, line 15 forward) | |
| %param_0.20 = f32[16,16]{1,0} parameter(0) | |
| # line 18 backward: g_h | |
| %bitcast.26 = f32[4,4,16]{2,1,0} bitcast(%param_0.20), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/dot_general" stack_frame_id=11} | |
| %param_2.15 = f32[16,16]{1,0} parameter(2) | |
| # line 15 forward: h1 | |
| %bitcast.27 = f32[4,4,16]{2,1,0} bitcast(%param_2.15), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5} | |
| %param_1.19 = f32[4,4,16]{2,1,0} parameter(1) | |
| # line 17 forward: silu(h1) = h1 * sigmoid(h1) (recomputed here) | |
| %mul.49 = f32[4,4,16]{2,1,0} multiply(%bitcast.27, %param_1.19), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6} | |
| # line 17 backward: g_h3 = g_h * silu(h1) (no transpose - for grad_x path) | |
| %multiply.7 = f32[4,4,16]{2,1,0} multiply(%bitcast.26, %mul.49), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| # Result: g_h3 as [16 (s*b), 16 (i)]; the call site feeds it to %dot.18 (g_x_h3 = g_h3 @ w3.T) | |
| ROOT %bitcast.25 = f32[16,16]{1,0} bitcast(%multiply.7), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # FUSED COMPUTATION 5: Forward - compute h = silu(h1) * h3 | |
| # Source: line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) | |
| # ---------------------------------------------------------------------------- | |
| # line 17: h = silu(h1) * h3 (forward pass) | |
| %fused_computation.5 (param_0.23: f32[16,16], param_1.23: f32[4,4,16], param_2.17: f32[16,16]) -> f32[16,16] { | |
| # param_0.23 = h3 (from line 16) | |
| # param_1.23 = sigmoid(h1) (from silu computation) | |
| # param_2.17 = h1 (from line 15) | |
| %param_2.17 = f32[16,16]{1,0} parameter(2) | |
| # line 15 forward: h1 = x @ w1 | |
| %bitcast.30 = f32[4,4,16]{2,1,0} bitcast(%param_2.17), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5} | |
| %param_1.23 = f32[4,4,16]{2,1,0} parameter(1) | |
| # line 17 forward: silu(h1) = h1 * sigmoid(h1) | |
| %mul.50 = f32[4,4,16]{2,1,0} multiply(%bitcast.30, %param_1.23), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6} | |
| %param_0.23 = f32[16,16]{1,0} parameter(0) | |
| # line 16 forward: h3 = x @ w3 | |
| %bitcast.29 = f32[4,4,16]{2,1,0} bitcast(%param_0.23), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=8} | |
| # line 17 forward: h = silu(h1) * h3 | |
| %multiply.8 = f32[4,4,16]{2,1,0} multiply(%mul.50, %bitcast.29), metadata={op_name="jit(forward_and_backward)/jvp(sbi,sbi->sbi)/dot_general" stack_frame_id=9} | |
| # Result: h as [16 (s*b), 16 (i)]; the call site feeds it to %dot.15 (out) and %dot.29 (g_w2) | |
| ROOT %bitcast.28 = f32[16,16]{1,0} bitcast(%multiply.8), metadata={op_name="jit(forward_and_backward)/jvp(sbi,sbi->sbi)/dot_general" stack_frame_id=9} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # FUSED COMPUTATION 6: Forward - compute sigmoid(h1) for silu | |
| # Source: line 17: jax.nn.silu(h1) = h1 * sigmoid(h1), this computes sigmoid(h1) | |
| # silu(x) = x * sigmoid(x), sigmoid(x) = 1 / (1 + exp(-x)) | |
| # ---------------------------------------------------------------------------- | |
| # line 17: sigmoid(h1) computation (part of silu forward) | |
| %fused_computation.6 (param_0.26: f32[16,16]) -> f32[4,4,16] { | |
| # param_0.26 = h1 (%dot.12 at the call site) | |
| # line 17: constant 1 for sigmoid (used both as the "1 +" addend and as the numerator) | |
| %constant.37 = f32[] constant(1), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/reshard" stack_frame_id=6} | |
| %jvp_jit_silu__.8 = f32[4,4,16]{2,1,0} broadcast(%constant.37), dimensions={}, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))"} | |
| %param_0.26 = f32[16,16]{1,0} parameter(0) | |
| # line 15 forward: h1 = x @ w1 (input to silu) | |
| %bitcast.31 = f32[4,4,16]{2,1,0} bitcast(%param_0.26), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5} | |
| # line 17: -h1 for sigmoid | |
| %neg.8 = f32[4,4,16]{2,1,0} negate(%bitcast.31), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/neg" stack_frame_id=6} | |
| # line 17: exp(-h1) | |
| %exp.8 = f32[4,4,16]{2,1,0} exponential(%neg.8), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/exp" stack_frame_id=6} | |
| # line 17: 1 + exp(-h1) | |
| %add.18 = f32[4,4,16]{2,1,0} add(%jvp_jit_silu__.8, %exp.8), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/add" stack_frame_id=6} | |
| # line 17: sigmoid(h1) = 1 / (1 + exp(-h1)) | |
| # Result stays [4,4,16] (no final bitcast); %add_divide_fusion is an input to fused_computation.1 through .5 | |
| ROOT %div.12 = f32[4,4,16]{2,1,0} divide(%jvp_jit_silu__.8, %add.18), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/div" stack_frame_id=6} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # WRAPPED ADD: Backward - sum g_x from h1 and h3 paths | |
| # Source: line 15+16 backward - combining input gradients from both paths | |
| # g_x = g_x_h1 + g_x_h3 | |
| # ---------------------------------------------------------------------------- | |
| # line 15+16 backward: sum input gradients | |
| %wrapped_add_computation (param_0.27: f32[4,4,16], param_1.29: f32[4,4,16]) -> f32[4,4,16] { | |
| # param_0.27 = g_x_h3 (%get-tuple-element, index 0 of %all-reduce, at the call site) | |
| %param_0.27 = f32[4,4,16]{2,1,0} parameter(0) | |
| # param_1.29 = g_x_h1 (%get-tuple-element.2, index 1 of %all-reduce) | |
| %param_1.29 = f32[4,4,16]{2,1,0} parameter(1) | |
| # g_x = g_x_h3 + g_x_h1 (both already summed across TP) | |
| ROOT %add_any.9 = f32[4,4,16]{2,1,0} add(%param_0.27, %param_1.29), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/add_any" stack_frame_id=5} | |
| } | |
| # ---------------------------------------------------------------------------- | |
| # REDUCTION COMPUTATIONS: Used by all-reduce for gradient accumulation | |
| # ---------------------------------------------------------------------------- | |
| # Used for all-reduce to sum input gradients across TP dimension | |
| %region_0.1 (dot_general.24: f32[], dot_general.0: f32[]) -> f32[] { | |
| # Scalar f32 sum reducer: to_apply of %all-reduce (input gradients, replica_groups={{0,1},{2,3}}). | |
| # The "dot_general" names come from op metadata; the operation is a plain add. | |
| %dot_general.24 = f32[] parameter(0), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general"} | |
| %dot_general.0 = f32[] parameter(1), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general"} | |
| ROOT %dot_general.1 = f32[] add(%dot_general.24, %dot_general.0), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8} | |
| } | |
| # Used for all-reduce to sum weight gradients across the data-parallel (batch-shard) groups {{0,2},{1,3}}, not across TP | |
| %region_2.5 (transpose.0: f32[], transpose.1: f32[]) -> f32[] { | |
| # Scalar f32 sum reducer: to_apply of %all-reduce.1 (weight gradients, replica_groups={{0,2},{1,3}}). | |
| # The "transpose" names come from op metadata; the operation is a plain add. | |
| %transpose.0 = f32[] parameter(0), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose"} | |
| %transpose.1 = f32[] parameter(1), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose"} | |
| ROOT %transpose.2 = f32[] add(%transpose.0, %transpose.1), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5} | |
| } | |
| # ============================================================================ | |
| # MAIN ENTRY POINT | |
| # Returns: (out, grad_x, grad_w1, grad_w3, grad_w2) | |
| # ============================================================================ | |
| ENTRY %main.11_spmd (param.5: f32[4,4,16], param.6: f32[16,16], param.7: f32[16,16], param.8: f32[16,16], param.9: f32[4,4,16]) -> (f32[4,4,16], f32[4,4,16], f32[16,16], f32[16,16], f32[16,16]) { | |
| # ========== PARAMETERS ========== | |
| # Device layout implied by the shardings below (4 devices; shapes here are per-device): | |
| #   x / grad_out: split 2-way on dim 1 (b); devices {0,1} hold one b-shard, {2,3} the other | |
| #   w1 / w3 (dim 1) and w2 (dim 0): split 2-way on i; devices {0,2} hold one i-shard, {1,3} the other | |
| #   => tensor-parallel (i) groups are {0,1},{2,3}; data-parallel (b) groups are {0,2},{1,3} | |
| # line 10: rx = jax.reshard(x, ...) - input activations | |
| %param.5 = f32[4,4,16]{2,1,0} parameter(0), sharding={devices=[1,2,1,2]<=[4] last_tile_dim_replicate}, metadata={op_name="x"} | |
| # line 11: rw1 = jax.reshard(w1, ...) - first FFN weight (h1 = x @ w1, the branch passed through silu) | |
| %param.6 = f32[16,16]{1,0} parameter(1), sharding={devices=[1,2,2]<=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name="w1"} | |
| # line 12: rw3 = jax.reshard(w3, ...) - up-projection weight (h3 = x @ w3, the un-activated branch multiplied by silu(h1)) | |
| %param.7 = f32[16,16]{1,0} parameter(2), sharding={devices=[1,2,2]<=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name="w3"} | |
| # line 13: rw2 = jax.reshard(w2, ...) - second FFN weight | |
| %param.8 = f32[16,16]{1,0} parameter(3), sharding={devices=[2,1,2]<=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name="w2"} | |
| # grad_out: gradient from loss w.r.t. output | |
| %param.9 = f32[4,4,16]{2,1,0} parameter(4), sharding={devices=[1,2,1,2]<=[4] last_tile_dim_replicate}, metadata={op_name="grad_out"} | |
| # ========== LAYOUT TRANSFORMATIONS ========== | |
| # Reshape x for matmul: [4,4,16] -> [16,16] | |
| %bitcast = f32[16,16]{1,0} bitcast(%param.5), metadata={op_name="x"} | |
| # Reshape grad_out for matmul: [4,4,16] -> [16,16] | |
| %bitcast.5 = f32[16,16]{1,0} bitcast(%param.9), metadata={op_name="grad_out"} | |
| # ========== BACKWARD: Prepare grad_out for g_w2 ========== | |
| # Transpose grad_out for weight gradient computation | |
| %copy_bitcast_fusion = f32[16,16]{1,0} fusion(%param.9), kind=kLoop, calls=%fused_computation, metadata={op_name="grad_out"} | |
| # ========== FORWARD: Compute h1 = x @ w1 ========== | |
| # line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1) | |
| %dot.12 = f32[16,16]{1,0} dot(%bitcast, %param.6), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5} | |
| # ========== FORWARD: Compute h3 = x @ w3 ========== | |
| # line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3) | |
| %dot.13 = f32[16,16]{1,0} dot(%bitcast, %param.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=8} | |
| # ========== BACKWARD: Compute g_h = grad_out @ w2.T ========== | |
| # line 18 backward: gradient flows back from output through w2 (contracts over unsharded hidden dim, so no reduction needed) | |
| %dot.16 = f32[16,16]{1,0} dot(%bitcast.5, %param.8), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/dot_general" stack_frame_id=11} | |
| # ========== FORWARD: Compute sigmoid(h1) for silu ========== | |
| # line 17: sigmoid(h1) = 1 / (1 + exp(-h1)) | |
| %add_divide_fusion = f32[4,4,16]{2,1,0} fusion(%dot.12), kind=kLoop, calls=%fused_computation.6, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/div" stack_frame_id=6} | |
| # ========== BACKWARD: Compute g_h1 (for grad_x path, no transpose) ========== | |
| # line 17 backward: g_h1 = g_h * h3 * silu'(h1) | |
| %add_bitcast_fusion = f32[16,16]{1,0} fusion(%add_divide_fusion, %dot.16, %dot.13, %dot.12), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| # ========== FORWARD: Compute h = silu(h1) * h3 ========== | |
| # line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) | |
| %multiply_bitcast_fusion.1 = f32[16,16]{1,0} fusion(%dot.13, %add_divide_fusion, %dot.12), kind=kLoop, calls=%fused_computation.5, metadata={op_name="jit(forward_and_backward)/jvp(sbi,sbi->sbi)/dot_general" stack_frame_id=9} | |
| # ========== BACKWARD: Compute g_h3 (for g_w3 path, with transpose) ========== | |
| # line 17 backward: g_h3 = g_h * silu(h1) | |
| %copy_bitcast_fusion.1 = f32[16,16]{1,0} fusion(%dot.16, %add_divide_fusion, %dot.12), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| # ========== BACKWARD: Compute g_h1 (for g_w1 path, with transpose) ========== | |
| # line 17 backward: g_h1 = g_h * h3 * silu'(h1) (transposed for weight gradient) | |
| %copy_bitcast_fusion.2 = f32[16,16]{1,0} fusion(%add_divide_fusion, %dot.16, %dot.13, %dot.12), kind=kLoop, calls=%fused_computation.2, metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6} | |
| # ========== BACKWARD: Compute g_h3 (for grad_x path, no transpose) ========== | |
| # line 17 backward: g_h3 for input gradient computation | |
| %multiply_bitcast_fusion = f32[16,16]{1,0} fusion(%dot.16, %add_divide_fusion, %dot.12), kind=kLoop, calls=%fused_computation.4, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9} | |
| # ========== BACKWARD: Compute g_x from h1 path = g_h1 @ w1.T ========== | |
| # line 15 backward: input gradient from h1 path (contracts over the TP-sharded i dim: partial sum until %all-reduce) | |
| %dot.20 = f32[16,16]{1,0} dot(%add_bitcast_fusion, %param.6), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=5} | |
| # ========== FORWARD: Compute out = h @ w2 ========== | |
| # line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, ...) | |
| # NOTE(review): this contracts over the TP-sharded i dim and no all-reduce follows before the ROOT tuple, | |
| # so each TP shard returns a partial sum of out; presumably the "..." in line 18 requests that (confirm in mlp2.py) | |
| %dot.15 = f32[16,16]{1,0} dot(%multiply_bitcast_fusion.1, %param.8), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(forward_and_backward)/jvp(sbi,ih->sbh)/dot_general" stack_frame_id=11} | |
| # ========== BACKWARD: Compute g_w2 = h.T @ grad_out ========== | |
| # line 18 backward: weight gradient for w2 | |
| %dot.29 = f32[16,16]{1,0} dot(%multiply_bitcast_fusion.1, %copy_bitcast_fusion), lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose" stack_frame_id=11} | |
| # ========== BACKWARD: Compute g_w3 = x.T @ g_h3 ========== | |
| # line 16 backward: weight gradient for w3 | |
| %dot.28 = f32[16,16]{1,0} dot(%bitcast, %copy_bitcast_fusion.1), lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=8} | |
| # ========== BACKWARD: Compute g_w1 = x.T @ g_h1 ========== | |
| # line 15 backward: weight gradient for w1 | |
| %dot.27 = f32[16,16]{1,0} dot(%bitcast, %copy_bitcast_fusion.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5} | |
| # ========== BACKWARD: Compute g_x from h3 path = g_h3 @ w3.T ========== | |
| # line 16 backward: input gradient from h3 path (contracts over the TP-sharded i dim: partial sum until %all-reduce) | |
| %dot.18 = f32[16,16]{1,0} dot(%multiply_bitcast_fusion, %param.7), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8} | |
| # ========== LAYOUT: Reshape g_x components ========== | |
| # line 15 backward: reshape g_x_h1 | |
| %bitcast.11 = f32[4,4,16]{2,1,0} bitcast(%dot.20), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=5} | |
| # line 18 forward: reshape out | |
| %bitcast.4 = f32[4,4,16]{2,1,0} bitcast(%dot.15), metadata={op_name="jit(forward_and_backward)/jvp(sbi,ih->sbh)/dot_general" stack_frame_id=11} | |
| # ========== BACKWARD: All-reduce weight gradients across data-parallel replicas ========== | |
| # Sum weight gradients across data-parallel devices (replica_groups={{0,2},{1,3}}: same i-shard, different b-shard). | |
| # The g_w dots contract over s*b and b is sharded across these groups, so each device holds only a partial sum. | |
| # Returns: (g_w1, g_w3, g_w2) | |
| %all-reduce.1 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}) all-reduce(%dot.27, %dot.28, %dot.29), channel_id=3, replica_groups={{0,2},{1,3}}, use_global_device_ids=true, to_apply=%region_2.5, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5} | |
| # line 16 backward: reshape g_x_h3 | |
| %bitcast.8 = f32[4,4,16]{2,1,0} bitcast(%dot.18), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8} | |
| # Extract weight gradients from all-reduce tuple | |
| # line 15 backward: g_w1 (reduced) | |
| %get-tuple-element.4 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=0, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5} | |
| # line 16 backward: g_w3 (reduced) | |
| %get-tuple-element.6 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=8} | |
| # line 18 backward: g_w2 (reduced) | |
| %get-tuple-element.8 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=2, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose/jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose" stack_frame_id=11} | |
| # ========== BACKWARD: All-reduce input gradients across TP ========== | |
| # Sum input gradients across tensor parallel devices (replica_groups={{0,1},{2,3}}: same b-shard, different i-shard) | |
| # Returns: (g_x_h3, g_x_h1) | |
| %all-reduce = (f32[4,4,16]{2,1,0}, f32[4,4,16]{2,1,0}) all-reduce(%bitcast.8, %bitcast.11), channel_id=1, replica_groups={{0,1},{2,3}}, use_global_device_ids=true, to_apply=%region_0.1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8} | |
| # Extract input gradients from all-reduce tuple | |
| # line 16 backward: g_x_h3 (reduced across TP) | |
| %get-tuple-element = f32[4,4,16]{2,1,0} get-tuple-element(%all-reduce), index=0, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8} | |
| # line 15 backward: g_x_h1 (reduced across TP) | |
| %get-tuple-element.2 = f32[4,4,16]{2,1,0} get-tuple-element(%all-reduce), index=1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=5} | |
| # ========== BACKWARD: Sum input gradients from both paths ========== | |
| # line 15+16 backward: g_x = g_x_h1 + g_x_h3 | |
| %wrapped_add = f32[4,4,16]{2,1,0} fusion(%get-tuple-element, %get-tuple-element.2), kind=kLoop, calls=%wrapped_add_computation, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/add_any" stack_frame_id=5} | |
| # ========== FINAL OUTPUT ========== | |
| # Returns: (out, grad_x, grad_w1, grad_w3, grad_w2) | |
| ROOT %tuple.6 = (f32[4,4,16]{2,1,0}, f32[4,4,16]{2,1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(%bitcast.4, %wrapped_add, %get-tuple-element.4, %get-tuple-element.6, %get-tuple-element.8) | |
| } | |
| # ============================================================================ | |
| # EXECUTION FLOW SUMMARY | |
| # ============================================================================ | |
| # | |
| # FORWARD PASS (in execution order): | |
| # 1. %dot.12: h1 = x @ w1 [line 15] | |
| # 2. %dot.13: h3 = x @ w3 [line 16] | |
| # 3. %add_divide_fusion: sigmoid(h1) [line 17 - silu part] | |
| # 4. %multiply_bitcast_fusion.1: h = silu(h1) * h3 [line 17] | |
| # 5. %dot.15: out = h @ w2 [line 18] | |
| # | |
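| # BACKWARD PASS (grouped by role; the ENTRY listing interleaves these with forward ops and emits %dot.29/%dot.28/%dot.27 before %dot.18): | |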
| # 1. %copy_bitcast_fusion: transpose grad_out [prep for g_w2] | |
| # 2. %dot.16: g_h = grad_out @ w2.T [line 18 backward] | |
| # 3. %add_bitcast_fusion: g_h1 (for grad_x, no transpose) [line 17 backward] | |
| # 4. %copy_bitcast_fusion.1: g_h3 (for g_w3, transposed) [line 17 backward] | |
| # 5. %copy_bitcast_fusion.2: g_h1 (for g_w1, transposed) [line 17 backward] | |
| # 6. %multiply_bitcast_fusion: g_h3 (for grad_x, no transpose) [line 17 backward] | |
| # 7. %dot.20: g_x_h1 = g_h1 @ w1.T [line 15 backward] | |
| # 8. %dot.18: g_x_h3 = g_h3 @ w3.T [line 16 backward] | |
| # 9. %dot.27: g_w1 = x.T @ g_h1 [line 15 backward] | |
| # 10. %dot.28: g_w3 = x.T @ g_h3 [line 16 backward] | |
| # 11. %dot.29: g_w2 = h.T @ grad_out [line 18 backward] | |
| # 12. %all-reduce.1: sum weight gradients across data-parallel replicas {{0,2},{1,3}} [communication] | |
| # 13. %all-reduce: sum input gradients across TP [communication] | |
| # 14. %wrapped_add: g_x = g_x_h1 + g_x_h3 [line 15+16 backward] | |
| # | |
| # OUTPUTS: (out, grad_x, grad_w1, grad_w3, grad_w2) | |
| # ============================================================================ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment