Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created February 4, 2026 03:14
Show Gist options
  • Select an option

  • Save ezyang/20715de362d7851598018500b23d1ed0 to your computer and use it in GitHub Desktop.

Select an option

Save ezyang/20715de362d7851598018500b23d1ed0 to your computer and use it in GitHub Desktop.
# Annotated Post-Partition HLO for mlp2.py (with grad_x)
# Forward/Backward split and source line annotations
#
# Stack Frame ID to Source Line Mapping:
# 5 → line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1)
# 6 → line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) [silu part]
# 8 → line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3)
# 9 → line 17: h = jnp.einsum("sbi,sbi->sbi", ...) [elementwise multiply]
# 11 → line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, ...)
#
# "jvp(...)" = forward pass operation
# "transpose(jvp(...))" = backward pass operation
#
# Function returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
# ============================================================================
# ----------------------------------------------------------------------------
# FUSED COMPUTATION: Backward - transpose grad_out for g_w2 computation
# Source: line 18 backward (grad_out preparation for weight gradient)
# ----------------------------------------------------------------------------
%fused_computation (param_0.2: f32[4,4,16]) -> f32[16,16] {
# Purpose: produce grad_out as a 2-D [h, s*b] matrix (the transpose of its
# [s*b, h] view) so that ENTRY %dot.29 can contract over dimension 1 (s*b).
# param_0.2 = grad_out (ENTRY passes %param.9)
%param_0.2 = f32[4,4,16]{2,1,0} parameter(0)
# Transpose grad_out: [4,4,16] -> [16,4,4] -> [16,16]
# dimensions={2,0,1} moves the last axis (size 16) to the front.
%transpose.84 = f32[16,4,4]{0,2,1} transpose(%param_0.2), dimensions={2,0,1}, metadata={op_name="grad_out"}
# The copy changes the layout from {0,2,1} to row-major {2,1,0}; after it the
# two trailing size-4 axes are contiguous and can be flattened by the bitcast.
%copy.13 = f32[16,4,4]{2,1,0} copy(%transpose.84), metadata={op_name="grad_out"}
# Flatten [16,4,4] -> [16,16]
ROOT %bitcast.15 = f32[16,16]{1,0} bitcast(%copy.13), metadata={op_name="grad_out"}
}
# ----------------------------------------------------------------------------
# FUSED COMPUTATION 1: Backward - g_h3 (for g_w3 path, with transpose)
# Source: line 17 backward - computing gradient for h3
# g_h3 = g_h * silu(h1) [where h = silu(h1) * h3]
# This version includes transpose for weight gradient computation
# ----------------------------------------------------------------------------
# line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3) - backward for h3
%fused_computation.1 (param_0.8: f32[16,16], param_1.3: f32[4,4,16], param_2.1: f32[16,16]) -> f32[16,16] {
# Purpose: compute g_h3 = g_h * silu(h1) and return it transposed as [i, s*b],
# the operand layout ENTRY %dot.28 (g_w3) contracts over.
# param_0.8 = g_h (from line 18 backward; ENTRY passes %dot.16)
# param_1.3 = sigmoid(h1) (saved from forward, line 17 silu; ENTRY passes %add_divide_fusion)
# param_2.1 = h1 (from line 15 forward; ENTRY passes %dot.12)
%param_0.8 = f32[16,16]{1,0} parameter(0)
# line 18 backward: g_h coming from output gradient, viewed as [4,4,16]
%bitcast.17 = f32[4,4,16]{2,1,0} bitcast(%param_0.8), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/dot_general" stack_frame_id=11}
%param_2.1 = f32[16,16]{1,0} parameter(2)
# line 15 forward: h1 = x @ w1, viewed as [4,4,16]
%bitcast.18 = f32[4,4,16]{2,1,0} bitcast(%param_2.1), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5}
%param_1.3 = f32[4,4,16]{2,1,0} parameter(1)
# line 17 forward: silu(h1) = h1 * sigmoid(h1) (recomputed here from the saved sigmoid)
%mul.40 = f32[4,4,16]{2,1,0} multiply(%bitcast.18, %param_1.3), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6}
# line 17 backward: g_h3 = g_h * silu(h1)
%multiply.4 = f32[4,4,16]{2,1,0} multiply(%bitcast.17, %mul.40), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# Transpose g_h3: [4,4,16] -> [16,4,4]; dimensions={2,0,1} moves the i axis to the front
%transpose.85 = f32[16,4,4]{0,2,1} transpose(%multiply.4), dimensions={2,0,1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# Copy to row-major {2,1,0} so the trailing axes can be flattened
%copy.14 = f32[16,4,4]{2,1,0} copy(%transpose.85), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# Flatten [16,4,4] -> [16,16]
ROOT %bitcast.16 = f32[16,16]{1,0} bitcast(%copy.14), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
}
# ----------------------------------------------------------------------------
# FUSED COMPUTATION 2: Backward - g_h1 (for g_w1 path, with transpose)
# Source: line 17 backward - computing gradient for h1 through silu
# g_h1 = g_h * h3 * silu'(h1) where silu'(x) = sigmoid(x) + x*sigmoid(x)*(1-sigmoid(x))
# This version includes transpose for weight gradient computation
# ----------------------------------------------------------------------------
# line 17: backward through silu - g_h1
%fused_computation.2 (param_0.13: f32[4,4,16], param_1.9: f32[16,16], param_2.7: f32[16,16], param_3.7: f32[16,16]) -> f32[16,16] {
# Purpose: compute g_h1 = g_silu * silu'(h1), with g_silu = g_h * h3, and return
# it transposed as [i, s*b], the operand layout ENTRY %dot.27 (g_w1) contracts over.
# param_0.13 = sigmoid(h1) (saved activation; ENTRY passes %add_divide_fusion)
# param_1.9 = g_h (from line 18 backward; ENTRY passes %dot.16)
# param_2.7 = h3 (from line 16 forward; ENTRY passes %dot.13)
# param_3.7 = h1 (from line 15 forward; ENTRY passes %dot.12)
%param_1.9 = f32[16,16]{1,0} parameter(1)
%param_2.7 = f32[16,16]{1,0} parameter(2)
# line 17 backward: g_silu = g_h * h3 (gradient through sbi,sbi->sbi w.r.t. first operand)
%multiply.5 = f32[16,16]{1,0} multiply(%param_1.9, %param_2.7), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# g_silu viewed as [4,4,16]
%bitcast.20 = f32[4,4,16]{2,1,0} bitcast(%multiply.5), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
%param_0.13 = f32[4,4,16]{2,1,0} parameter(0)
# line 17 backward through silu: term1 = g_silu * sigmoid(h1)
%mul.44 = f32[4,4,16]{2,1,0} multiply(%bitcast.20, %param_0.13), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6}
%param_3.7 = f32[16,16]{1,0} parameter(3)
# line 15 forward: h1, viewed as [4,4,16]
%bitcast.21 = f32[4,4,16]{2,1,0} bitcast(%param_3.7), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5}
# line 17 backward through silu: g_silu * h1
%mul.43 = f32[4,4,16]{2,1,0} multiply(%bitcast.21, %bitcast.20), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6}
# line 17: constant 1 for silu derivative
%constant.35 = f32[] constant(1), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/reshard" stack_frame_id=6}
# Scalar 1 broadcast to [4,4,16]
%jvp_jit_silu__.2 = f32[4,4,16]{2,1,0} broadcast(%constant.35), dimensions={}, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))"}
# line 17 backward: (1 - sigmoid(h1))
%sub.16 = f32[4,4,16]{2,1,0} subtract(%jvp_jit_silu__.2, %param_0.13), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/sub" stack_frame_id=6}
# line 17 backward: sigmoid(h1) * (1 - sigmoid(h1))
%mul.42 = f32[4,4,16]{2,1,0} multiply(%param_0.13, %sub.16), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6}
# line 17 backward: term2 = g_silu * h1 * sigmoid(h1) * (1-sigmoid(h1))
%mul.41 = f32[4,4,16]{2,1,0} multiply(%mul.43, %mul.42), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6}
# line 17 backward: g_h1 = term1 + term2
%add_any.7 = f32[4,4,16]{2,1,0} add(%mul.44, %mul.41), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
# Transpose g_h1: [4,4,16] -> [16,4,4]; dimensions={2,0,1} moves the i axis to the front
%transpose.86 = f32[16,4,4]{0,2,1} transpose(%add_any.7), dimensions={2,0,1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
# Copy to row-major {2,1,0} so the trailing axes can be flattened
%copy.15 = f32[16,4,4]{2,1,0} copy(%transpose.86), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
# Flatten [16,4,4] -> [16,16]
ROOT %bitcast.19 = f32[16,16]{1,0} bitcast(%copy.15), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
}
# ----------------------------------------------------------------------------
# FUSED COMPUTATION 3: Backward - g_h1 (for g_x path, no transpose)
# Source: line 17 backward - computing gradient for h1 for input gradient
# Same computation as fused_computation.2 but without the transpose at the end
# Used for computing g_x_h1 = g_h1 @ w1.T (the h1-path contribution to g_x; see %dot.20)
# ----------------------------------------------------------------------------
# line 17: backward through silu - g_h1 for grad_x computation
%fused_computation.3 (param_0.16: f32[4,4,16], param_1.15: f32[16,16], param_2.13: f32[16,16], param_3.15: f32[16,16]) -> f32[16,16] {
# Purpose: compute g_h1 = g_silu * silu'(h1), with g_silu = g_h * h3, and return
# it untransposed as [s*b, i]; ENTRY %dot.20 multiplies it with w1 to get g_x_h1.
# param_0.16 = sigmoid(h1) (ENTRY passes %add_divide_fusion)
# param_1.15 = g_h (ENTRY passes %dot.16)
# param_2.13 = h3 (ENTRY passes %dot.13)
# param_3.15 = h1 (ENTRY passes %dot.12)
%param_1.15 = f32[16,16]{1,0} parameter(1)
%param_2.13 = f32[16,16]{1,0} parameter(2)
# line 17 backward: g_silu = g_h * h3
%multiply.6 = f32[16,16]{1,0} multiply(%param_1.15, %param_2.13), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# g_silu viewed as [4,4,16]
%bitcast.23 = f32[4,4,16]{2,1,0} bitcast(%multiply.6), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
%param_0.16 = f32[4,4,16]{2,1,0} parameter(0)
# line 17 backward: term1 = g_silu * sigmoid(h1)
%mul.48 = f32[4,4,16]{2,1,0} multiply(%bitcast.23, %param_0.16), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6}
%param_3.15 = f32[16,16]{1,0} parameter(3)
# line 15 forward: h1, viewed as [4,4,16]
%bitcast.24 = f32[4,4,16]{2,1,0} bitcast(%param_3.15), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5}
# line 17 backward: g_silu * h1
%mul.47 = f32[4,4,16]{2,1,0} multiply(%bitcast.24, %bitcast.23), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6}
# line 17: constant 1 for silu derivative, broadcast to [4,4,16]
%constant.36 = f32[] constant(1), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/reshard" stack_frame_id=6}
%jvp_jit_silu__.3 = f32[4,4,16]{2,1,0} broadcast(%constant.36), dimensions={}, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))"}
# line 17 backward: (1 - sigmoid(h1))
%sub.17 = f32[4,4,16]{2,1,0} subtract(%jvp_jit_silu__.3, %param_0.16), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/sub" stack_frame_id=6}
# line 17 backward: sigmoid(h1) * (1 - sigmoid(h1))
%mul.46 = f32[4,4,16]{2,1,0} multiply(%param_0.16, %sub.17), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6}
# line 17 backward: term2 = g_silu * h1 * sigmoid(h1) * (1-sigmoid(h1))
%mul.45 = f32[4,4,16]{2,1,0} multiply(%mul.47, %mul.46), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/mul" stack_frame_id=6}
# line 17 backward: g_h1 = term1 + term2 (no transpose - for grad_x path)
%add_any.8 = f32[4,4,16]{2,1,0} add(%mul.48, %mul.45), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
# Flatten [4,4,16] -> [16,16]
ROOT %bitcast.22 = f32[16,16]{1,0} bitcast(%add_any.8), metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
}
# ----------------------------------------------------------------------------
# FUSED COMPUTATION 4: Backward - g_h3 (for g_x path, no transpose)
# Source: line 17 backward - computing gradient for h3 for input gradient
# g_h3 = g_h * silu(h1) - without transpose, for grad_x computation
# ----------------------------------------------------------------------------
# line 17: backward for h3 - for grad_x computation
%fused_computation.4 (param_0.20: f32[16,16], param_1.19: f32[4,4,16], param_2.15: f32[16,16]) -> f32[16,16] {
# Purpose: compute g_h3 = g_h * silu(h1) and return it untransposed as [s*b, i];
# ENTRY %dot.18 multiplies it with w3 to get g_x_h3.
# param_0.20 = g_h (ENTRY passes %dot.16)
# param_1.19 = sigmoid(h1) (ENTRY passes %add_divide_fusion)
# param_2.15 = h1 (ENTRY passes %dot.12)
%param_0.20 = f32[16,16]{1,0} parameter(0)
# line 18 backward: g_h, viewed as [4,4,16]
%bitcast.26 = f32[4,4,16]{2,1,0} bitcast(%param_0.20), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/dot_general" stack_frame_id=11}
%param_2.15 = f32[16,16]{1,0} parameter(2)
# line 15 forward: h1, viewed as [4,4,16]
%bitcast.27 = f32[4,4,16]{2,1,0} bitcast(%param_2.15), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5}
%param_1.19 = f32[4,4,16]{2,1,0} parameter(1)
# line 17 forward: silu(h1) = h1 * sigmoid(h1) (recomputed here from the saved sigmoid)
%mul.49 = f32[4,4,16]{2,1,0} multiply(%bitcast.27, %param_1.19), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6}
# line 17 backward: g_h3 = g_h * silu(h1) (no transpose - for grad_x path)
%multiply.7 = f32[4,4,16]{2,1,0} multiply(%bitcast.26, %mul.49), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# Flatten [4,4,16] -> [16,16]
ROOT %bitcast.25 = f32[16,16]{1,0} bitcast(%multiply.7), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
}
# ----------------------------------------------------------------------------
# FUSED COMPUTATION 5: Forward - compute h = silu(h1) * h3
# Source: line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
# ----------------------------------------------------------------------------
# line 17: h = silu(h1) * h3 (forward pass)
%fused_computation.5 (param_0.23: f32[16,16], param_1.23: f32[4,4,16], param_2.17: f32[16,16]) -> f32[16,16] {
# Purpose: forward h = silu(h1) * h3, returned as [s*b, i]. In ENTRY the result
# feeds both %dot.15 (out) and %dot.29 (g_w2).
# param_0.23 = h3 (from line 16; ENTRY passes %dot.13)
# param_1.23 = sigmoid(h1) (from silu computation; ENTRY passes %add_divide_fusion)
# param_2.17 = h1 (from line 15; ENTRY passes %dot.12)
%param_2.17 = f32[16,16]{1,0} parameter(2)
# line 15 forward: h1 = x @ w1, viewed as [4,4,16]
%bitcast.30 = f32[4,4,16]{2,1,0} bitcast(%param_2.17), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5}
%param_1.23 = f32[4,4,16]{2,1,0} parameter(1)
# line 17 forward: silu(h1) = h1 * sigmoid(h1)
%mul.50 = f32[4,4,16]{2,1,0} multiply(%bitcast.30, %param_1.23), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/mul" stack_frame_id=6}
%param_0.23 = f32[16,16]{1,0} parameter(0)
# line 16 forward: h3 = x @ w3, viewed as [4,4,16]
%bitcast.29 = f32[4,4,16]{2,1,0} bitcast(%param_0.23), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=8}
# line 17 forward: h = silu(h1) * h3
%multiply.8 = f32[4,4,16]{2,1,0} multiply(%mul.50, %bitcast.29), metadata={op_name="jit(forward_and_backward)/jvp(sbi,sbi->sbi)/dot_general" stack_frame_id=9}
# Flatten [4,4,16] -> [16,16]
ROOT %bitcast.28 = f32[16,16]{1,0} bitcast(%multiply.8), metadata={op_name="jit(forward_and_backward)/jvp(sbi,sbi->sbi)/dot_general" stack_frame_id=9}
}
# ----------------------------------------------------------------------------
# FUSED COMPUTATION 6: Forward - compute sigmoid(h1) for silu
# Source: line 17: jax.nn.silu(h1) = h1 * sigmoid(h1), this computes sigmoid(h1)
# silu(x) = x * sigmoid(x), sigmoid(x) = 1 / (1 + exp(-x))
# ----------------------------------------------------------------------------
# line 17: sigmoid(h1) computation (part of silu forward)
%fused_computation.6 (param_0.26: f32[16,16]) -> f32[4,4,16] {
# Purpose: sigmoid(h1) = 1 / (1 + exp(-h1)), returned in [4,4,16] shape. In ENTRY
# this single result (%add_divide_fusion) is passed to fused computations 1-5.
# param_0.26 = h1 (ENTRY passes %dot.12)
# line 17: constant 1 for sigmoid
%constant.37 = f32[] constant(1), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/reshard" stack_frame_id=6}
# Scalar 1 broadcast to [4,4,16]; used below both as the addend and as the numerator
%jvp_jit_silu__.8 = f32[4,4,16]{2,1,0} broadcast(%constant.37), dimensions={}, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))"}
%param_0.26 = f32[16,16]{1,0} parameter(0)
# line 15 forward: h1 = x @ w1 (input to silu), viewed as [4,4,16]
%bitcast.31 = f32[4,4,16]{2,1,0} bitcast(%param_0.26), metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5}
# line 17: -h1 for sigmoid
%neg.8 = f32[4,4,16]{2,1,0} negate(%bitcast.31), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/neg" stack_frame_id=6}
# line 17: exp(-h1)
%exp.8 = f32[4,4,16]{2,1,0} exponential(%neg.8), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/exp" stack_frame_id=6}
# line 17: 1 + exp(-h1)
%add.18 = f32[4,4,16]{2,1,0} add(%jvp_jit_silu__.8, %exp.8), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/add" stack_frame_id=6}
# line 17: sigmoid(h1) = 1 / (1 + exp(-h1))
ROOT %div.12 = f32[4,4,16]{2,1,0} divide(%jvp_jit_silu__.8, %add.18), metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/div" stack_frame_id=6}
}
# ----------------------------------------------------------------------------
# WRAPPED ADD: Backward - sum g_x from h1 and h3 paths
# Source: line 15+16 backward - combining input gradients from both paths
# g_x = g_x_h1 + g_x_h3
# ----------------------------------------------------------------------------
# line 15+16 backward: sum input gradients
%wrapped_add_computation (param_0.27: f32[4,4,16], param_1.29: f32[4,4,16]) -> f32[4,4,16] {
# Purpose: elementwise sum of the two already all-reduced input-gradient contributions.
# param_0.27 = g_x_h3 (ENTRY passes %get-tuple-element, index 0 of %all-reduce)
# param_1.29 = g_x_h1 (ENTRY passes %get-tuple-element.2, index 1 of %all-reduce)
%param_0.27 = f32[4,4,16]{2,1,0} parameter(0)
%param_1.29 = f32[4,4,16]{2,1,0} parameter(1)
# g_x = g_x_h3 + g_x_h1
ROOT %add_any.9 = f32[4,4,16]{2,1,0} add(%param_0.27, %param_1.29), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/add_any" stack_frame_id=5}
}
# ----------------------------------------------------------------------------
# REDUCTION COMPUTATIONS: Used by all-reduce for gradient accumulation
# ----------------------------------------------------------------------------
# Used for all-reduce to sum input gradients across TP dimension
%region_0.1 (dot_general.24: f32[], dot_general.0: f32[]) -> f32[] {
# Scalar reducer (a + b). Referenced via to_apply by ENTRY %all-reduce, which
# sums the g_x_h3 / g_x_h1 partial results over replica_groups={{0,1},{2,3}}.
%dot_general.24 = f32[] parameter(0), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general"}
%dot_general.0 = f32[] parameter(1), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general"}
ROOT %dot_general.1 = f32[] add(%dot_general.24, %dot_general.0), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
}
# Used for all-reduce to sum weight gradients across the batch-sharded (data-parallel) devices
%region_2.5 (transpose.0: f32[], transpose.1: f32[]) -> f32[] {
# Scalar reducer (a + b). Referenced via to_apply by ENTRY %all-reduce.1, whose
# replica_groups={{0,2},{1,3}} pair devices that hold the same weight tile but
# different batch tiles of x (see the parameter shardings in ENTRY) - i.e. this
# is the data-parallel reduction, not the TP one.
%transpose.0 = f32[] parameter(0), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose"}
%transpose.1 = f32[] parameter(1), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose"}
ROOT %transpose.2 = f32[] add(%transpose.0, %transpose.1), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5}
}
# ============================================================================
# MAIN ENTRY POINT
# Returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
# ============================================================================
ENTRY %main.11_spmd (param.5: f32[4,4,16], param.6: f32[16,16], param.7: f32[16,16], param.8: f32[16,16], param.9: f32[4,4,16]) -> (f32[4,4,16], f32[4,4,16], f32[16,16], f32[16,16], f32[16,16]) {
# Per-device program for 4 devices: forward MLP plus gradients w.r.t. x, w1, w3, w2.
# All shapes below are the per-device (post-partition) shapes.
# Device groups implied by the parameter shardings:
#   - x / grad_out: dim 1 (b) is tiled 2 ways; devices {0,1} hold one tile, {2,3} the other.
#   - w1 / w3 (dim 1 = i) and w2 (dim 0 = i): i is tiled 2 ways; with the
#     <=[2,2]T(1,0) device order, devices {0,2} hold one tile, {1,3} the other.
#   So {0,1} / {2,3} differ in the i (tensor-parallel) tile, and {0,2} / {1,3}
#   differ in the batch (data-parallel) tile.
# ========== PARAMETERS ==========
# line 10: rx = jax.reshard(x, ...) - input activations
%param.5 = f32[4,4,16]{2,1,0} parameter(0), sharding={devices=[1,2,1,2]<=[4] last_tile_dim_replicate}, metadata={op_name="x"}
# line 11: rw1 = jax.reshard(w1, ...) - projection producing h1, the branch that goes through silu
%param.6 = f32[16,16]{1,0} parameter(1), sharding={devices=[1,2,2]<=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name="w1"}
# line 12: rw3 = jax.reshard(w3, ...) - projection producing h3, which multiplies silu(h1) elementwise
%param.7 = f32[16,16]{1,0} parameter(2), sharding={devices=[1,2,2]<=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name="w3"}
# line 13: rw2 = jax.reshard(w2, ...) - output projection (i -> h)
%param.8 = f32[16,16]{1,0} parameter(3), sharding={devices=[2,1,2]<=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name="w2"}
# grad_out: gradient from loss w.r.t. output (same sharding as x)
%param.9 = f32[4,4,16]{2,1,0} parameter(4), sharding={devices=[1,2,1,2]<=[4] last_tile_dim_replicate}, metadata={op_name="grad_out"}
# ========== LAYOUT TRANSFORMATIONS ==========
# Reshape x for matmul: [4,4,16] -> [16,16]
%bitcast = f32[16,16]{1,0} bitcast(%param.5), metadata={op_name="x"}
# Reshape grad_out for matmul: [4,4,16] -> [16,16]
%bitcast.5 = f32[16,16]{1,0} bitcast(%param.9), metadata={op_name="grad_out"}
# ========== BACKWARD: Prepare grad_out for g_w2 ==========
# Transpose grad_out to [h, s*b] for weight gradient computation (consumed by %dot.29)
%copy_bitcast_fusion = f32[16,16]{1,0} fusion(%param.9), kind=kLoop, calls=%fused_computation, metadata={op_name="grad_out"}
# ========== FORWARD: Compute h1 = x @ w1 ==========
# line 15: h1 = jnp.einsum("sbh,hi->sbi", rx, rw1)
%dot.12 = f32[16,16]{1,0} dot(%bitcast, %param.6), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=5}
# ========== FORWARD: Compute h3 = x @ w3 ==========
# line 16: h3 = jnp.einsum("sbh,hi->sbi", rx, rw3)
%dot.13 = f32[16,16]{1,0} dot(%bitcast, %param.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(forward_and_backward)/jvp(sbh,hi->sbi)/dot_general" stack_frame_id=8}
# ========== BACKWARD: Compute g_h = grad_out @ w2.T ==========
# line 18 backward: gradient flows back from output through w2
# (both operands contract over their dimension 1, the h axis)
%dot.16 = f32[16,16]{1,0} dot(%bitcast.5, %param.8), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/dot_general" stack_frame_id=11}
# ========== FORWARD: Compute sigmoid(h1) for silu ==========
# line 17: sigmoid(h1) = 1 / (1 + exp(-h1))
# Computed once; reused by the forward h fusion and all four backward fusions below.
%add_divide_fusion = f32[4,4,16]{2,1,0} fusion(%dot.12), kind=kLoop, calls=%fused_computation.6, metadata={op_name="jit(forward_and_backward)/jvp(jit(silu))/div" stack_frame_id=6}
# ========== BACKWARD: Compute g_h1 (for grad_x path, no transpose) ==========
# line 17 backward: g_h1 = g_h * h3 * silu'(h1)
%add_bitcast_fusion = f32[16,16]{1,0} fusion(%add_divide_fusion, %dot.16, %dot.13, %dot.12), kind=kLoop, calls=%fused_computation.3, metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
# ========== FORWARD: Compute h = silu(h1) * h3 ==========
# line 17: h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
%multiply_bitcast_fusion.1 = f32[16,16]{1,0} fusion(%dot.13, %add_divide_fusion, %dot.12), kind=kLoop, calls=%fused_computation.5, metadata={op_name="jit(forward_and_backward)/jvp(sbi,sbi->sbi)/dot_general" stack_frame_id=9}
# ========== BACKWARD: Compute g_h3 (for g_w3 path, with transpose) ==========
# line 17 backward: g_h3 = g_h * silu(h1)
%copy_bitcast_fusion.1 = f32[16,16]{1,0} fusion(%dot.16, %add_divide_fusion, %dot.12), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# ========== BACKWARD: Compute g_h1 (for g_w1 path, with transpose) ==========
# line 17 backward: g_h1 = g_h * h3 * silu'(h1) (transposed for weight gradient)
%copy_bitcast_fusion.2 = f32[16,16]{1,0} fusion(%add_divide_fusion, %dot.16, %dot.13, %dot.12), kind=kLoop, calls=%fused_computation.2, metadata={op_name="jit(forward_and_backward)/transpose(jvp(jit(silu)))/add_any" stack_frame_id=6}
# ========== BACKWARD: Compute g_h3 (for grad_x path, no transpose) ==========
# line 17 backward: g_h3 for input gradient computation
%multiply_bitcast_fusion = f32[16,16]{1,0} fusion(%dot.16, %add_divide_fusion, %dot.12), kind=kLoop, calls=%fused_computation.4, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,sbi->sbi))/dot_general" stack_frame_id=9}
# ========== BACKWARD: Compute g_x from h1 path = g_h1 @ w1.T ==========
# line 15 backward: input gradient from h1 path (partial sum over this device's i tile)
%dot.20 = f32[16,16]{1,0} dot(%add_bitcast_fusion, %param.6), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=5}
# ========== FORWARD: Compute out = h @ w2 ==========
# line 18: out = jnp.einsum("sbi,ih->sbh", h, rw2, ...)
# NOTE(review): this contracts over i, which is tiled across devices (see %param.8),
# and the result is returned without an all-reduce in this listing - presumably the
# elided einsum arguments on line 18 request an unreduced output; confirm against mlp2.py.
%dot.15 = f32[16,16]{1,0} dot(%multiply_bitcast_fusion.1, %param.8), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(forward_and_backward)/jvp(sbi,ih->sbh)/dot_general" stack_frame_id=11}
# ========== BACKWARD: Compute g_w2 = h.T @ grad_out ==========
# line 18 backward: weight gradient for w2
# (lhs contracts over dim 0 = s*b; rhs is the pre-transposed grad_out, contracted over dim 1 = s*b)
%dot.29 = f32[16,16]{1,0} dot(%multiply_bitcast_fusion.1, %copy_bitcast_fusion), lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose" stack_frame_id=11}
# ========== BACKWARD: Compute g_w3 = x.T @ g_h3 ==========
# line 16 backward: weight gradient for w3
%dot.28 = f32[16,16]{1,0} dot(%bitcast, %copy_bitcast_fusion.1), lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=8}
# ========== BACKWARD: Compute g_w1 = x.T @ g_h1 ==========
# line 15 backward: weight gradient for w1
%dot.27 = f32[16,16]{1,0} dot(%bitcast, %copy_bitcast_fusion.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5}
# ========== BACKWARD: Compute g_x from h3 path = g_h3 @ w3.T ==========
# line 16 backward: input gradient from h3 path (partial sum over this device's i tile)
%dot.18 = f32[16,16]{1,0} dot(%multiply_bitcast_fusion, %param.7), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
# ========== LAYOUT: Reshape g_x components and out ==========
# line 15 backward: reshape g_x_h1
%bitcast.11 = f32[4,4,16]{2,1,0} bitcast(%dot.20), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=5}
# line 18 forward: reshape out
%bitcast.4 = f32[4,4,16]{2,1,0} bitcast(%dot.15), metadata={op_name="jit(forward_and_backward)/jvp(sbi,ih->sbh)/dot_general" stack_frame_id=11}
# ========== BACKWARD: All-reduce weight gradients across the data-parallel devices ==========
# Sum weight gradients over replica_groups={{0,2},{1,3}}: devices in a group hold
# the same weight tile but different batch tiles of x, so each contributes the
# gradient from its own slice of the batch. (This is not the TP direction.)
# Returns: (g_w1, g_w3, g_w2)
%all-reduce.1 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}) all-reduce(%dot.27, %dot.28, %dot.29), channel_id=3, replica_groups={{0,2},{1,3}}, use_global_device_ids=true, to_apply=%region_2.5, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5}
# line 16 backward: reshape g_x_h3
%bitcast.8 = f32[4,4,16]{2,1,0} bitcast(%dot.18), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
# Extract weight gradients from all-reduce tuple
# line 15 backward: g_w1 (reduced)
%get-tuple-element.4 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=0, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5}
# line 16 backward: g_w3 (reduced)
%get-tuple-element.6 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=8}
# line 18 backward: g_w2 (reduced)
%get-tuple-element.8 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=2, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose/jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose" stack_frame_id=11}
# ========== BACKWARD: All-reduce input gradients across TP ==========
# Sum input gradients across tensor parallel devices (replica_groups={{0,1},{2,3}}):
# devices in a group hold the same batch tile of x but different i tiles of w1/w3,
# so each holds a partial sum over i.
# Returns: (g_x_h3, g_x_h1)
%all-reduce = (f32[4,4,16]{2,1,0}, f32[4,4,16]{2,1,0}) all-reduce(%bitcast.8, %bitcast.11), channel_id=1, replica_groups={{0,1},{2,3}}, use_global_device_ids=true, to_apply=%region_0.1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
# Extract input gradients from all-reduce tuple
# line 16 backward: g_x_h3 (reduced across TP)
%get-tuple-element = f32[4,4,16]{2,1,0} get-tuple-element(%all-reduce), index=0, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
# line 15 backward: g_x_h1 (reduced across TP)
%get-tuple-element.2 = f32[4,4,16]{2,1,0} get-tuple-element(%all-reduce), index=1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=5}
# ========== BACKWARD: Sum input gradients from both paths ==========
# line 15+16 backward: g_x = g_x_h1 + g_x_h3
%wrapped_add = f32[4,4,16]{2,1,0} fusion(%get-tuple-element, %get-tuple-element.2), kind=kLoop, calls=%wrapped_add_computation, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/add_any" stack_frame_id=5}
# ========== FINAL OUTPUT ==========
# Returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
ROOT %tuple.6 = (f32[4,4,16]{2,1,0}, f32[4,4,16]{2,1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(%bitcast.4, %wrapped_add, %get-tuple-element.4, %get-tuple-element.6, %get-tuple-element.8)
}
# ============================================================================
# EXECUTION FLOW SUMMARY
# ============================================================================
#
# FORWARD PASS (in execution order):
# 1. %dot.12: h1 = x @ w1 [line 15]
# 2. %dot.13: h3 = x @ w3 [line 16]
# 3. %add_divide_fusion: sigmoid(h1) [line 17 - silu part]
# 4. %multiply_bitcast_fusion.1: h = silu(h1) * h3 [line 17]
# 5. %dot.15: out = h @ w2 [line 18]
#
# BACKWARD PASS (in the order the instructions appear in ENTRY, where they are
# interleaved with the forward instructions above):
# 1. %copy_bitcast_fusion: transpose grad_out [prep for g_w2]
# 2. %dot.16: g_h = grad_out @ w2.T [line 18 backward]
# 3. %add_bitcast_fusion: g_h1 (for grad_x, no transpose) [line 17 backward]
# 4. %copy_bitcast_fusion.1: g_h3 (for g_w3, transposed) [line 17 backward]
# 5. %copy_bitcast_fusion.2: g_h1 (for g_w1, transposed) [line 17 backward]
# 6. %multiply_bitcast_fusion: g_h3 (for grad_x, no transpose) [line 17 backward]
# 7. %dot.20: g_x_h1 = g_h1 @ w1.T [line 15 backward]
# 8. %dot.29: g_w2 = h.T @ grad_out [line 18 backward]
# 9. %dot.28: g_w3 = x.T @ g_h3 [line 16 backward]
# 10. %dot.27: g_w1 = x.T @ g_h1 [line 15 backward]
# 11. %dot.18: g_x_h3 = g_h3 @ w3.T [line 16 backward]
# 12. %all-reduce.1: sum weight gradients across the batch-sharded (data-parallel) devices, replica_groups={{0,2},{1,3}} [communication]
# 13. %all-reduce: sum input gradients across TP, replica_groups={{0,1},{2,3}} [communication]
# 14. %wrapped_add: g_x = g_x_h1 + g_x_h3 [line 15+16 backward]
#
# OUTPUTS: (out, grad_x, grad_w1, grad_w3, grad_w2)
# ============================================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment