Falcon 7B Thunder Debug
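Two Thunder execution traces of the augmented forward function for a single-transformer-block Falcon 7B model follow; the first trace failed, the second passed. For context, here is a minimal sketch of how such traces are usually dumped, assuming the lightning-thunder Python API (thunder.jit, thunder.last_traces); the model and input names are placeholders:

import torch
import thunder

def dump_last_forward_trace(model: torch.nn.Module, idx: torch.Tensor) -> None:
    # Compile the module with Thunder, run it once to record traces,
    # and print the final executed forward trace (like the ones below).
    jitted = thunder.jit(model)
    _ = jitted(idx)
    print(thunder.last_traces(jitted)[-1])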
@torch.no_grad()
@no_autocast
def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight):
# idx: "cuda:0 i64[1, 2048]"
# tos1: "cuda:0 bf16[2048, 64]"
# t_lm_head_weight: "cuda:0 bf16[65024, 4096]"
# t_sin: "cuda:0 bf16[2048, 64]"
# t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[8448, 4096]"
# t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 8192]"
# t_transformer_h_0_mlp_fc_weight: "cuda:0 bf16[18176, 4096]"
# t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 18176]"
# t_transformer_h_0_norm_1_bias: "cuda:0 bf16[4096]"
# t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]"
# t_transformer_ln_f_bias: "cuda:0 bf16[4096]"
# t_transformer_ln_f_weight: "cuda:0 bf16[4096]"
# t_transformer_wte_weight: "cuda:0 bf16[65024, 4096]"
t4 = torch.nn.functional.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4096]"
# t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4096]"
# t158 = ltorch.reshape(idx, [2048]) # t158: "cuda:0 i64[2048]"
# t158 = prims.reshape(idx, (2048,)) # t158: "cuda:0 i64[2048]"
# t159 = prims.take(t_transformer_wte_weight, t158, 0) # t159: "cuda:0 bf16[2048, 4096]"
# t4 = ltorch.reshape(t159, [1, 2048, 4096]) # t4: "cuda:0 bf16[1, 2048, 4096]"
# t4 = prims.reshape(t159, (1, 2048, 4096)) # t4: "cuda:0 bf16[1, 2048, 4096]"
t166 = torch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t166: "cuda:0 bf16[1, 4096]"
# t166 = ltorch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t166: "cuda:0 bf16[1, 4096]"
# t166 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, [1, 4096], [1]) # t166: "cuda:0 bf16[1, 4096]"
t167 = torch.unsqueeze(t166, 1) # t167: "cuda:0 bf16[1, 1, 4096]"
# t167 = ltorch.unsqueeze(t166, 1) # t167: "cuda:0 bf16[1, 1, 4096]"
# t167 = prims.broadcast_in_dim(t166, [1, 1, 4096], [0, 2]) # t167: "cuda:0 bf16[1, 1, 4096]"
del t166
t19 = Tensor.expand(t167, (1, 2048, 4096)) # t19: "cuda:0 bf16[1, 2048, 4096]"
# t19 = ltorch.expand(t167, (1, 2048, 4096)) # t19: "cuda:0 bf16[1, 2048, 4096]"
# t19 = prims.broadcast_in_dim(t167, (1, 2048, 4096), (0, 1, 2)) # t19: "cuda:0 bf16[1, 2048, 4096]"
del t167
t169 = torch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t169: "cuda:0 bf16[1, 4096]"
# t169 = ltorch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t169: "cuda:0 bf16[1, 4096]"
# t169 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_bias, [1, 4096], [1]) # t169: "cuda:0 bf16[1, 4096]"
t170 = torch.unsqueeze(t169, 1) # t170: "cuda:0 bf16[1, 1, 4096]"
# t170 = ltorch.unsqueeze(t169, 1) # t170: "cuda:0 bf16[1, 1, 4096]"
# t170 = prims.broadcast_in_dim(t169, [1, 1, 4096], [0, 2]) # t170: "cuda:0 bf16[1, 1, 4096]"
del t169
t22 = Tensor.expand(t170, (1, 2048, 4096)) # t22: "cuda:0 bf16[1, 2048, 4096]"
# t22 = ltorch.expand(t170, (1, 2048, 4096)) # t22: "cuda:0 bf16[1, 2048, 4096]"
# t22 = prims.broadcast_in_dim(t170, (1, 2048, 4096), (0, 1, 2)) # t22: "cuda:0 bf16[1, 2048, 4096]"
del t170
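# nvFusion0 below: nvFuser region computing the block's first LayerNorm (norm_1) over the token embeddings t4, using the expanded weight t19 and bias t22, and producing t25.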
[t13, t25, t9] = nvFusion0(t19, t22, t4)
# t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 2048, 4096]"
# (t8, t9) = prims.var_mean(t5, (2,), correction=0)
# t10 = prims.broadcast_in_dim(t8, [1, 2048, 1], [0, 1]) # t10: "cuda:0 f32[1, 2048, 1]"
# t11 = prims.broadcast_in_dim(t9, [1, 2048, 1], [0, 1]) # t11: "cuda:0 f32[1, 2048, 1]"
# t12 = prims.add(t10, 1e-05) # t12: "cuda:0 f32[1, 2048, 1]"
# t13 = prims.rsqrt(t12) # t13: "cuda:0 f32[1, 2048, 1]"
# t14 = prims.broadcast_in_dim(t11, (1, 2048, 4096), (0, 1, 2)) # t14: "cuda:0 f32[1, 2048, 4096]"
# t16 = prims.sub(t5, t14) # t16: "cuda:0 f32[1, 2048, 4096]"
# t17 = prims.broadcast_in_dim(t13, (1, 2048, 4096), (0, 1, 2)) # t17: "cuda:0 f32[1, 2048, 4096]"
# t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 2048, 4096]"
# t20 = prims.convert_element_type(t19, dtypes.float32) # t20: "cuda:0 f32[1, 2048, 4096]"
# t21 = prims.mul(t18, t20) # t21: "cuda:0 f32[1, 2048, 4096]"
# t23 = prims.convert_element_type(t22, dtypes.float32) # t23: "cuda:0 f32[1, 2048, 4096]"
# t24 = prims.add(t21, t23) # t24: "cuda:0 f32[1, 2048, 4096]"
# t25 = prims.convert_element_type(t24, dtypes.bfloat16) # t25: "cuda:0 bf16[1, 2048, 4096]"
del t22
t26 = torch.nn.functional.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]"
# t26 = ltorch.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]"
# t26 = prims.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]"
t107 = torch.nn.functional.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
# t107 = ltorch.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
# t107 = prims.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
t0 = torch_slice_prim_impl(tos1, [0, 0], [2048, 64], [1, 1]) # t0: "cuda:0 bf16[2048, 64]"
t1 = torch_slice_prim_impl(t_sin, [0, 0], [2048, 64], [1, 1]) # t1: "cuda:0 bf16[2048, 64]"
t27 = torch.reshape(t26, (1, 2048, 1, 66, 128)) # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]"
# t27 = ltorch.reshape(t26, (1, 2048, 1, 66, 128)) # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]"
# t27 = prims.reshape(t26, (1, 2048, 1, 66, 128)) # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]"
del t26
t28 = torch.permute(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]"
# t28 = ltorch.permute(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]"
# t28 = prims.transpose(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]"
del t27
(t29, t30, t31) = torch.split(t28, (64, 1, 1), 2)
# (t29, t30, t31) = ltorch.split(t28, (64, 1, 1), 2)
# t29 = prims.slice_prim(t28, [0, 0, 0, 0, 0], [1, 1, 64, 2048, 128], [1, 1, 1, 1, 1]) # t29: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t30 = prims.slice_prim(t28, [0, 0, 64, 0, 0], [1, 1, 65, 2048, 128], [1, 1, 1, 1, 1]) # t30: "cuda:0 bf16[1, 1, 1, 2048, 128]"
# t31 = prims.slice_prim(t28, [0, 0, 65, 0, 0], [1, 1, 66, 2048, 128], [1, 1, 1, 1, 1]) # t31: "cuda:0 bf16[1, 1, 1, 2048, 128]"
del t28
t32 = Tensor.expand(t30, (1, 1, 64, 2048, 128)) # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t32 = ltorch.expand(t30, (1, 1, 64, 2048, 128)) # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t32 = prims.broadcast_in_dim(t30, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4)) # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]"
del t30
t38 = Tensor.expand(t31, (1, 1, 64, 2048, 128)) # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t38 = ltorch.expand(t31, (1, 1, 64, 2048, 128)) # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t38 = prims.broadcast_in_dim(t31, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4)) # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]"
del t31
t39 = torch.reshape(t29, (1, 64, 2048, 128)) # t39: "cuda:0 bf16[1, 64, 2048, 128]"
# t39 = ltorch.reshape(t29, (1, 64, 2048, 128)) # t39: "cuda:0 bf16[1, 64, 2048, 128]"
# t39 = prims.reshape(t29, (1, 64, 2048, 128)) # t39: "cuda:0 bf16[1, 64, 2048, 128]"
del t29
t45 = torch.reshape(t32, (1, 64, 2048, 128)) # t45: "cuda:0 bf16[1, 64, 2048, 128]"
# t45 = ltorch.reshape(t32, (1, 64, 2048, 128)) # t45: "cuda:0 bf16[1, 64, 2048, 128]"
# t45 = prims.reshape(t32, (1, 64, 2048, 128)) # t45: "cuda:0 bf16[1, 64, 2048, 128]"
del t32
t51 = torch.reshape(t38, (1, 64, 2048, 128)) # t51: "cuda:0 bf16[1, 64, 2048, 128]"
# t51 = ltorch.reshape(t38, (1, 64, 2048, 128)) # t51: "cuda:0 bf16[1, 64, 2048, 128]"
# t51 = prims.reshape(t38, (1, 64, 2048, 128)) # t51: "cuda:0 bf16[1, 64, 2048, 128]"
del t38
t52 = torch_slice_prim_impl(t39, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1]) # t52: "cuda:0 bf16[1, 64, 2048, 64]"
t53 = torch_slice_prim_impl(t52, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1]) # t53: "cuda:0 bf16[1, 64, 2048, 32]"
t54 = torch_slice_prim_impl(t52, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 64, 2048, 32]"
t187 = torch.unsqueeze(t0, 0) # t187: "cuda:0 bf16[1, 2048, 64]"
# t187 = ltorch.unsqueeze(t0, 0) # t187: "cuda:0 bf16[1, 2048, 64]"
# t187 = prims.broadcast_in_dim(t0, [1, 2048, 64], [1, 2]) # t187: "cuda:0 bf16[1, 2048, 64]"
del t0
t188 = torch.unsqueeze(t187, 1) # t188: "cuda:0 bf16[1, 1, 2048, 64]"
# t188 = ltorch.unsqueeze(t187, 1) # t188: "cuda:0 bf16[1, 1, 2048, 64]"
# t188 = prims.broadcast_in_dim(t187, [1, 1, 2048, 64], [0, 2, 3]) # t188: "cuda:0 bf16[1, 1, 2048, 64]"
del t187
t59 = Tensor.expand(t188, (1, 64, 2048, 64)) # t59: "cuda:0 bf16[1, 64, 2048, 64]"
# t59 = ltorch.expand(t188, (1, 64, 2048, 64)) # t59: "cuda:0 bf16[1, 64, 2048, 64]"
# t59 = prims.broadcast_in_dim(t188, (1, 64, 2048, 64), (0, 1, 2, 3)) # t59: "cuda:0 bf16[1, 64, 2048, 64]"
del t188
t190 = torch.unsqueeze(t1, 0) # t190: "cuda:0 bf16[1, 2048, 64]"
# t190 = ltorch.unsqueeze(t1, 0) # t190: "cuda:0 bf16[1, 2048, 64]"
# t190 = prims.broadcast_in_dim(t1, [1, 2048, 64], [1, 2]) # t190: "cuda:0 bf16[1, 2048, 64]"
del t1
t191 = torch.unsqueeze(t190, 1) # t191: "cuda:0 bf16[1, 1, 2048, 64]"
# t191 = ltorch.unsqueeze(t190, 1) # t191: "cuda:0 bf16[1, 1, 2048, 64]"
# t191 = prims.broadcast_in_dim(t190, [1, 1, 2048, 64], [0, 2, 3]) # t191: "cuda:0 bf16[1, 1, 2048, 64]"
del t190
t64 = Tensor.expand(t191, (1, 64, 2048, 64)) # t64: "cuda:0 bf16[1, 64, 2048, 64]"
# t64 = ltorch.expand(t191, (1, 64, 2048, 64)) # t64: "cuda:0 bf16[1, 64, 2048, 64]"
# t64 = prims.broadcast_in_dim(t191, (1, 64, 2048, 64), (0, 1, 2, 3)) # t64: "cuda:0 bf16[1, 64, 2048, 64]"
del t191
t73 = torch_slice_prim_impl(t45, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1]) # t73: "cuda:0 bf16[1, 64, 2048, 64]"
t74 = torch_slice_prim_impl(t73, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1]) # t74: "cuda:0 bf16[1, 64, 2048, 32]"
t75 = torch_slice_prim_impl(t73, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1]) # t75: "cuda:0 bf16[1, 64, 2048, 32]"
t95 = torch_slice_prim_impl(t39, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1]) # t95: "cuda:0 bf16[1, 64, 2048, 64]"
del t39
t97 = torch_slice_prim_impl(t45, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1]) # t97: "cuda:0 bf16[1, 64, 2048, 64]"
del t45
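# nvFusion1 below: nvFuser region fusing the rotary-embedding rotation of the query slices (t52/t53/t54) and key slices (t73/t74/t75) with the erf-based GELU of the parallel MLP branch (t107).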
[t123, t96, t99] = nvFusion1(t107, t52, t53, t54, t59, t64, t73, t74, t75, t95, t97)
# t55 = prims.convert_element_type(t54, dtypes.float32) # t55: "cuda:0 f32[1, 64, 2048, 32]"
# t56 = prims.neg(t55) # t56: "cuda:0 f32[1, 64, 2048, 32]"
# t57 = prims.convert_element_type(t56, dtypes.bfloat16) # t57: "cuda:0 bf16[1, 64, 2048, 32]"
# t58 = prims.cat((t57, t53), -1) # t58: "cuda:0 bf16[1, 64, 2048, 64]"
# t60 = prims.convert_element_type(t52, dtypes.float32) # t60: "cuda:0 f32[1, 64, 2048, 64]"
# t61 = prims.convert_element_type(t59, dtypes.float32) # t61: "cuda:0 f32[1, 64, 2048, 64]"
# t62 = prims.mul(t60, t61) # t62: "cuda:0 f32[1, 64, 2048, 64]"
# t65 = prims.convert_element_type(t58, dtypes.float32) # t65: "cuda:0 f32[1, 64, 2048, 64]"
# t66 = prims.convert_element_type(t64, dtypes.float32) # t66: "cuda:0 f32[1, 64, 2048, 64]"
# t67 = prims.mul(t65, t66) # t67: "cuda:0 f32[1, 64, 2048, 64]"
# t71 = prims.add(t62, t67) # t71: "cuda:0 f32[1, 64, 2048, 64]"
# t72 = prims.convert_element_type(t71, dtypes.bfloat16) # t72: "cuda:0 bf16[1, 64, 2048, 64]"
# t76 = prims.convert_element_type(t75, dtypes.float32) # t76: "cuda:0 f32[1, 64, 2048, 32]"
# t77 = prims.neg(t76) # t77: "cuda:0 f32[1, 64, 2048, 32]"
# t78 = prims.convert_element_type(t77, dtypes.bfloat16) # t78: "cuda:0 bf16[1, 64, 2048, 32]"
# t80 = prims.cat((t78, t74), -1) # t80: "cuda:0 bf16[1, 64, 2048, 64]"
# t82 = prims.convert_element_type(t73, dtypes.float32) # t82: "cuda:0 f32[1, 64, 2048, 64]"
# t84 = prims.mul(t82, t61) # t84: "cuda:0 f32[1, 64, 2048, 64]"
# t87 = prims.convert_element_type(t80, dtypes.float32) # t87: "cuda:0 f32[1, 64, 2048, 64]"
# t89 = prims.mul(t87, t66) # t89: "cuda:0 f32[1, 64, 2048, 64]"
# t93 = prims.add(t84, t89) # t93: "cuda:0 f32[1, 64, 2048, 64]"
# t94 = prims.convert_element_type(t93, dtypes.bfloat16) # t94: "cuda:0 bf16[1, 64, 2048, 64]"
# t96 = prims.cat((t72, t95), -1) # t96: "cuda:0 bf16[1, 64, 2048, 128]"
# t99 = prims.cat((t94, t97), -1) # t99: "cuda:0 bf16[1, 64, 2048, 128]"
# t108 = prims.convert_element_type(t107, dtypes.float32) # t108: "cuda:0 f32[1, 2048, 18176]"
# t109 = prims.div(t108, 1.4142135623730951) # t109: "cuda:0 f32[1, 2048, 18176]"
# t112 = prims.erf(t109) # t112: "cuda:0 f32[1, 2048, 18176]"
# t115 = prims.mul(0.5, t112) # t115: "cuda:0 f32[1, 2048, 18176]"
# t118 = prims.add(0.5, t115) # t118: "cuda:0 f32[1, 2048, 18176]"
# t122 = prims.mul(t108, t118) # t122: "cuda:0 f32[1, 2048, 18176]"
# t123 = prims.convert_element_type(t122, dtypes.bfloat16) # t123: "cuda:0 bf16[1, 2048, 18176]"
del t52, t53, t54, t73, t74, t75, t95, t97
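# cudnn_sdpa_fwd below: cuDNN scaled-dot-product-attention executor with query t96, key t99, value t51, no mask, dropout 0.0, the causal flag, and scale 0.0884 = 1/sqrt(128).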
(t100, t101, t102, t103) = cudnn_sdpa_fwd(t96, t99, t51, None, 0.0, True, scale=0.08838834764831843)
t124 = torch.nn.functional.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]"
# t124 = ltorch.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]"
# t124 = prims.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]"
t104 = torch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]"
# t104 = ltorch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]"
# t104 = prims.transpose(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]"
t105 = torch.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]"
# t105 = ltorch.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]"
# t105 = prims.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]"
del t104
t106 = torch.nn.functional.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]"
# t106 = ltorch.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]"
# t106 = prims.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]"
t200 = torch.unsqueeze(t_transformer_ln_f_weight, 0) # t200: "cuda:0 bf16[1, 4096]"
# t200 = ltorch.unsqueeze(t_transformer_ln_f_weight, 0) # t200: "cuda:0 bf16[1, 4096]"
# t200 = prims.broadcast_in_dim(t_transformer_ln_f_weight, [1, 4096], [1]) # t200: "cuda:0 bf16[1, 4096]"
t201 = torch.unsqueeze(t200, 1) # t201: "cuda:0 bf16[1, 1, 4096]"
# t201 = ltorch.unsqueeze(t200, 1) # t201: "cuda:0 bf16[1, 1, 4096]"
# t201 = prims.broadcast_in_dim(t200, [1, 1, 4096], [0, 2]) # t201: "cuda:0 bf16[1, 1, 4096]"
del t200
t150 = Tensor.expand(t201, (1, 2048, 4096)) # t150: "cuda:0 bf16[1, 2048, 4096]"
# t150 = ltorch.expand(t201, (1, 2048, 4096)) # t150: "cuda:0 bf16[1, 2048, 4096]"
# t150 = prims.broadcast_in_dim(t201, (1, 2048, 4096), (0, 1, 2)) # t150: "cuda:0 bf16[1, 2048, 4096]"
del t201
t203 = torch.unsqueeze(t_transformer_ln_f_bias, 0) # t203: "cuda:0 bf16[1, 4096]"
# t203 = ltorch.unsqueeze(t_transformer_ln_f_bias, 0) # t203: "cuda:0 bf16[1, 4096]"
# t203 = prims.broadcast_in_dim(t_transformer_ln_f_bias, [1, 4096], [1]) # t203: "cuda:0 bf16[1, 4096]"
t204 = torch.unsqueeze(t203, 1) # t204: "cuda:0 bf16[1, 1, 4096]"
# t204 = ltorch.unsqueeze(t203, 1) # t204: "cuda:0 bf16[1, 1, 4096]"
# t204 = prims.broadcast_in_dim(t203, [1, 1, 4096], [0, 2]) # t204: "cuda:0 bf16[1, 1, 4096]"
del t203
t153 = Tensor.expand(t204, (1, 2048, 4096)) # t153: "cuda:0 bf16[1, 2048, 4096]"
# t153 = ltorch.expand(t204, (1, 2048, 4096)) # t153: "cuda:0 bf16[1, 2048, 4096]"
# t153 = prims.broadcast_in_dim(t204, (1, 2048, 4096), (0, 1, 2)) # t153: "cuda:0 bf16[1, 2048, 4096]"
del t204
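# nvFusion2 below: nvFuser region fusing the residual additions (attention output t106 + MLP output t124 + embeddings t4) with the final LayerNorm (ln_f), producing t156.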
[t139, t144, t156] = nvFusion2(t106, t124, t150, t153, t4)
# t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 2048, 4096]"
# t126 = prims.convert_element_type(t106, dtypes.float32) # t126: "cuda:0 f32[1, 2048, 4096]"
# t127 = prims.add(t125, t126) # t127: "cuda:0 f32[1, 2048, 4096]"
# t130 = prims.convert_element_type(t4, dtypes.float32) # t130: "cuda:0 f32[1, 2048, 4096]"
# t131 = prims.add(t127, t130) # t131: "cuda:0 f32[1, 2048, 4096]"
# (t138, t139) = prims.var_mean(t131, (2,), correction=0)
# t140 = prims.broadcast_in_dim(t138, [1, 2048, 1], [0, 1]) # t140: "cuda:0 f32[1, 2048, 1]"
# t141 = prims.broadcast_in_dim(t139, [1, 2048, 1], [0, 1]) # t141: "cuda:0 f32[1, 2048, 1]"
# t143 = prims.add(t140, 1e-05) # t143: "cuda:0 f32[1, 2048, 1]"
# t144 = prims.rsqrt(t143) # t144: "cuda:0 f32[1, 2048, 1]"
# t145 = prims.broadcast_in_dim(t141, (1, 2048, 4096), (0, 1, 2)) # t145: "cuda:0 f32[1, 2048, 4096]"
# t147 = prims.sub(t131, t145) # t147: "cuda:0 f32[1, 2048, 4096]"
# t148 = prims.broadcast_in_dim(t144, (1, 2048, 4096), (0, 1, 2)) # t148: "cuda:0 f32[1, 2048, 4096]"
# t149 = prims.mul(t147, t148) # t149: "cuda:0 f32[1, 2048, 4096]"
# t151 = prims.convert_element_type(t150, dtypes.float32) # t151: "cuda:0 f32[1, 2048, 4096]"
# t152 = prims.mul(t149, t151) # t152: "cuda:0 f32[1, 2048, 4096]"
# t154 = prims.convert_element_type(t153, dtypes.float32) # t154: "cuda:0 f32[1, 2048, 4096]"
# t155 = prims.add(t152, t154) # t155: "cuda:0 f32[1, 2048, 4096]"
# t156 = prims.convert_element_type(t155, dtypes.bfloat16) # t156: "cuda:0 bf16[1, 2048, 4096]"
del t153
t157 = torch.nn.functional.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]"
# t157 = ltorch.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]"
# t157 = prims.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]"
return {'output': t157, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t157,)}, ((idx, t100, t101, t102, t103, t105, t106, t107, t123, t124, t13, t139, t144, t150, t156, t19, t25, t4, t51, t59, t64, t9, t96, t99, t_lm_head_weight, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight), (False, False, True, 0.0, 0.08838834764831843, 1.4142135623730951, 0.5, 65024, 2, 0, 0))
The trace above failed.
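For reference, a minimal PyTorch sketch of the rotary-embedding math both traces implement, as read off the trace (an illustration, not the original model code): RoPE is applied to the first 64 of the 128 head dimensions, the rotate-half split uses 32-element halves, and the remaining 64 dimensions pass through unrotated (t95/t97). The trace upcasts to float32 for this math and casts the result back to bfloat16.

import torch

def apply_rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_dim: int = 64) -> torch.Tensor:
    # x: (B, n_head, T, head_dim); cos/sin: (T, rope_dim), broadcast over batch and heads
    x_rope, x_pass = x[..., :rope_dim], x[..., rope_dim:]
    half = rope_dim // 2
    x1, x2 = x_rope[..., :half], x_rope[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)   # the prims.neg + prims.cat pattern (t58, t80)
    out = x_rope * cos + rotated * sin       # the mul/mul/add sequence (t62 + t67, t84 + t89)
    return torch.cat((out, x_pass), dim=-1)  # the final prims.cat (t96, t99)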
@torch.no_grad()
@no_autocast
def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight):
# idx: "cuda:0 i64[1, 2048]"
# tos1: "cuda:0 bf16[2048, 64]"
# t_lm_head_weight: "cuda:0 bf16[65024, 4096]"
# t_sin: "cuda:0 bf16[2048, 64]"
# t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[8448, 4096]"
# t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 8192]"
# t_transformer_h_0_mlp_fc_weight: "cuda:0 bf16[18176, 4096]"
# t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 18176]"
# t_transformer_h_0_norm_1_bias: "cuda:0 bf16[4096]"
# t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]"
# t_transformer_ln_f_bias: "cuda:0 bf16[4096]"
# t_transformer_ln_f_weight: "cuda:0 bf16[4096]"
# t_transformer_wte_weight: "cuda:0 bf16[65024, 4096]"
t4 = torch.nn.functional.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4096]"
# t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4096]"
# t158 = ltorch.reshape(idx, [2048]) # t158: "cuda:0 i64[2048]"
# t158 = prims.reshape(idx, (2048,)) # t158: "cuda:0 i64[2048]"
# t159 = prims.take(t_transformer_wte_weight, t158, 0) # t159: "cuda:0 bf16[2048, 4096]"
# t4 = ltorch.reshape(t159, [1, 2048, 4096]) # t4: "cuda:0 bf16[1, 2048, 4096]"
# t4 = prims.reshape(t159, (1, 2048, 4096)) # t4: "cuda:0 bf16[1, 2048, 4096]"
t166 = torch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t166: "cuda:0 bf16[1, 4096]"
# t166 = ltorch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t166: "cuda:0 bf16[1, 4096]"
# t166 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, [1, 4096], [1]) # t166: "cuda:0 bf16[1, 4096]"
t167 = torch.unsqueeze(t166, 1) # t167: "cuda:0 bf16[1, 1, 4096]"
# t167 = ltorch.unsqueeze(t166, 1) # t167: "cuda:0 bf16[1, 1, 4096]"
# t167 = prims.broadcast_in_dim(t166, [1, 1, 4096], [0, 2]) # t167: "cuda:0 bf16[1, 1, 4096]"
del t166
t19 = Tensor.expand(t167, (1, 2048, 4096)) # t19: "cuda:0 bf16[1, 2048, 4096]"
# t19 = ltorch.expand(t167, (1, 2048, 4096)) # t19: "cuda:0 bf16[1, 2048, 4096]"
# t19 = prims.broadcast_in_dim(t167, (1, 2048, 4096), (0, 1, 2)) # t19: "cuda:0 bf16[1, 2048, 4096]"
del t167
t169 = torch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t169: "cuda:0 bf16[1, 4096]"
# t169 = ltorch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t169: "cuda:0 bf16[1, 4096]"
# t169 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_bias, [1, 4096], [1]) # t169: "cuda:0 bf16[1, 4096]"
t170 = torch.unsqueeze(t169, 1) # t170: "cuda:0 bf16[1, 1, 4096]"
# t170 = ltorch.unsqueeze(t169, 1) # t170: "cuda:0 bf16[1, 1, 4096]"
# t170 = prims.broadcast_in_dim(t169, [1, 1, 4096], [0, 2]) # t170: "cuda:0 bf16[1, 1, 4096]"
del t169
t22 = Tensor.expand(t170, (1, 2048, 4096)) # t22: "cuda:0 bf16[1, 2048, 4096]"
# t22 = ltorch.expand(t170, (1, 2048, 4096)) # t22: "cuda:0 bf16[1, 2048, 4096]"
# t22 = prims.broadcast_in_dim(t170, (1, 2048, 4096), (0, 1, 2)) # t22: "cuda:0 bf16[1, 2048, 4096]"
del t170
[t13, t25, t9] = nvFusion0(t19, t22, t4)
# t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 2048, 4096]"
# (t8, t9) = prims.var_mean(t5, (2,), correction=0)
# t10 = prims.broadcast_in_dim(t8, [1, 2048, 1], [0, 1]) # t10: "cuda:0 f32[1, 2048, 1]"
# t12 = prims.add(t10, 1e-05) # t12: "cuda:0 f32[1, 2048, 1]"
# t11 = prims.broadcast_in_dim(t9, [1, 2048, 1], [0, 1]) # t11: "cuda:0 f32[1, 2048, 1]"
# t14 = prims.broadcast_in_dim(t11, (1, 2048, 4096), (0, 1, 2)) # t14: "cuda:0 f32[1, 2048, 4096]"
# t13 = prims.rsqrt(t12) # t13: "cuda:0 f32[1, 2048, 1]"
# t16 = prims.sub(t5, t14) # t16: "cuda:0 f32[1, 2048, 4096]"
# t17 = prims.broadcast_in_dim(t13, (1, 2048, 4096), (0, 1, 2)) # t17: "cuda:0 f32[1, 2048, 4096]"
# t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 2048, 4096]"
# t20 = prims.convert_element_type(t19, dtypes.float32) # t20: "cuda:0 f32[1, 2048, 4096]"
# t21 = prims.mul(t18, t20) # t21: "cuda:0 f32[1, 2048, 4096]"
# t23 = prims.convert_element_type(t22, dtypes.float32) # t23: "cuda:0 f32[1, 2048, 4096]"
# t24 = prims.add(t21, t23) # t24: "cuda:0 f32[1, 2048, 4096]"
# t25 = prims.convert_element_type(t24, dtypes.bfloat16) # t25: "cuda:0 bf16[1, 2048, 4096]"
del t22
t26 = torch.nn.functional.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]"
# t26 = ltorch.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]"
# t26 = prims.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]"
t107 = torch.nn.functional.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
# t107 = ltorch.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
# t107 = prims.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
[t51, t61, t66, t96, t99] = TorchCompile0(t107, t26, t_sin, tos1)
# t0 = prims.slice_prim(tos1, [0, 0], [2048, 64], [1, 1]) # t0: "cuda:0 bf16[2048, 64]"
# t1 = prims.slice_prim(t_sin, [0, 0], [2048, 64], [1, 1]) # t1: "cuda:0 bf16[2048, 64]"
# t27 = prims.reshape(t26, (1, 2048, 1, 66, 128)) # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]"
# t28 = prims.transpose(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]"
# (t29, t30, t31) = ltorch.split(t28, (64, 1, 1), 2)
# t29 = prims.slice_prim(t28, [0, 0, 0, 0, 0], [1, 1, 64, 2048, 128], [1, 1, 1, 1, 1]) # t29: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t30 = prims.slice_prim(t28, [0, 0, 64, 0, 0], [1, 1, 65, 2048, 128], [1, 1, 1, 1, 1]) # t30: "cuda:0 bf16[1, 1, 1, 2048, 128]"
# t31 = prims.slice_prim(t28, [0, 0, 65, 0, 0], [1, 1, 66, 2048, 128], [1, 1, 1, 1, 1]) # t31: "cuda:0 bf16[1, 1, 1, 2048, 128]"
# t32 = prims.broadcast_in_dim(t30, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4)) # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t38 = prims.broadcast_in_dim(t31, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4)) # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]"
# t39 = prims.reshape(t29, (1, 64, 2048, 128)) # t39: "cuda:0 bf16[1, 64, 2048, 128]"
# t45 = prims.reshape(t32, (1, 64, 2048, 128)) # t45: "cuda:0 bf16[1, 64, 2048, 128]"
# t51 = prims.reshape(t38, (1, 64, 2048, 128)) # t51: "cuda:0 bf16[1, 64, 2048, 128]"
# t52 = prims.slice_prim(t39, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1]) # t52: "cuda:0 bf16[1, 64, 2048, 64]"
# t53 = prims.slice_prim(t52, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1]) # t53: "cuda:0 bf16[1, 64, 2048, 32]"
# t54 = prims.slice_prim(t52, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 64, 2048, 32]"
# t55 = prims.convert_element_type(t54, dtypes.float32) # t55: "cuda:0 f32[1, 64, 2048, 32]"
# t56 = prims.neg(t55) # t56: "cuda:0 f32[1, 64, 2048, 32]"
# t57 = prims.convert_element_type(t56, dtypes.bfloat16) # t57: "cuda:0 bf16[1, 64, 2048, 32]"
# t58 = prims.cat((t57, t53), -1) # t58: "cuda:0 bf16[1, 64, 2048, 64]"
# t59 = prims.broadcast_in_dim(t0, (1, 64, 2048, 64), (2, 3)) # t59: "cuda:0 bf16[1, 64, 2048, 64]"
# t60 = prims.convert_element_type(t52, dtypes.float32) # t60: "cuda:0 f32[1, 64, 2048, 64]"
# t61 = prims.convert_element_type(t59, dtypes.float32) # t61: "cuda:0 f32[1, 64, 2048, 64]"
# t62 = ltorch.mul(t60, t61) # t62: "cuda:0 f32[1, 64, 2048, 64]"
# t62 = prims.mul(t60, t61) # t62: "cuda:0 f32[1, 64, 2048, 64]"
# t63 = prims.convert_element_type(t62, dtypes.bfloat16) # t63: "cuda:0 bf16[1, 64, 2048, 64]"
# t64 = prims.broadcast_in_dim(t1, (1, 64, 2048, 64), (2, 3)) # t64: "cuda:0 bf16[1, 64, 2048, 64]"
# t65 = prims.convert_element_type(t58, dtypes.float32) # t65: "cuda:0 f32[1, 64, 2048, 64]"
# t66 = prims.convert_element_type(t64, dtypes.float32) # t66: "cuda:0 f32[1, 64, 2048, 64]"
# t67 = ltorch.mul(t65, t66) # t67: "cuda:0 f32[1, 64, 2048, 64]"
# t67 = prims.mul(t65, t66) # t67: "cuda:0 f32[1, 64, 2048, 64]"
# t68 = prims.convert_element_type(t67, dtypes.bfloat16) # t68: "cuda:0 bf16[1, 64, 2048, 64]"
# t71 = ltorch.add(t62, t67, alpha=None) # t71: "cuda:0 f32[1, 64, 2048, 64]"
# t71 = prims.add(t62, t67) # t71: "cuda:0 f32[1, 64, 2048, 64]"
# t72 = prims.convert_element_type(t71, dtypes.bfloat16) # t72: "cuda:0 bf16[1, 64, 2048, 64]"
# t73 = prims.slice_prim(t45, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1]) # t73: "cuda:0 bf16[1, 64, 2048, 64]"
# t74 = prims.slice_prim(t73, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1]) # t74: "cuda:0 bf16[1, 64, 2048, 32]"
# t75 = prims.slice_prim(t73, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1]) # t75: "cuda:0 bf16[1, 64, 2048, 32]"
# t76 = prims.convert_element_type(t75, dtypes.float32) # t76: "cuda:0 f32[1, 64, 2048, 32]"
# t77 = prims.neg(t76) # t77: "cuda:0 f32[1, 64, 2048, 32]"
# t78 = prims.convert_element_type(t77, dtypes.bfloat16) # t78: "cuda:0 bf16[1, 64, 2048, 32]"
# t80 = prims.cat((t78, t74), -1) # t80: "cuda:0 bf16[1, 64, 2048, 64]"
# t82 = prims.convert_element_type(t73, dtypes.float32) # t82: "cuda:0 f32[1, 64, 2048, 64]"
# t84 = ltorch.mul(t82, t61) # t84: "cuda:0 f32[1, 64, 2048, 64]"
# t84 = prims.mul(t82, t61) # t84: "cuda:0 f32[1, 64, 2048, 64]"
# t85 = prims.convert_element_type(t84, dtypes.bfloat16) # t85: "cuda:0 bf16[1, 64, 2048, 64]"
# t87 = prims.convert_element_type(t80, dtypes.float32) # t87: "cuda:0 f32[1, 64, 2048, 64]"
# t89 = ltorch.mul(t87, t66) # t89: "cuda:0 f32[1, 64, 2048, 64]"
# t89 = prims.mul(t87, t66) # t89: "cuda:0 f32[1, 64, 2048, 64]"
# t90 = prims.convert_element_type(t89, dtypes.bfloat16) # t90: "cuda:0 bf16[1, 64, 2048, 64]"
# t93 = ltorch.add(t84, t89, alpha=None) # t93: "cuda:0 f32[1, 64, 2048, 64]"
# t93 = prims.add(t84, t89) # t93: "cuda:0 f32[1, 64, 2048, 64]"
# t94 = prims.convert_element_type(t93, dtypes.bfloat16) # t94: "cuda:0 bf16[1, 64, 2048, 64]"
# t95 = prims.slice_prim(t39, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1]) # t95: "cuda:0 bf16[1, 64, 2048, 64]"
# t96 = prims.cat((t72, t95), -1) # t96: "cuda:0 bf16[1, 64, 2048, 128]"
# t97 = prims.slice_prim(t45, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1]) # t97: "cuda:0 bf16[1, 64, 2048, 64]"
# t99 = prims.cat((t94, t97), -1) # t99: "cuda:0 bf16[1, 64, 2048, 128]"
# t108 = prims.convert_element_type(t107, dtypes.float32) # t108: "cuda:0 f32[1, 2048, 18176]"
del t26
(t100, t101, t102, t103) = cudnn_sdpa_fwd(t96, t99, t51, None, 0.0, True, scale=0.08838834764831843)
t104 = torch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]"
# t104 = ltorch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]"
# t104 = prims.transpose(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]"
t105 = torch.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]"
# t105 = ltorch.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]"
# t105 = prims.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]"
del t104
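# nvFusion1 below: in this trace the nvFuser region contains only the erf-based GELU of the MLP branch (t107); the rotary-embedding math has moved into TorchCompile0 above.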
[t118, t123] = nvFusion1(t107)
# t108 = prims.convert_element_type(t107, dtypes.float32) # t108: "cuda:0 f32[1, 2048, 18176]"
# t109 = prims.div(t108, 1.4142135623730951) # t109: "cuda:0 f32[1, 2048, 18176]"
# t112 = prims.erf(t109) # t112: "cuda:0 f32[1, 2048, 18176]"
# t115 = prims.mul(0.5, t112) # t115: "cuda:0 f32[1, 2048, 18176]"
# t118 = prims.add(0.5, t115) # t118: "cuda:0 f32[1, 2048, 18176]"
# t122 = prims.mul(t108, t118) # t122: "cuda:0 f32[1, 2048, 18176]"
# t123 = prims.convert_element_type(t122, dtypes.bfloat16) # t123: "cuda:0 bf16[1, 2048, 18176]"
t124 = torch.nn.functional.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]"
# t124 = ltorch.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]"
# t124 = prims.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]"
t106 = torch.nn.functional.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]"
# t106 = ltorch.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]"
# t106 = prims.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]"
t174 = torch.unsqueeze(t_transformer_ln_f_weight, 0) # t174: "cuda:0 bf16[1, 4096]"
# t174 = ltorch.unsqueeze(t_transformer_ln_f_weight, 0) # t174: "cuda:0 bf16[1, 4096]"
# t174 = prims.broadcast_in_dim(t_transformer_ln_f_weight, [1, 4096], [1]) # t174: "cuda:0 bf16[1, 4096]"
t175 = torch.unsqueeze(t174, 1) # t175: "cuda:0 bf16[1, 1, 4096]"
# t175 = ltorch.unsqueeze(t174, 1) # t175: "cuda:0 bf16[1, 1, 4096]"
# t175 = prims.broadcast_in_dim(t174, [1, 1, 4096], [0, 2]) # t175: "cuda:0 bf16[1, 1, 4096]"
del t174
t150 = Tensor.expand(t175, (1, 2048, 4096)) # t150: "cuda:0 bf16[1, 2048, 4096]"
# t150 = ltorch.expand(t175, (1, 2048, 4096)) # t150: "cuda:0 bf16[1, 2048, 4096]"
# t150 = prims.broadcast_in_dim(t175, (1, 2048, 4096), (0, 1, 2)) # t150: "cuda:0 bf16[1, 2048, 4096]"
del t175
t177 = torch.unsqueeze(t_transformer_ln_f_bias, 0) # t177: "cuda:0 bf16[1, 4096]"
# t177 = ltorch.unsqueeze(t_transformer_ln_f_bias, 0) # t177: "cuda:0 bf16[1, 4096]"
# t177 = prims.broadcast_in_dim(t_transformer_ln_f_bias, [1, 4096], [1]) # t177: "cuda:0 bf16[1, 4096]"
t178 = torch.unsqueeze(t177, 1) # t178: "cuda:0 bf16[1, 1, 4096]"
# t178 = ltorch.unsqueeze(t177, 1) # t178: "cuda:0 bf16[1, 1, 4096]"
# t178 = prims.broadcast_in_dim(t177, [1, 1, 4096], [0, 2]) # t178: "cuda:0 bf16[1, 1, 4096]"
del t177
t153 = Tensor.expand(t178, (1, 2048, 4096)) # t153: "cuda:0 bf16[1, 2048, 4096]"
# t153 = ltorch.expand(t178, (1, 2048, 4096)) # t153: "cuda:0 bf16[1, 2048, 4096]"
# t153 = prims.broadcast_in_dim(t178, (1, 2048, 4096), (0, 1, 2)) # t153: "cuda:0 bf16[1, 2048, 4096]"
del t178
[t139, t144, t156] = nvFusion2(t106, t124, t150, t153, t4)
# t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 2048, 4096]"
# t126 = prims.convert_element_type(t106, dtypes.float32) # t126: "cuda:0 f32[1, 2048, 4096]"
# t127 = prims.add(t125, t126) # t127: "cuda:0 f32[1, 2048, 4096]"
# t130 = prims.convert_element_type(t4, dtypes.float32) # t130: "cuda:0 f32[1, 2048, 4096]"
# t131 = prims.add(t127, t130) # t131: "cuda:0 f32[1, 2048, 4096]"
# (t138, t139) = prims.var_mean(t131, (2,), correction=0)
# t140 = prims.broadcast_in_dim(t138, [1, 2048, 1], [0, 1]) # t140: "cuda:0 f32[1, 2048, 1]"
# t143 = prims.add(t140, 1e-05) # t143: "cuda:0 f32[1, 2048, 1]"
# t141 = prims.broadcast_in_dim(t139, [1, 2048, 1], [0, 1]) # t141: "cuda:0 f32[1, 2048, 1]"
# t145 = prims.broadcast_in_dim(t141, (1, 2048, 4096), (0, 1, 2)) # t145: "cuda:0 f32[1, 2048, 4096]"
# t144 = prims.rsqrt(t143) # t144: "cuda:0 f32[1, 2048, 1]"
# t147 = prims.sub(t131, t145) # t147: "cuda:0 f32[1, 2048, 4096]"
# t148 = prims.broadcast_in_dim(t144, (1, 2048, 4096), (0, 1, 2)) # t148: "cuda:0 f32[1, 2048, 4096]"
# t149 = prims.mul(t147, t148) # t149: "cuda:0 f32[1, 2048, 4096]"
# t151 = prims.convert_element_type(t150, dtypes.float32) # t151: "cuda:0 f32[1, 2048, 4096]"
# t152 = prims.mul(t149, t151) # t152: "cuda:0 f32[1, 2048, 4096]"
# t154 = prims.convert_element_type(t153, dtypes.float32) # t154: "cuda:0 f32[1, 2048, 4096]"
# t155 = prims.add(t152, t154) # t155: "cuda:0 f32[1, 2048, 4096]"
# t156 = prims.convert_element_type(t155, dtypes.bfloat16) # t156: "cuda:0 bf16[1, 2048, 4096]"
del t153
t157 = torch.nn.functional.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]"
# t157 = ltorch.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]"
# t157 = prims.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]"
return {'output': t157, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t157,)}, ((idx, t100, t101, t102, t103, t105, t106, t107, t118, t123, t124, t13, t139, t144, t150, t156, t19, t25, t4, t51, t61, t66, t9, t96, t99, t_lm_head_weight, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight), (False, False, True, 0.0, 0.08838834764831843, 1.4142135623730951, 0.5, 65024, 2, 0, 0))
The trace above passed. Compared with the failing trace, the visible difference is that the slice/split and rotary-embedding region is executed by the TorchCompile0 region rather than by individual torch_slice_prim_impl calls plus an nvFuser fusion, the MLP GELU is split out into its own nvFusion1, and the set of tensors saved for backward changes accordingly (e.g., t118, t61, t66 instead of t59, t64).