Created June 17, 2024 22:50
Save parthmannan/dc0c918dcea33bb8664c6577349f1d01 to your computer and use it in GitHub Desktop.
Falcon 7B Thunder Debug
Note: this file may contain bidirectional Unicode text that can be interpreted or compiled differently than it appears below. To review it safely, open the file in an editor that reveals hidden Unicode characters.
@torch.no_grad()
@no_autocast
def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight):
    """Thunder-generated augmented forward trace for a 1-layer Falcon-7B model.

    Machine-generated code (Lightning Thunder execution trace) — do not edit by
    hand.  The `# ...` comment lines are Thunder's record of the ltorch/prims
    decomposition each executed statement was fused from.  `nvFusion*`,
    `torch_slice_prim_impl` and `cudnn_sdpa_fwd` are callables injected into
    this trace's scope by Thunder's executors at runtime; they are not defined
    in this file.

    Pipeline: token embedding -> LayerNorm (norm_1, fused) -> fused QKV
    projection + RoPE -> cuDNN scaled-dot-product attention -> attention
    output projection; in parallel the MLP branch (fc -> GeLU -> proj);
    residual add + final LayerNorm (ln_f, fused) -> lm_head logits.

    Returns a dict with the logits ('output', 'flat_output') and the flat
    argument list, plus the saved-for-backward tensors/constants tuple.
    NOTE(review): all shape/dtype comments below are Thunder's own
    annotations (cuda:0, bf16, batch=1, seq=2048) — taken as authoritative.
    """
    # idx: "cuda:0 i64[1, 2048]"
    # tos1: "cuda:0 bf16[2048, 64]"
    # t_lm_head_weight: "cuda:0 bf16[65024, 4096]"
    # t_sin: "cuda:0 bf16[2048, 64]"
    # t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[8448, 4096]"
    # t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 8192]"
    # t_transformer_h_0_mlp_fc_weight: "cuda:0 bf16[18176, 4096]"
    # t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 18176]"
    # t_transformer_h_0_norm_1_bias: "cuda:0 bf16[4096]"
    # t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]"
    # t_transformer_ln_f_bias: "cuda:0 bf16[4096]"
    # t_transformer_ln_f_weight: "cuda:0 bf16[4096]"
    # t_transformer_wte_weight: "cuda:0 bf16[65024, 4096]"

    # --- token embedding ---
    t4 = torch.nn.functional.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False)  # t4: "cuda:0 bf16[1, 2048, 4096]"
        # t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False)  # t4: "cuda:0 bf16[1, 2048, 4096]"
        # t158 = ltorch.reshape(idx, [2048])  # t158: "cuda:0 i64[2048]"
        # t158 = prims.reshape(idx, (2048,))  # t158: "cuda:0 i64[2048]"
        # t159 = prims.take(t_transformer_wte_weight, t158, 0)  # t159: "cuda:0 bf16[2048, 4096]"
        # t4 = ltorch.reshape(t159, [1, 2048, 4096])  # t4: "cuda:0 bf16[1, 2048, 4096]"
        # t4 = prims.reshape(t159, (1, 2048, 4096))  # t4: "cuda:0 bf16[1, 2048, 4096]"

    # --- broadcast norm_1 weight/bias to (1, 2048, 4096) for the fused LayerNorm ---
    t166 = torch.unsqueeze(t_transformer_h_0_norm_1_weight, 0)  # t166: "cuda:0 bf16[1, 4096]"
        # t166 = ltorch.unsqueeze(t_transformer_h_0_norm_1_weight, 0)  # t166: "cuda:0 bf16[1, 4096]"
        # t166 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, [1, 4096], [1])  # t166: "cuda:0 bf16[1, 4096]"
    t167 = torch.unsqueeze(t166, 1)  # t167: "cuda:0 bf16[1, 1, 4096]"
        # t167 = ltorch.unsqueeze(t166, 1)  # t167: "cuda:0 bf16[1, 1, 4096]"
        # t167 = prims.broadcast_in_dim(t166, [1, 1, 4096], [0, 2])  # t167: "cuda:0 bf16[1, 1, 4096]"
    del t166
    t19 = Tensor.expand(t167, (1, 2048, 4096))  # t19: "cuda:0 bf16[1, 2048, 4096]"
        # t19 = ltorch.expand(t167, (1, 2048, 4096))  # t19: "cuda:0 bf16[1, 2048, 4096]"
        # t19 = prims.broadcast_in_dim(t167, (1, 2048, 4096), (0, 1, 2))  # t19: "cuda:0 bf16[1, 2048, 4096]"
    del t167
    t169 = torch.unsqueeze(t_transformer_h_0_norm_1_bias, 0)  # t169: "cuda:0 bf16[1, 4096]"
        # t169 = ltorch.unsqueeze(t_transformer_h_0_norm_1_bias, 0)  # t169: "cuda:0 bf16[1, 4096]"
        # t169 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_bias, [1, 4096], [1])  # t169: "cuda:0 bf16[1, 4096]"
    t170 = torch.unsqueeze(t169, 1)  # t170: "cuda:0 bf16[1, 1, 4096]"
        # t170 = ltorch.unsqueeze(t169, 1)  # t170: "cuda:0 bf16[1, 1, 4096]"
        # t170 = prims.broadcast_in_dim(t169, [1, 1, 4096], [0, 2])  # t170: "cuda:0 bf16[1, 1, 4096]"
    del t169
    t22 = Tensor.expand(t170, (1, 2048, 4096))  # t22: "cuda:0 bf16[1, 2048, 4096]"
        # t22 = ltorch.expand(t170, (1, 2048, 4096))  # t22: "cuda:0 bf16[1, 2048, 4096]"
        # t22 = prims.broadcast_in_dim(t170, (1, 2048, 4096), (0, 1, 2))  # t22: "cuda:0 bf16[1, 2048, 4096]"
    del t170

    # --- nvFuser region 0: LayerNorm(norm_1) in fp32, output back to bf16 ---
    # Also returns t13 (rsqrt of var) and t9 (mean), saved for backward.
    [t13, t25, t9] = nvFusion0(t19, t22, t4)
        # t5 = prims.convert_element_type(t4, dtypes.float32)  # t5: "cuda:0 f32[1, 2048, 4096]"
        # (t8, t9) = prims.var_mean(t5, (2,), correction=0)
        # t10 = prims.broadcast_in_dim(t8, [1, 2048, 1], [0, 1])  # t10: "cuda:0 f32[1, 2048, 1]"
        # t11 = prims.broadcast_in_dim(t9, [1, 2048, 1], [0, 1])  # t11: "cuda:0 f32[1, 2048, 1]"
        # t12 = prims.add(t10, 1e-05)  # t12: "cuda:0 f32[1, 2048, 1]"
        # t13 = prims.rsqrt(t12)  # t13: "cuda:0 f32[1, 2048, 1]"
        # t14 = prims.broadcast_in_dim(t11, (1, 2048, 4096), (0, 1, 2))  # t14: "cuda:0 f32[1, 2048, 4096]"
        # t16 = prims.sub(t5, t14)  # t16: "cuda:0 f32[1, 2048, 4096]"
        # t17 = prims.broadcast_in_dim(t13, (1, 2048, 4096), (0, 1, 2))  # t17: "cuda:0 f32[1, 2048, 4096]"
        # t18 = prims.mul(t16, t17)  # t18: "cuda:0 f32[1, 2048, 4096]"
        # t20 = prims.convert_element_type(t19, dtypes.float32)  # t20: "cuda:0 f32[1, 2048, 4096]"
        # t21 = prims.mul(t18, t20)  # t21: "cuda:0 f32[1, 2048, 4096]"
        # t23 = prims.convert_element_type(t22, dtypes.float32)  # t23: "cuda:0 f32[1, 2048, 4096]"
        # t24 = prims.add(t21, t23)  # t24: "cuda:0 f32[1, 2048, 4096]"
        # t25 = prims.convert_element_type(t24, dtypes.bfloat16)  # t25: "cuda:0 bf16[1, 2048, 4096]"
    del t22

    # --- attention QKV projection and MLP fc share the normed input t25 ---
    t26 = torch.nn.functional.linear(t25, t_transformer_h_0_attn_attn_weight, None)  # t26: "cuda:0 bf16[1, 2048, 8448]"
        # t26 = ltorch.linear(t25, t_transformer_h_0_attn_attn_weight, None)  # t26: "cuda:0 bf16[1, 2048, 8448]"
        # t26 = prims.linear(t25, t_transformer_h_0_attn_attn_weight, None)  # t26: "cuda:0 bf16[1, 2048, 8448]"
    t107 = torch.nn.functional.linear(t25, t_transformer_h_0_mlp_fc_weight, None)  # t107: "cuda:0 bf16[1, 2048, 18176]"
        # t107 = ltorch.linear(t25, t_transformer_h_0_mlp_fc_weight, None)  # t107: "cuda:0 bf16[1, 2048, 18176]"
        # t107 = prims.linear(t25, t_transformer_h_0_mlp_fc_weight, None)  # t107: "cuda:0 bf16[1, 2048, 18176]"

    # --- RoPE cos/sin tables, sliced to the sequence length ---
    t0 = torch_slice_prim_impl(tos1, [0, 0], [2048, 64], [1, 1])  # t0: "cuda:0 bf16[2048, 64]"
    t1 = torch_slice_prim_impl(t_sin, [0, 0], [2048, 64], [1, 1])  # t1: "cuda:0 bf16[2048, 64]"

    # --- split fused QKV (66 = 64 query heads + 1 key + 1 value, head dim 128) ---
    t27 = torch.reshape(t26, (1, 2048, 1, 66, 128))  # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]"
        # t27 = ltorch.reshape(t26, (1, 2048, 1, 66, 128))  # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]"
        # t27 = prims.reshape(t26, (1, 2048, 1, 66, 128))  # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]"
    del t26
    t28 = torch.permute(t27, (0, 2, 3, 1, 4))  # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]"
        # t28 = ltorch.permute(t27, (0, 2, 3, 1, 4))  # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]"
        # t28 = prims.transpose(t27, (0, 2, 3, 1, 4))  # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]"
    del t27
    (t29, t30, t31) = torch.split(t28, (64, 1, 1), 2)
        # (t29, t30, t31) = ltorch.split(t28, (64, 1, 1), 2)
        # t29 = prims.slice_prim(t28, [0, 0, 0, 0, 0], [1, 1, 64, 2048, 128], [1, 1, 1, 1, 1])  # t29: "cuda:0 bf16[1, 1, 64, 2048, 128]"
        # t30 = prims.slice_prim(t28, [0, 0, 64, 0, 0], [1, 1, 65, 2048, 128], [1, 1, 1, 1, 1])  # t30: "cuda:0 bf16[1, 1, 1, 2048, 128]"
        # t31 = prims.slice_prim(t28, [0, 0, 65, 0, 0], [1, 1, 66, 2048, 128], [1, 1, 1, 1, 1])  # t31: "cuda:0 bf16[1, 1, 1, 2048, 128]"
    del t28
    # Multi-query attention: the single K and V heads are expanded to all 64 heads.
    t32 = Tensor.expand(t30, (1, 1, 64, 2048, 128))  # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]"
        # t32 = ltorch.expand(t30, (1, 1, 64, 2048, 128))  # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]"
        # t32 = prims.broadcast_in_dim(t30, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4))  # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]"
    del t30
    t38 = Tensor.expand(t31, (1, 1, 64, 2048, 128))  # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]"
        # t38 = ltorch.expand(t31, (1, 1, 64, 2048, 128))  # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]"
        # t38 = prims.broadcast_in_dim(t31, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4))  # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]"
    del t31
    t39 = torch.reshape(t29, (1, 64, 2048, 128))  # t39: "cuda:0 bf16[1, 64, 2048, 128]"
        # t39 = ltorch.reshape(t29, (1, 64, 2048, 128))  # t39: "cuda:0 bf16[1, 64, 2048, 128]"
        # t39 = prims.reshape(t29, (1, 64, 2048, 128))  # t39: "cuda:0 bf16[1, 64, 2048, 128]"
    del t29
    t45 = torch.reshape(t32, (1, 64, 2048, 128))  # t45: "cuda:0 bf16[1, 64, 2048, 128]"
        # t45 = ltorch.reshape(t32, (1, 64, 2048, 128))  # t45: "cuda:0 bf16[1, 64, 2048, 128]"
        # t45 = prims.reshape(t32, (1, 64, 2048, 128))  # t45: "cuda:0 bf16[1, 64, 2048, 128]"
    del t32
    t51 = torch.reshape(t38, (1, 64, 2048, 128))  # t51: "cuda:0 bf16[1, 64, 2048, 128]"
        # t51 = ltorch.reshape(t38, (1, 64, 2048, 128))  # t51: "cuda:0 bf16[1, 64, 2048, 128]"
        # t51 = prims.reshape(t38, (1, 64, 2048, 128))  # t51: "cuda:0 bf16[1, 64, 2048, 128]"
    del t38

    # --- RoPE input slices: first 64 dims of Q (t52) and K (t73) get rotated ---
    t52 = torch_slice_prim_impl(t39, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1])  # t52: "cuda:0 bf16[1, 64, 2048, 64]"
    t53 = torch_slice_prim_impl(t52, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1])  # t53: "cuda:0 bf16[1, 64, 2048, 32]"
    t54 = torch_slice_prim_impl(t52, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1])  # t54: "cuda:0 bf16[1, 64, 2048, 32]"
    t187 = torch.unsqueeze(t0, 0)  # t187: "cuda:0 bf16[1, 2048, 64]"
        # t187 = ltorch.unsqueeze(t0, 0)  # t187: "cuda:0 bf16[1, 2048, 64]"
        # t187 = prims.broadcast_in_dim(t0, [1, 2048, 64], [1, 2])  # t187: "cuda:0 bf16[1, 2048, 64]"
    del t0
    t188 = torch.unsqueeze(t187, 1)  # t188: "cuda:0 bf16[1, 1, 2048, 64]"
        # t188 = ltorch.unsqueeze(t187, 1)  # t188: "cuda:0 bf16[1, 1, 2048, 64]"
        # t188 = prims.broadcast_in_dim(t187, [1, 1, 2048, 64], [0, 2, 3])  # t188: "cuda:0 bf16[1, 1, 2048, 64]"
    del t187
    t59 = Tensor.expand(t188, (1, 64, 2048, 64))  # t59: "cuda:0 bf16[1, 64, 2048, 64]"
        # t59 = ltorch.expand(t188, (1, 64, 2048, 64))  # t59: "cuda:0 bf16[1, 64, 2048, 64]"
        # t59 = prims.broadcast_in_dim(t188, (1, 64, 2048, 64), (0, 1, 2, 3))  # t59: "cuda:0 bf16[1, 64, 2048, 64]"
    del t188
    t190 = torch.unsqueeze(t1, 0)  # t190: "cuda:0 bf16[1, 2048, 64]"
        # t190 = ltorch.unsqueeze(t1, 0)  # t190: "cuda:0 bf16[1, 2048, 64]"
        # t190 = prims.broadcast_in_dim(t1, [1, 2048, 64], [1, 2])  # t190: "cuda:0 bf16[1, 2048, 64]"
    del t1
    t191 = torch.unsqueeze(t190, 1)  # t191: "cuda:0 bf16[1, 1, 2048, 64]"
        # t191 = ltorch.unsqueeze(t190, 1)  # t191: "cuda:0 bf16[1, 1, 2048, 64]"
        # t191 = prims.broadcast_in_dim(t190, [1, 1, 2048, 64], [0, 2, 3])  # t191: "cuda:0 bf16[1, 1, 2048, 64]"
    del t190
    t64 = Tensor.expand(t191, (1, 64, 2048, 64))  # t64: "cuda:0 bf16[1, 64, 2048, 64]"
        # t64 = ltorch.expand(t191, (1, 64, 2048, 64))  # t64: "cuda:0 bf16[1, 64, 2048, 64]"
        # t64 = prims.broadcast_in_dim(t191, (1, 64, 2048, 64), (0, 1, 2, 3))  # t64: "cuda:0 bf16[1, 64, 2048, 64]"
    del t191
    t73 = torch_slice_prim_impl(t45, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1])  # t73: "cuda:0 bf16[1, 64, 2048, 64]"
    t74 = torch_slice_prim_impl(t73, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1])  # t74: "cuda:0 bf16[1, 64, 2048, 32]"
    t75 = torch_slice_prim_impl(t73, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1])  # t75: "cuda:0 bf16[1, 64, 2048, 32]"
    # Pass-through (unrotated) tails of Q and K.
    t95 = torch_slice_prim_impl(t39, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1])  # t95: "cuda:0 bf16[1, 64, 2048, 64]"
    del t39
    t97 = torch_slice_prim_impl(t45, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1])  # t97: "cuda:0 bf16[1, 64, 2048, 64]"
    del t45

    # --- nvFuser region 1: RoPE rotation for Q (t96) and K (t99), plus MLP GeLU (t123) ---
    [t123, t96, t99] = nvFusion1(t107, t52, t53, t54, t59, t64, t73, t74, t75, t95, t97)
        # t55 = prims.convert_element_type(t54, dtypes.float32)  # t55: "cuda:0 f32[1, 64, 2048, 32]"
        # t56 = prims.neg(t55)  # t56: "cuda:0 f32[1, 64, 2048, 32]"
        # t57 = prims.convert_element_type(t56, dtypes.bfloat16)  # t57: "cuda:0 bf16[1, 64, 2048, 32]"
        # t58 = prims.cat((t57, t53), -1)  # t58: "cuda:0 bf16[1, 64, 2048, 64]"
        # t60 = prims.convert_element_type(t52, dtypes.float32)  # t60: "cuda:0 f32[1, 64, 2048, 64]"
        # t61 = prims.convert_element_type(t59, dtypes.float32)  # t61: "cuda:0 f32[1, 64, 2048, 64]"
        # t62 = prims.mul(t60, t61)  # t62: "cuda:0 f32[1, 64, 2048, 64]"
        # t65 = prims.convert_element_type(t58, dtypes.float32)  # t65: "cuda:0 f32[1, 64, 2048, 64]"
        # t66 = prims.convert_element_type(t64, dtypes.float32)  # t66: "cuda:0 f32[1, 64, 2048, 64]"
        # t67 = prims.mul(t65, t66)  # t67: "cuda:0 f32[1, 64, 2048, 64]"
        # t71 = prims.add(t62, t67)  # t71: "cuda:0 f32[1, 64, 2048, 64]"
        # t72 = prims.convert_element_type(t71, dtypes.bfloat16)  # t72: "cuda:0 bf16[1, 64, 2048, 64]"
        # t76 = prims.convert_element_type(t75, dtypes.float32)  # t76: "cuda:0 f32[1, 64, 2048, 32]"
        # t77 = prims.neg(t76)  # t77: "cuda:0 f32[1, 64, 2048, 32]"
        # t78 = prims.convert_element_type(t77, dtypes.bfloat16)  # t78: "cuda:0 bf16[1, 64, 2048, 32]"
        # t80 = prims.cat((t78, t74), -1)  # t80: "cuda:0 bf16[1, 64, 2048, 64]"
        # t82 = prims.convert_element_type(t73, dtypes.float32)  # t82: "cuda:0 f32[1, 64, 2048, 64]"
        # t84 = prims.mul(t82, t61)  # t84: "cuda:0 f32[1, 64, 2048, 64]"
        # t87 = prims.convert_element_type(t80, dtypes.float32)  # t87: "cuda:0 f32[1, 64, 2048, 64]"
        # t89 = prims.mul(t87, t66)  # t89: "cuda:0 f32[1, 64, 2048, 64]"
        # t93 = prims.add(t84, t89)  # t93: "cuda:0 f32[1, 64, 2048, 64]"
        # t94 = prims.convert_element_type(t93, dtypes.bfloat16)  # t94: "cuda:0 bf16[1, 64, 2048, 64]"
        # t96 = prims.cat((t72, t95), -1)  # t96: "cuda:0 bf16[1, 64, 2048, 128]"
        # t99 = prims.cat((t94, t97), -1)  # t99: "cuda:0 bf16[1, 64, 2048, 128]"
        # t108 = prims.convert_element_type(t107, dtypes.float32)  # t108: "cuda:0 f32[1, 2048, 18176]"
        # t109 = prims.div(t108, 1.4142135623730951)  # t109: "cuda:0 f32[1, 2048, 18176]"
        # t112 = prims.erf(t109)  # t112: "cuda:0 f32[1, 2048, 18176]"
        # t115 = prims.mul(0.5, t112)  # t115: "cuda:0 f32[1, 2048, 18176]"
        # t118 = prims.add(0.5, t115)  # t118: "cuda:0 f32[1, 2048, 18176]"
        # t122 = prims.mul(t108, t118)  # t122: "cuda:0 f32[1, 2048, 18176]"
        # t123 = prims.convert_element_type(t122, dtypes.bfloat16)  # t123: "cuda:0 bf16[1, 2048, 18176]"
    del t52, t53, t54, t73, t74, t75, t95, t97

    # --- cuDNN causal SDPA; scale = 1/sqrt(128) ---
    (t100, t101, t102, t103) = cudnn_sdpa_fwd(t96, t99, t51, None, 0.0, True, scale=0.08838834764831843)
    t124 = torch.nn.functional.linear(t123, t_transformer_h_0_mlp_proj_weight, None)  # t124: "cuda:0 bf16[1, 2048, 4096]"
        # t124 = ltorch.linear(t123, t_transformer_h_0_mlp_proj_weight, None)  # t124: "cuda:0 bf16[1, 2048, 4096]"
        # t124 = prims.linear(t123, t_transformer_h_0_mlp_proj_weight, None)  # t124: "cuda:0 bf16[1, 2048, 4096]"
    t104 = torch.permute(t100, (0, 2, 1, 3))  # t104: "cuda:0 bf16[1, 2048, 64, 128]"
        # t104 = ltorch.permute(t100, (0, 2, 1, 3))  # t104: "cuda:0 bf16[1, 2048, 64, 128]"
        # t104 = prims.transpose(t100, (0, 2, 1, 3))  # t104: "cuda:0 bf16[1, 2048, 64, 128]"
    t105 = torch.reshape(t104, (1, 2048, 8192))  # t105: "cuda:0 bf16[1, 2048, 8192]"
        # t105 = ltorch.reshape(t104, (1, 2048, 8192))  # t105: "cuda:0 bf16[1, 2048, 8192]"
        # t105 = prims.reshape(t104, (1, 2048, 8192))  # t105: "cuda:0 bf16[1, 2048, 8192]"
    del t104
    t106 = torch.nn.functional.linear(t105, t_transformer_h_0_attn_proj_weight, None)  # t106: "cuda:0 bf16[1, 2048, 4096]"
        # t106 = ltorch.linear(t105, t_transformer_h_0_attn_proj_weight, None)  # t106: "cuda:0 bf16[1, 2048, 4096]"
        # t106 = prims.linear(t105, t_transformer_h_0_attn_proj_weight, None)  # t106: "cuda:0 bf16[1, 2048, 4096]"

    # --- broadcast ln_f weight/bias for the final fused LayerNorm ---
    t200 = torch.unsqueeze(t_transformer_ln_f_weight, 0)  # t200: "cuda:0 bf16[1, 4096]"
        # t200 = ltorch.unsqueeze(t_transformer_ln_f_weight, 0)  # t200: "cuda:0 bf16[1, 4096]"
        # t200 = prims.broadcast_in_dim(t_transformer_ln_f_weight, [1, 4096], [1])  # t200: "cuda:0 bf16[1, 4096]"
    t201 = torch.unsqueeze(t200, 1)  # t201: "cuda:0 bf16[1, 1, 4096]"
        # t201 = ltorch.unsqueeze(t200, 1)  # t201: "cuda:0 bf16[1, 1, 4096]"
        # t201 = prims.broadcast_in_dim(t200, [1, 1, 4096], [0, 2])  # t201: "cuda:0 bf16[1, 1, 4096]"
    del t200
    t150 = Tensor.expand(t201, (1, 2048, 4096))  # t150: "cuda:0 bf16[1, 2048, 4096]"
        # t150 = ltorch.expand(t201, (1, 2048, 4096))  # t150: "cuda:0 bf16[1, 2048, 4096]"
        # t150 = prims.broadcast_in_dim(t201, (1, 2048, 4096), (0, 1, 2))  # t150: "cuda:0 bf16[1, 2048, 4096]"
    del t201
    t203 = torch.unsqueeze(t_transformer_ln_f_bias, 0)  # t203: "cuda:0 bf16[1, 4096]"
        # t203 = ltorch.unsqueeze(t_transformer_ln_f_bias, 0)  # t203: "cuda:0 bf16[1, 4096]"
        # t203 = prims.broadcast_in_dim(t_transformer_ln_f_bias, [1, 4096], [1])  # t203: "cuda:0 bf16[1, 4096]"
    t204 = torch.unsqueeze(t203, 1)  # t204: "cuda:0 bf16[1, 1, 4096]"
        # t204 = ltorch.unsqueeze(t203, 1)  # t204: "cuda:0 bf16[1, 1, 4096]"
        # t204 = prims.broadcast_in_dim(t203, [1, 1, 4096], [0, 2])  # t204: "cuda:0 bf16[1, 1, 4096]"
    del t203
    t153 = Tensor.expand(t204, (1, 2048, 4096))  # t153: "cuda:0 bf16[1, 2048, 4096]"
        # t153 = ltorch.expand(t204, (1, 2048, 4096))  # t153: "cuda:0 bf16[1, 2048, 4096]"
        # t153 = prims.broadcast_in_dim(t204, (1, 2048, 4096), (0, 1, 2))  # t153: "cuda:0 bf16[1, 2048, 4096]"
    del t204

    # --- nvFuser region 2: residual add (attn + mlp + embedding) then LayerNorm(ln_f) ---
    # Also returns t139 (mean) and t144 (rsqrt of var), saved for backward.
    [t139, t144, t156] = nvFusion2(t106, t124, t150, t153, t4)
        # t125 = prims.convert_element_type(t124, dtypes.float32)  # t125: "cuda:0 f32[1, 2048, 4096]"
        # t126 = prims.convert_element_type(t106, dtypes.float32)  # t126: "cuda:0 f32[1, 2048, 4096]"
        # t127 = prims.add(t125, t126)  # t127: "cuda:0 f32[1, 2048, 4096]"
        # t130 = prims.convert_element_type(t4, dtypes.float32)  # t130: "cuda:0 f32[1, 2048, 4096]"
        # t131 = prims.add(t127, t130)  # t131: "cuda:0 f32[1, 2048, 4096]"
        # (t138, t139) = prims.var_mean(t131, (2,), correction=0)
        # t140 = prims.broadcast_in_dim(t138, [1, 2048, 1], [0, 1])  # t140: "cuda:0 f32[1, 2048, 1]"
        # t141 = prims.broadcast_in_dim(t139, [1, 2048, 1], [0, 1])  # t141: "cuda:0 f32[1, 2048, 1]"
        # t143 = prims.add(t140, 1e-05)  # t143: "cuda:0 f32[1, 2048, 1]"
        # t144 = prims.rsqrt(t143)  # t144: "cuda:0 f32[1, 2048, 1]"
        # t145 = prims.broadcast_in_dim(t141, (1, 2048, 4096), (0, 1, 2))  # t145: "cuda:0 f32[1, 2048, 4096]"
        # t147 = prims.sub(t131, t145)  # t147: "cuda:0 f32[1, 2048, 4096]"
        # t148 = prims.broadcast_in_dim(t144, (1, 2048, 4096), (0, 1, 2))  # t148: "cuda:0 f32[1, 2048, 4096]"
        # t149 = prims.mul(t147, t148)  # t149: "cuda:0 f32[1, 2048, 4096]"
        # t151 = prims.convert_element_type(t150, dtypes.float32)  # t151: "cuda:0 f32[1, 2048, 4096]"
        # t152 = prims.mul(t149, t151)  # t152: "cuda:0 f32[1, 2048, 4096]"
        # t154 = prims.convert_element_type(t153, dtypes.float32)  # t154: "cuda:0 f32[1, 2048, 4096]"
        # t155 = prims.add(t152, t154)  # t155: "cuda:0 f32[1, 2048, 4096]"
        # t156 = prims.convert_element_type(t155, dtypes.bfloat16)  # t156: "cuda:0 bf16[1, 2048, 4096]"
    del t153

    # --- lm_head logits ---
    t157 = torch.nn.functional.linear(t156, t_lm_head_weight, None)  # t157: "cuda:0 bf16[1, 2048, 65024]"
        # t157 = ltorch.linear(t156, t_lm_head_weight, None)  # t157: "cuda:0 bf16[1, 2048, 65024]"
        # t157 = prims.linear(t156, t_lm_head_weight, None)  # t157: "cuda:0 bf16[1, 2048, 65024]"
    return {'output': t157, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t157,)}, ((idx, t100, t101, t102, t103, t105, t106, t107, t123, t124, t13, t139, t144, t150, t156, t19, t25, t4, t51, t59, t64, t9, t96, t99, t_lm_head_weight, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight), (False, False, True, 0.0, 0.08838834764831843, 1.4142135623730951, 0.5, 65024, 2, 0, 0))
The trace above failed to execute.
Note: this file may contain bidirectional Unicode text that can be interpreted or compiled differently than it appears below. To review it safely, open the file in an editor that reveals hidden Unicode characters.
@torch.no_grad() | |
@no_autocast | |
def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight): | |
# idx: "cuda:0 i64[1, 2048]" | |
# tos1: "cuda:0 bf16[2048, 64]" | |
# t_lm_head_weight: "cuda:0 bf16[65024, 4096]" | |
# t_sin: "cuda:0 bf16[2048, 64]" | |
# t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[8448, 4096]" | |
# t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 8192]" | |
# t_transformer_h_0_mlp_fc_weight: "cuda:0 bf16[18176, 4096]" | |
# t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 18176]" | |
# t_transformer_h_0_norm_1_bias: "cuda:0 bf16[4096]" | |
# t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]" | |
# t_transformer_ln_f_bias: "cuda:0 bf16[4096]" | |
# t_transformer_ln_f_weight: "cuda:0 bf16[4096]" | |
# t_transformer_wte_weight: "cuda:0 bf16[65024, 4096]" | |
t4 = torch.nn.functional.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4096]" | |
# t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4096]" | |
# t158 = ltorch.reshape(idx, [2048]) # t158: "cuda:0 i64[2048]" | |
# t158 = prims.reshape(idx, (2048,)) # t158: "cuda:0 i64[2048]" | |
# t159 = prims.take(t_transformer_wte_weight, t158, 0) # t159: "cuda:0 bf16[2048, 4096]" | |
# t4 = ltorch.reshape(t159, [1, 2048, 4096]) # t4: "cuda:0 bf16[1, 2048, 4096]" | |
# t4 = prims.reshape(t159, (1, 2048, 4096)) # t4: "cuda:0 bf16[1, 2048, 4096]" | |
t166 = torch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t166: "cuda:0 bf16[1, 4096]" | |
# t166 = ltorch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t166: "cuda:0 bf16[1, 4096]" | |
# t166 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, [1, 4096], [1]) # t166: "cuda:0 bf16[1, 4096]" | |
t167 = torch.unsqueeze(t166, 1) # t167: "cuda:0 bf16[1, 1, 4096]" | |
# t167 = ltorch.unsqueeze(t166, 1) # t167: "cuda:0 bf16[1, 1, 4096]" | |
# t167 = prims.broadcast_in_dim(t166, [1, 1, 4096], [0, 2]) # t167: "cuda:0 bf16[1, 1, 4096]" | |
del t166 | |
t19 = Tensor.expand(t167, (1, 2048, 4096)) # t19: "cuda:0 bf16[1, 2048, 4096]" | |
# t19 = ltorch.expand(t167, (1, 2048, 4096)) # t19: "cuda:0 bf16[1, 2048, 4096]" | |
# t19 = prims.broadcast_in_dim(t167, (1, 2048, 4096), (0, 1, 2)) # t19: "cuda:0 bf16[1, 2048, 4096]" | |
del t167 | |
t169 = torch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t169: "cuda:0 bf16[1, 4096]" | |
# t169 = ltorch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t169: "cuda:0 bf16[1, 4096]" | |
# t169 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_bias, [1, 4096], [1]) # t169: "cuda:0 bf16[1, 4096]" | |
t170 = torch.unsqueeze(t169, 1) # t170: "cuda:0 bf16[1, 1, 4096]" | |
# t170 = ltorch.unsqueeze(t169, 1) # t170: "cuda:0 bf16[1, 1, 4096]" | |
# t170 = prims.broadcast_in_dim(t169, [1, 1, 4096], [0, 2]) # t170: "cuda:0 bf16[1, 1, 4096]" | |
del t169 | |
t22 = Tensor.expand(t170, (1, 2048, 4096)) # t22: "cuda:0 bf16[1, 2048, 4096]" | |
# t22 = ltorch.expand(t170, (1, 2048, 4096)) # t22: "cuda:0 bf16[1, 2048, 4096]" | |
# t22 = prims.broadcast_in_dim(t170, (1, 2048, 4096), (0, 1, 2)) # t22: "cuda:0 bf16[1, 2048, 4096]" | |
del t170 | |
[t13, t25, t9] = nvFusion0(t19, t22, t4) | |
# t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 2048, 4096]" | |
# (t8, t9) = prims.var_mean(t5, (2,), correction=0) | |
# t10 = prims.broadcast_in_dim(t8, [1, 2048, 1], [0, 1]) # t10: "cuda:0 f32[1, 2048, 1]" | |
# t12 = prims.add(t10, 1e-05) # t12: "cuda:0 f32[1, 2048, 1]" | |
# t11 = prims.broadcast_in_dim(t9, [1, 2048, 1], [0, 1]) # t11: "cuda:0 f32[1, 2048, 1]" | |
# t14 = prims.broadcast_in_dim(t11, (1, 2048, 4096), (0, 1, 2)) # t14: "cuda:0 f32[1, 2048, 4096]" | |
# t13 = prims.rsqrt(t12) # t13: "cuda:0 f32[1, 2048, 1]" | |
# t16 = prims.sub(t5, t14) # t16: "cuda:0 f32[1, 2048, 4096]" | |
# t17 = prims.broadcast_in_dim(t13, (1, 2048, 4096), (0, 1, 2)) # t17: "cuda:0 f32[1, 2048, 4096]" | |
# t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 2048, 4096]" | |
# t20 = prims.convert_element_type(t19, dtypes.float32) # t20: "cuda:0 f32[1, 2048, 4096]" | |
# t21 = prims.mul(t18, t20) # t21: "cuda:0 f32[1, 2048, 4096]" | |
# t23 = prims.convert_element_type(t22, dtypes.float32) # t23: "cuda:0 f32[1, 2048, 4096]" | |
# t24 = prims.add(t21, t23) # t24: "cuda:0 f32[1, 2048, 4096]" | |
# t25 = prims.convert_element_type(t24, dtypes.bfloat16) # t25: "cuda:0 bf16[1, 2048, 4096]" | |
del t22 | |
t26 = torch.nn.functional.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]" | |
# t26 = ltorch.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]" | |
# t26 = prims.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 8448]" | |
t107 = torch.nn.functional.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]" | |
# t107 = ltorch.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]" | |
# t107 = prims.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]" | |
[t51, t61, t66, t96, t99] = TorchCompile0(t107, t26, t_sin, tos1) | |
# t0 = prims.slice_prim(tos1, [0, 0], [2048, 64], [1, 1]) # t0: "cuda:0 bf16[2048, 64]" | |
# t1 = prims.slice_prim(t_sin, [0, 0], [2048, 64], [1, 1]) # t1: "cuda:0 bf16[2048, 64]" | |
# t27 = prims.reshape(t26, (1, 2048, 1, 66, 128)) # t27: "cuda:0 bf16[1, 2048, 1, 66, 128]" | |
# t28 = prims.transpose(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 66, 2048, 128]" | |
# (t29, t30, t31) = ltorch.split(t28, (64, 1, 1), 2) | |
# t29 = prims.slice_prim(t28, [0, 0, 0, 0, 0], [1, 1, 64, 2048, 128], [1, 1, 1, 1, 1]) # t29: "cuda:0 bf16[1, 1, 64, 2048, 128]" | |
# t30 = prims.slice_prim(t28, [0, 0, 64, 0, 0], [1, 1, 65, 2048, 128], [1, 1, 1, 1, 1]) # t30: "cuda:0 bf16[1, 1, 1, 2048, 128]" | |
# t31 = prims.slice_prim(t28, [0, 0, 65, 0, 0], [1, 1, 66, 2048, 128], [1, 1, 1, 1, 1]) # t31: "cuda:0 bf16[1, 1, 1, 2048, 128]" | |
# t32 = prims.broadcast_in_dim(t30, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4)) # t32: "cuda:0 bf16[1, 1, 64, 2048, 128]" | |
# t38 = prims.broadcast_in_dim(t31, (1, 1, 64, 2048, 128), (0, 1, 2, 3, 4)) # t38: "cuda:0 bf16[1, 1, 64, 2048, 128]" | |
# t39 = prims.reshape(t29, (1, 64, 2048, 128)) # t39: "cuda:0 bf16[1, 64, 2048, 128]" | |
# t45 = prims.reshape(t32, (1, 64, 2048, 128)) # t45: "cuda:0 bf16[1, 64, 2048, 128]" | |
# t51 = prims.reshape(t38, (1, 64, 2048, 128)) # t51: "cuda:0 bf16[1, 64, 2048, 128]" | |
# t52 = prims.slice_prim(t39, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1]) # t52: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t53 = prims.slice_prim(t52, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1]) # t53: "cuda:0 bf16[1, 64, 2048, 32]" | |
# t54 = prims.slice_prim(t52, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 64, 2048, 32]" | |
# t55 = prims.convert_element_type(t54, dtypes.float32) # t55: "cuda:0 f32[1, 64, 2048, 32]" | |
# t56 = prims.neg(t55) # t56: "cuda:0 f32[1, 64, 2048, 32]" | |
# t57 = prims.convert_element_type(t56, dtypes.bfloat16) # t57: "cuda:0 bf16[1, 64, 2048, 32]" | |
# t58 = prims.cat((t57, t53), -1) # t58: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t59 = prims.broadcast_in_dim(t0, (1, 64, 2048, 64), (2, 3)) # t59: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t60 = prims.convert_element_type(t52, dtypes.float32) # t60: "cuda:0 f32[1, 64, 2048, 64]" | |
# t61 = prims.convert_element_type(t59, dtypes.float32) # t61: "cuda:0 f32[1, 64, 2048, 64]" | |
# t62 = ltorch.mul(t60, t61) # t62: "cuda:0 f32[1, 64, 2048, 64]" | |
# t62 = prims.mul(t60, t61) # t62: "cuda:0 f32[1, 64, 2048, 64]" | |
# t63 = prims.convert_element_type(t62, dtypes.bfloat16) # t63: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t64 = prims.broadcast_in_dim(t1, (1, 64, 2048, 64), (2, 3)) # t64: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t65 = prims.convert_element_type(t58, dtypes.float32) # t65: "cuda:0 f32[1, 64, 2048, 64]" | |
# t66 = prims.convert_element_type(t64, dtypes.float32) # t66: "cuda:0 f32[1, 64, 2048, 64]" | |
# t67 = ltorch.mul(t65, t66) # t67: "cuda:0 f32[1, 64, 2048, 64]" | |
# t67 = prims.mul(t65, t66) # t67: "cuda:0 f32[1, 64, 2048, 64]" | |
# t68 = prims.convert_element_type(t67, dtypes.bfloat16) # t68: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t71 = ltorch.add(t62, t67, alpha=None) # t71: "cuda:0 f32[1, 64, 2048, 64]" | |
# t71 = prims.add(t62, t67) # t71: "cuda:0 f32[1, 64, 2048, 64]" | |
# t72 = prims.convert_element_type(t71, dtypes.bfloat16) # t72: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t73 = prims.slice_prim(t45, [0, 0, 0, 0], [1, 64, 2048, 64], [1, 1, 1, 1]) # t73: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t74 = prims.slice_prim(t73, [0, 0, 0, 0], [1, 64, 2048, 32], [1, 1, 1, 1]) # t74: "cuda:0 bf16[1, 64, 2048, 32]" | |
# t75 = prims.slice_prim(t73, [0, 0, 0, 32], [1, 64, 2048, 64], [1, 1, 1, 1]) # t75: "cuda:0 bf16[1, 64, 2048, 32]" | |
# t76 = prims.convert_element_type(t75, dtypes.float32) # t76: "cuda:0 f32[1, 64, 2048, 32]" | |
# t77 = prims.neg(t76) # t77: "cuda:0 f32[1, 64, 2048, 32]" | |
# t78 = prims.convert_element_type(t77, dtypes.bfloat16) # t78: "cuda:0 bf16[1, 64, 2048, 32]" | |
# t80 = prims.cat((t78, t74), -1) # t80: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t82 = prims.convert_element_type(t73, dtypes.float32) # t82: "cuda:0 f32[1, 64, 2048, 64]" | |
# t84 = ltorch.mul(t82, t61) # t84: "cuda:0 f32[1, 64, 2048, 64]" | |
# t84 = prims.mul(t82, t61) # t84: "cuda:0 f32[1, 64, 2048, 64]" | |
# t85 = prims.convert_element_type(t84, dtypes.bfloat16) # t85: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t87 = prims.convert_element_type(t80, dtypes.float32) # t87: "cuda:0 f32[1, 64, 2048, 64]" | |
# t89 = ltorch.mul(t87, t66) # t89: "cuda:0 f32[1, 64, 2048, 64]" | |
# t89 = prims.mul(t87, t66) # t89: "cuda:0 f32[1, 64, 2048, 64]" | |
# t90 = prims.convert_element_type(t89, dtypes.bfloat16) # t90: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t93 = ltorch.add(t84, t89, alpha=None) # t93: "cuda:0 f32[1, 64, 2048, 64]" | |
# t93 = prims.add(t84, t89) # t93: "cuda:0 f32[1, 64, 2048, 64]" | |
# t94 = prims.convert_element_type(t93, dtypes.bfloat16) # t94: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t95 = prims.slice_prim(t39, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1]) # t95: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t96 = prims.cat((t72, t95), -1) # t96: "cuda:0 bf16[1, 64, 2048, 128]" | |
# t97 = prims.slice_prim(t45, [0, 0, 0, 64], [1, 64, 2048, 128], [1, 1, 1, 1]) # t97: "cuda:0 bf16[1, 64, 2048, 64]" | |
# t99 = prims.cat((t94, t97), -1) # t99: "cuda:0 bf16[1, 64, 2048, 128]" | |
# t108 = prims.convert_element_type(t107, dtypes.float32) # t108: "cuda:0 f32[1, 2048, 18176]" | |
del t26 | |
(t100, t101, t102, t103) = cudnn_sdpa_fwd(t96, t99, t51, None, 0.0, True, scale=0.08838834764831843) | |
t104 = torch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]" | |
# t104 = ltorch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]" | |
# t104 = prims.transpose(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 64, 128]" | |
t105 = torch.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]" | |
# t105 = ltorch.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]" | |
# t105 = prims.reshape(t104, (1, 2048, 8192)) # t105: "cuda:0 bf16[1, 2048, 8192]" | |
del t104 | |
[t118, t123] = nvFusion1(t107) | |
# t108 = prims.convert_element_type(t107, dtypes.float32) # t108: "cuda:0 f32[1, 2048, 18176]" | |
# t109 = prims.div(t108, 1.4142135623730951) # t109: "cuda:0 f32[1, 2048, 18176]" | |
# t112 = prims.erf(t109) # t112: "cuda:0 f32[1, 2048, 18176]" | |
# t115 = prims.mul(0.5, t112) # t115: "cuda:0 f32[1, 2048, 18176]" | |
# t118 = prims.add(0.5, t115) # t118: "cuda:0 f32[1, 2048, 18176]" | |
# t122 = prims.mul(t108, t118) # t122: "cuda:0 f32[1, 2048, 18176]" | |
# t123 = prims.convert_element_type(t122, dtypes.bfloat16) # t123: "cuda:0 bf16[1, 2048, 18176]" | |
t124 = torch.nn.functional.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]" | |
# t124 = ltorch.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]" | |
# t124 = prims.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4096]" | |
t106 = torch.nn.functional.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]" | |
# t106 = ltorch.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]" | |
# t106 = prims.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4096]" | |
t174 = torch.unsqueeze(t_transformer_ln_f_weight, 0) # t174: "cuda:0 bf16[1, 4096]" | |
# t174 = ltorch.unsqueeze(t_transformer_ln_f_weight, 0) # t174: "cuda:0 bf16[1, 4096]" | |
# t174 = prims.broadcast_in_dim(t_transformer_ln_f_weight, [1, 4096], [1]) # t174: "cuda:0 bf16[1, 4096]" | |
t175 = torch.unsqueeze(t174, 1) # t175: "cuda:0 bf16[1, 1, 4096]" | |
# t175 = ltorch.unsqueeze(t174, 1) # t175: "cuda:0 bf16[1, 1, 4096]" | |
# t175 = prims.broadcast_in_dim(t174, [1, 1, 4096], [0, 2]) # t175: "cuda:0 bf16[1, 1, 4096]" | |
del t174 | |
t150 = Tensor.expand(t175, (1, 2048, 4096)) # t150: "cuda:0 bf16[1, 2048, 4096]" | |
# t150 = ltorch.expand(t175, (1, 2048, 4096)) # t150: "cuda:0 bf16[1, 2048, 4096]" | |
# t150 = prims.broadcast_in_dim(t175, (1, 2048, 4096), (0, 1, 2)) # t150: "cuda:0 bf16[1, 2048, 4096]" | |
del t175 | |
t177 = torch.unsqueeze(t_transformer_ln_f_bias, 0) # t177: "cuda:0 bf16[1, 4096]" | |
# t177 = ltorch.unsqueeze(t_transformer_ln_f_bias, 0) # t177: "cuda:0 bf16[1, 4096]" | |
# t177 = prims.broadcast_in_dim(t_transformer_ln_f_bias, [1, 4096], [1]) # t177: "cuda:0 bf16[1, 4096]" | |
t178 = torch.unsqueeze(t177, 1) # t178: "cuda:0 bf16[1, 1, 4096]" | |
# t178 = ltorch.unsqueeze(t177, 1) # t178: "cuda:0 bf16[1, 1, 4096]" | |
# t178 = prims.broadcast_in_dim(t177, [1, 1, 4096], [0, 2]) # t178: "cuda:0 bf16[1, 1, 4096]" | |
del t177 | |
t153 = Tensor.expand(t178, (1, 2048, 4096)) # t153: "cuda:0 bf16[1, 2048, 4096]" | |
# t153 = ltorch.expand(t178, (1, 2048, 4096)) # t153: "cuda:0 bf16[1, 2048, 4096]" | |
# t153 = prims.broadcast_in_dim(t178, (1, 2048, 4096), (0, 1, 2)) # t153: "cuda:0 bf16[1, 2048, 4096]" | |
del t178 | |
[t139, t144, t156] = nvFusion2(t106, t124, t150, t153, t4) | |
# t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 2048, 4096]" | |
# t126 = prims.convert_element_type(t106, dtypes.float32) # t126: "cuda:0 f32[1, 2048, 4096]" | |
# t127 = prims.add(t125, t126) # t127: "cuda:0 f32[1, 2048, 4096]" | |
# t130 = prims.convert_element_type(t4, dtypes.float32) # t130: "cuda:0 f32[1, 2048, 4096]" | |
# t131 = prims.add(t127, t130) # t131: "cuda:0 f32[1, 2048, 4096]" | |
# (t138, t139) = prims.var_mean(t131, (2,), correction=0) | |
# t140 = prims.broadcast_in_dim(t138, [1, 2048, 1], [0, 1]) # t140: "cuda:0 f32[1, 2048, 1]" | |
# t143 = prims.add(t140, 1e-05) # t143: "cuda:0 f32[1, 2048, 1]" | |
# t141 = prims.broadcast_in_dim(t139, [1, 2048, 1], [0, 1]) # t141: "cuda:0 f32[1, 2048, 1]" | |
# t145 = prims.broadcast_in_dim(t141, (1, 2048, 4096), (0, 1, 2)) # t145: "cuda:0 f32[1, 2048, 4096]" | |
# t144 = prims.rsqrt(t143) # t144: "cuda:0 f32[1, 2048, 1]" | |
# t147 = prims.sub(t131, t145) # t147: "cuda:0 f32[1, 2048, 4096]" | |
# t148 = prims.broadcast_in_dim(t144, (1, 2048, 4096), (0, 1, 2)) # t148: "cuda:0 f32[1, 2048, 4096]" | |
# t149 = prims.mul(t147, t148) # t149: "cuda:0 f32[1, 2048, 4096]" | |
# t151 = prims.convert_element_type(t150, dtypes.float32) # t151: "cuda:0 f32[1, 2048, 4096]" | |
# t152 = prims.mul(t149, t151) # t152: "cuda:0 f32[1, 2048, 4096]" | |
# t154 = prims.convert_element_type(t153, dtypes.float32) # t154: "cuda:0 f32[1, 2048, 4096]" | |
# t155 = prims.add(t152, t154) # t155: "cuda:0 f32[1, 2048, 4096]" | |
# t156 = prims.convert_element_type(t155, dtypes.bfloat16) # t156: "cuda:0 bf16[1, 2048, 4096]" | |
del t153 | |
t157 = torch.nn.functional.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]" | |
# t157 = ltorch.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]" | |
# t157 = prims.linear(t156, t_lm_head_weight, None) # t157: "cuda:0 bf16[1, 2048, 65024]" | |
return {'output': t157, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t157,)}, ((idx, t100, t101, t102, t103, t105, t106, t107, t118, t123, t124, t13, t139, t144, t150, t156, t19, t25, t4, t51, t61, t66, t9, t96, t99, t_lm_head_weight, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight), (False, False, True, 0.0, 0.08838834764831843, 1.4142135623730951, 0.5, 65024, 2, 0, 0)) | |
The trace above executed successfully. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment