@crcrpar
Created June 16, 2024 06:28
# Constructed by Delete Last Used (took 16 milliseconds)
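# This file is a Lightning Thunder backward-pass execution trace. The shapes it
# carries (hidden size 4544 = 71 heads x 64, MLP width 18176, vocabulary 65024,
# 32 blocks h_0..h_31 with fused QKV and parallel attention/MLP branches) are
# consistent with a Falcon-7B-style model, though the trace itself does not
# name the model.
#
# A minimal, illustrative sketch of how such a trace can be obtained (the model
# construction below is a hypothetical placeholder, not taken from this gist):
#
#   import torch
#   import thunder
#
#   model = build_model().to(device="cuda", dtype=torch.bfloat16)  # hypothetical helper
#   jmodel = thunder.jit(model)              # compile the module with Thunder
#   logits = jmodel(input_ids)               # forward pass records the traces
#   logits.sum().backward()                  # triggers the backward trace
#   print(thunder.last_backward_traces(jmodel)[-1])  # prints a function like backward_fn below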
import operator
import torch
from thunder.executors.torchex import no_autocast
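# Note: names used below but not imported here (clear_mutable_collection, the
# nvFusionN callables, cudnn_sdpa_bwd, torch_slice_prim_impl) are not plain
# Python globals; Thunder binds them into this function's execution context
# when it runs the trace.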
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, C1, = saved_for_backward
clear_mutable_collection(saved_for_backward)
del saved_for_backward
t33, = cotangents
clear_mutable_collection(cotangents)
del cotangents
idx, t100, t1007, t101, t102, t103, t105, t1054, t1057, t1058, t1059, t106, \
t1060, t1061, t1068, t1069, t107, t1070, t1089, t1090, t1098, t1105, t1110, \
t1116, t1122, t1168, t1215, t1218, t1219, t1220, t1221, t1222, t1229, t123, \
t1230, t1231, t124, t1250, t1251, t1259, t1266, t1271, t1277, t1283, t13, t132, \
t1329, t1376, t1379, t1380, t1381, t1382, t1383, t139, t1390, t1391, t1392, \
t1411, t1412, t1420, t1427, t1432, t1438, t144, t1444, t1490, t150, t1537, \
t1540, t1541, t1542, t1543, t1544, t1551, t1552, t1553, t156, t1572, t1573, \
t1581, t1588, t1593, t1599, t1605, t1651, t1698, t1701, t1702, t1703, t1704, \
t1705, t1712, t1713, t1714, t1733, t1734, t1742, t1749, t1754, t1760, t1766, \
t1812, t1859, t1862, t1863, t1864, t1865, t1866, t1873, t1874, t1875, t1894, \
t1895, t19, t1903, t1910, t1915, t1921, t1927, t1973, t202, t2020, t2023, \
t2024, t2025, t2026, t2027, t2034, t2035, t2036, t2055, t2056, t2064, t2071, \
t2076, t2082, t2088, t2134, t2181, t2184, t2185, t2186, t2187, t2188, t2195, \
t2196, t2197, t2216, t2217, t2225, t2232, t2237, t2243, t2249, t2295, t2342, \
t2345, t2346, t2347, t2348, t2349, t2356, t2357, t2358, t2377, t2378, t2386, \
t2393, t2398, t2404, t2410, t2456, t249, t25, t2503, t2506, t2507, t2508, \
t2509, t2510, t2517, t2518, t2519, t252, t253, t2538, t2539, t254, t2547, t255, \
t2554, t2559, t256, t2565, t2571, t2617, t263, t264, t265, t2664, t2667, t2668, \
t2669, t2670, t2671, t2678, t2679, t2680, t2699, t2700, t2708, t2715, t2720, \
t2726, t2732, t2778, t2825, t2828, t2829, t2830, t2831, t2832, t2839, t284, \
t2840, t2841, t285, t2860, t2861, t2869, t2876, t2881, t2887, t2893, t293, \
t2939, t2986, t2989, t2990, t2991, t2992, t2993, t300, t3000, t3001, t3002, \
t3021, t3022, t3030, t3037, t3042, t3048, t305, t3054, t3100, t311, t3147, \
t3150, t3151, t3152, t3153, t3154, t3161, t3162, t3163, t317, t3182, t3183, \
t3191, t3198, t3203, t3209, t3215, t3261, t3308, t3311, t3312, t3313, t3314, \
t3315, t3322, t3323, t3324, t3343, t3344, t3352, t3359, t3364, t3370, t3376, \
t3422, t3469, t3472, t3473, t3474, t3475, t3476, t3483, t3484, t3485, t3504, \
t3505, t3513, t3520, t3525, t3531, t3537, t3583, t363, t3630, t3633, t3634, \
t3635, t3636, t3637, t3644, t3645, t3646, t3665, t3666, t3674, t3681, t3686, \
t3692, t3698, t3744, t3791, t3794, t3795, t3796, t3797, t3798, t3805, t3806, \
t3807, t3826, t3827, t3835, t3842, t3847, t3853, t3859, t3905, t3952, t3955, \
t3956, t3957, t3958, t3959, t3966, t3967, t3968, t3987, t3988, t3996, t4, \
t4003, t4008, t4014, t4020, t4066, t410, t4113, t4116, t4117, t4118, t4119, \
t4120, t4127, t4128, t4129, t413, t414, t4148, t4149, t415, t4157, t416, t4164, \
t4169, t417, t4175, t4181, t4227, t424, t425, t426, t4274, t4277, t4278, t4279, \
t4280, t4281, t4288, t4289, t4290, t4309, t4310, t4318, t4325, t4330, t4336, \
t4342, t4388, t4435, t4438, t4439, t4440, t4441, t4442, t4449, t445, t4450, \
t4451, t446, t4470, t4471, t4479, t4486, t4491, t4497, t4503, t454, t4549, \
t4596, t4599, t4600, t4601, t4602, t4603, t461, t4610, t4611, t4612, t4631, \
t4632, t4640, t4647, t4652, t4658, t466, t4664, t4710, t472, t4757, t4760, \
t4761, t4762, t4763, t4764, t4771, t4772, t4773, t478, t4792, t4793, t4801, \
t4808, t4813, t4819, t4825, t4871, t4918, t4921, t4922, t4923, t4924, t4925, \
t4932, t4933, t4934, t4953, t4954, t4962, t4969, t4974, t4980, t4986, t5032, \
t5079, t5082, t5083, t5084, t5085, t5086, t5093, t5094, t5095, t51, t5114, \
t5115, t5130, t5135, t5141, t5147, t524, t571, t574, t575, t576, t577, t578, \
t585, t586, t587, t606, t607, t61, t615, t622, t627, t633, t639, t66, t685, \
t732, t735, t736, t737, t738, t739, t746, t747, t748, t767, t768, t776, t783, \
t788, t794, t800, t846, t893, t896, t897, t898, t899, t9, t900, t907, t908, \
t909, t928, t929, t937, t944, t949, t955, t96, t961, t99, t_lm_head_weight, \
t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, \
t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, \
t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, \
t_transformer_h_10_mlp_fc_weight, t_transformer_h_10_mlp_proj_weight, \
t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, \
t_transformer_h_11_mlp_fc_weight, t_transformer_h_11_mlp_proj_weight, \
t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, \
t_transformer_h_12_mlp_fc_weight, t_transformer_h_12_mlp_proj_weight, \
t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, \
t_transformer_h_13_mlp_fc_weight, t_transformer_h_13_mlp_proj_weight, \
t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, \
t_transformer_h_14_mlp_fc_weight, t_transformer_h_14_mlp_proj_weight, \
t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, \
t_transformer_h_15_mlp_fc_weight, t_transformer_h_15_mlp_proj_weight, \
t_transformer_h_16_attn_attn_weight, t_transformer_h_16_attn_proj_weight, \
t_transformer_h_16_mlp_fc_weight, t_transformer_h_16_mlp_proj_weight, \
t_transformer_h_17_attn_attn_weight, t_transformer_h_17_attn_proj_weight, \
t_transformer_h_17_mlp_fc_weight, t_transformer_h_17_mlp_proj_weight, \
t_transformer_h_18_attn_attn_weight, t_transformer_h_18_attn_proj_weight, \
t_transformer_h_18_mlp_fc_weight, t_transformer_h_18_mlp_proj_weight, \
t_transformer_h_19_attn_attn_weight, t_transformer_h_19_attn_proj_weight, \
t_transformer_h_19_mlp_fc_weight, t_transformer_h_19_mlp_proj_weight, \
t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, \
t_transformer_h_1_mlp_fc_weight, t_transformer_h_1_mlp_proj_weight, \
t_transformer_h_20_attn_attn_weight, t_transformer_h_20_attn_proj_weight, \
t_transformer_h_20_mlp_fc_weight, t_transformer_h_20_mlp_proj_weight, \
t_transformer_h_21_attn_attn_weight, t_transformer_h_21_attn_proj_weight, \
t_transformer_h_21_mlp_fc_weight, t_transformer_h_21_mlp_proj_weight, \
t_transformer_h_22_attn_attn_weight, t_transformer_h_22_attn_proj_weight, \
t_transformer_h_22_mlp_fc_weight, t_transformer_h_22_mlp_proj_weight, \
t_transformer_h_23_attn_attn_weight, t_transformer_h_23_attn_proj_weight, \
t_transformer_h_23_mlp_fc_weight, t_transformer_h_23_mlp_proj_weight, \
t_transformer_h_24_attn_attn_weight, t_transformer_h_24_attn_proj_weight, \
t_transformer_h_24_mlp_fc_weight, t_transformer_h_24_mlp_proj_weight, \
t_transformer_h_25_attn_attn_weight, t_transformer_h_25_attn_proj_weight, \
t_transformer_h_25_mlp_fc_weight, t_transformer_h_25_mlp_proj_weight, \
t_transformer_h_26_attn_attn_weight, t_transformer_h_26_attn_proj_weight, \
t_transformer_h_26_mlp_fc_weight, t_transformer_h_26_mlp_proj_weight, \
t_transformer_h_27_attn_attn_weight, t_transformer_h_27_attn_proj_weight, \
t_transformer_h_27_mlp_fc_weight, t_transformer_h_27_mlp_proj_weight, \
t_transformer_h_28_attn_attn_weight, t_transformer_h_28_attn_proj_weight, \
t_transformer_h_28_mlp_fc_weight, t_transformer_h_28_mlp_proj_weight, \
t_transformer_h_29_attn_attn_weight, t_transformer_h_29_attn_proj_weight, \
t_transformer_h_29_mlp_fc_weight, t_transformer_h_29_mlp_proj_weight, \
t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, \
t_transformer_h_2_mlp_fc_weight, t_transformer_h_2_mlp_proj_weight, \
t_transformer_h_30_attn_attn_weight, t_transformer_h_30_attn_proj_weight, \
t_transformer_h_30_mlp_fc_weight, t_transformer_h_30_mlp_proj_weight, \
t_transformer_h_31_attn_attn_weight, t_transformer_h_31_attn_proj_weight, \
t_transformer_h_31_mlp_fc_weight, t_transformer_h_31_mlp_proj_weight, \
t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, \
t_transformer_h_3_mlp_fc_weight, t_transformer_h_3_mlp_proj_weight, \
t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, \
t_transformer_h_4_mlp_fc_weight, t_transformer_h_4_mlp_proj_weight, \
t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, \
t_transformer_h_5_mlp_fc_weight, t_transformer_h_5_mlp_proj_weight, \
t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, \
t_transformer_h_6_mlp_fc_weight, t_transformer_h_6_mlp_proj_weight, \
t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, \
t_transformer_h_7_mlp_fc_weight, t_transformer_h_7_mlp_proj_weight, \
t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, \
t_transformer_h_8_mlp_fc_weight, t_transformer_h_8_mlp_proj_weight, \
t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, \
t_transformer_h_9_mlp_fc_weight, t_transformer_h_9_mlp_proj_weight, = C0
clear_mutable_collection(C0)
del C0
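# C0 carried the tensors saved by the forward trace: intermediate activations
# (the tNNN names unpacked above) plus the lm_head and per-block attn/MLP
# weights, all of which are consumed below to form weight and input gradients.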
b1, b1015, b1079, b1143, b119, b1207, b1271, b1335, b1399, b1463, b1527, b1591, \
b1655, b1719, b1783, b183, b1847, b1911, b1975, b2, b2039, b247, b311, b375, \
b439, b503, b55, b567, b631, b695, b759, b823, b887, b951, f1014, f1016, f1023, \
f1025, f1078, f1080, f1087, f1089, f1142, f1144, f1151, f1153, f118, f120, \
f1206, f1208, f1215, f1217, f127, f1270, f1272, f1279, f1281, f129, f1334, \
f1336, f1343, f1345, f1398, f1400, f1407, f1409, f1462, f1464, f1471, f1473, \
f1526, f1528, f1535, f1537, f1590, f1592, f1599, f1601, f1654, f1656, f1663, \
f1665, f1718, f1720, f1727, f1729, f1782, f1784, f1791, f1793, f182, f184, \
f1846, f1848, f1855, f1857, f191, f1910, f1912, f1919, f1921, f193, f1974, \
f1976, f1983, f1985, f2038, f2040, f2047, f2049, f246, f248, f255, f257, f310, \
f312, f319, f321, f374, f376, f383, f385, f438, f440, f447, f449, f502, f504, \
f511, f513, f54, f56, f566, f568, f575, f577, f63, f630, f632, f639, f641, f65, \
f694, f696, f703, f705, f758, f760, f767, f769, f822, f824, f831, f833, f886, \
f888, f895, f897, f950, f952, f959, f961, i0, i1033, i1051, i1097, i1115, \
i1161, i1179, i1225, i1243, i1289, i1307, i1353, i137, i1371, i1417, i1435, \
i1481, i1499, i1545, i155, i1563, i1609, i1627, i1673, i1691, i1737, i1755, \
i1801, i1819, i1865, i1883, i1929, i1947, i1993, i201, i2011, i2057, i219, \
i265, i27, i283, i329, i347, i393, i411, i457, i475, i521, i539, i585, i603, \
i649, i667, i713, i73, i731, i777, i795, i841, i859, i9, i905, i91, i923, i969, \
i987, = C1
clear_mutable_collection(C1)
del C1
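# C1 carried the non-tensor values saved for backward: booleans (b...), floats
# (f...) such as dropout probabilities and attention scales, and ints (i...)
# such as split sizes and concatenation dimensions used by the fused kernels
# further down.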
t12436 = torch.reshape(t33, (-1, 65024)) # t12436: "cuda:0 bf16[2048, 65024]"
# t12436 = ltorch.reshape(t33, (-1, 65024)) # t12436: "cuda:0 bf16[2048, 65024]"
# t12436 = prims.reshape(t33, (2048, 65024)) # t12436: "cuda:0 bf16[2048, 65024]"
del t33
t12437 = torch.permute(t12436, (1, 0)) # t12437: "cuda:0 bf16[65024, 2048]"
# t12437 = ltorch.permute(t12436, (1, 0)) # t12437: "cuda:0 bf16[65024, 2048]"
# t12437 = prims.transpose(t12436, (1, 0)) # t12437: "cuda:0 bf16[65024, 2048]"
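# t12436 is the incoming logits cotangent flattened to (2048, 65024); its
# transpose t12437 is used below to form the lm_head weight gradient.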
t12438 = torch.reshape(t5147, (-1, 4544)) # t12438: "cuda:0 bf16[2048, 4544]"
# t12438 = ltorch.reshape(t5147, (-1, 4544)) # t12438: "cuda:0 bf16[2048, 4544]"
# t12438 = prims.reshape(t5147, (2048, 4544)) # t12438: "cuda:0 bf16[2048, 4544]"
del t5147
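# The long run of reshapes that follows flattens saved per-block activations
# (MLP intermediates of width 18176 and residual-stream tensors of width 4544)
# to 2D so they can serve as operands of the weight-gradient matmuls later in
# the trace; each source tensor is deleted as soon as it has been reshaped, in
# line with the "Delete Last Used" pass named in the header.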
t18718 = torch.reshape(t1572, (-1, 18176)) # t18718: "cuda:0 bf16[2048, 18176]"
# t18718 = ltorch.reshape(t1572, (-1, 18176)) # t18718: "cuda:0 bf16[2048, 18176]"
# t18718 = prims.reshape(t1572, (2048, 18176)) # t18718: "cuda:0 bf16[2048, 18176]"
del t1572
t15605 = torch.reshape(t3343, (-1, 18176)) # t15605: "cuda:0 bf16[2048, 18176]"
# t15605 = ltorch.reshape(t3343, (-1, 18176)) # t15605: "cuda:0 bf16[2048, 18176]"
# t15605 = prims.reshape(t3343, (2048, 18176)) # t15605: "cuda:0 bf16[2048, 18176]"
del t3343
t18752 = torch.reshape(t1444, (-1, 4544)) # t18752: "cuda:0 bf16[2048, 4544]"
# t18752 = ltorch.reshape(t1444, (-1, 4544)) # t18752: "cuda:0 bf16[2048, 4544]"
# t18752 = prims.reshape(t1444, (2048, 4544)) # t18752: "cuda:0 bf16[2048, 4544]"
del t1444
t18759 = torch.reshape(t1551, (-1, 4544)) # t18759: "cuda:0 bf16[2048, 4544]"
# t18759 = ltorch.reshape(t1551, (-1, 4544)) # t18759: "cuda:0 bf16[2048, 4544]"
# t18759 = prims.reshape(t1551, (2048, 4544)) # t18759: "cuda:0 bf16[2048, 4544]"
del t1551
t12500 = torch.reshape(t5114, (-1, 18176)) # t12500: "cuda:0 bf16[2048, 18176]"
# t12500 = ltorch.reshape(t5114, (-1, 18176)) # t12500: "cuda:0 bf16[2048, 18176]"
# t12500 = prims.reshape(t5114, (2048, 18176)) # t12500: "cuda:0 bf16[2048, 18176]"
del t5114
t15639 = torch.reshape(t3215, (-1, 4544)) # t15639: "cuda:0 bf16[2048, 4544]"
# t15639 = ltorch.reshape(t3215, (-1, 4544)) # t15639: "cuda:0 bf16[2048, 4544]"
# t15639 = prims.reshape(t3215, (2048, 4544)) # t15639: "cuda:0 bf16[2048, 4544]"
del t3215
t15646 = torch.reshape(t3322, (-1, 4544)) # t15646: "cuda:0 bf16[2048, 4544]"
# t15646 = ltorch.reshape(t3322, (-1, 4544)) # t15646: "cuda:0 bf16[2048, 4544]"
# t15646 = prims.reshape(t3322, (2048, 4544)) # t15646: "cuda:0 bf16[2048, 4544]"
del t3322
t12534 = torch.reshape(t4986, (-1, 4544)) # t12534: "cuda:0 bf16[2048, 4544]"
# t12534 = ltorch.reshape(t4986, (-1, 4544)) # t12534: "cuda:0 bf16[2048, 4544]"
# t12534 = prims.reshape(t4986, (2048, 4544)) # t12534: "cuda:0 bf16[2048, 4544]"
del t4986
t12541 = torch.reshape(t5093, (-1, 4544)) # t12541: "cuda:0 bf16[2048, 4544]"
# t12541 = ltorch.reshape(t5093, (-1, 4544)) # t12541: "cuda:0 bf16[2048, 4544]"
# t12541 = prims.reshape(t5093, (2048, 4544)) # t12541: "cuda:0 bf16[2048, 4544]"
del t5093
t19001 = torch.reshape(t1411, (-1, 18176)) # t19001: "cuda:0 bf16[2048, 18176]"
# t19001 = ltorch.reshape(t1411, (-1, 18176)) # t19001: "cuda:0 bf16[2048, 18176]"
# t19001 = prims.reshape(t1411, (2048, 18176)) # t19001: "cuda:0 bf16[2048, 18176]"
del t1411
t15888 = torch.reshape(t3182, (-1, 18176)) # t15888: "cuda:0 bf16[2048, 18176]"
# t15888 = ltorch.reshape(t3182, (-1, 18176)) # t15888: "cuda:0 bf16[2048, 18176]"
# t15888 = prims.reshape(t3182, (2048, 18176)) # t15888: "cuda:0 bf16[2048, 18176]"
del t3182
t19035 = torch.reshape(t1283, (-1, 4544)) # t19035: "cuda:0 bf16[2048, 4544]"
# t19035 = ltorch.reshape(t1283, (-1, 4544)) # t19035: "cuda:0 bf16[2048, 4544]"
# t19035 = prims.reshape(t1283, (2048, 4544)) # t19035: "cuda:0 bf16[2048, 4544]"
del t1283
t19042 = torch.reshape(t1390, (-1, 4544)) # t19042: "cuda:0 bf16[2048, 4544]"
# t19042 = ltorch.reshape(t1390, (-1, 4544)) # t19042: "cuda:0 bf16[2048, 4544]"
# t19042 = prims.reshape(t1390, (2048, 4544)) # t19042: "cuda:0 bf16[2048, 4544]"
del t1390
t12775 = torch.reshape(t4953, (-1, 18176)) # t12775: "cuda:0 bf16[2048, 18176]"
# t12775 = ltorch.reshape(t4953, (-1, 18176)) # t12775: "cuda:0 bf16[2048, 18176]"
# t12775 = prims.reshape(t4953, (2048, 18176)) # t12775: "cuda:0 bf16[2048, 18176]"
del t4953
t15922 = torch.reshape(t3054, (-1, 4544)) # t15922: "cuda:0 bf16[2048, 4544]"
# t15922 = ltorch.reshape(t3054, (-1, 4544)) # t15922: "cuda:0 bf16[2048, 4544]"
# t15922 = prims.reshape(t3054, (2048, 4544)) # t15922: "cuda:0 bf16[2048, 4544]"
del t3054
t15929 = torch.reshape(t3161, (-1, 4544)) # t15929: "cuda:0 bf16[2048, 4544]"
# t15929 = ltorch.reshape(t3161, (-1, 4544)) # t15929: "cuda:0 bf16[2048, 4544]"
# t15929 = prims.reshape(t3161, (2048, 4544)) # t15929: "cuda:0 bf16[2048, 4544]"
del t3161
t12809 = torch.reshape(t4825, (-1, 4544)) # t12809: "cuda:0 bf16[2048, 4544]"
# t12809 = ltorch.reshape(t4825, (-1, 4544)) # t12809: "cuda:0 bf16[2048, 4544]"
# t12809 = prims.reshape(t4825, (2048, 4544)) # t12809: "cuda:0 bf16[2048, 4544]"
del t4825
t12816 = torch.reshape(t4932, (-1, 4544)) # t12816: "cuda:0 bf16[2048, 4544]"
# t12816 = ltorch.reshape(t4932, (-1, 4544)) # t12816: "cuda:0 bf16[2048, 4544]"
# t12816 = prims.reshape(t4932, (2048, 4544)) # t12816: "cuda:0 bf16[2048, 4544]"
del t4932
t19284 = torch.reshape(t1250, (-1, 18176)) # t19284: "cuda:0 bf16[2048, 18176]"
# t19284 = ltorch.reshape(t1250, (-1, 18176)) # t19284: "cuda:0 bf16[2048, 18176]"
# t19284 = prims.reshape(t1250, (2048, 18176)) # t19284: "cuda:0 bf16[2048, 18176]"
del t1250
t16171 = torch.reshape(t3021, (-1, 18176)) # t16171: "cuda:0 bf16[2048, 18176]"
# t16171 = ltorch.reshape(t3021, (-1, 18176)) # t16171: "cuda:0 bf16[2048, 18176]"
# t16171 = prims.reshape(t3021, (2048, 18176)) # t16171: "cuda:0 bf16[2048, 18176]"
del t3021
t19318 = torch.reshape(t1122, (-1, 4544)) # t19318: "cuda:0 bf16[2048, 4544]"
# t19318 = ltorch.reshape(t1122, (-1, 4544)) # t19318: "cuda:0 bf16[2048, 4544]"
# t19318 = prims.reshape(t1122, (2048, 4544)) # t19318: "cuda:0 bf16[2048, 4544]"
del t1122
t19325 = torch.reshape(t1229, (-1, 4544)) # t19325: "cuda:0 bf16[2048, 4544]"
# t19325 = ltorch.reshape(t1229, (-1, 4544)) # t19325: "cuda:0 bf16[2048, 4544]"
# t19325 = prims.reshape(t1229, (2048, 4544)) # t19325: "cuda:0 bf16[2048, 4544]"
del t1229
t13058 = torch.reshape(t4792, (-1, 18176)) # t13058: "cuda:0 bf16[2048, 18176]"
# t13058 = ltorch.reshape(t4792, (-1, 18176)) # t13058: "cuda:0 bf16[2048, 18176]"
# t13058 = prims.reshape(t4792, (2048, 18176)) # t13058: "cuda:0 bf16[2048, 18176]"
del t4792
t16205 = torch.reshape(t2893, (-1, 4544)) # t16205: "cuda:0 bf16[2048, 4544]"
# t16205 = ltorch.reshape(t2893, (-1, 4544)) # t16205: "cuda:0 bf16[2048, 4544]"
# t16205 = prims.reshape(t2893, (2048, 4544)) # t16205: "cuda:0 bf16[2048, 4544]"
del t2893
t16212 = torch.reshape(t3000, (-1, 4544)) # t16212: "cuda:0 bf16[2048, 4544]"
# t16212 = ltorch.reshape(t3000, (-1, 4544)) # t16212: "cuda:0 bf16[2048, 4544]"
# t16212 = prims.reshape(t3000, (2048, 4544)) # t16212: "cuda:0 bf16[2048, 4544]"
del t3000
t13092 = torch.reshape(t4664, (-1, 4544)) # t13092: "cuda:0 bf16[2048, 4544]"
# t13092 = ltorch.reshape(t4664, (-1, 4544)) # t13092: "cuda:0 bf16[2048, 4544]"
# t13092 = prims.reshape(t4664, (2048, 4544)) # t13092: "cuda:0 bf16[2048, 4544]"
del t4664
t13099 = torch.reshape(t4771, (-1, 4544)) # t13099: "cuda:0 bf16[2048, 4544]"
# t13099 = ltorch.reshape(t4771, (-1, 4544)) # t13099: "cuda:0 bf16[2048, 4544]"
# t13099 = prims.reshape(t4771, (2048, 4544)) # t13099: "cuda:0 bf16[2048, 4544]"
del t4771
t19567 = torch.reshape(t1089, (-1, 18176)) # t19567: "cuda:0 bf16[2048, 18176]"
# t19567 = ltorch.reshape(t1089, (-1, 18176)) # t19567: "cuda:0 bf16[2048, 18176]"
# t19567 = prims.reshape(t1089, (2048, 18176)) # t19567: "cuda:0 bf16[2048, 18176]"
del t1089
t16454 = torch.reshape(t2860, (-1, 18176)) # t16454: "cuda:0 bf16[2048, 18176]"
# t16454 = ltorch.reshape(t2860, (-1, 18176)) # t16454: "cuda:0 bf16[2048, 18176]"
# t16454 = prims.reshape(t2860, (2048, 18176)) # t16454: "cuda:0 bf16[2048, 18176]"
del t2860
t19601 = torch.reshape(t961, (-1, 4544)) # t19601: "cuda:0 bf16[2048, 4544]"
# t19601 = ltorch.reshape(t961, (-1, 4544)) # t19601: "cuda:0 bf16[2048, 4544]"
# t19601 = prims.reshape(t961, (2048, 4544)) # t19601: "cuda:0 bf16[2048, 4544]"
del t961
t19608 = torch.reshape(t1068, (-1, 4544)) # t19608: "cuda:0 bf16[2048, 4544]"
# t19608 = ltorch.reshape(t1068, (-1, 4544)) # t19608: "cuda:0 bf16[2048, 4544]"
# t19608 = prims.reshape(t1068, (2048, 4544)) # t19608: "cuda:0 bf16[2048, 4544]"
del t1068
t13341 = torch.reshape(t4631, (-1, 18176)) # t13341: "cuda:0 bf16[2048, 18176]"
# t13341 = ltorch.reshape(t4631, (-1, 18176)) # t13341: "cuda:0 bf16[2048, 18176]"
# t13341 = prims.reshape(t4631, (2048, 18176)) # t13341: "cuda:0 bf16[2048, 18176]"
del t4631
t16488 = torch.reshape(t2732, (-1, 4544)) # t16488: "cuda:0 bf16[2048, 4544]"
# t16488 = ltorch.reshape(t2732, (-1, 4544)) # t16488: "cuda:0 bf16[2048, 4544]"
# t16488 = prims.reshape(t2732, (2048, 4544)) # t16488: "cuda:0 bf16[2048, 4544]"
del t2732
t16495 = torch.reshape(t2839, (-1, 4544)) # t16495: "cuda:0 bf16[2048, 4544]"
# t16495 = ltorch.reshape(t2839, (-1, 4544)) # t16495: "cuda:0 bf16[2048, 4544]"
# t16495 = prims.reshape(t2839, (2048, 4544)) # t16495: "cuda:0 bf16[2048, 4544]"
del t2839
t13375 = torch.reshape(t4503, (-1, 4544)) # t13375: "cuda:0 bf16[2048, 4544]"
# t13375 = ltorch.reshape(t4503, (-1, 4544)) # t13375: "cuda:0 bf16[2048, 4544]"
# t13375 = prims.reshape(t4503, (2048, 4544)) # t13375: "cuda:0 bf16[2048, 4544]"
del t4503
t13382 = torch.reshape(t4610, (-1, 4544)) # t13382: "cuda:0 bf16[2048, 4544]"
# t13382 = ltorch.reshape(t4610, (-1, 4544)) # t13382: "cuda:0 bf16[2048, 4544]"
# t13382 = prims.reshape(t4610, (2048, 4544)) # t13382: "cuda:0 bf16[2048, 4544]"
del t4610
t19850 = torch.reshape(t928, (-1, 18176)) # t19850: "cuda:0 bf16[2048, 18176]"
# t19850 = ltorch.reshape(t928, (-1, 18176)) # t19850: "cuda:0 bf16[2048, 18176]"
# t19850 = prims.reshape(t928, (2048, 18176)) # t19850: "cuda:0 bf16[2048, 18176]"
del t928
t16737 = torch.reshape(t2699, (-1, 18176)) # t16737: "cuda:0 bf16[2048, 18176]"
# t16737 = ltorch.reshape(t2699, (-1, 18176)) # t16737: "cuda:0 bf16[2048, 18176]"
# t16737 = prims.reshape(t2699, (2048, 18176)) # t16737: "cuda:0 bf16[2048, 18176]"
del t2699
t19884 = torch.reshape(t800, (-1, 4544)) # t19884: "cuda:0 bf16[2048, 4544]"
# t19884 = ltorch.reshape(t800, (-1, 4544)) # t19884: "cuda:0 bf16[2048, 4544]"
# t19884 = prims.reshape(t800, (2048, 4544)) # t19884: "cuda:0 bf16[2048, 4544]"
del t800
t19891 = torch.reshape(t907, (-1, 4544)) # t19891: "cuda:0 bf16[2048, 4544]"
# t19891 = ltorch.reshape(t907, (-1, 4544)) # t19891: "cuda:0 bf16[2048, 4544]"
# t19891 = prims.reshape(t907, (2048, 4544)) # t19891: "cuda:0 bf16[2048, 4544]"
del t907
t13624 = torch.reshape(t4470, (-1, 18176)) # t13624: "cuda:0 bf16[2048, 18176]"
# t13624 = ltorch.reshape(t4470, (-1, 18176)) # t13624: "cuda:0 bf16[2048, 18176]"
# t13624 = prims.reshape(t4470, (2048, 18176)) # t13624: "cuda:0 bf16[2048, 18176]"
del t4470
t16771 = torch.reshape(t2571, (-1, 4544)) # t16771: "cuda:0 bf16[2048, 4544]"
# t16771 = ltorch.reshape(t2571, (-1, 4544)) # t16771: "cuda:0 bf16[2048, 4544]"
# t16771 = prims.reshape(t2571, (2048, 4544)) # t16771: "cuda:0 bf16[2048, 4544]"
del t2571
t16778 = torch.reshape(t2678, (-1, 4544)) # t16778: "cuda:0 bf16[2048, 4544]"
# t16778 = ltorch.reshape(t2678, (-1, 4544)) # t16778: "cuda:0 bf16[2048, 4544]"
# t16778 = prims.reshape(t2678, (2048, 4544)) # t16778: "cuda:0 bf16[2048, 4544]"
del t2678
t13658 = torch.reshape(t4342, (-1, 4544)) # t13658: "cuda:0 bf16[2048, 4544]"
# t13658 = ltorch.reshape(t4342, (-1, 4544)) # t13658: "cuda:0 bf16[2048, 4544]"
# t13658 = prims.reshape(t4342, (2048, 4544)) # t13658: "cuda:0 bf16[2048, 4544]"
del t4342
t13665 = torch.reshape(t4449, (-1, 4544)) # t13665: "cuda:0 bf16[2048, 4544]"
# t13665 = ltorch.reshape(t4449, (-1, 4544)) # t13665: "cuda:0 bf16[2048, 4544]"
# t13665 = prims.reshape(t4449, (2048, 4544)) # t13665: "cuda:0 bf16[2048, 4544]"
del t4449
t20133 = torch.reshape(t767, (-1, 18176)) # t20133: "cuda:0 bf16[2048, 18176]"
# t20133 = ltorch.reshape(t767, (-1, 18176)) # t20133: "cuda:0 bf16[2048, 18176]"
# t20133 = prims.reshape(t767, (2048, 18176)) # t20133: "cuda:0 bf16[2048, 18176]"
del t767
t17020 = torch.reshape(t2538, (-1, 18176)) # t17020: "cuda:0 bf16[2048, 18176]"
# t17020 = ltorch.reshape(t2538, (-1, 18176)) # t17020: "cuda:0 bf16[2048, 18176]"
# t17020 = prims.reshape(t2538, (2048, 18176)) # t17020: "cuda:0 bf16[2048, 18176]"
del t2538
t20167 = torch.reshape(t639, (-1, 4544)) # t20167: "cuda:0 bf16[2048, 4544]"
# t20167 = ltorch.reshape(t639, (-1, 4544)) # t20167: "cuda:0 bf16[2048, 4544]"
# t20167 = prims.reshape(t639, (2048, 4544)) # t20167: "cuda:0 bf16[2048, 4544]"
del t639
t20174 = torch.reshape(t746, (-1, 4544)) # t20174: "cuda:0 bf16[2048, 4544]"
# t20174 = ltorch.reshape(t746, (-1, 4544)) # t20174: "cuda:0 bf16[2048, 4544]"
# t20174 = prims.reshape(t746, (2048, 4544)) # t20174: "cuda:0 bf16[2048, 4544]"
del t746
t13907 = torch.reshape(t4309, (-1, 18176)) # t13907: "cuda:0 bf16[2048, 18176]"
# t13907 = ltorch.reshape(t4309, (-1, 18176)) # t13907: "cuda:0 bf16[2048, 18176]"
# t13907 = prims.reshape(t4309, (2048, 18176)) # t13907: "cuda:0 bf16[2048, 18176]"
del t4309
t17054 = torch.reshape(t2410, (-1, 4544)) # t17054: "cuda:0 bf16[2048, 4544]"
# t17054 = ltorch.reshape(t2410, (-1, 4544)) # t17054: "cuda:0 bf16[2048, 4544]"
# t17054 = prims.reshape(t2410, (2048, 4544)) # t17054: "cuda:0 bf16[2048, 4544]"
del t2410
t17061 = torch.reshape(t2517, (-1, 4544)) # t17061: "cuda:0 bf16[2048, 4544]"
# t17061 = ltorch.reshape(t2517, (-1, 4544)) # t17061: "cuda:0 bf16[2048, 4544]"
# t17061 = prims.reshape(t2517, (2048, 4544)) # t17061: "cuda:0 bf16[2048, 4544]"
del t2517
t13941 = torch.reshape(t4181, (-1, 4544)) # t13941: "cuda:0 bf16[2048, 4544]"
# t13941 = ltorch.reshape(t4181, (-1, 4544)) # t13941: "cuda:0 bf16[2048, 4544]"
# t13941 = prims.reshape(t4181, (2048, 4544)) # t13941: "cuda:0 bf16[2048, 4544]"
del t4181
t13948 = torch.reshape(t4288, (-1, 4544)) # t13948: "cuda:0 bf16[2048, 4544]"
# t13948 = ltorch.reshape(t4288, (-1, 4544)) # t13948: "cuda:0 bf16[2048, 4544]"
# t13948 = prims.reshape(t4288, (2048, 4544)) # t13948: "cuda:0 bf16[2048, 4544]"
del t4288
t20416 = torch.reshape(t606, (-1, 18176)) # t20416: "cuda:0 bf16[2048, 18176]"
# t20416 = ltorch.reshape(t606, (-1, 18176)) # t20416: "cuda:0 bf16[2048, 18176]"
# t20416 = prims.reshape(t606, (2048, 18176)) # t20416: "cuda:0 bf16[2048, 18176]"
del t606
t17303 = torch.reshape(t2377, (-1, 18176)) # t17303: "cuda:0 bf16[2048, 18176]"
# t17303 = ltorch.reshape(t2377, (-1, 18176)) # t17303: "cuda:0 bf16[2048, 18176]"
# t17303 = prims.reshape(t2377, (2048, 18176)) # t17303: "cuda:0 bf16[2048, 18176]"
del t2377
t20450 = torch.reshape(t478, (-1, 4544)) # t20450: "cuda:0 bf16[2048, 4544]"
# t20450 = ltorch.reshape(t478, (-1, 4544)) # t20450: "cuda:0 bf16[2048, 4544]"
# t20450 = prims.reshape(t478, (2048, 4544)) # t20450: "cuda:0 bf16[2048, 4544]"
del t478
t20457 = torch.reshape(t585, (-1, 4544)) # t20457: "cuda:0 bf16[2048, 4544]"
# t20457 = ltorch.reshape(t585, (-1, 4544)) # t20457: "cuda:0 bf16[2048, 4544]"
# t20457 = prims.reshape(t585, (2048, 4544)) # t20457: "cuda:0 bf16[2048, 4544]"
del t585
t14190 = torch.reshape(t4148, (-1, 18176)) # t14190: "cuda:0 bf16[2048, 18176]"
# t14190 = ltorch.reshape(t4148, (-1, 18176)) # t14190: "cuda:0 bf16[2048, 18176]"
# t14190 = prims.reshape(t4148, (2048, 18176)) # t14190: "cuda:0 bf16[2048, 18176]"
del t4148
t17337 = torch.reshape(t2249, (-1, 4544)) # t17337: "cuda:0 bf16[2048, 4544]"
# t17337 = ltorch.reshape(t2249, (-1, 4544)) # t17337: "cuda:0 bf16[2048, 4544]"
# t17337 = prims.reshape(t2249, (2048, 4544)) # t17337: "cuda:0 bf16[2048, 4544]"
del t2249
t17344 = torch.reshape(t2356, (-1, 4544)) # t17344: "cuda:0 bf16[2048, 4544]"
# t17344 = ltorch.reshape(t2356, (-1, 4544)) # t17344: "cuda:0 bf16[2048, 4544]"
# t17344 = prims.reshape(t2356, (2048, 4544)) # t17344: "cuda:0 bf16[2048, 4544]"
del t2356
t14224 = torch.reshape(t4020, (-1, 4544)) # t14224: "cuda:0 bf16[2048, 4544]"
# t14224 = ltorch.reshape(t4020, (-1, 4544)) # t14224: "cuda:0 bf16[2048, 4544]"
# t14224 = prims.reshape(t4020, (2048, 4544)) # t14224: "cuda:0 bf16[2048, 4544]"
del t4020
t14231 = torch.reshape(t4127, (-1, 4544)) # t14231: "cuda:0 bf16[2048, 4544]"
# t14231 = ltorch.reshape(t4127, (-1, 4544)) # t14231: "cuda:0 bf16[2048, 4544]"
# t14231 = prims.reshape(t4127, (2048, 4544)) # t14231: "cuda:0 bf16[2048, 4544]"
del t4127
t20699 = torch.reshape(t445, (-1, 18176)) # t20699: "cuda:0 bf16[2048, 18176]"
# t20699 = ltorch.reshape(t445, (-1, 18176)) # t20699: "cuda:0 bf16[2048, 18176]"
# t20699 = prims.reshape(t445, (2048, 18176)) # t20699: "cuda:0 bf16[2048, 18176]"
del t445
t17586 = torch.reshape(t2216, (-1, 18176)) # t17586: "cuda:0 bf16[2048, 18176]"
# t17586 = ltorch.reshape(t2216, (-1, 18176)) # t17586: "cuda:0 bf16[2048, 18176]"
# t17586 = prims.reshape(t2216, (2048, 18176)) # t17586: "cuda:0 bf16[2048, 18176]"
del t2216
t20733 = torch.reshape(t317, (-1, 4544)) # t20733: "cuda:0 bf16[2048, 4544]"
# t20733 = ltorch.reshape(t317, (-1, 4544)) # t20733: "cuda:0 bf16[2048, 4544]"
# t20733 = prims.reshape(t317, (2048, 4544)) # t20733: "cuda:0 bf16[2048, 4544]"
del t317
t20740 = torch.reshape(t424, (-1, 4544)) # t20740: "cuda:0 bf16[2048, 4544]"
# t20740 = ltorch.reshape(t424, (-1, 4544)) # t20740: "cuda:0 bf16[2048, 4544]"
# t20740 = prims.reshape(t424, (2048, 4544)) # t20740: "cuda:0 bf16[2048, 4544]"
del t424
t14473 = torch.reshape(t3987, (-1, 18176)) # t14473: "cuda:0 bf16[2048, 18176]"
# t14473 = ltorch.reshape(t3987, (-1, 18176)) # t14473: "cuda:0 bf16[2048, 18176]"
# t14473 = prims.reshape(t3987, (2048, 18176)) # t14473: "cuda:0 bf16[2048, 18176]"
del t3987
t17620 = torch.reshape(t2088, (-1, 4544)) # t17620: "cuda:0 bf16[2048, 4544]"
# t17620 = ltorch.reshape(t2088, (-1, 4544)) # t17620: "cuda:0 bf16[2048, 4544]"
# t17620 = prims.reshape(t2088, (2048, 4544)) # t17620: "cuda:0 bf16[2048, 4544]"
del t2088
t17627 = torch.reshape(t2195, (-1, 4544)) # t17627: "cuda:0 bf16[2048, 4544]"
# t17627 = ltorch.reshape(t2195, (-1, 4544)) # t17627: "cuda:0 bf16[2048, 4544]"
# t17627 = prims.reshape(t2195, (2048, 4544)) # t17627: "cuda:0 bf16[2048, 4544]"
del t2195
t14507 = torch.reshape(t3859, (-1, 4544)) # t14507: "cuda:0 bf16[2048, 4544]"
# t14507 = ltorch.reshape(t3859, (-1, 4544)) # t14507: "cuda:0 bf16[2048, 4544]"
# t14507 = prims.reshape(t3859, (2048, 4544)) # t14507: "cuda:0 bf16[2048, 4544]"
del t3859
t14514 = torch.reshape(t3966, (-1, 4544)) # t14514: "cuda:0 bf16[2048, 4544]"
# t14514 = ltorch.reshape(t3966, (-1, 4544)) # t14514: "cuda:0 bf16[2048, 4544]"
# t14514 = prims.reshape(t3966, (2048, 4544)) # t14514: "cuda:0 bf16[2048, 4544]"
del t3966
t20982 = torch.reshape(t284, (-1, 18176)) # t20982: "cuda:0 bf16[2048, 18176]"
# t20982 = ltorch.reshape(t284, (-1, 18176)) # t20982: "cuda:0 bf16[2048, 18176]"
# t20982 = prims.reshape(t284, (2048, 18176)) # t20982: "cuda:0 bf16[2048, 18176]"
del t284
t17869 = torch.reshape(t2055, (-1, 18176)) # t17869: "cuda:0 bf16[2048, 18176]"
# t17869 = ltorch.reshape(t2055, (-1, 18176)) # t17869: "cuda:0 bf16[2048, 18176]"
# t17869 = prims.reshape(t2055, (2048, 18176)) # t17869: "cuda:0 bf16[2048, 18176]"
del t2055
t21016 = torch.reshape(t156, (-1, 4544)) # t21016: "cuda:0 bf16[2048, 4544]"
# t21016 = ltorch.reshape(t156, (-1, 4544)) # t21016: "cuda:0 bf16[2048, 4544]"
# t21016 = prims.reshape(t156, (2048, 4544)) # t21016: "cuda:0 bf16[2048, 4544]"
del t156
t21023 = torch.reshape(t263, (-1, 4544)) # t21023: "cuda:0 bf16[2048, 4544]"
# t21023 = ltorch.reshape(t263, (-1, 4544)) # t21023: "cuda:0 bf16[2048, 4544]"
# t21023 = prims.reshape(t263, (2048, 4544)) # t21023: "cuda:0 bf16[2048, 4544]"
del t263
t14756 = torch.reshape(t3826, (-1, 18176)) # t14756: "cuda:0 bf16[2048, 18176]"
# t14756 = ltorch.reshape(t3826, (-1, 18176)) # t14756: "cuda:0 bf16[2048, 18176]"
# t14756 = prims.reshape(t3826, (2048, 18176)) # t14756: "cuda:0 bf16[2048, 18176]"
del t3826
t17903 = torch.reshape(t1927, (-1, 4544)) # t17903: "cuda:0 bf16[2048, 4544]"
# t17903 = ltorch.reshape(t1927, (-1, 4544)) # t17903: "cuda:0 bf16[2048, 4544]"
# t17903 = prims.reshape(t1927, (2048, 4544)) # t17903: "cuda:0 bf16[2048, 4544]"
del t1927
t17910 = torch.reshape(t2034, (-1, 4544)) # t17910: "cuda:0 bf16[2048, 4544]"
# t17910 = ltorch.reshape(t2034, (-1, 4544)) # t17910: "cuda:0 bf16[2048, 4544]"
# t17910 = prims.reshape(t2034, (2048, 4544)) # t17910: "cuda:0 bf16[2048, 4544]"
del t2034
t14790 = torch.reshape(t3698, (-1, 4544)) # t14790: "cuda:0 bf16[2048, 4544]"
# t14790 = ltorch.reshape(t3698, (-1, 4544)) # t14790: "cuda:0 bf16[2048, 4544]"
# t14790 = prims.reshape(t3698, (2048, 4544)) # t14790: "cuda:0 bf16[2048, 4544]"
del t3698
t14797 = torch.reshape(t3805, (-1, 4544)) # t14797: "cuda:0 bf16[2048, 4544]"
# t14797 = ltorch.reshape(t3805, (-1, 4544)) # t14797: "cuda:0 bf16[2048, 4544]"
# t14797 = prims.reshape(t3805, (2048, 4544)) # t14797: "cuda:0 bf16[2048, 4544]"
del t3805
t21265 = torch.reshape(t123, (-1, 18176)) # t21265: "cuda:0 bf16[2048, 18176]"
# t21265 = ltorch.reshape(t123, (-1, 18176)) # t21265: "cuda:0 bf16[2048, 18176]"
# t21265 = prims.reshape(t123, (2048, 18176)) # t21265: "cuda:0 bf16[2048, 18176]"
del t123
t18152 = torch.reshape(t1894, (-1, 18176)) # t18152: "cuda:0 bf16[2048, 18176]"
# t18152 = ltorch.reshape(t1894, (-1, 18176)) # t18152: "cuda:0 bf16[2048, 18176]"
# t18152 = prims.reshape(t1894, (2048, 18176)) # t18152: "cuda:0 bf16[2048, 18176]"
del t1894
t21299 = torch.reshape(t25, (-1, 4544)) # t21299: "cuda:0 bf16[2048, 4544]"
# t21299 = ltorch.reshape(t25, (-1, 4544)) # t21299: "cuda:0 bf16[2048, 4544]"
# t21299 = prims.reshape(t25, (2048, 4544)) # t21299: "cuda:0 bf16[2048, 4544]"
del t25
t21306 = torch.reshape(t105, (-1, 4544)) # t21306: "cuda:0 bf16[2048, 4544]"
# t21306 = ltorch.reshape(t105, (-1, 4544)) # t21306: "cuda:0 bf16[2048, 4544]"
# t21306 = prims.reshape(t105, (2048, 4544)) # t21306: "cuda:0 bf16[2048, 4544]"
del t105
t15039 = torch.reshape(t3665, (-1, 18176)) # t15039: "cuda:0 bf16[2048, 18176]"
# t15039 = ltorch.reshape(t3665, (-1, 18176)) # t15039: "cuda:0 bf16[2048, 18176]"
# t15039 = prims.reshape(t3665, (2048, 18176)) # t15039: "cuda:0 bf16[2048, 18176]"
del t3665
t18186 = torch.reshape(t1766, (-1, 4544)) # t18186: "cuda:0 bf16[2048, 4544]"
# t18186 = ltorch.reshape(t1766, (-1, 4544)) # t18186: "cuda:0 bf16[2048, 4544]"
# t18186 = prims.reshape(t1766, (2048, 4544)) # t18186: "cuda:0 bf16[2048, 4544]"
del t1766
t18193 = torch.reshape(t1873, (-1, 4544)) # t18193: "cuda:0 bf16[2048, 4544]"
# t18193 = ltorch.reshape(t1873, (-1, 4544)) # t18193: "cuda:0 bf16[2048, 4544]"
# t18193 = prims.reshape(t1873, (2048, 4544)) # t18193: "cuda:0 bf16[2048, 4544]"
del t1873
t15073 = torch.reshape(t3537, (-1, 4544)) # t15073: "cuda:0 bf16[2048, 4544]"
# t15073 = ltorch.reshape(t3537, (-1, 4544)) # t15073: "cuda:0 bf16[2048, 4544]"
# t15073 = prims.reshape(t3537, (2048, 4544)) # t15073: "cuda:0 bf16[2048, 4544]"
del t3537
t15080 = torch.reshape(t3644, (-1, 4544)) # t15080: "cuda:0 bf16[2048, 4544]"
# t15080 = ltorch.reshape(t3644, (-1, 4544)) # t15080: "cuda:0 bf16[2048, 4544]"
# t15080 = prims.reshape(t3644, (2048, 4544)) # t15080: "cuda:0 bf16[2048, 4544]"
del t3644
t18435 = torch.reshape(t1733, (-1, 18176)) # t18435: "cuda:0 bf16[2048, 18176]"
# t18435 = ltorch.reshape(t1733, (-1, 18176)) # t18435: "cuda:0 bf16[2048, 18176]"
# t18435 = prims.reshape(t1733, (2048, 18176)) # t18435: "cuda:0 bf16[2048, 18176]"
del t1733
t15322 = torch.reshape(t3504, (-1, 18176)) # t15322: "cuda:0 bf16[2048, 18176]"
# t15322 = ltorch.reshape(t3504, (-1, 18176)) # t15322: "cuda:0 bf16[2048, 18176]"
# t15322 = prims.reshape(t3504, (2048, 18176)) # t15322: "cuda:0 bf16[2048, 18176]"
del t3504
t18469 = torch.reshape(t1605, (-1, 4544)) # t18469: "cuda:0 bf16[2048, 4544]"
# t18469 = ltorch.reshape(t1605, (-1, 4544)) # t18469: "cuda:0 bf16[2048, 4544]"
# t18469 = prims.reshape(t1605, (2048, 4544)) # t18469: "cuda:0 bf16[2048, 4544]"
del t1605
t18476 = torch.reshape(t1712, (-1, 4544)) # t18476: "cuda:0 bf16[2048, 4544]"
# t18476 = ltorch.reshape(t1712, (-1, 4544)) # t18476: "cuda:0 bf16[2048, 4544]"
# t18476 = prims.reshape(t1712, (2048, 4544)) # t18476: "cuda:0 bf16[2048, 4544]"
del t1712
t15356 = torch.reshape(t3376, (-1, 4544)) # t15356: "cuda:0 bf16[2048, 4544]"
# t15356 = ltorch.reshape(t3376, (-1, 4544)) # t15356: "cuda:0 bf16[2048, 4544]"
# t15356 = prims.reshape(t3376, (2048, 4544)) # t15356: "cuda:0 bf16[2048, 4544]"
del t3376
t15363 = torch.reshape(t3483, (-1, 4544)) # t15363: "cuda:0 bf16[2048, 4544]"
# t15363 = ltorch.reshape(t3483, (-1, 4544)) # t15363: "cuda:0 bf16[2048, 4544]"
# t15363 = prims.reshape(t3483, (2048, 4544)) # t15363: "cuda:0 bf16[2048, 4544]"
del t3483
i15573 = operator.sub(4544, i1353) # i15573: "int 4544"
# i15573 = prims.sub(4544, i1353) # i15573: "int 4544"
del i1353
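# The operator.sub calls below recover a per-block element count (4544 - i,
# annotated as "int 4544" here); these iNNNNN values are the divisors used in
# the variance term of the layer-norm backward fusions (e.g. i12472 inside
# nvFusion0 further down).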
i12472 = operator.sub(4544, i2057) # i12472: "int 4544"
# i12472 = prims.sub(4544, i2057) # i12472: "int 4544"
del i2057
i20384 = operator.sub(4544, i265) # i20384: "int 4544"
# i20384 = prims.sub(4544, i265) # i20384: "int 4544"
del i265
i17271 = operator.sub(4544, i969) # i17271: "int 4544"
# i17271 = prims.sub(4544, i969) # i17271: "int 4544"
del i969
i14158 = operator.sub(4544, i1673) # i14158: "int 4544"
# i14158 = prims.sub(4544, i1673) # i14158: "int 4544"
del i1673
i18969 = operator.sub(4544, i585) # i18969: "int 4544"
# i18969 = prims.sub(4544, i585) # i18969: "int 4544"
del i585
i15856 = operator.sub(4544, i1289) # i15856: "int 4544"
# i15856 = prims.sub(4544, i1289) # i15856: "int 4544"
del i1289
i12743 = operator.sub(4544, i1993) # i12743: "int 4544"
# i12743 = prims.sub(4544, i1993) # i12743: "int 4544"
del i1993
i20667 = operator.sub(4544, i201) # i20667: "int 4544"
# i20667 = prims.sub(4544, i201) # i20667: "int 4544"
del i201
i17554 = operator.sub(4544, i905) # i17554: "int 4544"
# i17554 = prims.sub(4544, i905) # i17554: "int 4544"
del i905
i14441 = operator.sub(4544, i1609) # i14441: "int 4544"
# i14441 = prims.sub(4544, i1609) # i14441: "int 4544"
del i1609
i19252 = operator.sub(4544, i521) # i19252: "int 4544"
# i19252 = prims.sub(4544, i521) # i19252: "int 4544"
del i521
i16139 = operator.sub(4544, i1225) # i16139: "int 4544"
# i16139 = prims.sub(4544, i1225) # i16139: "int 4544"
del i1225
i13026 = operator.sub(4544, i1929) # i13026: "int 4544"
# i13026 = prims.sub(4544, i1929) # i13026: "int 4544"
del i1929
i20950 = operator.sub(4544, i137) # i20950: "int 4544"
# i20950 = prims.sub(4544, i137) # i20950: "int 4544"
del i137
i17837 = operator.sub(4544, i841) # i17837: "int 4544"
# i17837 = prims.sub(4544, i841) # i17837: "int 4544"
del i841
i14724 = operator.sub(4544, i1545) # i14724: "int 4544"
# i14724 = prims.sub(4544, i1545) # i14724: "int 4544"
del i1545
i19535 = operator.sub(4544, i457) # i19535: "int 4544"
# i19535 = prims.sub(4544, i457) # i19535: "int 4544"
del i457
i16422 = operator.sub(4544, i1161) # i16422: "int 4544"
# i16422 = prims.sub(4544, i1161) # i16422: "int 4544"
del i1161
i13309 = operator.sub(4544, i1865) # i13309: "int 4544"
# i13309 = prims.sub(4544, i1865) # i13309: "int 4544"
del i1865
i21233 = operator.sub(4544, i73) # i21233: "int 4544"
# i21233 = prims.sub(4544, i73) # i21233: "int 4544"
del i73
i18120 = operator.sub(4544, i777) # i18120: "int 4544"
# i18120 = prims.sub(4544, i777) # i18120: "int 4544"
del i777
i15007 = operator.sub(4544, i1481) # i15007: "int 4544"
# i15007 = prims.sub(4544, i1481) # i15007: "int 4544"
del i1481
i19818 = operator.sub(4544, i393) # i19818: "int 4544"
# i19818 = prims.sub(4544, i393) # i19818: "int 4544"
del i393
i16705 = operator.sub(4544, i1097) # i16705: "int 4544"
# i16705 = prims.sub(4544, i1097) # i16705: "int 4544"
del i1097
i13592 = operator.sub(4544, i1801) # i13592: "int 4544"
# i13592 = prims.sub(4544, i1801) # i13592: "int 4544"
del i1801
i21516 = operator.sub(4544, i9) # i21516: "int 4544"
# i21516 = prims.sub(4544, i9) # i21516: "int 4544"
del i9
i18403 = operator.sub(4544, i713) # i18403: "int 4544"
# i18403 = prims.sub(4544, i713) # i18403: "int 4544"
del i713
i15290 = operator.sub(4544, i1417) # i15290: "int 4544"
# i15290 = prims.sub(4544, i1417) # i15290: "int 4544"
del i1417
i20101 = operator.sub(4544, i329) # i20101: "int 4544"
# i20101 = prims.sub(4544, i329) # i20101: "int 4544"
del i329
i16988 = operator.sub(4544, i1033) # i16988: "int 4544"
# i16988 = prims.sub(4544, i1033) # i16988: "int 4544"
del i1033
i13875 = operator.sub(4544, i1737) # i13875: "int 4544"
# i13875 = prims.sub(4544, i1737) # i13875: "int 4544"
del i1737
i18686 = operator.sub(4544, i649) # i18686: "int 4544"
# i18686 = prims.sub(4544, i649) # i18686: "int 4544"
del i649
t12434 = torch.matmul(t12436, t_lm_head_weight) # t12434: "cuda:0 bf16[2048, 4544]"
# t12434 = ltorch.matmul(t12433, t_lm_head_weight) # t12434: "cuda:0 bf16[2048, 4544]"
# t12434 = prims.matmul(t12433, t_lm_head_weight) # t12434: "cuda:0 bf16[2048, 4544]"
del t12436, t_lm_head_weight
t12439 = torch.matmul(t12437, t12438) # t12439: "cuda:0 bf16[65024, 4544]"
# t12439 = ltorch.matmul(t12437, t12438) # t12439: "cuda:0 bf16[65024, 4544]"
# t12439 = prims.matmul(t12437, t12438) # t12439: "cuda:0 bf16[65024, 4544]"
del t12437, t12438
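# t12434 above is the gradient flowing back into the final hidden states
# (grad_logits @ W_lm_head), and t12439 is the lm_head weight gradient
# (grad_logits^T @ lm_head input), with t5147/t12438 appearing to be the saved
# input to lm_head.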
t12435 = torch.reshape(t12434, (1, 2048, 4544)) # t12435: "cuda:0 bf16[1, 2048, 4544]"
# t12435 = ltorch.reshape(t12434, (1, 2048, 4544)) # t12435: "cuda:0 bf16[1, 2048, 4544]"
# t12435 = prims.reshape(t12434, (1, 2048, 4544)) # t12435: "cuda:0 bf16[1, 2048, 4544]"
del t12434
[t12444, t12450, t12488] = nvFusion0(i12472, t12435, t4962, t5094, t5115, t5130, t5135, t5141)
# t5121 = prims.convert_element_type(t4962, dtypes.float32) # t5121: "cuda:0 f32[1, 2048, 4544]"
# t5116 = prims.convert_element_type(t5115, dtypes.float32) # t5116: "cuda:0 f32[1, 2048, 4544]"
# t5117 = prims.convert_element_type(t5094, dtypes.float32) # t5117: "cuda:0 f32[1, 2048, 4544]"
# t5118 = prims.add(t5116, t5117) # t5118: "cuda:0 f32[1, 2048, 4544]"
# t5122 = prims.add(t5118, t5121) # t5122: "cuda:0 f32[1, 2048, 4544]"
# t5132 = prims.broadcast_in_dim(t5130, [1, 2048, 1], [0, 1]) # t5132: "cuda:0 f32[1, 2048, 1]"
# t5136 = prims.broadcast_in_dim(t5132, (1, 2048, 4544), (0, 1, 2)) # t5136: "cuda:0 f32[1, 2048, 4544]"
# t5138 = prims.sub(t5122, t5136) # t5138: "cuda:0 f32[1, 2048, 4544]"
# t5139 = prims.broadcast_in_dim(t5135, (1, 2048, 4544), (0, 1, 2)) # t5139: "cuda:0 f32[1, 2048, 4544]"
# t5140 = prims.mul(t5138, t5139) # t5140: "cuda:0 f32[1, 2048, 4544]"
# t5142 = prims.convert_element_type(t5141, dtypes.float32) # t5142: "cuda:0 f32[1, 2048, 4544]"
# t12440 = prims.convert_element_type(t12435, dtypes.float32) # t12440: "cuda:0 f32[1, 2048, 4544]"
# t12443 = prims.sum(t12440, (0, 1)) # t12443: "cuda:0 f32[4544]"
# t12444 = prims.convert_element_type(t12443, dtypes.bfloat16) # t12444: "cuda:0 bf16[4544]"
# t12445 = prims.mul(t5142, t12440) # t12445: "cuda:0 f32[1, 2048, 4544]"
# t12446 = prims.mul(t5140, t12440) # t12446: "cuda:0 f32[1, 2048, 4544]"
# t12449 = prims.sum(t12446, (0, 1)) # t12449: "cuda:0 f32[4544]"
# t12450 = prims.convert_element_type(t12449, dtypes.bfloat16) # t12450: "cuda:0 bf16[4544]"
# t12451 = prims.mul(t5139, t12445) # t12451: "cuda:0 f32[1, 2048, 4544]"
# t12452 = prims.mul(t5138, t12445) # t12452: "cuda:0 f32[1, 2048, 4544]"
# t12453 = prims.sum(t12452, (0, 2)) # t12453: "cuda:0 f32[2048]"
# t12454 = prims.broadcast_in_dim(t12453, [1, 2048, 1], [1]) # t12454: "cuda:0 f32[1, 2048, 1]"
# t12455 = prims.neg(t12451) # t12455: "cuda:0 f32[1, 2048, 4544]"
# t12457 = prims.sum(t12455, (0, 2)) # t12457: "cuda:0 f32[2048]"
# t12458 = prims.broadcast_in_dim(t12457, [1, 2048, 1], [1]) # t12458: "cuda:0 f32[1, 2048, 1]"
# t12459 = prims.mul(-0.5, t12454) # t12459: "cuda:0 f32[1, 2048, 1]"
# t12460 = prims.pow(t5135, 3.0) # t12460: "cuda:0 f32[1, 2048, 1]"
# t12461 = prims.mul(t12459, t12460) # t12461: "cuda:0 f32[1, 2048, 1]"
# t12463 = prims.sum(t12458, (0, 2)) # t12463: "cuda:0 f32[2048]"
# t12464 = prims.broadcast_in_dim(t12463, [1, 2048], [1]) # t12464: "cuda:0 f32[1, 2048]"
# t12465 = prims.sum(t12461, (0, 2)) # t12465: "cuda:0 f32[2048]"
# t12466 = prims.broadcast_in_dim(t12465, [1, 2048], [1]) # t12466: "cuda:0 f32[1, 2048]"
# t12469 = prims.broadcast_in_dim(t12464, [1, 2048, 1], [0, 1]) # t12469: "cuda:0 f32[1, 2048, 1]"
# t12470 = prims.broadcast_in_dim(t12469, (1, 2048, 4544), (0, 1, 2)) # t12470: "cuda:0 f32[1, 2048, 4544]"
# t12471 = prims.mul(0.00022007042253521127, t12470) # t12471: "cuda:0 f32[1, 2048, 4544]"
# t12473 = prims.broadcast_in_dim(t12466, [1, 2048, 1], [0, 1]) # t12473: "cuda:0 f32[1, 2048, 1]"
# t12474 = prims.broadcast_in_dim(t12473, (1, 2048, 4544), (0, 1, 2)) # t12474: "cuda:0 f32[1, 2048, 4544]"
# t12476 = prims.broadcast_in_dim(t5130, [1, 2048, 1], [0, 1]) # t12476: "cuda:0 f32[1, 2048, 1]"
# t12477 = prims.broadcast_in_dim(t12476, (1, 2048, 4544), (0, 1, 2)) # t12477: "cuda:0 f32[1, 2048, 4544]"
# t12478 = prims.mul(2.0, t12474) # t12478: "cuda:0 f32[1, 2048, 4544]"
# t12479 = prims.sub(t5122, t12477) # t12479: "cuda:0 f32[1, 2048, 4544]"
# t12480 = prims.mul(t12478, t12479) # t12480: "cuda:0 f32[1, 2048, 4544]"
# f12481 = prims.convert_element_type(i12472, float) # f12481: "float 4544.0"
# t12482 = prims.div(t12480, f12481) # t12482: "cuda:0 f32[1, 2048, 4544]"
# t12483 = prims.add(t12471, t12482) # t12483: "cuda:0 f32[1, 2048, 4544]"
# t12487 = prims.add(t12451, t12483) # t12487: "cuda:0 f32[1, 2048, 4544]"
# t12488 = prims.convert_element_type(t12487, dtypes.bfloat16) # t12488: "cuda:0 bf16[1, 2048, 4544]"
del i12472, t12435, t4962, t5094, t5115, t5130, t5135, t5141
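# nvFusion0 is the fused backward of the final layer norm (applied to the
# residual sum of block 31's attention, MLP and skip paths): t12444 and t12450
# are the norm's bias and weight gradients (reductions over batch and sequence
# dims), and t12488 is the gradient with respect to the pre-norm residual,
# rebuilt from the saved mean (t5130) and reciprocal standard deviation (t5135).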
t12495 = torch.reshape(t12488, (-1, 4544)) # t12495: "cuda:0 bf16[2048, 4544]"
# t12495 = ltorch.reshape(t12488, (-1, 4544)) # t12495: "cuda:0 bf16[2048, 4544]"
# t12495 = prims.reshape(t12488, (2048, 4544)) # t12495: "cuda:0 bf16[2048, 4544]"
t12499 = torch.permute(t12495, (1, 0)) # t12499: "cuda:0 bf16[4544, 2048]"
# t12499 = ltorch.permute(t12495, (1, 0)) # t12499: "cuda:0 bf16[4544, 2048]"
# t12499 = prims.transpose(t12495, (1, 0)) # t12499: "cuda:0 bf16[4544, 2048]"
t12501 = torch.matmul(t12499, t12500) # t12501: "cuda:0 bf16[4544, 18176]"
# t12501 = ltorch.matmul(t12499, t12500) # t12501: "cuda:0 bf16[4544, 18176]"
# t12501 = prims.matmul(t12499, t12500) # t12501: "cuda:0 bf16[4544, 18176]"
del t12500
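# From here the per-block pattern repeats: transposed-activation matmuls such
# as t12501 = t12499 @ t12500 produce weight gradients (this one is block 31's
# mlp.proj weight grad), while matmuls against the saved weights (t12496 and
# t12537 below) carry the gradient back to the corresponding layer inputs.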
t12537 = torch.matmul(t12495, t_transformer_h_31_attn_proj_weight) # t12537: "cuda:0 bf16[2048, 4544]"
# t12537 = ltorch.matmul(t12536, t_transformer_h_31_attn_proj_weight) # t12537: "cuda:0 bf16[2048, 4544]"
# t12537 = prims.matmul(t12536, t_transformer_h_31_attn_proj_weight) # t12537: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_31_attn_proj_weight
t12542 = torch.matmul(t12499, t12541) # t12542: "cuda:0 bf16[4544, 4544]"
# t12542 = ltorch.matmul(t12540, t12541) # t12542: "cuda:0 bf16[4544, 4544]"
# t12542 = prims.matmul(t12540, t12541) # t12542: "cuda:0 bf16[4544, 4544]"
del t12499, t12541
t12496 = torch.matmul(t12495, t_transformer_h_31_mlp_proj_weight) # t12496: "cuda:0 bf16[2048, 18176]"
# t12496 = ltorch.matmul(t12495, t_transformer_h_31_mlp_proj_weight) # t12496: "cuda:0 bf16[2048, 18176]"
# t12496 = prims.matmul(t12495, t_transformer_h_31_mlp_proj_weight) # t12496: "cuda:0 bf16[2048, 18176]"
del t12495, t_transformer_h_31_mlp_proj_weight
t12538 = torch.reshape(t12537, (1, 2048, 4544)) # t12538: "cuda:0 bf16[1, 2048, 4544]"
# t12538 = ltorch.reshape(t12537, (1, 2048, 4544)) # t12538: "cuda:0 bf16[1, 2048, 4544]"
# t12538 = prims.reshape(t12537, (1, 2048, 4544)) # t12538: "cuda:0 bf16[1, 2048, 4544]"
del t12537
t12546 = torch.reshape(t12538, (1, 2048, 71, 64)) # t12546: "cuda:0 bf16[1, 2048, 71, 64]"
# t12546 = ltorch.reshape(t12538, (1, 2048, 71, 64)) # t12546: "cuda:0 bf16[1, 2048, 71, 64]"
# t12546 = prims.reshape(t12538, (1, 2048, 71, 64)) # t12546: "cuda:0 bf16[1, 2048, 71, 64]"
del t12538
t12549 = torch.permute(t12546, (0, 2, 1, 3)) # t12549: "cuda:0 bf16[1, 71, 2048, 64]"
# t12549 = ltorch.permute(t12546, (0, 2, 1, 3)) # t12549: "cuda:0 bf16[1, 71, 2048, 64]"
# t12549 = prims.transpose(t12546, (0, 2, 1, 3)) # t12549: "cuda:0 bf16[1, 71, 2048, 64]"
del t12546
t12497 = torch.reshape(t12496, (1, 2048, 18176)) # t12497: "cuda:0 bf16[1, 2048, 18176]"
# t12497 = ltorch.reshape(t12496, (1, 2048, 18176)) # t12497: "cuda:0 bf16[1, 2048, 18176]"
# t12497 = prims.reshape(t12496, (1, 2048, 18176)) # t12497: "cuda:0 bf16[1, 2048, 18176]"
del t12496
[t12528] = nvFusion1(f2047, f2049, t12497, t5095)
# t5096 = prims.convert_element_type(t5095, dtypes.float32) # t5096: "cuda:0 f32[1, 2048, 18176]"
# t5098 = prims.div(t5096, 1.4142135623730951) # t5098: "cuda:0 f32[1, 2048, 18176]"
# t5101 = prims.erf(t5098) # t5101: "cuda:0 f32[1, 2048, 18176]"
# t5105 = prims.mul(0.5, t5101) # t5105: "cuda:0 f32[1, 2048, 18176]"
# t5109 = prims.add(0.5, t5105) # t5109: "cuda:0 f32[1, 2048, 18176]"
# t12502 = prims.convert_element_type(t12497, dtypes.float32) # t12502: "cuda:0 f32[1, 2048, 18176]"
# t12503 = prims.mul(t5109, t12502) # t12503: "cuda:0 f32[1, 2048, 18176]"
# t12504 = prims.mul(t5096, t12502) # t12504: "cuda:0 f32[1, 2048, 18176]"
# t12512 = prims.mul(f2049, t12504) # t12512: "cuda:0 f32[1, 2048, 18176]"
# t12515 = prims.pow(t5098, 2.0) # t12515: "cuda:0 f32[1, 2048, 18176]"
# t12516 = prims.neg(t12515) # t12516: "cuda:0 f32[1, 2048, 18176]"
# t12517 = prims.exp(t12516) # t12517: "cuda:0 f32[1, 2048, 18176]"
# t12518 = prims.mul(1.1283791670955126, t12517) # t12518: "cuda:0 f32[1, 2048, 18176]"
# t12519 = prims.mul(t12518, t12512) # t12519: "cuda:0 f32[1, 2048, 18176]"
# t12523 = prims.div(t12519, f2047) # t12523: "cuda:0 f32[1, 2048, 18176]"
# t12527 = prims.add(t12503, t12523) # t12527: "cuda:0 f32[1, 2048, 18176]"
# t12528 = prims.convert_element_type(t12527, dtypes.bfloat16) # t12528: "cuda:0 bf16[1, 2048, 18176]"
del f2047, f2049, t12497, t5095
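# nvFusion1 is the GELU backward for block 31's MLP: it recomputes the
# erf-based GELU factor from the saved mlp.fc output t5095 and combines it with
# the x * pdf term (note 1.1283791670955126 = 2/sqrt(pi)) to scale the incoming
# gradient t12497 into t12528, the gradient at the mlp.fc output.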
t12529 = torch.reshape(t12528, (-1, 18176)) # t12529: "cuda:0 bf16[2048, 18176]"
# t12529 = ltorch.reshape(t12528, (-1, 18176)) # t12529: "cuda:0 bf16[2048, 18176]"
# t12529 = prims.reshape(t12528, (2048, 18176)) # t12529: "cuda:0 bf16[2048, 18176]"
del t12528
t12533 = torch.permute(t12529, (1, 0)) # t12533: "cuda:0 bf16[18176, 2048]"
# t12533 = ltorch.permute(t12529, (1, 0)) # t12533: "cuda:0 bf16[18176, 2048]"
# t12533 = prims.transpose(t12529, (1, 0)) # t12533: "cuda:0 bf16[18176, 2048]"
(t12550, t12551, t12552) = cudnn_sdpa_bwd(t12549, t5079, t5082, t5032, None, f2038, b2039, t5083, t5084, t5085, t5086, scale=f2040, cat_grad_qkv=False)
del t12549, t5079, t5082, t5032, f2038, b2039, t5083, t5084, t5085, t5086, f2040
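# cudnn_sdpa_bwd is the cuDNN fused scaled-dot-product-attention backward for
# block 31: from the output gradient t12549 and the saved query/key/value and
# softmax statistics it appears to return the query, key and value gradients
# (t12550, t12551, t12552 respectively).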
t12535 = torch.matmul(t12533, t12534) # t12535: "cuda:0 bf16[18176, 4544]"
# t12535 = ltorch.matmul(t12533, t12534) # t12535: "cuda:0 bf16[18176, 4544]"
# t12535 = prims.matmul(t12533, t12534) # t12535: "cuda:0 bf16[18176, 4544]"
del t12533
t12530 = torch.matmul(t12529, t_transformer_h_31_mlp_fc_weight) # t12530: "cuda:0 bf16[2048, 4544]"
# t12530 = ltorch.matmul(t12529, t_transformer_h_31_mlp_fc_weight) # t12530: "cuda:0 bf16[2048, 4544]"
# t12530 = prims.matmul(t12529, t_transformer_h_31_mlp_fc_weight) # t12530: "cuda:0 bf16[2048, 4544]"
del t12529, t_transformer_h_31_mlp_fc_weight
t12554 = torch_slice_prim_impl(t12551, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12554: "cuda:0 bf16[1, 71, 2048, 64]"
del t12551
t12558 = torch_slice_prim_impl(t12550, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12558: "cuda:0 bf16[1, 71, 2048, 64]"
del t12550
t12653 = torch.reshape(t12552, (1, 1, 71, 2048, 64)) # t12653: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12653 = ltorch.reshape(t12552, (1, 1, 71, 2048, 64)) # t12653: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12653 = prims.reshape(t12552, (1, 1, 71, 2048, 64)) # t12653: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t12552
[t12687] = nvFusion2(i2011, t12554, t12558, t12653, t61, t66)
# t12555 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t12555: "cuda:0 bf16[1, 71, 2048, 0]"
# t12556 = prims.pad(t12555, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t12556: "cuda:0 bf16[1, 71, 2048, 64]"
# t12559 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t12559: "cuda:0 bf16[1, 71, 2048, 0]"
# t12560 = prims.pad(t12559, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t12560: "cuda:0 bf16[1, 71, 2048, 64]"
# t12561 = prims.convert_element_type(t12554, dtypes.float32) # t12561: "cuda:0 f32[1, 71, 2048, 64]"
# t12565 = prims.mul(t66, t12561) # t12565: "cuda:0 f32[1, 71, 2048, 64]"
# t12568 = prims.convert_element_type(t12565, dtypes.bfloat16) # t12568: "cuda:0 bf16[1, 71, 2048, 64]"
# t12573 = prims.mul(t61, t12561) # t12573: "cuda:0 f32[1, 71, 2048, 64]"
# t12581 = prims.slice_prim(t12568, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t12581: "cuda:0 bf16[1, 71, 2048, 32]"
# t12582 = prims.slice_prim(t12568, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12582: "cuda:0 bf16[1, 71, 2048, 32]"
# t12583 = prims.convert_element_type(t12581, dtypes.float32) # t12583: "cuda:0 f32[1, 71, 2048, 32]"
# t12584 = prims.neg(t12583) # t12584: "cuda:0 f32[1, 71, 2048, 32]"
# t12585 = prims.convert_element_type(t12584, dtypes.bfloat16) # t12585: "cuda:0 bf16[1, 71, 2048, 32]"
# t12586 = prims.pad(t12585, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t12586: "cuda:0 bf16[1, 71, 2048, 64]"
# t12588 = prims.convert_element_type(t12586, dtypes.float32) # t12588: "cuda:0 f32[1, 71, 2048, 64]"
# t12589 = prims.add(t12573, t12588) # t12589: "cuda:0 f32[1, 71, 2048, 64]"
# t12591 = prims.pad(t12582, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t12591: "cuda:0 bf16[1, 71, 2048, 64]"
# t12593 = prims.convert_element_type(t12591, dtypes.float32) # t12593: "cuda:0 f32[1, 71, 2048, 64]"
# t12594 = prims.add(t12589, t12593) # t12594: "cuda:0 f32[1, 71, 2048, 64]"
# t12595 = prims.convert_element_type(t12594, dtypes.bfloat16) # t12595: "cuda:0 bf16[1, 71, 2048, 64]"
# t12596 = prims.pad(t12595, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t12596: "cuda:0 bf16[1, 71, 2048, 64]"
# t12597 = prims.convert_element_type(t12556, dtypes.float32) # t12597: "cuda:0 f32[1, 71, 2048, 64]"
# t12598 = prims.convert_element_type(t12596, dtypes.float32) # t12598: "cuda:0 f32[1, 71, 2048, 64]"
# t12599 = prims.add(t12597, t12598) # t12599: "cuda:0 f32[1, 71, 2048, 64]"
# t12600 = prims.convert_element_type(t12599, dtypes.bfloat16) # t12600: "cuda:0 bf16[1, 71, 2048, 64]"
# t12601 = prims.convert_element_type(t12558, dtypes.float32) # t12601: "cuda:0 f32[1, 71, 2048, 64]"
# t12605 = prims.mul(t66, t12601) # t12605: "cuda:0 f32[1, 71, 2048, 64]"
# t12608 = prims.convert_element_type(t12605, dtypes.bfloat16) # t12608: "cuda:0 bf16[1, 71, 2048, 64]"
# t12617 = prims.mul(t61, t12601) # t12617: "cuda:0 f32[1, 71, 2048, 64]"
# t12629 = prims.slice_prim(t12608, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t12629: "cuda:0 bf16[1, 71, 2048, 32]"
# t12630 = prims.slice_prim(t12608, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12630: "cuda:0 bf16[1, 71, 2048, 32]"
# t12631 = prims.convert_element_type(t12629, dtypes.float32) # t12631: "cuda:0 f32[1, 71, 2048, 32]"
# t12632 = prims.neg(t12631) # t12632: "cuda:0 f32[1, 71, 2048, 32]"
# t12633 = prims.convert_element_type(t12632, dtypes.bfloat16) # t12633: "cuda:0 bf16[1, 71, 2048, 32]"
# t12634 = prims.pad(t12633, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t12634: "cuda:0 bf16[1, 71, 2048, 64]"
# t12636 = prims.convert_element_type(t12634, dtypes.float32) # t12636: "cuda:0 f32[1, 71, 2048, 64]"
# t12637 = prims.add(t12617, t12636) # t12637: "cuda:0 f32[1, 71, 2048, 64]"
# t12639 = prims.pad(t12630, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t12639: "cuda:0 bf16[1, 71, 2048, 64]"
# t12641 = prims.convert_element_type(t12639, dtypes.float32) # t12641: "cuda:0 f32[1, 71, 2048, 64]"
# t12642 = prims.add(t12637, t12641) # t12642: "cuda:0 f32[1, 71, 2048, 64]"
# t12643 = prims.convert_element_type(t12642, dtypes.bfloat16) # t12643: "cuda:0 bf16[1, 71, 2048, 64]"
# t12644 = prims.pad(t12643, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t12644: "cuda:0 bf16[1, 71, 2048, 64]"
# t12645 = prims.convert_element_type(t12560, dtypes.float32) # t12645: "cuda:0 f32[1, 71, 2048, 64]"
# t12646 = prims.convert_element_type(t12644, dtypes.float32) # t12646: "cuda:0 f32[1, 71, 2048, 64]"
# t12647 = prims.add(t12645, t12646) # t12647: "cuda:0 f32[1, 71, 2048, 64]"
# t12648 = prims.convert_element_type(t12647, dtypes.bfloat16) # t12648: "cuda:0 bf16[1, 71, 2048, 64]"
# t12658 = prims.reshape(t12600, (1, 1, 71, 2048, 64)) # t12658: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12663 = prims.reshape(t12648, (1, 1, 71, 2048, 64)) # t12663: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12669 = prims.convert_element_type(t12653, dtypes.float32) # t12669: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t12670 = prims.sum(t12669, (0, 1, 2)) # t12670: "cuda:0 f32[2048, 64]"
# t12671 = prims.convert_element_type(t12670, dtypes.bfloat16) # t12671: "cuda:0 bf16[2048, 64]"
# t12672 = prims.broadcast_in_dim(t12671, [1, 1, 1, 2048, 64], [3, 4]) # t12672: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t12678 = prims.convert_element_type(t12658, dtypes.float32) # t12678: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t12679 = prims.sum(t12678, (0, 1, 2)) # t12679: "cuda:0 f32[2048, 64]"
# t12680 = prims.convert_element_type(t12679, dtypes.bfloat16) # t12680: "cuda:0 bf16[2048, 64]"
# t12681 = prims.broadcast_in_dim(t12680, [1, 1, 1, 2048, 64], [3, 4]) # t12681: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t12687 = prims.cat((t12663, t12681, t12672), i2011) # t12687: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i2011, t12554, t12558, t12653
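# nvFusion2 undoes the rotary position embedding on the query and key
# gradients (the slice/negate/pad pattern combined with the saved cos/sin
# factors t61 and t66), reduces the key and value gradients over the 71 query
# heads down to the single shared KV head, and concatenates everything back
# into the fused-QKV layout (71 + 1 + 1 = 73 head slots along dim i2011).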
t12693 = torch.permute(t12687, (0, 3, 1, 2, 4)) # t12693: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t12693 = ltorch.permute(t12687, (0, 3, 1, 2, 4)) # t12693: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t12693 = prims.transpose(t12687, (0, 3, 1, 2, 4)) # t12693: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t12687
t12699 = torch.reshape(t12693, (1, 2048, 4672)) # t12699: "cuda:0 bf16[1, 2048, 4672]"
# t12699 = ltorch.reshape(t12693, (1, 2048, 4672)) # t12699: "cuda:0 bf16[1, 2048, 4672]"
# t12699 = prims.reshape(t12693, (1, 2048, 4672)) # t12699: "cuda:0 bf16[1, 2048, 4672]"
del t12693
t12700 = torch.reshape(t12699, (-1, 4672)) # t12700: "cuda:0 bf16[2048, 4672]"
# t12700 = ltorch.reshape(t12699, (-1, 4672)) # t12700: "cuda:0 bf16[2048, 4672]"
# t12700 = prims.reshape(t12699, (2048, 4672)) # t12700: "cuda:0 bf16[2048, 4672]"
del t12699
t12704 = torch.permute(t12700, (1, 0)) # t12704: "cuda:0 bf16[4672, 2048]"
# t12704 = ltorch.permute(t12700, (1, 0)) # t12704: "cuda:0 bf16[4672, 2048]"
# t12704 = prims.transpose(t12700, (1, 0)) # t12704: "cuda:0 bf16[4672, 2048]"
t12706 = torch.matmul(t12704, t12534) # t12706: "cuda:0 bf16[4672, 4544]"
# t12706 = ltorch.matmul(t12704, t12705) # t12706: "cuda:0 bf16[4672, 4544]"
# t12706 = prims.matmul(t12704, t12705) # t12706: "cuda:0 bf16[4672, 4544]"
del t12704, t12534
t12701 = torch.matmul(t12700, t_transformer_h_31_attn_attn_weight) # t12701: "cuda:0 bf16[2048, 4544]"
# t12701 = ltorch.matmul(t12700, t_transformer_h_31_attn_attn_weight) # t12701: "cuda:0 bf16[2048, 4544]"
# t12701 = prims.matmul(t12700, t_transformer_h_31_attn_attn_weight) # t12701: "cuda:0 bf16[2048, 4544]"
del t12700, t_transformer_h_31_attn_attn_weight
t12531 = torch.reshape(t12530, (1, 2048, 4544)) # t12531: "cuda:0 bf16[1, 2048, 4544]"
# t12531 = ltorch.reshape(t12530, (1, 2048, 4544)) # t12531: "cuda:0 bf16[1, 2048, 4544]"
# t12531 = prims.reshape(t12530, (1, 2048, 4544)) # t12531: "cuda:0 bf16[1, 2048, 4544]"
del t12530
t12702 = torch.reshape(t12701, (1, 2048, 4544)) # t12702: "cuda:0 bf16[1, 2048, 4544]"
# t12702 = ltorch.reshape(t12701, (1, 2048, 4544)) # t12702: "cuda:0 bf16[1, 2048, 4544]"
# t12702 = prims.reshape(t12701, (1, 2048, 4544)) # t12702: "cuda:0 bf16[1, 2048, 4544]"
del t12701
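# nvFusion3 appears to be the fused backward of layer 31's LayerNorm plus the residual accumulation:
# it recomputes the pre-norm input and the normalized activation from the saved mean (t4969) and
# rstd (t4974), reduces the branch gradients into the bias grad (t12715) and weight grad (t12721),
# and adds the resulting input grad to the residual-stream grad t12488 to produce t12763.
# The constant 0.00022007042253521127 is 1/4544, the reciprocal of the normalized dimension.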
[t12715, t12721, t12763] = nvFusion3(i12743, t12488, t12531, t12702, t4801, t4933, t4954, t4969, t4974, t4980)
# t4960 = prims.convert_element_type(t4801, dtypes.float32) # t4960: "cuda:0 f32[1, 2048, 4544]"
# t4955 = prims.convert_element_type(t4954, dtypes.float32) # t4955: "cuda:0 f32[1, 2048, 4544]"
# t4956 = prims.convert_element_type(t4933, dtypes.float32) # t4956: "cuda:0 f32[1, 2048, 4544]"
# t4957 = prims.add(t4955, t4956) # t4957: "cuda:0 f32[1, 2048, 4544]"
# t4961 = prims.add(t4957, t4960) # t4961: "cuda:0 f32[1, 2048, 4544]"
# t4971 = prims.broadcast_in_dim(t4969, [1, 2048, 1], [0, 1]) # t4971: "cuda:0 f32[1, 2048, 1]"
# t4975 = prims.broadcast_in_dim(t4971, (1, 2048, 4544), (0, 1, 2)) # t4975: "cuda:0 f32[1, 2048, 4544]"
# t4977 = prims.sub(t4961, t4975) # t4977: "cuda:0 f32[1, 2048, 4544]"
# t4978 = prims.broadcast_in_dim(t4974, (1, 2048, 4544), (0, 1, 2)) # t4978: "cuda:0 f32[1, 2048, 4544]"
# t4979 = prims.mul(t4977, t4978) # t4979: "cuda:0 f32[1, 2048, 4544]"
# t4981 = prims.convert_element_type(t4980, dtypes.float32) # t4981: "cuda:0 f32[1, 2048, 4544]"
# t12760 = prims.convert_element_type(t12488, dtypes.float32) # t12760: "cuda:0 f32[1, 2048, 4544]"
# t12707 = prims.convert_element_type(t12531, dtypes.float32) # t12707: "cuda:0 f32[1, 2048, 4544]"
# t12708 = prims.convert_element_type(t12702, dtypes.float32) # t12708: "cuda:0 f32[1, 2048, 4544]"
# t12709 = prims.add(t12707, t12708) # t12709: "cuda:0 f32[1, 2048, 4544]"
# t12714 = prims.sum(t12709, (0, 1)) # t12714: "cuda:0 f32[4544]"
# t12715 = prims.convert_element_type(t12714, dtypes.bfloat16) # t12715: "cuda:0 bf16[4544]"
# t12716 = prims.mul(t4981, t12709) # t12716: "cuda:0 f32[1, 2048, 4544]"
# t12717 = prims.mul(t4979, t12709) # t12717: "cuda:0 f32[1, 2048, 4544]"
# t12720 = prims.sum(t12717, (0, 1)) # t12720: "cuda:0 f32[4544]"
# t12721 = prims.convert_element_type(t12720, dtypes.bfloat16) # t12721: "cuda:0 bf16[4544]"
# t12722 = prims.mul(t4978, t12716) # t12722: "cuda:0 f32[1, 2048, 4544]"
# t12723 = prims.mul(t4977, t12716) # t12723: "cuda:0 f32[1, 2048, 4544]"
# t12724 = prims.sum(t12723, (0, 2)) # t12724: "cuda:0 f32[2048]"
# t12725 = prims.broadcast_in_dim(t12724, [1, 2048, 1], [1]) # t12725: "cuda:0 f32[1, 2048, 1]"
# t12726 = prims.neg(t12722) # t12726: "cuda:0 f32[1, 2048, 4544]"
# t12728 = prims.sum(t12726, (0, 2)) # t12728: "cuda:0 f32[2048]"
# t12729 = prims.broadcast_in_dim(t12728, [1, 2048, 1], [1]) # t12729: "cuda:0 f32[1, 2048, 1]"
# t12730 = prims.mul(-0.5, t12725) # t12730: "cuda:0 f32[1, 2048, 1]"
# t12731 = prims.pow(t4974, 3.0) # t12731: "cuda:0 f32[1, 2048, 1]"
# t12732 = prims.mul(t12730, t12731) # t12732: "cuda:0 f32[1, 2048, 1]"
# t12734 = prims.sum(t12729, (0, 2)) # t12734: "cuda:0 f32[2048]"
# t12735 = prims.broadcast_in_dim(t12734, [1, 2048], [1]) # t12735: "cuda:0 f32[1, 2048]"
# t12736 = prims.sum(t12732, (0, 2)) # t12736: "cuda:0 f32[2048]"
# t12737 = prims.broadcast_in_dim(t12736, [1, 2048], [1]) # t12737: "cuda:0 f32[1, 2048]"
# t12740 = prims.broadcast_in_dim(t12735, [1, 2048, 1], [0, 1]) # t12740: "cuda:0 f32[1, 2048, 1]"
# t12741 = prims.broadcast_in_dim(t12740, (1, 2048, 4544), (0, 1, 2)) # t12741: "cuda:0 f32[1, 2048, 4544]"
# t12742 = prims.mul(0.00022007042253521127, t12741) # t12742: "cuda:0 f32[1, 2048, 4544]"
# t12744 = prims.broadcast_in_dim(t12737, [1, 2048, 1], [0, 1]) # t12744: "cuda:0 f32[1, 2048, 1]"
# t12745 = prims.broadcast_in_dim(t12744, (1, 2048, 4544), (0, 1, 2)) # t12745: "cuda:0 f32[1, 2048, 4544]"
# t12747 = prims.broadcast_in_dim(t4969, [1, 2048, 1], [0, 1]) # t12747: "cuda:0 f32[1, 2048, 1]"
# t12748 = prims.broadcast_in_dim(t12747, (1, 2048, 4544), (0, 1, 2)) # t12748: "cuda:0 f32[1, 2048, 4544]"
# t12749 = prims.mul(2.0, t12745) # t12749: "cuda:0 f32[1, 2048, 4544]"
# t12750 = prims.sub(t4961, t12748) # t12750: "cuda:0 f32[1, 2048, 4544]"
# t12751 = prims.mul(t12749, t12750) # t12751: "cuda:0 f32[1, 2048, 4544]"
# f12752 = prims.convert_element_type(i12743, float) # f12752: "float 4544.0"
# t12753 = prims.div(t12751, f12752) # t12753: "cuda:0 f32[1, 2048, 4544]"
# t12754 = prims.add(t12742, t12753) # t12754: "cuda:0 f32[1, 2048, 4544]"
# t12758 = prims.add(t12722, t12754) # t12758: "cuda:0 f32[1, 2048, 4544]"
# t12762 = prims.add(t12760, t12758) # t12762: "cuda:0 f32[1, 2048, 4544]"
# t12763 = prims.convert_element_type(t12762, dtypes.bfloat16) # t12763: "cuda:0 bf16[1, 2048, 4544]"
del i12743, t12488, t12531, t12702, t4801, t4933, t4954, t4969, t4974, t4980
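# This block uses a parallel attention/MLP layout (Falcon-style): the same flattened gradient t12770
# feeds both the MLP down-projection backward (input grad t12771, weight grad t12776) and the
# attention output-projection backward (input grad t12812, weight grad t12817) of layer 30. The two
# branches' gradients w.r.t. the shared LayerNorm output are recombined later in nvFusion6.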
t12770 = torch.reshape(t12763, (-1, 4544)) # t12770: "cuda:0 bf16[2048, 4544]"
# t12770 = ltorch.reshape(t12763, (-1, 4544)) # t12770: "cuda:0 bf16[2048, 4544]"
# t12770 = prims.reshape(t12763, (2048, 4544)) # t12770: "cuda:0 bf16[2048, 4544]"
t12774 = torch.permute(t12770, (1, 0)) # t12774: "cuda:0 bf16[4544, 2048]"
# t12774 = ltorch.permute(t12770, (1, 0)) # t12774: "cuda:0 bf16[4544, 2048]"
# t12774 = prims.transpose(t12770, (1, 0)) # t12774: "cuda:0 bf16[4544, 2048]"
t12817 = torch.matmul(t12774, t12816) # t12817: "cuda:0 bf16[4544, 4544]"
# t12817 = ltorch.matmul(t12815, t12816) # t12817: "cuda:0 bf16[4544, 4544]"
# t12817 = prims.matmul(t12815, t12816) # t12817: "cuda:0 bf16[4544, 4544]"
del t12816
t12771 = torch.matmul(t12770, t_transformer_h_30_mlp_proj_weight) # t12771: "cuda:0 bf16[2048, 18176]"
# t12771 = ltorch.matmul(t12770, t_transformer_h_30_mlp_proj_weight) # t12771: "cuda:0 bf16[2048, 18176]"
# t12771 = prims.matmul(t12770, t_transformer_h_30_mlp_proj_weight) # t12771: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_30_mlp_proj_weight
t12776 = torch.matmul(t12774, t12775) # t12776: "cuda:0 bf16[4544, 18176]"
# t12776 = ltorch.matmul(t12774, t12775) # t12776: "cuda:0 bf16[4544, 18176]"
# t12776 = prims.matmul(t12774, t12775) # t12776: "cuda:0 bf16[4544, 18176]"
del t12774, t12775
t12812 = torch.matmul(t12770, t_transformer_h_30_attn_proj_weight) # t12812: "cuda:0 bf16[2048, 4544]"
# t12812 = ltorch.matmul(t12811, t_transformer_h_30_attn_proj_weight) # t12812: "cuda:0 bf16[2048, 4544]"
# t12812 = prims.matmul(t12811, t_transformer_h_30_attn_proj_weight) # t12812: "cuda:0 bf16[2048, 4544]"
del t12770, t_transformer_h_30_attn_proj_weight
t12772 = torch.reshape(t12771, (1, 2048, 18176)) # t12772: "cuda:0 bf16[1, 2048, 18176]"
# t12772 = ltorch.reshape(t12771, (1, 2048, 18176)) # t12772: "cuda:0 bf16[1, 2048, 18176]"
# t12772 = prims.reshape(t12771, (1, 2048, 18176)) # t12772: "cuda:0 bf16[1, 2048, 18176]"
del t12771
t12813 = torch.reshape(t12812, (1, 2048, 4544)) # t12813: "cuda:0 bf16[1, 2048, 4544]"
# t12813 = ltorch.reshape(t12812, (1, 2048, 4544)) # t12813: "cuda:0 bf16[1, 2048, 4544]"
# t12813 = prims.reshape(t12812, (1, 2048, 4544)) # t12813: "cuda:0 bf16[1, 2048, 4544]"
del t12812
t12821 = torch.reshape(t12813, (1, 2048, 71, 64)) # t12821: "cuda:0 bf16[1, 2048, 71, 64]"
# t12821 = ltorch.reshape(t12813, (1, 2048, 71, 64)) # t12821: "cuda:0 bf16[1, 2048, 71, 64]"
# t12821 = prims.reshape(t12813, (1, 2048, 71, 64)) # t12821: "cuda:0 bf16[1, 2048, 71, 64]"
del t12813
t12824 = torch.permute(t12821, (0, 2, 1, 3)) # t12824: "cuda:0 bf16[1, 71, 2048, 64]"
# t12824 = ltorch.permute(t12821, (0, 2, 1, 3)) # t12824: "cuda:0 bf16[1, 71, 2048, 64]"
# t12824 = prims.transpose(t12821, (0, 2, 1, 3)) # t12824: "cuda:0 bf16[1, 71, 2048, 64]"
del t12821
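# nvFusion4 looks like the backward of the exact (erf-based) GELU in layer 30's MLP: it recomputes
# 0.5 * (1 + erf(x / sqrt(2))) from the saved pre-activation t4934 and scales the incoming grad by
# gelu'(x). The constant 1.1283791670955126 is 2/sqrt(pi) (from the derivative of erf); f1983 and
# f1985 are presumably the sqrt(2) and 0.5 constants captured from the forward trace.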
[t12803] = nvFusion4(f1983, f1985, t12772, t4934)
# t4935 = prims.convert_element_type(t4934, dtypes.float32) # t4935: "cuda:0 f32[1, 2048, 18176]"
# t4937 = prims.div(t4935, 1.4142135623730951) # t4937: "cuda:0 f32[1, 2048, 18176]"
# t4940 = prims.erf(t4937) # t4940: "cuda:0 f32[1, 2048, 18176]"
# t4944 = prims.mul(0.5, t4940) # t4944: "cuda:0 f32[1, 2048, 18176]"
# t4948 = prims.add(0.5, t4944) # t4948: "cuda:0 f32[1, 2048, 18176]"
# t12777 = prims.convert_element_type(t12772, dtypes.float32) # t12777: "cuda:0 f32[1, 2048, 18176]"
# t12778 = prims.mul(t4948, t12777) # t12778: "cuda:0 f32[1, 2048, 18176]"
# t12779 = prims.mul(t4935, t12777) # t12779: "cuda:0 f32[1, 2048, 18176]"
# t12787 = prims.mul(f1985, t12779) # t12787: "cuda:0 f32[1, 2048, 18176]"
# t12790 = prims.pow(t4937, 2.0) # t12790: "cuda:0 f32[1, 2048, 18176]"
# t12791 = prims.neg(t12790) # t12791: "cuda:0 f32[1, 2048, 18176]"
# t12792 = prims.exp(t12791) # t12792: "cuda:0 f32[1, 2048, 18176]"
# t12793 = prims.mul(1.1283791670955126, t12792) # t12793: "cuda:0 f32[1, 2048, 18176]"
# t12794 = prims.mul(t12793, t12787) # t12794: "cuda:0 f32[1, 2048, 18176]"
# t12798 = prims.div(t12794, f1983) # t12798: "cuda:0 f32[1, 2048, 18176]"
# t12802 = prims.add(t12778, t12798) # t12802: "cuda:0 f32[1, 2048, 18176]"
# t12803 = prims.convert_element_type(t12802, dtypes.bfloat16) # t12803: "cuda:0 bf16[1, 2048, 18176]"
del f1983, f1985, t12772, t4934
t12804 = torch.reshape(t12803, (-1, 18176)) # t12804: "cuda:0 bf16[2048, 18176]"
# t12804 = ltorch.reshape(t12803, (-1, 18176)) # t12804: "cuda:0 bf16[2048, 18176]"
# t12804 = prims.reshape(t12803, (2048, 18176)) # t12804: "cuda:0 bf16[2048, 18176]"
del t12803
t12808 = torch.permute(t12804, (1, 0)) # t12808: "cuda:0 bf16[18176, 2048]"
# t12808 = ltorch.permute(t12804, (1, 0)) # t12808: "cuda:0 bf16[18176, 2048]"
# t12808 = prims.transpose(t12804, (1, 0)) # t12808: "cuda:0 bf16[18176, 2048]"
t12810 = torch.matmul(t12808, t12809) # t12810: "cuda:0 bf16[18176, 4544]"
# t12810 = ltorch.matmul(t12808, t12809) # t12810: "cuda:0 bf16[18176, 4544]"
# t12810 = prims.matmul(t12808, t12809) # t12810: "cuda:0 bf16[18176, 4544]"
del t12808
t12805 = torch.matmul(t12804, t_transformer_h_30_mlp_fc_weight) # t12805: "cuda:0 bf16[2048, 4544]"
# t12805 = ltorch.matmul(t12804, t_transformer_h_30_mlp_fc_weight) # t12805: "cuda:0 bf16[2048, 4544]"
# t12805 = prims.matmul(t12804, t_transformer_h_30_mlp_fc_weight) # t12805: "cuda:0 bf16[2048, 4544]"
del t12804, t_transformer_h_30_mlp_fc_weight
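# Backward of layer 30's scaled dot-product attention via cuDNN: given the gradient of the attention
# output (t12824) and the saved q/k/v plus softmax statistics, cudnn_sdpa_bwd returns what appear to
# be the query (t12825), key (t12826) and value (t12827) gradients.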
(t12825, t12826, t12827) = cudnn_sdpa_bwd(t12824, t4918, t4921, t4871, None, f1974, b1975, t4922, t4923, t4924, t4925, scale=f1976, cat_grad_qkv=False)
del t12824, t4918, t4921, t4871, f1974, b1975, t4922, t4923, t4924, t4925, f1976
t12829 = torch_slice_prim_impl(t12826, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12829: "cuda:0 bf16[1, 71, 2048, 64]"
del t12826
t12833 = torch_slice_prim_impl(t12825, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12833: "cuda:0 bf16[1, 71, 2048, 64]"
del t12825
t12936 = torch.reshape(t12827, (1, 1, 71, 2048, 64)) # t12936: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12936 = ltorch.reshape(t12827, (1, 1, 71, 2048, 64)) # t12936: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12936 = prims.reshape(t12827, (1, 1, 71, 2048, 64)) # t12936: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t12827
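# nvFusion5 appears to undo the rotary position embedding (RoPE) on the query and key gradients:
# t61 and t66 look like the saved sin/cos rotation tables, and the slice / neg / pad sequence is the
# backward of the "rotate half" step. The key and value gradients are then summed over the 71 query
# heads onto the single shared key/value head (multi-query attention) and concatenated with the
# per-head query gradient along dim i1947 into the packed [1, 1, 73, 2048, 64] QKV gradient t12970.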
[t12970] = nvFusion5(i1947, t12829, t12833, t12936, t61, t66)
# t12830 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t12830: "cuda:0 bf16[1, 71, 2048, 0]"
# t12831 = prims.pad(t12830, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t12831: "cuda:0 bf16[1, 71, 2048, 64]"
# t12834 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t12834: "cuda:0 bf16[1, 71, 2048, 0]"
# t12835 = prims.pad(t12834, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t12835: "cuda:0 bf16[1, 71, 2048, 64]"
# t12836 = prims.convert_element_type(t12829, dtypes.float32) # t12836: "cuda:0 f32[1, 71, 2048, 64]"
# t12840 = prims.mul(t66, t12836) # t12840: "cuda:0 f32[1, 71, 2048, 64]"
# t12843 = prims.convert_element_type(t12840, dtypes.bfloat16) # t12843: "cuda:0 bf16[1, 71, 2048, 64]"
# t12852 = prims.mul(t61, t12836) # t12852: "cuda:0 f32[1, 71, 2048, 64]"
# t12864 = prims.slice_prim(t12843, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t12864: "cuda:0 bf16[1, 71, 2048, 32]"
# t12865 = prims.slice_prim(t12843, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12865: "cuda:0 bf16[1, 71, 2048, 32]"
# t12866 = prims.convert_element_type(t12864, dtypes.float32) # t12866: "cuda:0 f32[1, 71, 2048, 32]"
# t12867 = prims.neg(t12866) # t12867: "cuda:0 f32[1, 71, 2048, 32]"
# t12868 = prims.convert_element_type(t12867, dtypes.bfloat16) # t12868: "cuda:0 bf16[1, 71, 2048, 32]"
# t12869 = prims.pad(t12868, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t12869: "cuda:0 bf16[1, 71, 2048, 64]"
# t12871 = prims.convert_element_type(t12869, dtypes.float32) # t12871: "cuda:0 f32[1, 71, 2048, 64]"
# t12872 = prims.add(t12852, t12871) # t12872: "cuda:0 f32[1, 71, 2048, 64]"
# t12874 = prims.pad(t12865, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t12874: "cuda:0 bf16[1, 71, 2048, 64]"
# t12876 = prims.convert_element_type(t12874, dtypes.float32) # t12876: "cuda:0 f32[1, 71, 2048, 64]"
# t12877 = prims.add(t12872, t12876) # t12877: "cuda:0 f32[1, 71, 2048, 64]"
# t12878 = prims.convert_element_type(t12877, dtypes.bfloat16) # t12878: "cuda:0 bf16[1, 71, 2048, 64]"
# t12879 = prims.pad(t12878, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t12879: "cuda:0 bf16[1, 71, 2048, 64]"
# t12880 = prims.convert_element_type(t12831, dtypes.float32) # t12880: "cuda:0 f32[1, 71, 2048, 64]"
# t12881 = prims.convert_element_type(t12879, dtypes.float32) # t12881: "cuda:0 f32[1, 71, 2048, 64]"
# t12882 = prims.add(t12880, t12881) # t12882: "cuda:0 f32[1, 71, 2048, 64]"
# t12883 = prims.convert_element_type(t12882, dtypes.bfloat16) # t12883: "cuda:0 bf16[1, 71, 2048, 64]"
# t12884 = prims.convert_element_type(t12833, dtypes.float32) # t12884: "cuda:0 f32[1, 71, 2048, 64]"
# t12888 = prims.mul(t66, t12884) # t12888: "cuda:0 f32[1, 71, 2048, 64]"
# t12891 = prims.convert_element_type(t12888, dtypes.bfloat16) # t12891: "cuda:0 bf16[1, 71, 2048, 64]"
# t12900 = prims.mul(t61, t12884) # t12900: "cuda:0 f32[1, 71, 2048, 64]"
# t12912 = prims.slice_prim(t12891, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t12912: "cuda:0 bf16[1, 71, 2048, 32]"
# t12913 = prims.slice_prim(t12891, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t12913: "cuda:0 bf16[1, 71, 2048, 32]"
# t12914 = prims.convert_element_type(t12912, dtypes.float32) # t12914: "cuda:0 f32[1, 71, 2048, 32]"
# t12915 = prims.neg(t12914) # t12915: "cuda:0 f32[1, 71, 2048, 32]"
# t12916 = prims.convert_element_type(t12915, dtypes.bfloat16) # t12916: "cuda:0 bf16[1, 71, 2048, 32]"
# t12917 = prims.pad(t12916, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t12917: "cuda:0 bf16[1, 71, 2048, 64]"
# t12919 = prims.convert_element_type(t12917, dtypes.float32) # t12919: "cuda:0 f32[1, 71, 2048, 64]"
# t12920 = prims.add(t12900, t12919) # t12920: "cuda:0 f32[1, 71, 2048, 64]"
# t12922 = prims.pad(t12913, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t12922: "cuda:0 bf16[1, 71, 2048, 64]"
# t12924 = prims.convert_element_type(t12922, dtypes.float32) # t12924: "cuda:0 f32[1, 71, 2048, 64]"
# t12925 = prims.add(t12920, t12924) # t12925: "cuda:0 f32[1, 71, 2048, 64]"
# t12926 = prims.convert_element_type(t12925, dtypes.bfloat16) # t12926: "cuda:0 bf16[1, 71, 2048, 64]"
# t12927 = prims.pad(t12926, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t12927: "cuda:0 bf16[1, 71, 2048, 64]"
# t12928 = prims.convert_element_type(t12835, dtypes.float32) # t12928: "cuda:0 f32[1, 71, 2048, 64]"
# t12929 = prims.convert_element_type(t12927, dtypes.float32) # t12929: "cuda:0 f32[1, 71, 2048, 64]"
# t12930 = prims.add(t12928, t12929) # t12930: "cuda:0 f32[1, 71, 2048, 64]"
# t12931 = prims.convert_element_type(t12930, dtypes.bfloat16) # t12931: "cuda:0 bf16[1, 71, 2048, 64]"
# t12941 = prims.reshape(t12883, (1, 1, 71, 2048, 64)) # t12941: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12946 = prims.reshape(t12931, (1, 1, 71, 2048, 64)) # t12946: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t12952 = prims.convert_element_type(t12936, dtypes.float32) # t12952: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t12953 = prims.sum(t12952, (0, 1, 2)) # t12953: "cuda:0 f32[2048, 64]"
# t12954 = prims.convert_element_type(t12953, dtypes.bfloat16) # t12954: "cuda:0 bf16[2048, 64]"
# t12955 = prims.broadcast_in_dim(t12954, [1, 1, 1, 2048, 64], [3, 4]) # t12955: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t12961 = prims.convert_element_type(t12941, dtypes.float32) # t12961: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t12962 = prims.sum(t12961, (0, 1, 2)) # t12962: "cuda:0 f32[2048, 64]"
# t12963 = prims.convert_element_type(t12962, dtypes.bfloat16) # t12963: "cuda:0 bf16[2048, 64]"
# t12964 = prims.broadcast_in_dim(t12963, [1, 1, 1, 2048, 64], [3, 4]) # t12964: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t12970 = prims.cat((t12946, t12964, t12955), i1947) # t12970: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1947, t12829, t12833, t12936
t12976 = torch.permute(t12970, (0, 3, 1, 2, 4)) # t12976: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t12976 = ltorch.permute(t12970, (0, 3, 1, 2, 4)) # t12976: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t12976 = prims.transpose(t12970, (0, 3, 1, 2, 4)) # t12976: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t12970
t12982 = torch.reshape(t12976, (1, 2048, 4672)) # t12982: "cuda:0 bf16[1, 2048, 4672]"
# t12982 = ltorch.reshape(t12976, (1, 2048, 4672)) # t12982: "cuda:0 bf16[1, 2048, 4672]"
# t12982 = prims.reshape(t12976, (1, 2048, 4672)) # t12982: "cuda:0 bf16[1, 2048, 4672]"
del t12976
t12983 = torch.reshape(t12982, (-1, 4672)) # t12983: "cuda:0 bf16[2048, 4672]"
# t12983 = ltorch.reshape(t12982, (-1, 4672)) # t12983: "cuda:0 bf16[2048, 4672]"
# t12983 = prims.reshape(t12982, (2048, 4672)) # t12983: "cuda:0 bf16[2048, 4672]"
del t12982
t12987 = torch.permute(t12983, (1, 0)) # t12987: "cuda:0 bf16[4672, 2048]"
# t12987 = ltorch.permute(t12983, (1, 0)) # t12987: "cuda:0 bf16[4672, 2048]"
# t12987 = prims.transpose(t12983, (1, 0)) # t12987: "cuda:0 bf16[4672, 2048]"
t12989 = torch.matmul(t12987, t12809) # t12989: "cuda:0 bf16[4672, 4544]"
# t12989 = ltorch.matmul(t12987, t12988) # t12989: "cuda:0 bf16[4672, 4544]"
# t12989 = prims.matmul(t12987, t12988) # t12989: "cuda:0 bf16[4672, 4544]"
del t12987, t12809
t12984 = torch.matmul(t12983, t_transformer_h_30_attn_attn_weight) # t12984: "cuda:0 bf16[2048, 4544]"
# t12984 = ltorch.matmul(t12983, t_transformer_h_30_attn_attn_weight) # t12984: "cuda:0 bf16[2048, 4544]"
# t12984 = prims.matmul(t12983, t_transformer_h_30_attn_attn_weight) # t12984: "cuda:0 bf16[2048, 4544]"
del t12983, t_transformer_h_30_attn_attn_weight
t12806 = torch.reshape(t12805, (1, 2048, 4544)) # t12806: "cuda:0 bf16[1, 2048, 4544]"
# t12806 = ltorch.reshape(t12805, (1, 2048, 4544)) # t12806: "cuda:0 bf16[1, 2048, 4544]"
# t12806 = prims.reshape(t12805, (1, 2048, 4544)) # t12806: "cuda:0 bf16[1, 2048, 4544]"
del t12805
t12985 = torch.reshape(t12984, (1, 2048, 4544)) # t12985: "cuda:0 bf16[1, 2048, 4544]"
# t12985 = ltorch.reshape(t12984, (1, 2048, 4544)) # t12985: "cuda:0 bf16[1, 2048, 4544]"
# t12985 = prims.reshape(t12984, (1, 2048, 4544)) # t12985: "cuda:0 bf16[1, 2048, 4544]"
del t12984
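# As with nvFusion3 above, nvFusion6 is the fused LayerNorm backward (plus residual accumulation
# with t12763) for layer 30, producing the bias grad t12998, the weight grad t13004 and the
# residual-stream grad t13046 that flows into layer 29.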
[t12998, t13004, t13046] = nvFusion6(i13026, t12763, t12806, t12985, t4640, t4772, t4793, t4808, t4813, t4819)
# t4799 = prims.convert_element_type(t4640, dtypes.float32) # t4799: "cuda:0 f32[1, 2048, 4544]"
# t4794 = prims.convert_element_type(t4793, dtypes.float32) # t4794: "cuda:0 f32[1, 2048, 4544]"
# t4795 = prims.convert_element_type(t4772, dtypes.float32) # t4795: "cuda:0 f32[1, 2048, 4544]"
# t4796 = prims.add(t4794, t4795) # t4796: "cuda:0 f32[1, 2048, 4544]"
# t4800 = prims.add(t4796, t4799) # t4800: "cuda:0 f32[1, 2048, 4544]"
# t4810 = prims.broadcast_in_dim(t4808, [1, 2048, 1], [0, 1]) # t4810: "cuda:0 f32[1, 2048, 1]"
# t4814 = prims.broadcast_in_dim(t4810, (1, 2048, 4544), (0, 1, 2)) # t4814: "cuda:0 f32[1, 2048, 4544]"
# t4816 = prims.sub(t4800, t4814) # t4816: "cuda:0 f32[1, 2048, 4544]"
# t4817 = prims.broadcast_in_dim(t4813, (1, 2048, 4544), (0, 1, 2)) # t4817: "cuda:0 f32[1, 2048, 4544]"
# t4818 = prims.mul(t4816, t4817) # t4818: "cuda:0 f32[1, 2048, 4544]"
# t4820 = prims.convert_element_type(t4819, dtypes.float32) # t4820: "cuda:0 f32[1, 2048, 4544]"
# t13043 = prims.convert_element_type(t12763, dtypes.float32) # t13043: "cuda:0 f32[1, 2048, 4544]"
# t12990 = prims.convert_element_type(t12806, dtypes.float32) # t12990: "cuda:0 f32[1, 2048, 4544]"
# t12991 = prims.convert_element_type(t12985, dtypes.float32) # t12991: "cuda:0 f32[1, 2048, 4544]"
# t12992 = prims.add(t12990, t12991) # t12992: "cuda:0 f32[1, 2048, 4544]"
# t12997 = prims.sum(t12992, (0, 1)) # t12997: "cuda:0 f32[4544]"
# t12998 = prims.convert_element_type(t12997, dtypes.bfloat16) # t12998: "cuda:0 bf16[4544]"
# t12999 = prims.mul(t4820, t12992) # t12999: "cuda:0 f32[1, 2048, 4544]"
# t13000 = prims.mul(t4818, t12992) # t13000: "cuda:0 f32[1, 2048, 4544]"
# t13003 = prims.sum(t13000, (0, 1)) # t13003: "cuda:0 f32[4544]"
# t13004 = prims.convert_element_type(t13003, dtypes.bfloat16) # t13004: "cuda:0 bf16[4544]"
# t13005 = prims.mul(t4817, t12999) # t13005: "cuda:0 f32[1, 2048, 4544]"
# t13006 = prims.mul(t4816, t12999) # t13006: "cuda:0 f32[1, 2048, 4544]"
# t13007 = prims.sum(t13006, (0, 2)) # t13007: "cuda:0 f32[2048]"
# t13008 = prims.broadcast_in_dim(t13007, [1, 2048, 1], [1]) # t13008: "cuda:0 f32[1, 2048, 1]"
# t13009 = prims.neg(t13005) # t13009: "cuda:0 f32[1, 2048, 4544]"
# t13011 = prims.sum(t13009, (0, 2)) # t13011: "cuda:0 f32[2048]"
# t13012 = prims.broadcast_in_dim(t13011, [1, 2048, 1], [1]) # t13012: "cuda:0 f32[1, 2048, 1]"
# t13013 = prims.mul(-0.5, t13008) # t13013: "cuda:0 f32[1, 2048, 1]"
# t13014 = prims.pow(t4813, 3.0) # t13014: "cuda:0 f32[1, 2048, 1]"
# t13015 = prims.mul(t13013, t13014) # t13015: "cuda:0 f32[1, 2048, 1]"
# t13017 = prims.sum(t13012, (0, 2)) # t13017: "cuda:0 f32[2048]"
# t13018 = prims.broadcast_in_dim(t13017, [1, 2048], [1]) # t13018: "cuda:0 f32[1, 2048]"
# t13019 = prims.sum(t13015, (0, 2)) # t13019: "cuda:0 f32[2048]"
# t13020 = prims.broadcast_in_dim(t13019, [1, 2048], [1]) # t13020: "cuda:0 f32[1, 2048]"
# t13023 = prims.broadcast_in_dim(t13018, [1, 2048, 1], [0, 1]) # t13023: "cuda:0 f32[1, 2048, 1]"
# t13024 = prims.broadcast_in_dim(t13023, (1, 2048, 4544), (0, 1, 2)) # t13024: "cuda:0 f32[1, 2048, 4544]"
# t13025 = prims.mul(0.00022007042253521127, t13024) # t13025: "cuda:0 f32[1, 2048, 4544]"
# t13027 = prims.broadcast_in_dim(t13020, [1, 2048, 1], [0, 1]) # t13027: "cuda:0 f32[1, 2048, 1]"
# t13028 = prims.broadcast_in_dim(t13027, (1, 2048, 4544), (0, 1, 2)) # t13028: "cuda:0 f32[1, 2048, 4544]"
# t13030 = prims.broadcast_in_dim(t4808, [1, 2048, 1], [0, 1]) # t13030: "cuda:0 f32[1, 2048, 1]"
# t13031 = prims.broadcast_in_dim(t13030, (1, 2048, 4544), (0, 1, 2)) # t13031: "cuda:0 f32[1, 2048, 4544]"
# t13032 = prims.mul(2.0, t13028) # t13032: "cuda:0 f32[1, 2048, 4544]"
# t13033 = prims.sub(t4800, t13031) # t13033: "cuda:0 f32[1, 2048, 4544]"
# t13034 = prims.mul(t13032, t13033) # t13034: "cuda:0 f32[1, 2048, 4544]"
# f13035 = prims.convert_element_type(i13026, float) # f13035: "float 4544.0"
# t13036 = prims.div(t13034, f13035) # t13036: "cuda:0 f32[1, 2048, 4544]"
# t13037 = prims.add(t13025, t13036) # t13037: "cuda:0 f32[1, 2048, 4544]"
# t13041 = prims.add(t13005, t13037) # t13041: "cuda:0 f32[1, 2048, 4544]"
# t13045 = prims.add(t13043, t13041) # t13045: "cuda:0 f32[1, 2048, 4544]"
# t13046 = prims.convert_element_type(t13045, dtypes.bfloat16) # t13046: "cuda:0 bf16[1, 2048, 4544]"
del i13026, t12763, t12806, t12985, t4640, t4772, t4793, t4808, t4813, t4819
t13053 = torch.reshape(t13046, (-1, 4544)) # t13053: "cuda:0 bf16[2048, 4544]"
# t13053 = ltorch.reshape(t13046, (-1, 4544)) # t13053: "cuda:0 bf16[2048, 4544]"
# t13053 = prims.reshape(t13046, (2048, 4544)) # t13053: "cuda:0 bf16[2048, 4544]"
t13057 = torch.permute(t13053, (1, 0)) # t13057: "cuda:0 bf16[4544, 2048]"
# t13057 = ltorch.permute(t13053, (1, 0)) # t13057: "cuda:0 bf16[4544, 2048]"
# t13057 = prims.transpose(t13053, (1, 0)) # t13057: "cuda:0 bf16[4544, 2048]"
t13054 = torch.matmul(t13053, t_transformer_h_29_mlp_proj_weight) # t13054: "cuda:0 bf16[2048, 18176]"
# t13054 = ltorch.matmul(t13053, t_transformer_h_29_mlp_proj_weight) # t13054: "cuda:0 bf16[2048, 18176]"
# t13054 = prims.matmul(t13053, t_transformer_h_29_mlp_proj_weight) # t13054: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_29_mlp_proj_weight
t13059 = torch.matmul(t13057, t13058) # t13059: "cuda:0 bf16[4544, 18176]"
# t13059 = ltorch.matmul(t13057, t13058) # t13059: "cuda:0 bf16[4544, 18176]"
# t13059 = prims.matmul(t13057, t13058) # t13059: "cuda:0 bf16[4544, 18176]"
del t13058
t13095 = torch.matmul(t13053, t_transformer_h_29_attn_proj_weight) # t13095: "cuda:0 bf16[2048, 4544]"
# t13095 = ltorch.matmul(t13094, t_transformer_h_29_attn_proj_weight) # t13095: "cuda:0 bf16[2048, 4544]"
# t13095 = prims.matmul(t13094, t_transformer_h_29_attn_proj_weight) # t13095: "cuda:0 bf16[2048, 4544]"
del t13053, t_transformer_h_29_attn_proj_weight
t13100 = torch.matmul(t13057, t13099) # t13100: "cuda:0 bf16[4544, 4544]"
# t13100 = ltorch.matmul(t13098, t13099) # t13100: "cuda:0 bf16[4544, 4544]"
# t13100 = prims.matmul(t13098, t13099) # t13100: "cuda:0 bf16[4544, 4544]"
del t13057, t13099
t13055 = torch.reshape(t13054, (1, 2048, 18176)) # t13055: "cuda:0 bf16[1, 2048, 18176]"
# t13055 = ltorch.reshape(t13054, (1, 2048, 18176)) # t13055: "cuda:0 bf16[1, 2048, 18176]"
# t13055 = prims.reshape(t13054, (1, 2048, 18176)) # t13055: "cuda:0 bf16[1, 2048, 18176]"
del t13054
t13096 = torch.reshape(t13095, (1, 2048, 4544)) # t13096: "cuda:0 bf16[1, 2048, 4544]"
# t13096 = ltorch.reshape(t13095, (1, 2048, 4544)) # t13096: "cuda:0 bf16[1, 2048, 4544]"
# t13096 = prims.reshape(t13095, (1, 2048, 4544)) # t13096: "cuda:0 bf16[1, 2048, 4544]"
del t13095
t13104 = torch.reshape(t13096, (1, 2048, 71, 64)) # t13104: "cuda:0 bf16[1, 2048, 71, 64]"
# t13104 = ltorch.reshape(t13096, (1, 2048, 71, 64)) # t13104: "cuda:0 bf16[1, 2048, 71, 64]"
# t13104 = prims.reshape(t13096, (1, 2048, 71, 64)) # t13104: "cuda:0 bf16[1, 2048, 71, 64]"
del t13096
t13107 = torch.permute(t13104, (0, 2, 1, 3)) # t13107: "cuda:0 bf16[1, 71, 2048, 64]"
# t13107 = ltorch.permute(t13104, (0, 2, 1, 3)) # t13107: "cuda:0 bf16[1, 71, 2048, 64]"
# t13107 = prims.transpose(t13104, (0, 2, 1, 3)) # t13107: "cuda:0 bf16[1, 71, 2048, 64]"
del t13104
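# Same erf-based GELU backward as nvFusion4 above, here for layer 29's MLP.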
[t13086] = nvFusion7(f1919, f1921, t13055, t4773)
# t4774 = prims.convert_element_type(t4773, dtypes.float32) # t4774: "cuda:0 f32[1, 2048, 18176]"
# t4776 = prims.div(t4774, 1.4142135623730951) # t4776: "cuda:0 f32[1, 2048, 18176]"
# t4779 = prims.erf(t4776) # t4779: "cuda:0 f32[1, 2048, 18176]"
# t4783 = prims.mul(0.5, t4779) # t4783: "cuda:0 f32[1, 2048, 18176]"
# t4787 = prims.add(0.5, t4783) # t4787: "cuda:0 f32[1, 2048, 18176]"
# t13060 = prims.convert_element_type(t13055, dtypes.float32) # t13060: "cuda:0 f32[1, 2048, 18176]"
# t13061 = prims.mul(t4787, t13060) # t13061: "cuda:0 f32[1, 2048, 18176]"
# t13062 = prims.mul(t4774, t13060) # t13062: "cuda:0 f32[1, 2048, 18176]"
# t13070 = prims.mul(f1921, t13062) # t13070: "cuda:0 f32[1, 2048, 18176]"
# t13073 = prims.pow(t4776, 2.0) # t13073: "cuda:0 f32[1, 2048, 18176]"
# t13074 = prims.neg(t13073) # t13074: "cuda:0 f32[1, 2048, 18176]"
# t13075 = prims.exp(t13074) # t13075: "cuda:0 f32[1, 2048, 18176]"
# t13076 = prims.mul(1.1283791670955126, t13075) # t13076: "cuda:0 f32[1, 2048, 18176]"
# t13077 = prims.mul(t13076, t13070) # t13077: "cuda:0 f32[1, 2048, 18176]"
# t13081 = prims.div(t13077, f1919) # t13081: "cuda:0 f32[1, 2048, 18176]"
# t13085 = prims.add(t13061, t13081) # t13085: "cuda:0 f32[1, 2048, 18176]"
# t13086 = prims.convert_element_type(t13085, dtypes.bfloat16) # t13086: "cuda:0 bf16[1, 2048, 18176]"
del f1919, f1921, t13055, t4773
t13087 = torch.reshape(t13086, (-1, 18176)) # t13087: "cuda:0 bf16[2048, 18176]"
# t13087 = ltorch.reshape(t13086, (-1, 18176)) # t13087: "cuda:0 bf16[2048, 18176]"
# t13087 = prims.reshape(t13086, (2048, 18176)) # t13087: "cuda:0 bf16[2048, 18176]"
del t13086
t13091 = torch.permute(t13087, (1, 0)) # t13091: "cuda:0 bf16[18176, 2048]"
# t13091 = ltorch.permute(t13087, (1, 0)) # t13091: "cuda:0 bf16[18176, 2048]"
# t13091 = prims.transpose(t13087, (1, 0)) # t13091: "cuda:0 bf16[18176, 2048]"
t13093 = torch.matmul(t13091, t13092) # t13093: "cuda:0 bf16[18176, 4544]"
# t13093 = ltorch.matmul(t13091, t13092) # t13093: "cuda:0 bf16[18176, 4544]"
# t13093 = prims.matmul(t13091, t13092) # t13093: "cuda:0 bf16[18176, 4544]"
del t13091
t13088 = torch.matmul(t13087, t_transformer_h_29_mlp_fc_weight) # t13088: "cuda:0 bf16[2048, 4544]"
# t13088 = ltorch.matmul(t13087, t_transformer_h_29_mlp_fc_weight) # t13088: "cuda:0 bf16[2048, 4544]"
# t13088 = prims.matmul(t13087, t_transformer_h_29_mlp_fc_weight) # t13088: "cuda:0 bf16[2048, 4544]"
del t13087, t_transformer_h_29_mlp_fc_weight
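# cuDNN scaled-dot-product-attention backward for layer 29, returning query/key/value gradients.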
(t13108, t13109, t13110) = cudnn_sdpa_bwd(t13107, t4757, t4760, t4710, None, f1910, b1911, t4761, t4762, t4763, t4764, scale=f1912, cat_grad_qkv=False)
del t13107, t4757, t4760, t4710, f1910, b1911, t4761, t4762, t4763, t4764, f1912
t13112 = torch_slice_prim_impl(t13109, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13112: "cuda:0 bf16[1, 71, 2048, 64]"
del t13109
t13116 = torch_slice_prim_impl(t13108, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13116: "cuda:0 bf16[1, 71, 2048, 64]"
del t13108
t13219 = torch.reshape(t13110, (1, 1, 71, 2048, 64)) # t13219: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13219 = ltorch.reshape(t13110, (1, 1, 71, 2048, 64)) # t13219: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13219 = prims.reshape(t13110, (1, 1, 71, 2048, 64)) # t13219: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t13110
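# Same RoPE backward and multi-query head reduction as nvFusion5 above, here for layer 29.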
[t13253] = nvFusion8(i1883, t13112, t13116, t13219, t61, t66)
# t13113 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13113: "cuda:0 bf16[1, 71, 2048, 0]"
# t13114 = prims.pad(t13113, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13114: "cuda:0 bf16[1, 71, 2048, 64]"
# t13117 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13117: "cuda:0 bf16[1, 71, 2048, 0]"
# t13118 = prims.pad(t13117, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13118: "cuda:0 bf16[1, 71, 2048, 64]"
# t13119 = prims.convert_element_type(t13112, dtypes.float32) # t13119: "cuda:0 f32[1, 71, 2048, 64]"
# t13123 = prims.mul(t66, t13119) # t13123: "cuda:0 f32[1, 71, 2048, 64]"
# t13126 = prims.convert_element_type(t13123, dtypes.bfloat16) # t13126: "cuda:0 bf16[1, 71, 2048, 64]"
# t13135 = prims.mul(t61, t13119) # t13135: "cuda:0 f32[1, 71, 2048, 64]"
# t13147 = prims.slice_prim(t13126, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t13147: "cuda:0 bf16[1, 71, 2048, 32]"
# t13148 = prims.slice_prim(t13126, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13148: "cuda:0 bf16[1, 71, 2048, 32]"
# t13149 = prims.convert_element_type(t13147, dtypes.float32) # t13149: "cuda:0 f32[1, 71, 2048, 32]"
# t13150 = prims.neg(t13149) # t13150: "cuda:0 f32[1, 71, 2048, 32]"
# t13151 = prims.convert_element_type(t13150, dtypes.bfloat16) # t13151: "cuda:0 bf16[1, 71, 2048, 32]"
# t13152 = prims.pad(t13151, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t13152: "cuda:0 bf16[1, 71, 2048, 64]"
# t13154 = prims.convert_element_type(t13152, dtypes.float32) # t13154: "cuda:0 f32[1, 71, 2048, 64]"
# t13155 = prims.add(t13135, t13154) # t13155: "cuda:0 f32[1, 71, 2048, 64]"
# t13157 = prims.pad(t13148, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t13157: "cuda:0 bf16[1, 71, 2048, 64]"
# t13159 = prims.convert_element_type(t13157, dtypes.float32) # t13159: "cuda:0 f32[1, 71, 2048, 64]"
# t13160 = prims.add(t13155, t13159) # t13160: "cuda:0 f32[1, 71, 2048, 64]"
# t13161 = prims.convert_element_type(t13160, dtypes.bfloat16) # t13161: "cuda:0 bf16[1, 71, 2048, 64]"
# t13162 = prims.pad(t13161, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t13162: "cuda:0 bf16[1, 71, 2048, 64]"
# t13163 = prims.convert_element_type(t13114, dtypes.float32) # t13163: "cuda:0 f32[1, 71, 2048, 64]"
# t13164 = prims.convert_element_type(t13162, dtypes.float32) # t13164: "cuda:0 f32[1, 71, 2048, 64]"
# t13165 = prims.add(t13163, t13164) # t13165: "cuda:0 f32[1, 71, 2048, 64]"
# t13166 = prims.convert_element_type(t13165, dtypes.bfloat16) # t13166: "cuda:0 bf16[1, 71, 2048, 64]"
# t13167 = prims.convert_element_type(t13116, dtypes.float32) # t13167: "cuda:0 f32[1, 71, 2048, 64]"
# t13171 = prims.mul(t66, t13167) # t13171: "cuda:0 f32[1, 71, 2048, 64]"
# t13174 = prims.convert_element_type(t13171, dtypes.bfloat16) # t13174: "cuda:0 bf16[1, 71, 2048, 64]"
# t13183 = prims.mul(t61, t13167) # t13183: "cuda:0 f32[1, 71, 2048, 64]"
# t13195 = prims.slice_prim(t13174, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t13195: "cuda:0 bf16[1, 71, 2048, 32]"
# t13196 = prims.slice_prim(t13174, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13196: "cuda:0 bf16[1, 71, 2048, 32]"
# t13197 = prims.convert_element_type(t13195, dtypes.float32) # t13197: "cuda:0 f32[1, 71, 2048, 32]"
# t13198 = prims.neg(t13197) # t13198: "cuda:0 f32[1, 71, 2048, 32]"
# t13199 = prims.convert_element_type(t13198, dtypes.bfloat16) # t13199: "cuda:0 bf16[1, 71, 2048, 32]"
# t13200 = prims.pad(t13199, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t13200: "cuda:0 bf16[1, 71, 2048, 64]"
# t13202 = prims.convert_element_type(t13200, dtypes.float32) # t13202: "cuda:0 f32[1, 71, 2048, 64]"
# t13203 = prims.add(t13183, t13202) # t13203: "cuda:0 f32[1, 71, 2048, 64]"
# t13205 = prims.pad(t13196, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t13205: "cuda:0 bf16[1, 71, 2048, 64]"
# t13207 = prims.convert_element_type(t13205, dtypes.float32) # t13207: "cuda:0 f32[1, 71, 2048, 64]"
# t13208 = prims.add(t13203, t13207) # t13208: "cuda:0 f32[1, 71, 2048, 64]"
# t13209 = prims.convert_element_type(t13208, dtypes.bfloat16) # t13209: "cuda:0 bf16[1, 71, 2048, 64]"
# t13210 = prims.pad(t13209, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t13210: "cuda:0 bf16[1, 71, 2048, 64]"
# t13211 = prims.convert_element_type(t13118, dtypes.float32) # t13211: "cuda:0 f32[1, 71, 2048, 64]"
# t13212 = prims.convert_element_type(t13210, dtypes.float32) # t13212: "cuda:0 f32[1, 71, 2048, 64]"
# t13213 = prims.add(t13211, t13212) # t13213: "cuda:0 f32[1, 71, 2048, 64]"
# t13214 = prims.convert_element_type(t13213, dtypes.bfloat16) # t13214: "cuda:0 bf16[1, 71, 2048, 64]"
# t13224 = prims.reshape(t13166, (1, 1, 71, 2048, 64)) # t13224: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13229 = prims.reshape(t13214, (1, 1, 71, 2048, 64)) # t13229: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13235 = prims.convert_element_type(t13219, dtypes.float32) # t13235: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t13236 = prims.sum(t13235, (0, 1, 2)) # t13236: "cuda:0 f32[2048, 64]"
# t13237 = prims.convert_element_type(t13236, dtypes.bfloat16) # t13237: "cuda:0 bf16[2048, 64]"
# t13238 = prims.broadcast_in_dim(t13237, [1, 1, 1, 2048, 64], [3, 4]) # t13238: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t13244 = prims.convert_element_type(t13224, dtypes.float32) # t13244: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t13245 = prims.sum(t13244, (0, 1, 2)) # t13245: "cuda:0 f32[2048, 64]"
# t13246 = prims.convert_element_type(t13245, dtypes.bfloat16) # t13246: "cuda:0 bf16[2048, 64]"
# t13247 = prims.broadcast_in_dim(t13246, [1, 1, 1, 2048, 64], [3, 4]) # t13247: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t13253 = prims.cat((t13229, t13247, t13238), i1883) # t13253: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1883, t13112, t13116, t13219
t13259 = torch.permute(t13253, (0, 3, 1, 2, 4)) # t13259: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t13259 = ltorch.permute(t13253, (0, 3, 1, 2, 4)) # t13259: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t13259 = prims.transpose(t13253, (0, 3, 1, 2, 4)) # t13259: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t13253
t13265 = torch.reshape(t13259, (1, 2048, 4672)) # t13265: "cuda:0 bf16[1, 2048, 4672]"
# t13265 = ltorch.reshape(t13259, (1, 2048, 4672)) # t13265: "cuda:0 bf16[1, 2048, 4672]"
# t13265 = prims.reshape(t13259, (1, 2048, 4672)) # t13265: "cuda:0 bf16[1, 2048, 4672]"
del t13259
t13266 = torch.reshape(t13265, (-1, 4672)) # t13266: "cuda:0 bf16[2048, 4672]"
# t13266 = ltorch.reshape(t13265, (-1, 4672)) # t13266: "cuda:0 bf16[2048, 4672]"
# t13266 = prims.reshape(t13265, (2048, 4672)) # t13266: "cuda:0 bf16[2048, 4672]"
del t13265
t13270 = torch.permute(t13266, (1, 0)) # t13270: "cuda:0 bf16[4672, 2048]"
# t13270 = ltorch.permute(t13266, (1, 0)) # t13270: "cuda:0 bf16[4672, 2048]"
# t13270 = prims.transpose(t13266, (1, 0)) # t13270: "cuda:0 bf16[4672, 2048]"
t13272 = torch.matmul(t13270, t13092) # t13272: "cuda:0 bf16[4672, 4544]"
# t13272 = ltorch.matmul(t13270, t13271) # t13272: "cuda:0 bf16[4672, 4544]"
# t13272 = prims.matmul(t13270, t13271) # t13272: "cuda:0 bf16[4672, 4544]"
del t13270, t13092
t13267 = torch.matmul(t13266, t_transformer_h_29_attn_attn_weight) # t13267: "cuda:0 bf16[2048, 4544]"
# t13267 = ltorch.matmul(t13266, t_transformer_h_29_attn_attn_weight) # t13267: "cuda:0 bf16[2048, 4544]"
# t13267 = prims.matmul(t13266, t_transformer_h_29_attn_attn_weight) # t13267: "cuda:0 bf16[2048, 4544]"
del t13266, t_transformer_h_29_attn_attn_weight
t13089 = torch.reshape(t13088, (1, 2048, 4544)) # t13089: "cuda:0 bf16[1, 2048, 4544]"
# t13089 = ltorch.reshape(t13088, (1, 2048, 4544)) # t13089: "cuda:0 bf16[1, 2048, 4544]"
# t13089 = prims.reshape(t13088, (1, 2048, 4544)) # t13089: "cuda:0 bf16[1, 2048, 4544]"
del t13088
t13268 = torch.reshape(t13267, (1, 2048, 4544)) # t13268: "cuda:0 bf16[1, 2048, 4544]"
# t13268 = ltorch.reshape(t13267, (1, 2048, 4544)) # t13268: "cuda:0 bf16[1, 2048, 4544]"
# t13268 = prims.reshape(t13267, (1, 2048, 4544)) # t13268: "cuda:0 bf16[1, 2048, 4544]"
del t13267
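# LayerNorm backward for layer 29, following the same pattern as nvFusion3 / nvFusion6.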
[t13281, t13287, t13329] = nvFusion9(i13309, t13046, t13089, t13268, t4479, t4611, t4632, t4647, t4652, t4658)
# t4638 = prims.convert_element_type(t4479, dtypes.float32) # t4638: "cuda:0 f32[1, 2048, 4544]"
# t4633 = prims.convert_element_type(t4632, dtypes.float32) # t4633: "cuda:0 f32[1, 2048, 4544]"
# t4634 = prims.convert_element_type(t4611, dtypes.float32) # t4634: "cuda:0 f32[1, 2048, 4544]"
# t4635 = prims.add(t4633, t4634) # t4635: "cuda:0 f32[1, 2048, 4544]"
# t4639 = prims.add(t4635, t4638) # t4639: "cuda:0 f32[1, 2048, 4544]"
# t4649 = prims.broadcast_in_dim(t4647, [1, 2048, 1], [0, 1]) # t4649: "cuda:0 f32[1, 2048, 1]"
# t4653 = prims.broadcast_in_dim(t4649, (1, 2048, 4544), (0, 1, 2)) # t4653: "cuda:0 f32[1, 2048, 4544]"
# t4655 = prims.sub(t4639, t4653) # t4655: "cuda:0 f32[1, 2048, 4544]"
# t4656 = prims.broadcast_in_dim(t4652, (1, 2048, 4544), (0, 1, 2)) # t4656: "cuda:0 f32[1, 2048, 4544]"
# t4657 = prims.mul(t4655, t4656) # t4657: "cuda:0 f32[1, 2048, 4544]"
# t4659 = prims.convert_element_type(t4658, dtypes.float32) # t4659: "cuda:0 f32[1, 2048, 4544]"
# t13326 = prims.convert_element_type(t13046, dtypes.float32) # t13326: "cuda:0 f32[1, 2048, 4544]"
# t13273 = prims.convert_element_type(t13089, dtypes.float32) # t13273: "cuda:0 f32[1, 2048, 4544]"
# t13274 = prims.convert_element_type(t13268, dtypes.float32) # t13274: "cuda:0 f32[1, 2048, 4544]"
# t13275 = prims.add(t13273, t13274) # t13275: "cuda:0 f32[1, 2048, 4544]"
# t13280 = prims.sum(t13275, (0, 1)) # t13280: "cuda:0 f32[4544]"
# t13281 = prims.convert_element_type(t13280, dtypes.bfloat16) # t13281: "cuda:0 bf16[4544]"
# t13282 = prims.mul(t4659, t13275) # t13282: "cuda:0 f32[1, 2048, 4544]"
# t13283 = prims.mul(t4657, t13275) # t13283: "cuda:0 f32[1, 2048, 4544]"
# t13286 = prims.sum(t13283, (0, 1)) # t13286: "cuda:0 f32[4544]"
# t13287 = prims.convert_element_type(t13286, dtypes.bfloat16) # t13287: "cuda:0 bf16[4544]"
# t13288 = prims.mul(t4656, t13282) # t13288: "cuda:0 f32[1, 2048, 4544]"
# t13289 = prims.mul(t4655, t13282) # t13289: "cuda:0 f32[1, 2048, 4544]"
# t13290 = prims.sum(t13289, (0, 2)) # t13290: "cuda:0 f32[2048]"
# t13291 = prims.broadcast_in_dim(t13290, [1, 2048, 1], [1]) # t13291: "cuda:0 f32[1, 2048, 1]"
# t13292 = prims.neg(t13288) # t13292: "cuda:0 f32[1, 2048, 4544]"
# t13294 = prims.sum(t13292, (0, 2)) # t13294: "cuda:0 f32[2048]"
# t13295 = prims.broadcast_in_dim(t13294, [1, 2048, 1], [1]) # t13295: "cuda:0 f32[1, 2048, 1]"
# t13296 = prims.mul(-0.5, t13291) # t13296: "cuda:0 f32[1, 2048, 1]"
# t13297 = prims.pow(t4652, 3.0) # t13297: "cuda:0 f32[1, 2048, 1]"
# t13298 = prims.mul(t13296, t13297) # t13298: "cuda:0 f32[1, 2048, 1]"
# t13300 = prims.sum(t13295, (0, 2)) # t13300: "cuda:0 f32[2048]"
# t13301 = prims.broadcast_in_dim(t13300, [1, 2048], [1]) # t13301: "cuda:0 f32[1, 2048]"
# t13302 = prims.sum(t13298, (0, 2)) # t13302: "cuda:0 f32[2048]"
# t13303 = prims.broadcast_in_dim(t13302, [1, 2048], [1]) # t13303: "cuda:0 f32[1, 2048]"
# t13306 = prims.broadcast_in_dim(t13301, [1, 2048, 1], [0, 1]) # t13306: "cuda:0 f32[1, 2048, 1]"
# t13307 = prims.broadcast_in_dim(t13306, (1, 2048, 4544), (0, 1, 2)) # t13307: "cuda:0 f32[1, 2048, 4544]"
# t13308 = prims.mul(0.00022007042253521127, t13307) # t13308: "cuda:0 f32[1, 2048, 4544]"
# t13310 = prims.broadcast_in_dim(t13303, [1, 2048, 1], [0, 1]) # t13310: "cuda:0 f32[1, 2048, 1]"
# t13311 = prims.broadcast_in_dim(t13310, (1, 2048, 4544), (0, 1, 2)) # t13311: "cuda:0 f32[1, 2048, 4544]"
# t13313 = prims.broadcast_in_dim(t4647, [1, 2048, 1], [0, 1]) # t13313: "cuda:0 f32[1, 2048, 1]"
# t13314 = prims.broadcast_in_dim(t13313, (1, 2048, 4544), (0, 1, 2)) # t13314: "cuda:0 f32[1, 2048, 4544]"
# t13315 = prims.mul(2.0, t13311) # t13315: "cuda:0 f32[1, 2048, 4544]"
# t13316 = prims.sub(t4639, t13314) # t13316: "cuda:0 f32[1, 2048, 4544]"
# t13317 = prims.mul(t13315, t13316) # t13317: "cuda:0 f32[1, 2048, 4544]"
# f13318 = prims.convert_element_type(i13309, float) # f13318: "float 4544.0"
# t13319 = prims.div(t13317, f13318) # t13319: "cuda:0 f32[1, 2048, 4544]"
# t13320 = prims.add(t13308, t13319) # t13320: "cuda:0 f32[1, 2048, 4544]"
# t13324 = prims.add(t13288, t13320) # t13324: "cuda:0 f32[1, 2048, 4544]"
# t13328 = prims.add(t13326, t13324) # t13328: "cuda:0 f32[1, 2048, 4544]"
# t13329 = prims.convert_element_type(t13328, dtypes.bfloat16) # t13329: "cuda:0 bf16[1, 2048, 4544]"
del i13309, t13046, t13089, t13268, t4479, t4611, t4632, t4647, t4652, t4658
t13336 = torch.reshape(t13329, (-1, 4544)) # t13336: "cuda:0 bf16[2048, 4544]"
# t13336 = ltorch.reshape(t13329, (-1, 4544)) # t13336: "cuda:0 bf16[2048, 4544]"
# t13336 = prims.reshape(t13329, (2048, 4544)) # t13336: "cuda:0 bf16[2048, 4544]"
t13340 = torch.permute(t13336, (1, 0)) # t13340: "cuda:0 bf16[4544, 2048]"
# t13340 = ltorch.permute(t13336, (1, 0)) # t13340: "cuda:0 bf16[4544, 2048]"
# t13340 = prims.transpose(t13336, (1, 0)) # t13340: "cuda:0 bf16[4544, 2048]"
t13337 = torch.matmul(t13336, t_transformer_h_28_mlp_proj_weight) # t13337: "cuda:0 bf16[2048, 18176]"
# t13337 = ltorch.matmul(t13336, t_transformer_h_28_mlp_proj_weight) # t13337: "cuda:0 bf16[2048, 18176]"
# t13337 = prims.matmul(t13336, t_transformer_h_28_mlp_proj_weight) # t13337: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_28_mlp_proj_weight
t13342 = torch.matmul(t13340, t13341) # t13342: "cuda:0 bf16[4544, 18176]"
# t13342 = ltorch.matmul(t13340, t13341) # t13342: "cuda:0 bf16[4544, 18176]"
# t13342 = prims.matmul(t13340, t13341) # t13342: "cuda:0 bf16[4544, 18176]"
del t13341
t13378 = torch.matmul(t13336, t_transformer_h_28_attn_proj_weight) # t13378: "cuda:0 bf16[2048, 4544]"
# t13378 = ltorch.matmul(t13377, t_transformer_h_28_attn_proj_weight) # t13378: "cuda:0 bf16[2048, 4544]"
# t13378 = prims.matmul(t13377, t_transformer_h_28_attn_proj_weight) # t13378: "cuda:0 bf16[2048, 4544]"
del t13336, t_transformer_h_28_attn_proj_weight
t13383 = torch.matmul(t13340, t13382) # t13383: "cuda:0 bf16[4544, 4544]"
# t13383 = ltorch.matmul(t13381, t13382) # t13383: "cuda:0 bf16[4544, 4544]"
# t13383 = prims.matmul(t13381, t13382) # t13383: "cuda:0 bf16[4544, 4544]"
del t13340, t13382
t13338 = torch.reshape(t13337, (1, 2048, 18176)) # t13338: "cuda:0 bf16[1, 2048, 18176]"
# t13338 = ltorch.reshape(t13337, (1, 2048, 18176)) # t13338: "cuda:0 bf16[1, 2048, 18176]"
# t13338 = prims.reshape(t13337, (1, 2048, 18176)) # t13338: "cuda:0 bf16[1, 2048, 18176]"
del t13337
t13379 = torch.reshape(t13378, (1, 2048, 4544)) # t13379: "cuda:0 bf16[1, 2048, 4544]"
# t13379 = ltorch.reshape(t13378, (1, 2048, 4544)) # t13379: "cuda:0 bf16[1, 2048, 4544]"
# t13379 = prims.reshape(t13378, (1, 2048, 4544)) # t13379: "cuda:0 bf16[1, 2048, 4544]"
del t13378
t13387 = torch.reshape(t13379, (1, 2048, 71, 64)) # t13387: "cuda:0 bf16[1, 2048, 71, 64]"
# t13387 = ltorch.reshape(t13379, (1, 2048, 71, 64)) # t13387: "cuda:0 bf16[1, 2048, 71, 64]"
# t13387 = prims.reshape(t13379, (1, 2048, 71, 64)) # t13387: "cuda:0 bf16[1, 2048, 71, 64]"
del t13379
t13390 = torch.permute(t13387, (0, 2, 1, 3)) # t13390: "cuda:0 bf16[1, 71, 2048, 64]"
# t13390 = ltorch.permute(t13387, (0, 2, 1, 3)) # t13390: "cuda:0 bf16[1, 71, 2048, 64]"
# t13390 = prims.transpose(t13387, (0, 2, 1, 3)) # t13390: "cuda:0 bf16[1, 71, 2048, 64]"
del t13387
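# erf-based GELU backward for layer 28's MLP, same pattern as nvFusion4.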
[t13369] = nvFusion10(f1855, f1857, t13338, t4612)
# t4613 = prims.convert_element_type(t4612, dtypes.float32) # t4613: "cuda:0 f32[1, 2048, 18176]"
# t4615 = prims.div(t4613, 1.4142135623730951) # t4615: "cuda:0 f32[1, 2048, 18176]"
# t4618 = prims.erf(t4615) # t4618: "cuda:0 f32[1, 2048, 18176]"
# t4622 = prims.mul(0.5, t4618) # t4622: "cuda:0 f32[1, 2048, 18176]"
# t4626 = prims.add(0.5, t4622) # t4626: "cuda:0 f32[1, 2048, 18176]"
# t13343 = prims.convert_element_type(t13338, dtypes.float32) # t13343: "cuda:0 f32[1, 2048, 18176]"
# t13344 = prims.mul(t4626, t13343) # t13344: "cuda:0 f32[1, 2048, 18176]"
# t13345 = prims.mul(t4613, t13343) # t13345: "cuda:0 f32[1, 2048, 18176]"
# t13353 = prims.mul(f1857, t13345) # t13353: "cuda:0 f32[1, 2048, 18176]"
# t13356 = prims.pow(t4615, 2.0) # t13356: "cuda:0 f32[1, 2048, 18176]"
# t13357 = prims.neg(t13356) # t13357: "cuda:0 f32[1, 2048, 18176]"
# t13358 = prims.exp(t13357) # t13358: "cuda:0 f32[1, 2048, 18176]"
# t13359 = prims.mul(1.1283791670955126, t13358) # t13359: "cuda:0 f32[1, 2048, 18176]"
# t13360 = prims.mul(t13359, t13353) # t13360: "cuda:0 f32[1, 2048, 18176]"
# t13364 = prims.div(t13360, f1855) # t13364: "cuda:0 f32[1, 2048, 18176]"
# t13368 = prims.add(t13344, t13364) # t13368: "cuda:0 f32[1, 2048, 18176]"
# t13369 = prims.convert_element_type(t13368, dtypes.bfloat16) # t13369: "cuda:0 bf16[1, 2048, 18176]"
del f1855, f1857, t13338, t4612
t13370 = torch.reshape(t13369, (-1, 18176)) # t13370: "cuda:0 bf16[2048, 18176]"
# t13370 = ltorch.reshape(t13369, (-1, 18176)) # t13370: "cuda:0 bf16[2048, 18176]"
# t13370 = prims.reshape(t13369, (2048, 18176)) # t13370: "cuda:0 bf16[2048, 18176]"
del t13369
t13374 = torch.permute(t13370, (1, 0)) # t13374: "cuda:0 bf16[18176, 2048]"
# t13374 = ltorch.permute(t13370, (1, 0)) # t13374: "cuda:0 bf16[18176, 2048]"
# t13374 = prims.transpose(t13370, (1, 0)) # t13374: "cuda:0 bf16[18176, 2048]"
t13376 = torch.matmul(t13374, t13375) # t13376: "cuda:0 bf16[18176, 4544]"
# t13376 = ltorch.matmul(t13374, t13375) # t13376: "cuda:0 bf16[18176, 4544]"
# t13376 = prims.matmul(t13374, t13375) # t13376: "cuda:0 bf16[18176, 4544]"
del t13374
t13371 = torch.matmul(t13370, t_transformer_h_28_mlp_fc_weight) # t13371: "cuda:0 bf16[2048, 4544]"
# t13371 = ltorch.matmul(t13370, t_transformer_h_28_mlp_fc_weight) # t13371: "cuda:0 bf16[2048, 4544]"
# t13371 = prims.matmul(t13370, t_transformer_h_28_mlp_fc_weight) # t13371: "cuda:0 bf16[2048, 4544]"
del t13370, t_transformer_h_28_mlp_fc_weight
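# cuDNN scaled-dot-product-attention backward for layer 28.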
(t13391, t13392, t13393) = cudnn_sdpa_bwd(t13390, t4596, t4599, t4549, None, f1846, b1847, t4600, t4601, t4602, t4603, scale=f1848, cat_grad_qkv=False)
del t13390, t4596, t4599, t4549, f1846, b1847, t4600, t4601, t4602, t4603, f1848
t13395 = torch_slice_prim_impl(t13392, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13395: "cuda:0 bf16[1, 71, 2048, 64]"
del t13392
t13399 = torch_slice_prim_impl(t13391, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13399: "cuda:0 bf16[1, 71, 2048, 64]"
del t13391
t13502 = torch.reshape(t13393, (1, 1, 71, 2048, 64)) # t13502: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13502 = ltorch.reshape(t13393, (1, 1, 71, 2048, 64)) # t13502: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13502 = prims.reshape(t13393, (1, 1, 71, 2048, 64)) # t13502: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t13393
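# RoPE backward and multi-query head reduction for layer 28, same pattern as nvFusion5.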
[t13536] = nvFusion11(i1819, t13395, t13399, t13502, t61, t66)
# t13396 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13396: "cuda:0 bf16[1, 71, 2048, 0]"
# t13397 = prims.pad(t13396, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13397: "cuda:0 bf16[1, 71, 2048, 64]"
# t13400 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13400: "cuda:0 bf16[1, 71, 2048, 0]"
# t13401 = prims.pad(t13400, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13401: "cuda:0 bf16[1, 71, 2048, 64]"
# t13402 = prims.convert_element_type(t13395, dtypes.float32) # t13402: "cuda:0 f32[1, 71, 2048, 64]"
# t13406 = prims.mul(t66, t13402) # t13406: "cuda:0 f32[1, 71, 2048, 64]"
# t13409 = prims.convert_element_type(t13406, dtypes.bfloat16) # t13409: "cuda:0 bf16[1, 71, 2048, 64]"
# t13418 = prims.mul(t61, t13402) # t13418: "cuda:0 f32[1, 71, 2048, 64]"
# t13430 = prims.slice_prim(t13409, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t13430: "cuda:0 bf16[1, 71, 2048, 32]"
# t13431 = prims.slice_prim(t13409, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13431: "cuda:0 bf16[1, 71, 2048, 32]"
# t13432 = prims.convert_element_type(t13430, dtypes.float32) # t13432: "cuda:0 f32[1, 71, 2048, 32]"
# t13433 = prims.neg(t13432) # t13433: "cuda:0 f32[1, 71, 2048, 32]"
# t13434 = prims.convert_element_type(t13433, dtypes.bfloat16) # t13434: "cuda:0 bf16[1, 71, 2048, 32]"
# t13435 = prims.pad(t13434, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t13435: "cuda:0 bf16[1, 71, 2048, 64]"
# t13437 = prims.convert_element_type(t13435, dtypes.float32) # t13437: "cuda:0 f32[1, 71, 2048, 64]"
# t13438 = prims.add(t13418, t13437) # t13438: "cuda:0 f32[1, 71, 2048, 64]"
# t13440 = prims.pad(t13431, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t13440: "cuda:0 bf16[1, 71, 2048, 64]"
# t13442 = prims.convert_element_type(t13440, dtypes.float32) # t13442: "cuda:0 f32[1, 71, 2048, 64]"
# t13443 = prims.add(t13438, t13442) # t13443: "cuda:0 f32[1, 71, 2048, 64]"
# t13444 = prims.convert_element_type(t13443, dtypes.bfloat16) # t13444: "cuda:0 bf16[1, 71, 2048, 64]"
# t13445 = prims.pad(t13444, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t13445: "cuda:0 bf16[1, 71, 2048, 64]"
# t13446 = prims.convert_element_type(t13397, dtypes.float32) # t13446: "cuda:0 f32[1, 71, 2048, 64]"
# t13447 = prims.convert_element_type(t13445, dtypes.float32) # t13447: "cuda:0 f32[1, 71, 2048, 64]"
# t13448 = prims.add(t13446, t13447) # t13448: "cuda:0 f32[1, 71, 2048, 64]"
# t13449 = prims.convert_element_type(t13448, dtypes.bfloat16) # t13449: "cuda:0 bf16[1, 71, 2048, 64]"
# t13450 = prims.convert_element_type(t13399, dtypes.float32) # t13450: "cuda:0 f32[1, 71, 2048, 64]"
# t13454 = prims.mul(t66, t13450) # t13454: "cuda:0 f32[1, 71, 2048, 64]"
# t13457 = prims.convert_element_type(t13454, dtypes.bfloat16) # t13457: "cuda:0 bf16[1, 71, 2048, 64]"
# t13466 = prims.mul(t61, t13450) # t13466: "cuda:0 f32[1, 71, 2048, 64]"
# t13478 = prims.slice_prim(t13457, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t13478: "cuda:0 bf16[1, 71, 2048, 32]"
# t13479 = prims.slice_prim(t13457, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13479: "cuda:0 bf16[1, 71, 2048, 32]"
# t13480 = prims.convert_element_type(t13478, dtypes.float32) # t13480: "cuda:0 f32[1, 71, 2048, 32]"
# t13481 = prims.neg(t13480) # t13481: "cuda:0 f32[1, 71, 2048, 32]"
# t13482 = prims.convert_element_type(t13481, dtypes.bfloat16) # t13482: "cuda:0 bf16[1, 71, 2048, 32]"
# t13483 = prims.pad(t13482, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t13483: "cuda:0 bf16[1, 71, 2048, 64]"
# t13485 = prims.convert_element_type(t13483, dtypes.float32) # t13485: "cuda:0 f32[1, 71, 2048, 64]"
# t13486 = prims.add(t13466, t13485) # t13486: "cuda:0 f32[1, 71, 2048, 64]"
# t13488 = prims.pad(t13479, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t13488: "cuda:0 bf16[1, 71, 2048, 64]"
# t13490 = prims.convert_element_type(t13488, dtypes.float32) # t13490: "cuda:0 f32[1, 71, 2048, 64]"
# t13491 = prims.add(t13486, t13490) # t13491: "cuda:0 f32[1, 71, 2048, 64]"
# t13492 = prims.convert_element_type(t13491, dtypes.bfloat16) # t13492: "cuda:0 bf16[1, 71, 2048, 64]"
# t13493 = prims.pad(t13492, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t13493: "cuda:0 bf16[1, 71, 2048, 64]"
# t13494 = prims.convert_element_type(t13401, dtypes.float32) # t13494: "cuda:0 f32[1, 71, 2048, 64]"
# t13495 = prims.convert_element_type(t13493, dtypes.float32) # t13495: "cuda:0 f32[1, 71, 2048, 64]"
# t13496 = prims.add(t13494, t13495) # t13496: "cuda:0 f32[1, 71, 2048, 64]"
# t13497 = prims.convert_element_type(t13496, dtypes.bfloat16) # t13497: "cuda:0 bf16[1, 71, 2048, 64]"
# t13507 = prims.reshape(t13449, (1, 1, 71, 2048, 64)) # t13507: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13512 = prims.reshape(t13497, (1, 1, 71, 2048, 64)) # t13512: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13518 = prims.convert_element_type(t13502, dtypes.float32) # t13518: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t13519 = prims.sum(t13518, (0, 1, 2)) # t13519: "cuda:0 f32[2048, 64]"
# t13520 = prims.convert_element_type(t13519, dtypes.bfloat16) # t13520: "cuda:0 bf16[2048, 64]"
# t13521 = prims.broadcast_in_dim(t13520, [1, 1, 1, 2048, 64], [3, 4]) # t13521: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t13527 = prims.convert_element_type(t13507, dtypes.float32) # t13527: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t13528 = prims.sum(t13527, (0, 1, 2)) # t13528: "cuda:0 f32[2048, 64]"
# t13529 = prims.convert_element_type(t13528, dtypes.bfloat16) # t13529: "cuda:0 bf16[2048, 64]"
# t13530 = prims.broadcast_in_dim(t13529, [1, 1, 1, 2048, 64], [3, 4]) # t13530: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t13536 = prims.cat((t13512, t13530, t13521), i1819) # t13536: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1819, t13395, t13399, t13502
t13542 = torch.permute(t13536, (0, 3, 1, 2, 4)) # t13542: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t13542 = ltorch.permute(t13536, (0, 3, 1, 2, 4)) # t13542: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t13542 = prims.transpose(t13536, (0, 3, 1, 2, 4)) # t13542: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t13536
t13548 = torch.reshape(t13542, (1, 2048, 4672)) # t13548: "cuda:0 bf16[1, 2048, 4672]"
# t13548 = ltorch.reshape(t13542, (1, 2048, 4672)) # t13548: "cuda:0 bf16[1, 2048, 4672]"
# t13548 = prims.reshape(t13542, (1, 2048, 4672)) # t13548: "cuda:0 bf16[1, 2048, 4672]"
del t13542
t13549 = torch.reshape(t13548, (-1, 4672)) # t13549: "cuda:0 bf16[2048, 4672]"
# t13549 = ltorch.reshape(t13548, (-1, 4672)) # t13549: "cuda:0 bf16[2048, 4672]"
# t13549 = prims.reshape(t13548, (2048, 4672)) # t13549: "cuda:0 bf16[2048, 4672]"
del t13548
t13553 = torch.permute(t13549, (1, 0)) # t13553: "cuda:0 bf16[4672, 2048]"
# t13553 = ltorch.permute(t13549, (1, 0)) # t13553: "cuda:0 bf16[4672, 2048]"
# t13553 = prims.transpose(t13549, (1, 0)) # t13553: "cuda:0 bf16[4672, 2048]"
t13555 = torch.matmul(t13553, t13375) # t13555: "cuda:0 bf16[4672, 4544]"
# t13555 = ltorch.matmul(t13553, t13554) # t13555: "cuda:0 bf16[4672, 4544]"
# t13555 = prims.matmul(t13553, t13554) # t13555: "cuda:0 bf16[4672, 4544]"
del t13553, t13375
t13550 = torch.matmul(t13549, t_transformer_h_28_attn_attn_weight) # t13550: "cuda:0 bf16[2048, 4544]"
# t13550 = ltorch.matmul(t13549, t_transformer_h_28_attn_attn_weight) # t13550: "cuda:0 bf16[2048, 4544]"
# t13550 = prims.matmul(t13549, t_transformer_h_28_attn_attn_weight) # t13550: "cuda:0 bf16[2048, 4544]"
del t13549, t_transformer_h_28_attn_attn_weight
t13372 = torch.reshape(t13371, (1, 2048, 4544)) # t13372: "cuda:0 bf16[1, 2048, 4544]"
# t13372 = ltorch.reshape(t13371, (1, 2048, 4544)) # t13372: "cuda:0 bf16[1, 2048, 4544]"
# t13372 = prims.reshape(t13371, (1, 2048, 4544)) # t13372: "cuda:0 bf16[1, 2048, 4544]"
del t13371
t13551 = torch.reshape(t13550, (1, 2048, 4544)) # t13551: "cuda:0 bf16[1, 2048, 4544]"
# t13551 = ltorch.reshape(t13550, (1, 2048, 4544)) # t13551: "cuda:0 bf16[1, 2048, 4544]"
# t13551 = prims.reshape(t13550, (1, 2048, 4544)) # t13551: "cuda:0 bf16[1, 2048, 4544]"
del t13550
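# NOTE: nvFusion12 appears to fuse the LayerNorm backward of transformer.h.28 with the residual
# accumulation: t13564 and t13570 are the reductions giving the norm's bias and weight gradients,
# and t13612 is the gradient passed on to the next-earlier block (it already folds in the
# incoming residual gradient t13329).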
[t13564, t13570, t13612] = nvFusion12(i13592, t13329, t13372, t13551, t4318, t4450, t4471, t4486, t4491, t4497)
# t4477 = prims.convert_element_type(t4318, dtypes.float32) # t4477: "cuda:0 f32[1, 2048, 4544]"
# t4472 = prims.convert_element_type(t4471, dtypes.float32) # t4472: "cuda:0 f32[1, 2048, 4544]"
# t4473 = prims.convert_element_type(t4450, dtypes.float32) # t4473: "cuda:0 f32[1, 2048, 4544]"
# t4474 = prims.add(t4472, t4473) # t4474: "cuda:0 f32[1, 2048, 4544]"
# t4478 = prims.add(t4474, t4477) # t4478: "cuda:0 f32[1, 2048, 4544]"
# t4488 = prims.broadcast_in_dim(t4486, [1, 2048, 1], [0, 1]) # t4488: "cuda:0 f32[1, 2048, 1]"
# t4492 = prims.broadcast_in_dim(t4488, (1, 2048, 4544), (0, 1, 2)) # t4492: "cuda:0 f32[1, 2048, 4544]"
# t4494 = prims.sub(t4478, t4492) # t4494: "cuda:0 f32[1, 2048, 4544]"
# t4495 = prims.broadcast_in_dim(t4491, (1, 2048, 4544), (0, 1, 2)) # t4495: "cuda:0 f32[1, 2048, 4544]"
# t4496 = prims.mul(t4494, t4495) # t4496: "cuda:0 f32[1, 2048, 4544]"
# t4498 = prims.convert_element_type(t4497, dtypes.float32) # t4498: "cuda:0 f32[1, 2048, 4544]"
# t13609 = prims.convert_element_type(t13329, dtypes.float32) # t13609: "cuda:0 f32[1, 2048, 4544]"
# t13556 = prims.convert_element_type(t13372, dtypes.float32) # t13556: "cuda:0 f32[1, 2048, 4544]"
# t13557 = prims.convert_element_type(t13551, dtypes.float32) # t13557: "cuda:0 f32[1, 2048, 4544]"
# t13558 = prims.add(t13556, t13557) # t13558: "cuda:0 f32[1, 2048, 4544]"
# t13563 = prims.sum(t13558, (0, 1)) # t13563: "cuda:0 f32[4544]"
# t13564 = prims.convert_element_type(t13563, dtypes.bfloat16) # t13564: "cuda:0 bf16[4544]"
# t13565 = prims.mul(t4498, t13558) # t13565: "cuda:0 f32[1, 2048, 4544]"
# t13566 = prims.mul(t4496, t13558) # t13566: "cuda:0 f32[1, 2048, 4544]"
# t13569 = prims.sum(t13566, (0, 1)) # t13569: "cuda:0 f32[4544]"
# t13570 = prims.convert_element_type(t13569, dtypes.bfloat16) # t13570: "cuda:0 bf16[4544]"
# t13571 = prims.mul(t4495, t13565) # t13571: "cuda:0 f32[1, 2048, 4544]"
# t13572 = prims.mul(t4494, t13565) # t13572: "cuda:0 f32[1, 2048, 4544]"
# t13573 = prims.sum(t13572, (0, 2)) # t13573: "cuda:0 f32[2048]"
# t13574 = prims.broadcast_in_dim(t13573, [1, 2048, 1], [1]) # t13574: "cuda:0 f32[1, 2048, 1]"
# t13575 = prims.neg(t13571) # t13575: "cuda:0 f32[1, 2048, 4544]"
# t13577 = prims.sum(t13575, (0, 2)) # t13577: "cuda:0 f32[2048]"
# t13578 = prims.broadcast_in_dim(t13577, [1, 2048, 1], [1]) # t13578: "cuda:0 f32[1, 2048, 1]"
# t13579 = prims.mul(-0.5, t13574) # t13579: "cuda:0 f32[1, 2048, 1]"
# t13580 = prims.pow(t4491, 3.0) # t13580: "cuda:0 f32[1, 2048, 1]"
# t13581 = prims.mul(t13579, t13580) # t13581: "cuda:0 f32[1, 2048, 1]"
# t13583 = prims.sum(t13578, (0, 2)) # t13583: "cuda:0 f32[2048]"
# t13584 = prims.broadcast_in_dim(t13583, [1, 2048], [1]) # t13584: "cuda:0 f32[1, 2048]"
# t13585 = prims.sum(t13581, (0, 2)) # t13585: "cuda:0 f32[2048]"
# t13586 = prims.broadcast_in_dim(t13585, [1, 2048], [1]) # t13586: "cuda:0 f32[1, 2048]"
# t13589 = prims.broadcast_in_dim(t13584, [1, 2048, 1], [0, 1]) # t13589: "cuda:0 f32[1, 2048, 1]"
# t13590 = prims.broadcast_in_dim(t13589, (1, 2048, 4544), (0, 1, 2)) # t13590: "cuda:0 f32[1, 2048, 4544]"
# t13591 = prims.mul(0.00022007042253521127, t13590) # t13591: "cuda:0 f32[1, 2048, 4544]"
# t13593 = prims.broadcast_in_dim(t13586, [1, 2048, 1], [0, 1]) # t13593: "cuda:0 f32[1, 2048, 1]"
# t13594 = prims.broadcast_in_dim(t13593, (1, 2048, 4544), (0, 1, 2)) # t13594: "cuda:0 f32[1, 2048, 4544]"
# t13596 = prims.broadcast_in_dim(t4486, [1, 2048, 1], [0, 1]) # t13596: "cuda:0 f32[1, 2048, 1]"
# t13597 = prims.broadcast_in_dim(t13596, (1, 2048, 4544), (0, 1, 2)) # t13597: "cuda:0 f32[1, 2048, 4544]"
# t13598 = prims.mul(2.0, t13594) # t13598: "cuda:0 f32[1, 2048, 4544]"
# t13599 = prims.sub(t4478, t13597) # t13599: "cuda:0 f32[1, 2048, 4544]"
# t13600 = prims.mul(t13598, t13599) # t13600: "cuda:0 f32[1, 2048, 4544]"
# f13601 = prims.convert_element_type(i13592, float) # f13601: "float 4544.0"
# t13602 = prims.div(t13600, f13601) # t13602: "cuda:0 f32[1, 2048, 4544]"
# t13603 = prims.add(t13591, t13602) # t13603: "cuda:0 f32[1, 2048, 4544]"
# t13607 = prims.add(t13571, t13603) # t13607: "cuda:0 f32[1, 2048, 4544]"
# t13611 = prims.add(t13609, t13607) # t13611: "cuda:0 f32[1, 2048, 4544]"
# t13612 = prims.convert_element_type(t13611, dtypes.bfloat16) # t13612: "cuda:0 bf16[1, 2048, 4544]"
del i13592, t13329, t13372, t13551, t4318, t4450, t4471, t4486, t4491, t4497
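# NOTE: backward through transformer.h.27 starts here. t13612 is flattened to [2048, 4544] and the
# matmuls below produce, in order, the weight gradient of the h_27 MLP projection (t13625), the
# data gradient through t_transformer_h_27_attn_proj_weight (t13661), the weight gradient of that
# projection (t13666), and the data gradient through t_transformer_h_27_mlp_proj_weight (t13620).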
t13619 = torch.reshape(t13612, (-1, 4544)) # t13619: "cuda:0 bf16[2048, 4544]"
# t13619 = ltorch.reshape(t13612, (-1, 4544)) # t13619: "cuda:0 bf16[2048, 4544]"
# t13619 = prims.reshape(t13612, (2048, 4544)) # t13619: "cuda:0 bf16[2048, 4544]"
t13623 = torch.permute(t13619, (1, 0)) # t13623: "cuda:0 bf16[4544, 2048]"
# t13623 = ltorch.permute(t13619, (1, 0)) # t13623: "cuda:0 bf16[4544, 2048]"
# t13623 = prims.transpose(t13619, (1, 0)) # t13623: "cuda:0 bf16[4544, 2048]"
t13625 = torch.matmul(t13623, t13624) # t13625: "cuda:0 bf16[4544, 18176]"
# t13625 = ltorch.matmul(t13623, t13624) # t13625: "cuda:0 bf16[4544, 18176]"
# t13625 = prims.matmul(t13623, t13624) # t13625: "cuda:0 bf16[4544, 18176]"
del t13624
t13661 = torch.matmul(t13619, t_transformer_h_27_attn_proj_weight) # t13661: "cuda:0 bf16[2048, 4544]"
# t13661 = ltorch.matmul(t13660, t_transformer_h_27_attn_proj_weight) # t13661: "cuda:0 bf16[2048, 4544]"
# t13661 = prims.matmul(t13660, t_transformer_h_27_attn_proj_weight) # t13661: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_27_attn_proj_weight
t13666 = torch.matmul(t13623, t13665) # t13666: "cuda:0 bf16[4544, 4544]"
# t13666 = ltorch.matmul(t13664, t13665) # t13666: "cuda:0 bf16[4544, 4544]"
# t13666 = prims.matmul(t13664, t13665) # t13666: "cuda:0 bf16[4544, 4544]"
del t13623, t13665
t13620 = torch.matmul(t13619, t_transformer_h_27_mlp_proj_weight) # t13620: "cuda:0 bf16[2048, 18176]"
# t13620 = ltorch.matmul(t13619, t_transformer_h_27_mlp_proj_weight) # t13620: "cuda:0 bf16[2048, 18176]"
# t13620 = prims.matmul(t13619, t_transformer_h_27_mlp_proj_weight) # t13620: "cuda:0 bf16[2048, 18176]"
del t13619, t_transformer_h_27_mlp_proj_weight
t13662 = torch.reshape(t13661, (1, 2048, 4544)) # t13662: "cuda:0 bf16[1, 2048, 4544]"
# t13662 = ltorch.reshape(t13661, (1, 2048, 4544)) # t13662: "cuda:0 bf16[1, 2048, 4544]"
# t13662 = prims.reshape(t13661, (1, 2048, 4544)) # t13662: "cuda:0 bf16[1, 2048, 4544]"
del t13661
t13670 = torch.reshape(t13662, (1, 2048, 71, 64)) # t13670: "cuda:0 bf16[1, 2048, 71, 64]"
# t13670 = ltorch.reshape(t13662, (1, 2048, 71, 64)) # t13670: "cuda:0 bf16[1, 2048, 71, 64]"
# t13670 = prims.reshape(t13662, (1, 2048, 71, 64)) # t13670: "cuda:0 bf16[1, 2048, 71, 64]"
del t13662
t13673 = torch.permute(t13670, (0, 2, 1, 3)) # t13673: "cuda:0 bf16[1, 71, 2048, 64]"
# t13673 = ltorch.permute(t13670, (0, 2, 1, 3)) # t13673: "cuda:0 bf16[1, 71, 2048, 64]"
# t13673 = prims.transpose(t13670, (0, 2, 1, 3)) # t13673: "cuda:0 bf16[1, 71, 2048, 64]"
del t13670
t13621 = torch.reshape(t13620, (1, 2048, 18176)) # t13621: "cuda:0 bf16[1, 2048, 18176]"
# t13621 = ltorch.reshape(t13620, (1, 2048, 18176)) # t13621: "cuda:0 bf16[1, 2048, 18176]"
# t13621 = prims.reshape(t13620, (1, 2048, 18176)) # t13621: "cuda:0 bf16[1, 2048, 18176]"
del t13620
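# NOTE: nvFusion13 looks like the fused GELU backward for the h_27 MLP: it recomputes the erf-based
# GELU terms from the saved pre-activation t4451 and multiplies them into the incoming gradient
# t13621, producing t13652.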
[t13652] = nvFusion13(f1791, f1793, t13621, t4451)
# t4452 = prims.convert_element_type(t4451, dtypes.float32) # t4452: "cuda:0 f32[1, 2048, 18176]"
# t4454 = prims.div(t4452, 1.4142135623730951) # t4454: "cuda:0 f32[1, 2048, 18176]"
# t4457 = prims.erf(t4454) # t4457: "cuda:0 f32[1, 2048, 18176]"
# t4461 = prims.mul(0.5, t4457) # t4461: "cuda:0 f32[1, 2048, 18176]"
# t4465 = prims.add(0.5, t4461) # t4465: "cuda:0 f32[1, 2048, 18176]"
# t13626 = prims.convert_element_type(t13621, dtypes.float32) # t13626: "cuda:0 f32[1, 2048, 18176]"
# t13627 = prims.mul(t4465, t13626) # t13627: "cuda:0 f32[1, 2048, 18176]"
# t13628 = prims.mul(t4452, t13626) # t13628: "cuda:0 f32[1, 2048, 18176]"
# t13636 = prims.mul(f1793, t13628) # t13636: "cuda:0 f32[1, 2048, 18176]"
# t13639 = prims.pow(t4454, 2.0) # t13639: "cuda:0 f32[1, 2048, 18176]"
# t13640 = prims.neg(t13639) # t13640: "cuda:0 f32[1, 2048, 18176]"
# t13641 = prims.exp(t13640) # t13641: "cuda:0 f32[1, 2048, 18176]"
# t13642 = prims.mul(1.1283791670955126, t13641) # t13642: "cuda:0 f32[1, 2048, 18176]"
# t13643 = prims.mul(t13642, t13636) # t13643: "cuda:0 f32[1, 2048, 18176]"
# t13647 = prims.div(t13643, f1791) # t13647: "cuda:0 f32[1, 2048, 18176]"
# t13651 = prims.add(t13627, t13647) # t13651: "cuda:0 f32[1, 2048, 18176]"
# t13652 = prims.convert_element_type(t13651, dtypes.bfloat16) # t13652: "cuda:0 bf16[1, 2048, 18176]"
del f1791, f1793, t13621, t4451
t13653 = torch.reshape(t13652, (-1, 18176)) # t13653: "cuda:0 bf16[2048, 18176]"
# t13653 = ltorch.reshape(t13652, (-1, 18176)) # t13653: "cuda:0 bf16[2048, 18176]"
# t13653 = prims.reshape(t13652, (2048, 18176)) # t13653: "cuda:0 bf16[2048, 18176]"
del t13652
t13657 = torch.permute(t13653, (1, 0)) # t13657: "cuda:0 bf16[18176, 2048]"
# t13657 = ltorch.permute(t13653, (1, 0)) # t13657: "cuda:0 bf16[18176, 2048]"
# t13657 = prims.transpose(t13653, (1, 0)) # t13657: "cuda:0 bf16[18176, 2048]"
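# NOTE: cudnn_sdpa_bwd is the cuDNN scaled-dot-product-attention backward for h_27: it consumes the
# gradient t13673 together with the saved query/key/value tensors and softmax statistics, and
# (t13674, t13675, t13676) are presumably the gradients w.r.t. query, key and value.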
(t13674, t13675, t13676) = cudnn_sdpa_bwd(t13673, t4435, t4438, t4388, None, f1782, b1783, t4439, t4440, t4441, t4442, scale=f1784, cat_grad_qkv=False)
del t13673, t4435, t4438, t4388, f1782, b1783, t4439, t4440, t4441, t4442, f1784
t13659 = torch.matmul(t13657, t13658) # t13659: "cuda:0 bf16[18176, 4544]"
# t13659 = ltorch.matmul(t13657, t13658) # t13659: "cuda:0 bf16[18176, 4544]"
# t13659 = prims.matmul(t13657, t13658) # t13659: "cuda:0 bf16[18176, 4544]"
del t13657
t13654 = torch.matmul(t13653, t_transformer_h_27_mlp_fc_weight) # t13654: "cuda:0 bf16[2048, 4544]"
# t13654 = ltorch.matmul(t13653, t_transformer_h_27_mlp_fc_weight) # t13654: "cuda:0 bf16[2048, 4544]"
# t13654 = prims.matmul(t13653, t_transformer_h_27_mlp_fc_weight) # t13654: "cuda:0 bf16[2048, 4544]"
del t13653, t_transformer_h_27_mlp_fc_weight
t13678 = torch_slice_prim_impl(t13675, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13678: "cuda:0 bf16[1, 71, 2048, 64]"
del t13675
t13682 = torch_slice_prim_impl(t13674, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13682: "cuda:0 bf16[1, 71, 2048, 64]"
del t13674
t13785 = torch.reshape(t13676, (1, 1, 71, 2048, 64)) # t13785: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13785 = ltorch.reshape(t13676, (1, 1, 71, 2048, 64)) # t13785: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13785 = prims.reshape(t13676, (1, 1, 71, 2048, 64)) # t13785: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t13676
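# NOTE: nvFusion14 appears to undo the rotary embedding on the query/key gradients (t13682, t13678)
# using the cached RoPE tables t61/t66, reduce the key and value gradients over the 71 broadcast
# heads down to the single shared head, and concatenate everything into the fused qkv gradient
# t13819 ([1, 1, 73, 2048, 64]).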
[t13819] = nvFusion14(i1755, t13678, t13682, t13785, t61, t66)
# t13679 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13679: "cuda:0 bf16[1, 71, 2048, 0]"
# t13680 = prims.pad(t13679, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13680: "cuda:0 bf16[1, 71, 2048, 64]"
# t13683 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13683: "cuda:0 bf16[1, 71, 2048, 0]"
# t13684 = prims.pad(t13683, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13684: "cuda:0 bf16[1, 71, 2048, 64]"
# t13685 = prims.convert_element_type(t13678, dtypes.float32) # t13685: "cuda:0 f32[1, 71, 2048, 64]"
# t13689 = prims.mul(t66, t13685) # t13689: "cuda:0 f32[1, 71, 2048, 64]"
# t13692 = prims.convert_element_type(t13689, dtypes.bfloat16) # t13692: "cuda:0 bf16[1, 71, 2048, 64]"
# t13701 = prims.mul(t61, t13685) # t13701: "cuda:0 f32[1, 71, 2048, 64]"
# t13713 = prims.slice_prim(t13692, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t13713: "cuda:0 bf16[1, 71, 2048, 32]"
# t13714 = prims.slice_prim(t13692, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13714: "cuda:0 bf16[1, 71, 2048, 32]"
# t13715 = prims.convert_element_type(t13713, dtypes.float32) # t13715: "cuda:0 f32[1, 71, 2048, 32]"
# t13716 = prims.neg(t13715) # t13716: "cuda:0 f32[1, 71, 2048, 32]"
# t13717 = prims.convert_element_type(t13716, dtypes.bfloat16) # t13717: "cuda:0 bf16[1, 71, 2048, 32]"
# t13718 = prims.pad(t13717, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t13718: "cuda:0 bf16[1, 71, 2048, 64]"
# t13720 = prims.convert_element_type(t13718, dtypes.float32) # t13720: "cuda:0 f32[1, 71, 2048, 64]"
# t13721 = prims.add(t13701, t13720) # t13721: "cuda:0 f32[1, 71, 2048, 64]"
# t13723 = prims.pad(t13714, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t13723: "cuda:0 bf16[1, 71, 2048, 64]"
# t13725 = prims.convert_element_type(t13723, dtypes.float32) # t13725: "cuda:0 f32[1, 71, 2048, 64]"
# t13726 = prims.add(t13721, t13725) # t13726: "cuda:0 f32[1, 71, 2048, 64]"
# t13727 = prims.convert_element_type(t13726, dtypes.bfloat16) # t13727: "cuda:0 bf16[1, 71, 2048, 64]"
# t13728 = prims.pad(t13727, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t13728: "cuda:0 bf16[1, 71, 2048, 64]"
# t13729 = prims.convert_element_type(t13680, dtypes.float32) # t13729: "cuda:0 f32[1, 71, 2048, 64]"
# t13730 = prims.convert_element_type(t13728, dtypes.float32) # t13730: "cuda:0 f32[1, 71, 2048, 64]"
# t13731 = prims.add(t13729, t13730) # t13731: "cuda:0 f32[1, 71, 2048, 64]"
# t13732 = prims.convert_element_type(t13731, dtypes.bfloat16) # t13732: "cuda:0 bf16[1, 71, 2048, 64]"
# t13733 = prims.convert_element_type(t13682, dtypes.float32) # t13733: "cuda:0 f32[1, 71, 2048, 64]"
# t13737 = prims.mul(t66, t13733) # t13737: "cuda:0 f32[1, 71, 2048, 64]"
# t13740 = prims.convert_element_type(t13737, dtypes.bfloat16) # t13740: "cuda:0 bf16[1, 71, 2048, 64]"
# t13749 = prims.mul(t61, t13733) # t13749: "cuda:0 f32[1, 71, 2048, 64]"
# t13761 = prims.slice_prim(t13740, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t13761: "cuda:0 bf16[1, 71, 2048, 32]"
# t13762 = prims.slice_prim(t13740, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13762: "cuda:0 bf16[1, 71, 2048, 32]"
# t13763 = prims.convert_element_type(t13761, dtypes.float32) # t13763: "cuda:0 f32[1, 71, 2048, 32]"
# t13764 = prims.neg(t13763) # t13764: "cuda:0 f32[1, 71, 2048, 32]"
# t13765 = prims.convert_element_type(t13764, dtypes.bfloat16) # t13765: "cuda:0 bf16[1, 71, 2048, 32]"
# t13766 = prims.pad(t13765, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t13766: "cuda:0 bf16[1, 71, 2048, 64]"
# t13768 = prims.convert_element_type(t13766, dtypes.float32) # t13768: "cuda:0 f32[1, 71, 2048, 64]"
# t13769 = prims.add(t13749, t13768) # t13769: "cuda:0 f32[1, 71, 2048, 64]"
# t13771 = prims.pad(t13762, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t13771: "cuda:0 bf16[1, 71, 2048, 64]"
# t13773 = prims.convert_element_type(t13771, dtypes.float32) # t13773: "cuda:0 f32[1, 71, 2048, 64]"
# t13774 = prims.add(t13769, t13773) # t13774: "cuda:0 f32[1, 71, 2048, 64]"
# t13775 = prims.convert_element_type(t13774, dtypes.bfloat16) # t13775: "cuda:0 bf16[1, 71, 2048, 64]"
# t13776 = prims.pad(t13775, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t13776: "cuda:0 bf16[1, 71, 2048, 64]"
# t13777 = prims.convert_element_type(t13684, dtypes.float32) # t13777: "cuda:0 f32[1, 71, 2048, 64]"
# t13778 = prims.convert_element_type(t13776, dtypes.float32) # t13778: "cuda:0 f32[1, 71, 2048, 64]"
# t13779 = prims.add(t13777, t13778) # t13779: "cuda:0 f32[1, 71, 2048, 64]"
# t13780 = prims.convert_element_type(t13779, dtypes.bfloat16) # t13780: "cuda:0 bf16[1, 71, 2048, 64]"
# t13790 = prims.reshape(t13732, (1, 1, 71, 2048, 64)) # t13790: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13795 = prims.reshape(t13780, (1, 1, 71, 2048, 64)) # t13795: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t13801 = prims.convert_element_type(t13785, dtypes.float32) # t13801: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t13802 = prims.sum(t13801, (0, 1, 2)) # t13802: "cuda:0 f32[2048, 64]"
# t13803 = prims.convert_element_type(t13802, dtypes.bfloat16) # t13803: "cuda:0 bf16[2048, 64]"
# t13804 = prims.broadcast_in_dim(t13803, [1, 1, 1, 2048, 64], [3, 4]) # t13804: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t13810 = prims.convert_element_type(t13790, dtypes.float32) # t13810: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t13811 = prims.sum(t13810, (0, 1, 2)) # t13811: "cuda:0 f32[2048, 64]"
# t13812 = prims.convert_element_type(t13811, dtypes.bfloat16) # t13812: "cuda:0 bf16[2048, 64]"
# t13813 = prims.broadcast_in_dim(t13812, [1, 1, 1, 2048, 64], [3, 4]) # t13813: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t13819 = prims.cat((t13795, t13813, t13804), i1755) # t13819: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1755, t13678, t13682, t13785
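# NOTE: as in the h_28 block above, the fused qkv gradient is repacked to [2048, 4672]; the two
# matmuls below give the weight gradient of t_transformer_h_27_attn_attn_weight (t13838) and the
# data gradient into the h_27 LayerNorm output (t13833).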
t13825 = torch.permute(t13819, (0, 3, 1, 2, 4)) # t13825: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t13825 = ltorch.permute(t13819, (0, 3, 1, 2, 4)) # t13825: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t13825 = prims.transpose(t13819, (0, 3, 1, 2, 4)) # t13825: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t13819
t13831 = torch.reshape(t13825, (1, 2048, 4672)) # t13831: "cuda:0 bf16[1, 2048, 4672]"
# t13831 = ltorch.reshape(t13825, (1, 2048, 4672)) # t13831: "cuda:0 bf16[1, 2048, 4672]"
# t13831 = prims.reshape(t13825, (1, 2048, 4672)) # t13831: "cuda:0 bf16[1, 2048, 4672]"
del t13825
t13832 = torch.reshape(t13831, (-1, 4672)) # t13832: "cuda:0 bf16[2048, 4672]"
# t13832 = ltorch.reshape(t13831, (-1, 4672)) # t13832: "cuda:0 bf16[2048, 4672]"
# t13832 = prims.reshape(t13831, (2048, 4672)) # t13832: "cuda:0 bf16[2048, 4672]"
del t13831
t13836 = torch.permute(t13832, (1, 0)) # t13836: "cuda:0 bf16[4672, 2048]"
# t13836 = ltorch.permute(t13832, (1, 0)) # t13836: "cuda:0 bf16[4672, 2048]"
# t13836 = prims.transpose(t13832, (1, 0)) # t13836: "cuda:0 bf16[4672, 2048]"
t13838 = torch.matmul(t13836, t13658) # t13838: "cuda:0 bf16[4672, 4544]"
# t13838 = ltorch.matmul(t13836, t13837) # t13838: "cuda:0 bf16[4672, 4544]"
# t13838 = prims.matmul(t13836, t13837) # t13838: "cuda:0 bf16[4672, 4544]"
del t13836, t13658
t13833 = torch.matmul(t13832, t_transformer_h_27_attn_attn_weight) # t13833: "cuda:0 bf16[2048, 4544]"
# t13833 = ltorch.matmul(t13832, t_transformer_h_27_attn_attn_weight) # t13833: "cuda:0 bf16[2048, 4544]"
# t13833 = prims.matmul(t13832, t_transformer_h_27_attn_attn_weight) # t13833: "cuda:0 bf16[2048, 4544]"
del t13832, t_transformer_h_27_attn_attn_weight
t13655 = torch.reshape(t13654, (1, 2048, 4544)) # t13655: "cuda:0 bf16[1, 2048, 4544]"
# t13655 = ltorch.reshape(t13654, (1, 2048, 4544)) # t13655: "cuda:0 bf16[1, 2048, 4544]"
# t13655 = prims.reshape(t13654, (1, 2048, 4544)) # t13655: "cuda:0 bf16[1, 2048, 4544]"
del t13654
t13834 = torch.reshape(t13833, (1, 2048, 4544)) # t13834: "cuda:0 bf16[1, 2048, 4544]"
# t13834 = ltorch.reshape(t13833, (1, 2048, 4544)) # t13834: "cuda:0 bf16[1, 2048, 4544]"
# t13834 = prims.reshape(t13833, (1, 2048, 4544)) # t13834: "cuda:0 bf16[1, 2048, 4544]"
del t13833
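# NOTE: nvFusion15 is the h_27 counterpart of nvFusion12: fused LayerNorm backward plus residual
# accumulation, yielding the norm's bias gradient (t13847), weight gradient (t13853), and the
# gradient t13895 flowing into h_26. The same per-layer pattern repeats below for h_26 and h_25.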
[t13847, t13853, t13895] = nvFusion15(i13875, t13612, t13655, t13834, t4157, t4289, t4310, t4325, t4330, t4336)
# t4316 = prims.convert_element_type(t4157, dtypes.float32) # t4316: "cuda:0 f32[1, 2048, 4544]"
# t4311 = prims.convert_element_type(t4310, dtypes.float32) # t4311: "cuda:0 f32[1, 2048, 4544]"
# t4312 = prims.convert_element_type(t4289, dtypes.float32) # t4312: "cuda:0 f32[1, 2048, 4544]"
# t4313 = prims.add(t4311, t4312) # t4313: "cuda:0 f32[1, 2048, 4544]"
# t4317 = prims.add(t4313, t4316) # t4317: "cuda:0 f32[1, 2048, 4544]"
# t4327 = prims.broadcast_in_dim(t4325, [1, 2048, 1], [0, 1]) # t4327: "cuda:0 f32[1, 2048, 1]"
# t4331 = prims.broadcast_in_dim(t4327, (1, 2048, 4544), (0, 1, 2)) # t4331: "cuda:0 f32[1, 2048, 4544]"
# t4333 = prims.sub(t4317, t4331) # t4333: "cuda:0 f32[1, 2048, 4544]"
# t4334 = prims.broadcast_in_dim(t4330, (1, 2048, 4544), (0, 1, 2)) # t4334: "cuda:0 f32[1, 2048, 4544]"
# t4335 = prims.mul(t4333, t4334) # t4335: "cuda:0 f32[1, 2048, 4544]"
# t4337 = prims.convert_element_type(t4336, dtypes.float32) # t4337: "cuda:0 f32[1, 2048, 4544]"
# t13892 = prims.convert_element_type(t13612, dtypes.float32) # t13892: "cuda:0 f32[1, 2048, 4544]"
# t13839 = prims.convert_element_type(t13655, dtypes.float32) # t13839: "cuda:0 f32[1, 2048, 4544]"
# t13840 = prims.convert_element_type(t13834, dtypes.float32) # t13840: "cuda:0 f32[1, 2048, 4544]"
# t13841 = prims.add(t13839, t13840) # t13841: "cuda:0 f32[1, 2048, 4544]"
# t13846 = prims.sum(t13841, (0, 1)) # t13846: "cuda:0 f32[4544]"
# t13847 = prims.convert_element_type(t13846, dtypes.bfloat16) # t13847: "cuda:0 bf16[4544]"
# t13848 = prims.mul(t4337, t13841) # t13848: "cuda:0 f32[1, 2048, 4544]"
# t13849 = prims.mul(t4335, t13841) # t13849: "cuda:0 f32[1, 2048, 4544]"
# t13852 = prims.sum(t13849, (0, 1)) # t13852: "cuda:0 f32[4544]"
# t13853 = prims.convert_element_type(t13852, dtypes.bfloat16) # t13853: "cuda:0 bf16[4544]"
# t13854 = prims.mul(t4334, t13848) # t13854: "cuda:0 f32[1, 2048, 4544]"
# t13855 = prims.mul(t4333, t13848) # t13855: "cuda:0 f32[1, 2048, 4544]"
# t13856 = prims.sum(t13855, (0, 2)) # t13856: "cuda:0 f32[2048]"
# t13857 = prims.broadcast_in_dim(t13856, [1, 2048, 1], [1]) # t13857: "cuda:0 f32[1, 2048, 1]"
# t13858 = prims.neg(t13854) # t13858: "cuda:0 f32[1, 2048, 4544]"
# t13860 = prims.sum(t13858, (0, 2)) # t13860: "cuda:0 f32[2048]"
# t13861 = prims.broadcast_in_dim(t13860, [1, 2048, 1], [1]) # t13861: "cuda:0 f32[1, 2048, 1]"
# t13862 = prims.mul(-0.5, t13857) # t13862: "cuda:0 f32[1, 2048, 1]"
# t13863 = prims.pow(t4330, 3.0) # t13863: "cuda:0 f32[1, 2048, 1]"
# t13864 = prims.mul(t13862, t13863) # t13864: "cuda:0 f32[1, 2048, 1]"
# t13866 = prims.sum(t13861, (0, 2)) # t13866: "cuda:0 f32[2048]"
# t13867 = prims.broadcast_in_dim(t13866, [1, 2048], [1]) # t13867: "cuda:0 f32[1, 2048]"
# t13868 = prims.sum(t13864, (0, 2)) # t13868: "cuda:0 f32[2048]"
# t13869 = prims.broadcast_in_dim(t13868, [1, 2048], [1]) # t13869: "cuda:0 f32[1, 2048]"
# t13872 = prims.broadcast_in_dim(t13867, [1, 2048, 1], [0, 1]) # t13872: "cuda:0 f32[1, 2048, 1]"
# t13873 = prims.broadcast_in_dim(t13872, (1, 2048, 4544), (0, 1, 2)) # t13873: "cuda:0 f32[1, 2048, 4544]"
# t13874 = prims.mul(0.00022007042253521127, t13873) # t13874: "cuda:0 f32[1, 2048, 4544]"
# t13876 = prims.broadcast_in_dim(t13869, [1, 2048, 1], [0, 1]) # t13876: "cuda:0 f32[1, 2048, 1]"
# t13877 = prims.broadcast_in_dim(t13876, (1, 2048, 4544), (0, 1, 2)) # t13877: "cuda:0 f32[1, 2048, 4544]"
# t13879 = prims.broadcast_in_dim(t4325, [1, 2048, 1], [0, 1]) # t13879: "cuda:0 f32[1, 2048, 1]"
# t13880 = prims.broadcast_in_dim(t13879, (1, 2048, 4544), (0, 1, 2)) # t13880: "cuda:0 f32[1, 2048, 4544]"
# t13881 = prims.mul(2.0, t13877) # t13881: "cuda:0 f32[1, 2048, 4544]"
# t13882 = prims.sub(t4317, t13880) # t13882: "cuda:0 f32[1, 2048, 4544]"
# t13883 = prims.mul(t13881, t13882) # t13883: "cuda:0 f32[1, 2048, 4544]"
# f13884 = prims.convert_element_type(i13875, float) # f13884: "float 4544.0"
# t13885 = prims.div(t13883, f13884) # t13885: "cuda:0 f32[1, 2048, 4544]"
# t13886 = prims.add(t13874, t13885) # t13886: "cuda:0 f32[1, 2048, 4544]"
# t13890 = prims.add(t13854, t13886) # t13890: "cuda:0 f32[1, 2048, 4544]"
# t13894 = prims.add(t13892, t13890) # t13894: "cuda:0 f32[1, 2048, 4544]"
# t13895 = prims.convert_element_type(t13894, dtypes.bfloat16) # t13895: "cuda:0 bf16[1, 2048, 4544]"
del i13875, t13612, t13655, t13834, t4157, t4289, t4310, t4325, t4330, t4336
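# NOTE: backward through transformer.h.26 starts here; the matmuls mirror the h_27 block
# (t13949 and t13908 are the attn.proj and mlp.proj weight gradients, t13944 and t13903 the
# matching data gradients).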
t13902 = torch.reshape(t13895, (-1, 4544)) # t13902: "cuda:0 bf16[2048, 4544]"
# t13902 = ltorch.reshape(t13895, (-1, 4544)) # t13902: "cuda:0 bf16[2048, 4544]"
# t13902 = prims.reshape(t13895, (2048, 4544)) # t13902: "cuda:0 bf16[2048, 4544]"
t13906 = torch.permute(t13902, (1, 0)) # t13906: "cuda:0 bf16[4544, 2048]"
# t13906 = ltorch.permute(t13902, (1, 0)) # t13906: "cuda:0 bf16[4544, 2048]"
# t13906 = prims.transpose(t13902, (1, 0)) # t13906: "cuda:0 bf16[4544, 2048]"
t13949 = torch.matmul(t13906, t13948) # t13949: "cuda:0 bf16[4544, 4544]"
# t13949 = ltorch.matmul(t13947, t13948) # t13949: "cuda:0 bf16[4544, 4544]"
# t13949 = prims.matmul(t13947, t13948) # t13949: "cuda:0 bf16[4544, 4544]"
del t13948
t13903 = torch.matmul(t13902, t_transformer_h_26_mlp_proj_weight) # t13903: "cuda:0 bf16[2048, 18176]"
# t13903 = ltorch.matmul(t13902, t_transformer_h_26_mlp_proj_weight) # t13903: "cuda:0 bf16[2048, 18176]"
# t13903 = prims.matmul(t13902, t_transformer_h_26_mlp_proj_weight) # t13903: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_26_mlp_proj_weight
t13908 = torch.matmul(t13906, t13907) # t13908: "cuda:0 bf16[4544, 18176]"
# t13908 = ltorch.matmul(t13906, t13907) # t13908: "cuda:0 bf16[4544, 18176]"
# t13908 = prims.matmul(t13906, t13907) # t13908: "cuda:0 bf16[4544, 18176]"
del t13906, t13907
t13944 = torch.matmul(t13902, t_transformer_h_26_attn_proj_weight) # t13944: "cuda:0 bf16[2048, 4544]"
# t13944 = ltorch.matmul(t13943, t_transformer_h_26_attn_proj_weight) # t13944: "cuda:0 bf16[2048, 4544]"
# t13944 = prims.matmul(t13943, t_transformer_h_26_attn_proj_weight) # t13944: "cuda:0 bf16[2048, 4544]"
del t13902, t_transformer_h_26_attn_proj_weight
t13904 = torch.reshape(t13903, (1, 2048, 18176)) # t13904: "cuda:0 bf16[1, 2048, 18176]"
# t13904 = ltorch.reshape(t13903, (1, 2048, 18176)) # t13904: "cuda:0 bf16[1, 2048, 18176]"
# t13904 = prims.reshape(t13903, (1, 2048, 18176)) # t13904: "cuda:0 bf16[1, 2048, 18176]"
del t13903
t13945 = torch.reshape(t13944, (1, 2048, 4544)) # t13945: "cuda:0 bf16[1, 2048, 4544]"
# t13945 = ltorch.reshape(t13944, (1, 2048, 4544)) # t13945: "cuda:0 bf16[1, 2048, 4544]"
# t13945 = prims.reshape(t13944, (1, 2048, 4544)) # t13945: "cuda:0 bf16[1, 2048, 4544]"
del t13944
t13953 = torch.reshape(t13945, (1, 2048, 71, 64)) # t13953: "cuda:0 bf16[1, 2048, 71, 64]"
# t13953 = ltorch.reshape(t13945, (1, 2048, 71, 64)) # t13953: "cuda:0 bf16[1, 2048, 71, 64]"
# t13953 = prims.reshape(t13945, (1, 2048, 71, 64)) # t13953: "cuda:0 bf16[1, 2048, 71, 64]"
del t13945
t13956 = torch.permute(t13953, (0, 2, 1, 3)) # t13956: "cuda:0 bf16[1, 71, 2048, 64]"
# t13956 = ltorch.permute(t13953, (0, 2, 1, 3)) # t13956: "cuda:0 bf16[1, 71, 2048, 64]"
# t13956 = prims.transpose(t13953, (0, 2, 1, 3)) # t13956: "cuda:0 bf16[1, 71, 2048, 64]"
del t13953
[t13935] = nvFusion16(f1727, f1729, t13904, t4290)
# t4291 = prims.convert_element_type(t4290, dtypes.float32) # t4291: "cuda:0 f32[1, 2048, 18176]"
# t4293 = prims.div(t4291, 1.4142135623730951) # t4293: "cuda:0 f32[1, 2048, 18176]"
# t4296 = prims.erf(t4293) # t4296: "cuda:0 f32[1, 2048, 18176]"
# t4300 = prims.mul(0.5, t4296) # t4300: "cuda:0 f32[1, 2048, 18176]"
# t4304 = prims.add(0.5, t4300) # t4304: "cuda:0 f32[1, 2048, 18176]"
# t13909 = prims.convert_element_type(t13904, dtypes.float32) # t13909: "cuda:0 f32[1, 2048, 18176]"
# t13910 = prims.mul(t4304, t13909) # t13910: "cuda:0 f32[1, 2048, 18176]"
# t13911 = prims.mul(t4291, t13909) # t13911: "cuda:0 f32[1, 2048, 18176]"
# t13919 = prims.mul(f1729, t13911) # t13919: "cuda:0 f32[1, 2048, 18176]"
# t13922 = prims.pow(t4293, 2.0) # t13922: "cuda:0 f32[1, 2048, 18176]"
# t13923 = prims.neg(t13922) # t13923: "cuda:0 f32[1, 2048, 18176]"
# t13924 = prims.exp(t13923) # t13924: "cuda:0 f32[1, 2048, 18176]"
# t13925 = prims.mul(1.1283791670955126, t13924) # t13925: "cuda:0 f32[1, 2048, 18176]"
# t13926 = prims.mul(t13925, t13919) # t13926: "cuda:0 f32[1, 2048, 18176]"
# t13930 = prims.div(t13926, f1727) # t13930: "cuda:0 f32[1, 2048, 18176]"
# t13934 = prims.add(t13910, t13930) # t13934: "cuda:0 f32[1, 2048, 18176]"
# t13935 = prims.convert_element_type(t13934, dtypes.bfloat16) # t13935: "cuda:0 bf16[1, 2048, 18176]"
del f1727, f1729, t13904, t4290
t13936 = torch.reshape(t13935, (-1, 18176)) # t13936: "cuda:0 bf16[2048, 18176]"
# t13936 = ltorch.reshape(t13935, (-1, 18176)) # t13936: "cuda:0 bf16[2048, 18176]"
# t13936 = prims.reshape(t13935, (2048, 18176)) # t13936: "cuda:0 bf16[2048, 18176]"
del t13935
t13940 = torch.permute(t13936, (1, 0)) # t13940: "cuda:0 bf16[18176, 2048]"
# t13940 = ltorch.permute(t13936, (1, 0)) # t13940: "cuda:0 bf16[18176, 2048]"
# t13940 = prims.transpose(t13936, (1, 0)) # t13940: "cuda:0 bf16[18176, 2048]"
t13942 = torch.matmul(t13940, t13941) # t13942: "cuda:0 bf16[18176, 4544]"
# t13942 = ltorch.matmul(t13940, t13941) # t13942: "cuda:0 bf16[18176, 4544]"
# t13942 = prims.matmul(t13940, t13941) # t13942: "cuda:0 bf16[18176, 4544]"
del t13940
t13937 = torch.matmul(t13936, t_transformer_h_26_mlp_fc_weight) # t13937: "cuda:0 bf16[2048, 4544]"
# t13937 = ltorch.matmul(t13936, t_transformer_h_26_mlp_fc_weight) # t13937: "cuda:0 bf16[2048, 4544]"
# t13937 = prims.matmul(t13936, t_transformer_h_26_mlp_fc_weight) # t13937: "cuda:0 bf16[2048, 4544]"
del t13936, t_transformer_h_26_mlp_fc_weight
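# NOTE: cuDNN SDPA backward for h_26, followed by the same RoPE-undo / qkv-concatenation fusion
# (nvFusion17) as in the layers above.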
(t13957, t13958, t13959) = cudnn_sdpa_bwd(t13956, t4274, t4277, t4227, None, f1718, b1719, t4278, t4279, t4280, t4281, scale=f1720, cat_grad_qkv=False)
del t13956, t4274, t4277, t4227, f1718, b1719, t4278, t4279, t4280, t4281, f1720
t13961 = torch_slice_prim_impl(t13958, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13961: "cuda:0 bf16[1, 71, 2048, 64]"
del t13958
t13965 = torch_slice_prim_impl(t13957, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13965: "cuda:0 bf16[1, 71, 2048, 64]"
del t13957
t14068 = torch.reshape(t13959, (1, 1, 71, 2048, 64)) # t14068: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14068 = ltorch.reshape(t13959, (1, 1, 71, 2048, 64)) # t14068: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14068 = prims.reshape(t13959, (1, 1, 71, 2048, 64)) # t14068: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t13959
[t14102] = nvFusion17(i1691, t13961, t13965, t14068, t61, t66)
# t13962 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13962: "cuda:0 bf16[1, 71, 2048, 0]"
# t13963 = prims.pad(t13962, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13963: "cuda:0 bf16[1, 71, 2048, 64]"
# t13966 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t13966: "cuda:0 bf16[1, 71, 2048, 0]"
# t13967 = prims.pad(t13966, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t13967: "cuda:0 bf16[1, 71, 2048, 64]"
# t13968 = prims.convert_element_type(t13961, dtypes.float32) # t13968: "cuda:0 f32[1, 71, 2048, 64]"
# t13972 = prims.mul(t66, t13968) # t13972: "cuda:0 f32[1, 71, 2048, 64]"
# t13975 = prims.convert_element_type(t13972, dtypes.bfloat16) # t13975: "cuda:0 bf16[1, 71, 2048, 64]"
# t13984 = prims.mul(t61, t13968) # t13984: "cuda:0 f32[1, 71, 2048, 64]"
# t13996 = prims.slice_prim(t13975, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t13996: "cuda:0 bf16[1, 71, 2048, 32]"
# t13997 = prims.slice_prim(t13975, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t13997: "cuda:0 bf16[1, 71, 2048, 32]"
# t13998 = prims.convert_element_type(t13996, dtypes.float32) # t13998: "cuda:0 f32[1, 71, 2048, 32]"
# t13999 = prims.neg(t13998) # t13999: "cuda:0 f32[1, 71, 2048, 32]"
# t14000 = prims.convert_element_type(t13999, dtypes.bfloat16) # t14000: "cuda:0 bf16[1, 71, 2048, 32]"
# t14001 = prims.pad(t14000, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14001: "cuda:0 bf16[1, 71, 2048, 64]"
# t14003 = prims.convert_element_type(t14001, dtypes.float32) # t14003: "cuda:0 f32[1, 71, 2048, 64]"
# t14004 = prims.add(t13984, t14003) # t14004: "cuda:0 f32[1, 71, 2048, 64]"
# t14006 = prims.pad(t13997, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14006: "cuda:0 bf16[1, 71, 2048, 64]"
# t14008 = prims.convert_element_type(t14006, dtypes.float32) # t14008: "cuda:0 f32[1, 71, 2048, 64]"
# t14009 = prims.add(t14004, t14008) # t14009: "cuda:0 f32[1, 71, 2048, 64]"
# t14010 = prims.convert_element_type(t14009, dtypes.bfloat16) # t14010: "cuda:0 bf16[1, 71, 2048, 64]"
# t14011 = prims.pad(t14010, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14011: "cuda:0 bf16[1, 71, 2048, 64]"
# t14012 = prims.convert_element_type(t13963, dtypes.float32) # t14012: "cuda:0 f32[1, 71, 2048, 64]"
# t14013 = prims.convert_element_type(t14011, dtypes.float32) # t14013: "cuda:0 f32[1, 71, 2048, 64]"
# t14014 = prims.add(t14012, t14013) # t14014: "cuda:0 f32[1, 71, 2048, 64]"
# t14015 = prims.convert_element_type(t14014, dtypes.bfloat16) # t14015: "cuda:0 bf16[1, 71, 2048, 64]"
# t14016 = prims.convert_element_type(t13965, dtypes.float32) # t14016: "cuda:0 f32[1, 71, 2048, 64]"
# t14020 = prims.mul(t66, t14016) # t14020: "cuda:0 f32[1, 71, 2048, 64]"
# t14023 = prims.convert_element_type(t14020, dtypes.bfloat16) # t14023: "cuda:0 bf16[1, 71, 2048, 64]"
# t14032 = prims.mul(t61, t14016) # t14032: "cuda:0 f32[1, 71, 2048, 64]"
# t14044 = prims.slice_prim(t14023, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t14044: "cuda:0 bf16[1, 71, 2048, 32]"
# t14045 = prims.slice_prim(t14023, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14045: "cuda:0 bf16[1, 71, 2048, 32]"
# t14046 = prims.convert_element_type(t14044, dtypes.float32) # t14046: "cuda:0 f32[1, 71, 2048, 32]"
# t14047 = prims.neg(t14046) # t14047: "cuda:0 f32[1, 71, 2048, 32]"
# t14048 = prims.convert_element_type(t14047, dtypes.bfloat16) # t14048: "cuda:0 bf16[1, 71, 2048, 32]"
# t14049 = prims.pad(t14048, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14049: "cuda:0 bf16[1, 71, 2048, 64]"
# t14051 = prims.convert_element_type(t14049, dtypes.float32) # t14051: "cuda:0 f32[1, 71, 2048, 64]"
# t14052 = prims.add(t14032, t14051) # t14052: "cuda:0 f32[1, 71, 2048, 64]"
# t14054 = prims.pad(t14045, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14054: "cuda:0 bf16[1, 71, 2048, 64]"
# t14056 = prims.convert_element_type(t14054, dtypes.float32) # t14056: "cuda:0 f32[1, 71, 2048, 64]"
# t14057 = prims.add(t14052, t14056) # t14057: "cuda:0 f32[1, 71, 2048, 64]"
# t14058 = prims.convert_element_type(t14057, dtypes.bfloat16) # t14058: "cuda:0 bf16[1, 71, 2048, 64]"
# t14059 = prims.pad(t14058, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14059: "cuda:0 bf16[1, 71, 2048, 64]"
# t14060 = prims.convert_element_type(t13967, dtypes.float32) # t14060: "cuda:0 f32[1, 71, 2048, 64]"
# t14061 = prims.convert_element_type(t14059, dtypes.float32) # t14061: "cuda:0 f32[1, 71, 2048, 64]"
# t14062 = prims.add(t14060, t14061) # t14062: "cuda:0 f32[1, 71, 2048, 64]"
# t14063 = prims.convert_element_type(t14062, dtypes.bfloat16) # t14063: "cuda:0 bf16[1, 71, 2048, 64]"
# t14073 = prims.reshape(t14015, (1, 1, 71, 2048, 64)) # t14073: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14078 = prims.reshape(t14063, (1, 1, 71, 2048, 64)) # t14078: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14084 = prims.convert_element_type(t14068, dtypes.float32) # t14084: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14085 = prims.sum(t14084, (0, 1, 2)) # t14085: "cuda:0 f32[2048, 64]"
# t14086 = prims.convert_element_type(t14085, dtypes.bfloat16) # t14086: "cuda:0 bf16[2048, 64]"
# t14087 = prims.broadcast_in_dim(t14086, [1, 1, 1, 2048, 64], [3, 4]) # t14087: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14093 = prims.convert_element_type(t14073, dtypes.float32) # t14093: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14094 = prims.sum(t14093, (0, 1, 2)) # t14094: "cuda:0 f32[2048, 64]"
# t14095 = prims.convert_element_type(t14094, dtypes.bfloat16) # t14095: "cuda:0 bf16[2048, 64]"
# t14096 = prims.broadcast_in_dim(t14095, [1, 1, 1, 2048, 64], [3, 4]) # t14096: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14102 = prims.cat((t14078, t14096, t14087), i1691) # t14102: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1691, t13961, t13965, t14068
t14108 = torch.permute(t14102, (0, 3, 1, 2, 4)) # t14108: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14108 = ltorch.permute(t14102, (0, 3, 1, 2, 4)) # t14108: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14108 = prims.transpose(t14102, (0, 3, 1, 2, 4)) # t14108: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t14102
t14114 = torch.reshape(t14108, (1, 2048, 4672)) # t14114: "cuda:0 bf16[1, 2048, 4672]"
# t14114 = ltorch.reshape(t14108, (1, 2048, 4672)) # t14114: "cuda:0 bf16[1, 2048, 4672]"
# t14114 = prims.reshape(t14108, (1, 2048, 4672)) # t14114: "cuda:0 bf16[1, 2048, 4672]"
del t14108
t14115 = torch.reshape(t14114, (-1, 4672)) # t14115: "cuda:0 bf16[2048, 4672]"
# t14115 = ltorch.reshape(t14114, (-1, 4672)) # t14115: "cuda:0 bf16[2048, 4672]"
# t14115 = prims.reshape(t14114, (2048, 4672)) # t14115: "cuda:0 bf16[2048, 4672]"
del t14114
t14119 = torch.permute(t14115, (1, 0)) # t14119: "cuda:0 bf16[4672, 2048]"
# t14119 = ltorch.permute(t14115, (1, 0)) # t14119: "cuda:0 bf16[4672, 2048]"
# t14119 = prims.transpose(t14115, (1, 0)) # t14119: "cuda:0 bf16[4672, 2048]"
t14121 = torch.matmul(t14119, t13941) # t14121: "cuda:0 bf16[4672, 4544]"
# t14121 = ltorch.matmul(t14119, t14120) # t14121: "cuda:0 bf16[4672, 4544]"
# t14121 = prims.matmul(t14119, t14120) # t14121: "cuda:0 bf16[4672, 4544]"
del t14119, t13941
t14116 = torch.matmul(t14115, t_transformer_h_26_attn_attn_weight) # t14116: "cuda:0 bf16[2048, 4544]"
# t14116 = ltorch.matmul(t14115, t_transformer_h_26_attn_attn_weight) # t14116: "cuda:0 bf16[2048, 4544]"
# t14116 = prims.matmul(t14115, t_transformer_h_26_attn_attn_weight) # t14116: "cuda:0 bf16[2048, 4544]"
del t14115, t_transformer_h_26_attn_attn_weight
t13938 = torch.reshape(t13937, (1, 2048, 4544)) # t13938: "cuda:0 bf16[1, 2048, 4544]"
# t13938 = ltorch.reshape(t13937, (1, 2048, 4544)) # t13938: "cuda:0 bf16[1, 2048, 4544]"
# t13938 = prims.reshape(t13937, (1, 2048, 4544)) # t13938: "cuda:0 bf16[1, 2048, 4544]"
del t13937
t14117 = torch.reshape(t14116, (1, 2048, 4544)) # t14117: "cuda:0 bf16[1, 2048, 4544]"
# t14117 = ltorch.reshape(t14116, (1, 2048, 4544)) # t14117: "cuda:0 bf16[1, 2048, 4544]"
# t14117 = prims.reshape(t14116, (1, 2048, 4544)) # t14117: "cuda:0 bf16[1, 2048, 4544]"
del t14116
[t14130, t14136, t14178] = nvFusion18(i14158, t13895, t13938, t14117, t3996, t4128, t4149, t4164, t4169, t4175)
# t4155 = prims.convert_element_type(t3996, dtypes.float32) # t4155: "cuda:0 f32[1, 2048, 4544]"
# t4150 = prims.convert_element_type(t4149, dtypes.float32) # t4150: "cuda:0 f32[1, 2048, 4544]"
# t4151 = prims.convert_element_type(t4128, dtypes.float32) # t4151: "cuda:0 f32[1, 2048, 4544]"
# t4152 = prims.add(t4150, t4151) # t4152: "cuda:0 f32[1, 2048, 4544]"
# t4156 = prims.add(t4152, t4155) # t4156: "cuda:0 f32[1, 2048, 4544]"
# t4166 = prims.broadcast_in_dim(t4164, [1, 2048, 1], [0, 1]) # t4166: "cuda:0 f32[1, 2048, 1]"
# t4170 = prims.broadcast_in_dim(t4166, (1, 2048, 4544), (0, 1, 2)) # t4170: "cuda:0 f32[1, 2048, 4544]"
# t4172 = prims.sub(t4156, t4170) # t4172: "cuda:0 f32[1, 2048, 4544]"
# t4173 = prims.broadcast_in_dim(t4169, (1, 2048, 4544), (0, 1, 2)) # t4173: "cuda:0 f32[1, 2048, 4544]"
# t4174 = prims.mul(t4172, t4173) # t4174: "cuda:0 f32[1, 2048, 4544]"
# t4176 = prims.convert_element_type(t4175, dtypes.float32) # t4176: "cuda:0 f32[1, 2048, 4544]"
# t14175 = prims.convert_element_type(t13895, dtypes.float32) # t14175: "cuda:0 f32[1, 2048, 4544]"
# t14122 = prims.convert_element_type(t13938, dtypes.float32) # t14122: "cuda:0 f32[1, 2048, 4544]"
# t14123 = prims.convert_element_type(t14117, dtypes.float32) # t14123: "cuda:0 f32[1, 2048, 4544]"
# t14124 = prims.add(t14122, t14123) # t14124: "cuda:0 f32[1, 2048, 4544]"
# t14129 = prims.sum(t14124, (0, 1)) # t14129: "cuda:0 f32[4544]"
# t14130 = prims.convert_element_type(t14129, dtypes.bfloat16) # t14130: "cuda:0 bf16[4544]"
# t14131 = prims.mul(t4176, t14124) # t14131: "cuda:0 f32[1, 2048, 4544]"
# t14132 = prims.mul(t4174, t14124) # t14132: "cuda:0 f32[1, 2048, 4544]"
# t14135 = prims.sum(t14132, (0, 1)) # t14135: "cuda:0 f32[4544]"
# t14136 = prims.convert_element_type(t14135, dtypes.bfloat16) # t14136: "cuda:0 bf16[4544]"
# t14137 = prims.mul(t4173, t14131) # t14137: "cuda:0 f32[1, 2048, 4544]"
# t14138 = prims.mul(t4172, t14131) # t14138: "cuda:0 f32[1, 2048, 4544]"
# t14139 = prims.sum(t14138, (0, 2)) # t14139: "cuda:0 f32[2048]"
# t14140 = prims.broadcast_in_dim(t14139, [1, 2048, 1], [1]) # t14140: "cuda:0 f32[1, 2048, 1]"
# t14141 = prims.neg(t14137) # t14141: "cuda:0 f32[1, 2048, 4544]"
# t14143 = prims.sum(t14141, (0, 2)) # t14143: "cuda:0 f32[2048]"
# t14144 = prims.broadcast_in_dim(t14143, [1, 2048, 1], [1]) # t14144: "cuda:0 f32[1, 2048, 1]"
# t14145 = prims.mul(-0.5, t14140) # t14145: "cuda:0 f32[1, 2048, 1]"
# t14146 = prims.pow(t4169, 3.0) # t14146: "cuda:0 f32[1, 2048, 1]"
# t14147 = prims.mul(t14145, t14146) # t14147: "cuda:0 f32[1, 2048, 1]"
# t14149 = prims.sum(t14144, (0, 2)) # t14149: "cuda:0 f32[2048]"
# t14150 = prims.broadcast_in_dim(t14149, [1, 2048], [1]) # t14150: "cuda:0 f32[1, 2048]"
# t14151 = prims.sum(t14147, (0, 2)) # t14151: "cuda:0 f32[2048]"
# t14152 = prims.broadcast_in_dim(t14151, [1, 2048], [1]) # t14152: "cuda:0 f32[1, 2048]"
# t14155 = prims.broadcast_in_dim(t14150, [1, 2048, 1], [0, 1]) # t14155: "cuda:0 f32[1, 2048, 1]"
# t14156 = prims.broadcast_in_dim(t14155, (1, 2048, 4544), (0, 1, 2)) # t14156: "cuda:0 f32[1, 2048, 4544]"
# t14157 = prims.mul(0.00022007042253521127, t14156) # t14157: "cuda:0 f32[1, 2048, 4544]"
# t14159 = prims.broadcast_in_dim(t14152, [1, 2048, 1], [0, 1]) # t14159: "cuda:0 f32[1, 2048, 1]"
# t14160 = prims.broadcast_in_dim(t14159, (1, 2048, 4544), (0, 1, 2)) # t14160: "cuda:0 f32[1, 2048, 4544]"
# t14162 = prims.broadcast_in_dim(t4164, [1, 2048, 1], [0, 1]) # t14162: "cuda:0 f32[1, 2048, 1]"
# t14163 = prims.broadcast_in_dim(t14162, (1, 2048, 4544), (0, 1, 2)) # t14163: "cuda:0 f32[1, 2048, 4544]"
# t14164 = prims.mul(2.0, t14160) # t14164: "cuda:0 f32[1, 2048, 4544]"
# t14165 = prims.sub(t4156, t14163) # t14165: "cuda:0 f32[1, 2048, 4544]"
# t14166 = prims.mul(t14164, t14165) # t14166: "cuda:0 f32[1, 2048, 4544]"
# f14167 = prims.convert_element_type(i14158, float) # f14167: "float 4544.0"
# t14168 = prims.div(t14166, f14167) # t14168: "cuda:0 f32[1, 2048, 4544]"
# t14169 = prims.add(t14157, t14168) # t14169: "cuda:0 f32[1, 2048, 4544]"
# t14173 = prims.add(t14137, t14169) # t14173: "cuda:0 f32[1, 2048, 4544]"
# t14177 = prims.add(t14175, t14173) # t14177: "cuda:0 f32[1, 2048, 4544]"
# t14178 = prims.convert_element_type(t14177, dtypes.bfloat16) # t14178: "cuda:0 bf16[1, 2048, 4544]"
del i14158, t13895, t13938, t14117, t3996, t4128, t4149, t4164, t4169, t4175
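# NOTE: backward through transformer.h.25 starts here, repeating the same sequence: projection
# matmuls, GELU backward (nvFusion19), cuDNN SDPA backward, RoPE-undo plus qkv concatenation
# (nvFusion20), and finally the LayerNorm backward (nvFusion21).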
t14185 = torch.reshape(t14178, (-1, 4544)) # t14185: "cuda:0 bf16[2048, 4544]"
# t14185 = ltorch.reshape(t14178, (-1, 4544)) # t14185: "cuda:0 bf16[2048, 4544]"
# t14185 = prims.reshape(t14178, (2048, 4544)) # t14185: "cuda:0 bf16[2048, 4544]"
t14189 = torch.permute(t14185, (1, 0)) # t14189: "cuda:0 bf16[4544, 2048]"
# t14189 = ltorch.permute(t14185, (1, 0)) # t14189: "cuda:0 bf16[4544, 2048]"
# t14189 = prims.transpose(t14185, (1, 0)) # t14189: "cuda:0 bf16[4544, 2048]"
t14186 = torch.matmul(t14185, t_transformer_h_25_mlp_proj_weight) # t14186: "cuda:0 bf16[2048, 18176]"
# t14186 = ltorch.matmul(t14185, t_transformer_h_25_mlp_proj_weight) # t14186: "cuda:0 bf16[2048, 18176]"
# t14186 = prims.matmul(t14185, t_transformer_h_25_mlp_proj_weight) # t14186: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_25_mlp_proj_weight
t14191 = torch.matmul(t14189, t14190) # t14191: "cuda:0 bf16[4544, 18176]"
# t14191 = ltorch.matmul(t14189, t14190) # t14191: "cuda:0 bf16[4544, 18176]"
# t14191 = prims.matmul(t14189, t14190) # t14191: "cuda:0 bf16[4544, 18176]"
del t14190
t14227 = torch.matmul(t14185, t_transformer_h_25_attn_proj_weight) # t14227: "cuda:0 bf16[2048, 4544]"
# t14227 = ltorch.matmul(t14226, t_transformer_h_25_attn_proj_weight) # t14227: "cuda:0 bf16[2048, 4544]"
# t14227 = prims.matmul(t14226, t_transformer_h_25_attn_proj_weight) # t14227: "cuda:0 bf16[2048, 4544]"
del t14185, t_transformer_h_25_attn_proj_weight
t14232 = torch.matmul(t14189, t14231) # t14232: "cuda:0 bf16[4544, 4544]"
# t14232 = ltorch.matmul(t14230, t14231) # t14232: "cuda:0 bf16[4544, 4544]"
# t14232 = prims.matmul(t14230, t14231) # t14232: "cuda:0 bf16[4544, 4544]"
del t14189, t14231
t14187 = torch.reshape(t14186, (1, 2048, 18176)) # t14187: "cuda:0 bf16[1, 2048, 18176]"
# t14187 = ltorch.reshape(t14186, (1, 2048, 18176)) # t14187: "cuda:0 bf16[1, 2048, 18176]"
# t14187 = prims.reshape(t14186, (1, 2048, 18176)) # t14187: "cuda:0 bf16[1, 2048, 18176]"
del t14186
t14228 = torch.reshape(t14227, (1, 2048, 4544)) # t14228: "cuda:0 bf16[1, 2048, 4544]"
# t14228 = ltorch.reshape(t14227, (1, 2048, 4544)) # t14228: "cuda:0 bf16[1, 2048, 4544]"
# t14228 = prims.reshape(t14227, (1, 2048, 4544)) # t14228: "cuda:0 bf16[1, 2048, 4544]"
del t14227
t14236 = torch.reshape(t14228, (1, 2048, 71, 64)) # t14236: "cuda:0 bf16[1, 2048, 71, 64]"
# t14236 = ltorch.reshape(t14228, (1, 2048, 71, 64)) # t14236: "cuda:0 bf16[1, 2048, 71, 64]"
# t14236 = prims.reshape(t14228, (1, 2048, 71, 64)) # t14236: "cuda:0 bf16[1, 2048, 71, 64]"
del t14228
t14239 = torch.permute(t14236, (0, 2, 1, 3)) # t14239: "cuda:0 bf16[1, 71, 2048, 64]"
# t14239 = ltorch.permute(t14236, (0, 2, 1, 3)) # t14239: "cuda:0 bf16[1, 71, 2048, 64]"
# t14239 = prims.transpose(t14236, (0, 2, 1, 3)) # t14239: "cuda:0 bf16[1, 71, 2048, 64]"
del t14236
[t14218] = nvFusion19(f1663, f1665, t14187, t4129)
# t4130 = prims.convert_element_type(t4129, dtypes.float32) # t4130: "cuda:0 f32[1, 2048, 18176]"
# t4132 = prims.div(t4130, 1.4142135623730951) # t4132: "cuda:0 f32[1, 2048, 18176]"
# t4135 = prims.erf(t4132) # t4135: "cuda:0 f32[1, 2048, 18176]"
# t4139 = prims.mul(0.5, t4135) # t4139: "cuda:0 f32[1, 2048, 18176]"
# t4143 = prims.add(0.5, t4139) # t4143: "cuda:0 f32[1, 2048, 18176]"
# t14192 = prims.convert_element_type(t14187, dtypes.float32) # t14192: "cuda:0 f32[1, 2048, 18176]"
# t14193 = prims.mul(t4143, t14192) # t14193: "cuda:0 f32[1, 2048, 18176]"
# t14194 = prims.mul(t4130, t14192) # t14194: "cuda:0 f32[1, 2048, 18176]"
# t14202 = prims.mul(f1665, t14194) # t14202: "cuda:0 f32[1, 2048, 18176]"
# t14205 = prims.pow(t4132, 2.0) # t14205: "cuda:0 f32[1, 2048, 18176]"
# t14206 = prims.neg(t14205) # t14206: "cuda:0 f32[1, 2048, 18176]"
# t14207 = prims.exp(t14206) # t14207: "cuda:0 f32[1, 2048, 18176]"
# t14208 = prims.mul(1.1283791670955126, t14207) # t14208: "cuda:0 f32[1, 2048, 18176]"
# t14209 = prims.mul(t14208, t14202) # t14209: "cuda:0 f32[1, 2048, 18176]"
# t14213 = prims.div(t14209, f1663) # t14213: "cuda:0 f32[1, 2048, 18176]"
# t14217 = prims.add(t14193, t14213) # t14217: "cuda:0 f32[1, 2048, 18176]"
# t14218 = prims.convert_element_type(t14217, dtypes.bfloat16) # t14218: "cuda:0 bf16[1, 2048, 18176]"
del f1663, f1665, t14187, t4129
t14219 = torch.reshape(t14218, (-1, 18176)) # t14219: "cuda:0 bf16[2048, 18176]"
# t14219 = ltorch.reshape(t14218, (-1, 18176)) # t14219: "cuda:0 bf16[2048, 18176]"
# t14219 = prims.reshape(t14218, (2048, 18176)) # t14219: "cuda:0 bf16[2048, 18176]"
del t14218
t14223 = torch.permute(t14219, (1, 0)) # t14223: "cuda:0 bf16[18176, 2048]"
# t14223 = ltorch.permute(t14219, (1, 0)) # t14223: "cuda:0 bf16[18176, 2048]"
# t14223 = prims.transpose(t14219, (1, 0)) # t14223: "cuda:0 bf16[18176, 2048]"
t14225 = torch.matmul(t14223, t14224) # t14225: "cuda:0 bf16[18176, 4544]"
# t14225 = ltorch.matmul(t14223, t14224) # t14225: "cuda:0 bf16[18176, 4544]"
# t14225 = prims.matmul(t14223, t14224) # t14225: "cuda:0 bf16[18176, 4544]"
del t14223
t14220 = torch.matmul(t14219, t_transformer_h_25_mlp_fc_weight) # t14220: "cuda:0 bf16[2048, 4544]"
# t14220 = ltorch.matmul(t14219, t_transformer_h_25_mlp_fc_weight) # t14220: "cuda:0 bf16[2048, 4544]"
# t14220 = prims.matmul(t14219, t_transformer_h_25_mlp_fc_weight) # t14220: "cuda:0 bf16[2048, 4544]"
del t14219, t_transformer_h_25_mlp_fc_weight
(t14240, t14241, t14242) = cudnn_sdpa_bwd(t14239, t4113, t4116, t4066, None, f1654, b1655, t4117, t4118, t4119, t4120, scale=f1656, cat_grad_qkv=False)
del t14239, t4113, t4116, t4066, f1654, b1655, t4117, t4118, t4119, t4120, f1656
t14244 = torch_slice_prim_impl(t14241, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14244: "cuda:0 bf16[1, 71, 2048, 64]"
del t14241
t14248 = torch_slice_prim_impl(t14240, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14248: "cuda:0 bf16[1, 71, 2048, 64]"
del t14240
t14351 = torch.reshape(t14242, (1, 1, 71, 2048, 64)) # t14351: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14351 = ltorch.reshape(t14242, (1, 1, 71, 2048, 64)) # t14351: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14351 = prims.reshape(t14242, (1, 1, 71, 2048, 64)) # t14351: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t14242
[t14385] = nvFusion20(i1627, t14244, t14248, t14351, t61, t66)
# t14245 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t14245: "cuda:0 bf16[1, 71, 2048, 0]"
# t14246 = prims.pad(t14245, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t14246: "cuda:0 bf16[1, 71, 2048, 64]"
# t14249 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t14249: "cuda:0 bf16[1, 71, 2048, 0]"
# t14250 = prims.pad(t14249, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t14250: "cuda:0 bf16[1, 71, 2048, 64]"
# t14251 = prims.convert_element_type(t14244, dtypes.float32) # t14251: "cuda:0 f32[1, 71, 2048, 64]"
# t14255 = prims.mul(t66, t14251) # t14255: "cuda:0 f32[1, 71, 2048, 64]"
# t14258 = prims.convert_element_type(t14255, dtypes.bfloat16) # t14258: "cuda:0 bf16[1, 71, 2048, 64]"
# t14267 = prims.mul(t61, t14251) # t14267: "cuda:0 f32[1, 71, 2048, 64]"
# t14279 = prims.slice_prim(t14258, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t14279: "cuda:0 bf16[1, 71, 2048, 32]"
# t14280 = prims.slice_prim(t14258, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14280: "cuda:0 bf16[1, 71, 2048, 32]"
# t14281 = prims.convert_element_type(t14279, dtypes.float32) # t14281: "cuda:0 f32[1, 71, 2048, 32]"
# t14282 = prims.neg(t14281) # t14282: "cuda:0 f32[1, 71, 2048, 32]"
# t14283 = prims.convert_element_type(t14282, dtypes.bfloat16) # t14283: "cuda:0 bf16[1, 71, 2048, 32]"
# t14284 = prims.pad(t14283, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14284: "cuda:0 bf16[1, 71, 2048, 64]"
# t14286 = prims.convert_element_type(t14284, dtypes.float32) # t14286: "cuda:0 f32[1, 71, 2048, 64]"
# t14287 = prims.add(t14267, t14286) # t14287: "cuda:0 f32[1, 71, 2048, 64]"
# t14289 = prims.pad(t14280, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14289: "cuda:0 bf16[1, 71, 2048, 64]"
# t14291 = prims.convert_element_type(t14289, dtypes.float32) # t14291: "cuda:0 f32[1, 71, 2048, 64]"
# t14292 = prims.add(t14287, t14291) # t14292: "cuda:0 f32[1, 71, 2048, 64]"
# t14293 = prims.convert_element_type(t14292, dtypes.bfloat16) # t14293: "cuda:0 bf16[1, 71, 2048, 64]"
# t14294 = prims.pad(t14293, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14294: "cuda:0 bf16[1, 71, 2048, 64]"
# t14295 = prims.convert_element_type(t14246, dtypes.float32) # t14295: "cuda:0 f32[1, 71, 2048, 64]"
# t14296 = prims.convert_element_type(t14294, dtypes.float32) # t14296: "cuda:0 f32[1, 71, 2048, 64]"
# t14297 = prims.add(t14295, t14296) # t14297: "cuda:0 f32[1, 71, 2048, 64]"
# t14298 = prims.convert_element_type(t14297, dtypes.bfloat16) # t14298: "cuda:0 bf16[1, 71, 2048, 64]"
# t14299 = prims.convert_element_type(t14248, dtypes.float32) # t14299: "cuda:0 f32[1, 71, 2048, 64]"
# t14303 = prims.mul(t66, t14299) # t14303: "cuda:0 f32[1, 71, 2048, 64]"
# t14306 = prims.convert_element_type(t14303, dtypes.bfloat16) # t14306: "cuda:0 bf16[1, 71, 2048, 64]"
# t14315 = prims.mul(t61, t14299) # t14315: "cuda:0 f32[1, 71, 2048, 64]"
# t14327 = prims.slice_prim(t14306, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t14327: "cuda:0 bf16[1, 71, 2048, 32]"
# t14328 = prims.slice_prim(t14306, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14328: "cuda:0 bf16[1, 71, 2048, 32]"
# t14329 = prims.convert_element_type(t14327, dtypes.float32) # t14329: "cuda:0 f32[1, 71, 2048, 32]"
# t14330 = prims.neg(t14329) # t14330: "cuda:0 f32[1, 71, 2048, 32]"
# t14331 = prims.convert_element_type(t14330, dtypes.bfloat16) # t14331: "cuda:0 bf16[1, 71, 2048, 32]"
# t14332 = prims.pad(t14331, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14332: "cuda:0 bf16[1, 71, 2048, 64]"
# t14334 = prims.convert_element_type(t14332, dtypes.float32) # t14334: "cuda:0 f32[1, 71, 2048, 64]"
# t14335 = prims.add(t14315, t14334) # t14335: "cuda:0 f32[1, 71, 2048, 64]"
# t14337 = prims.pad(t14328, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14337: "cuda:0 bf16[1, 71, 2048, 64]"
# t14339 = prims.convert_element_type(t14337, dtypes.float32) # t14339: "cuda:0 f32[1, 71, 2048, 64]"
# t14340 = prims.add(t14335, t14339) # t14340: "cuda:0 f32[1, 71, 2048, 64]"
# t14341 = prims.convert_element_type(t14340, dtypes.bfloat16) # t14341: "cuda:0 bf16[1, 71, 2048, 64]"
# t14342 = prims.pad(t14341, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14342: "cuda:0 bf16[1, 71, 2048, 64]"
# t14343 = prims.convert_element_type(t14250, dtypes.float32) # t14343: "cuda:0 f32[1, 71, 2048, 64]"
# t14344 = prims.convert_element_type(t14342, dtypes.float32) # t14344: "cuda:0 f32[1, 71, 2048, 64]"
# t14345 = prims.add(t14343, t14344) # t14345: "cuda:0 f32[1, 71, 2048, 64]"
# t14346 = prims.convert_element_type(t14345, dtypes.bfloat16) # t14346: "cuda:0 bf16[1, 71, 2048, 64]"
# t14356 = prims.reshape(t14298, (1, 1, 71, 2048, 64)) # t14356: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14361 = prims.reshape(t14346, (1, 1, 71, 2048, 64)) # t14361: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14367 = prims.convert_element_type(t14351, dtypes.float32) # t14367: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14368 = prims.sum(t14367, (0, 1, 2)) # t14368: "cuda:0 f32[2048, 64]"
# t14369 = prims.convert_element_type(t14368, dtypes.bfloat16) # t14369: "cuda:0 bf16[2048, 64]"
# t14370 = prims.broadcast_in_dim(t14369, [1, 1, 1, 2048, 64], [3, 4]) # t14370: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14376 = prims.convert_element_type(t14356, dtypes.float32) # t14376: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14377 = prims.sum(t14376, (0, 1, 2)) # t14377: "cuda:0 f32[2048, 64]"
# t14378 = prims.convert_element_type(t14377, dtypes.bfloat16) # t14378: "cuda:0 bf16[2048, 64]"
# t14379 = prims.broadcast_in_dim(t14378, [1, 1, 1, 2048, 64], [3, 4]) # t14379: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14385 = prims.cat((t14361, t14379, t14370), i1627) # t14385: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1627, t14244, t14248, t14351
t14391 = torch.permute(t14385, (0, 3, 1, 2, 4)) # t14391: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14391 = ltorch.permute(t14385, (0, 3, 1, 2, 4)) # t14391: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14391 = prims.transpose(t14385, (0, 3, 1, 2, 4)) # t14391: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t14385
t14397 = torch.reshape(t14391, (1, 2048, 4672)) # t14397: "cuda:0 bf16[1, 2048, 4672]"
# t14397 = ltorch.reshape(t14391, (1, 2048, 4672)) # t14397: "cuda:0 bf16[1, 2048, 4672]"
# t14397 = prims.reshape(t14391, (1, 2048, 4672)) # t14397: "cuda:0 bf16[1, 2048, 4672]"
del t14391
t14398 = torch.reshape(t14397, (-1, 4672)) # t14398: "cuda:0 bf16[2048, 4672]"
# t14398 = ltorch.reshape(t14397, (-1, 4672)) # t14398: "cuda:0 bf16[2048, 4672]"
# t14398 = prims.reshape(t14397, (2048, 4672)) # t14398: "cuda:0 bf16[2048, 4672]"
del t14397
t14402 = torch.permute(t14398, (1, 0)) # t14402: "cuda:0 bf16[4672, 2048]"
# t14402 = ltorch.permute(t14398, (1, 0)) # t14402: "cuda:0 bf16[4672, 2048]"
# t14402 = prims.transpose(t14398, (1, 0)) # t14402: "cuda:0 bf16[4672, 2048]"
t14404 = torch.matmul(t14402, t14224) # t14404: "cuda:0 bf16[4672, 4544]"
# t14404 = ltorch.matmul(t14402, t14403) # t14404: "cuda:0 bf16[4672, 4544]"
# t14404 = prims.matmul(t14402, t14403) # t14404: "cuda:0 bf16[4672, 4544]"
del t14402, t14224
t14399 = torch.matmul(t14398, t_transformer_h_25_attn_attn_weight) # t14399: "cuda:0 bf16[2048, 4544]"
# t14399 = ltorch.matmul(t14398, t_transformer_h_25_attn_attn_weight) # t14399: "cuda:0 bf16[2048, 4544]"
# t14399 = prims.matmul(t14398, t_transformer_h_25_attn_attn_weight) # t14399: "cuda:0 bf16[2048, 4544]"
del t14398, t_transformer_h_25_attn_attn_weight
t14221 = torch.reshape(t14220, (1, 2048, 4544)) # t14221: "cuda:0 bf16[1, 2048, 4544]"
# t14221 = ltorch.reshape(t14220, (1, 2048, 4544)) # t14221: "cuda:0 bf16[1, 2048, 4544]"
# t14221 = prims.reshape(t14220, (1, 2048, 4544)) # t14221: "cuda:0 bf16[1, 2048, 4544]"
del t14220
t14400 = torch.reshape(t14399, (1, 2048, 4544)) # t14400: "cuda:0 bf16[1, 2048, 4544]"
# t14400 = ltorch.reshape(t14399, (1, 2048, 4544)) # t14400: "cuda:0 bf16[1, 2048, 4544]"
# t14400 = prims.reshape(t14399, (1, 2048, 4544)) # t14400: "cuda:0 bf16[1, 2048, 4544]"
del t14399
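# NOTE: nvFusion21 below appears to be the fused LayerNorm backward for block
# h.25 (parallel attention/MLP layout): the two branch gradients w.r.t. the
# LayerNorm output (t14221, t14400) are summed, reduced over (batch, seq) into
# the bias grad t14413 and, after scaling by the saved normalized input, into
# the weight grad t14419; the input gradient is then rebuilt from the saved
# mean/rstd statistics (t4003, t4008; 0.00022007042253521127 == 1/4544) and
# accumulated onto the residual-stream gradient t14178 to produce t14461.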
[t14413, t14419, t14461] = nvFusion21(i14441, t14178, t14221, t14400, t3835, t3967, t3988, t4003, t4008, t4014)
# t3994 = prims.convert_element_type(t3835, dtypes.float32) # t3994: "cuda:0 f32[1, 2048, 4544]"
# t3989 = prims.convert_element_type(t3988, dtypes.float32) # t3989: "cuda:0 f32[1, 2048, 4544]"
# t3990 = prims.convert_element_type(t3967, dtypes.float32) # t3990: "cuda:0 f32[1, 2048, 4544]"
# t3991 = prims.add(t3989, t3990) # t3991: "cuda:0 f32[1, 2048, 4544]"
# t3995 = prims.add(t3991, t3994) # t3995: "cuda:0 f32[1, 2048, 4544]"
# t4005 = prims.broadcast_in_dim(t4003, [1, 2048, 1], [0, 1]) # t4005: "cuda:0 f32[1, 2048, 1]"
# t4009 = prims.broadcast_in_dim(t4005, (1, 2048, 4544), (0, 1, 2)) # t4009: "cuda:0 f32[1, 2048, 4544]"
# t4011 = prims.sub(t3995, t4009) # t4011: "cuda:0 f32[1, 2048, 4544]"
# t4012 = prims.broadcast_in_dim(t4008, (1, 2048, 4544), (0, 1, 2)) # t4012: "cuda:0 f32[1, 2048, 4544]"
# t4013 = prims.mul(t4011, t4012) # t4013: "cuda:0 f32[1, 2048, 4544]"
# t4015 = prims.convert_element_type(t4014, dtypes.float32) # t4015: "cuda:0 f32[1, 2048, 4544]"
# t14458 = prims.convert_element_type(t14178, dtypes.float32) # t14458: "cuda:0 f32[1, 2048, 4544]"
# t14405 = prims.convert_element_type(t14221, dtypes.float32) # t14405: "cuda:0 f32[1, 2048, 4544]"
# t14406 = prims.convert_element_type(t14400, dtypes.float32) # t14406: "cuda:0 f32[1, 2048, 4544]"
# t14407 = prims.add(t14405, t14406) # t14407: "cuda:0 f32[1, 2048, 4544]"
# t14412 = prims.sum(t14407, (0, 1)) # t14412: "cuda:0 f32[4544]"
# t14413 = prims.convert_element_type(t14412, dtypes.bfloat16) # t14413: "cuda:0 bf16[4544]"
# t14414 = prims.mul(t4015, t14407) # t14414: "cuda:0 f32[1, 2048, 4544]"
# t14415 = prims.mul(t4013, t14407) # t14415: "cuda:0 f32[1, 2048, 4544]"
# t14418 = prims.sum(t14415, (0, 1)) # t14418: "cuda:0 f32[4544]"
# t14419 = prims.convert_element_type(t14418, dtypes.bfloat16) # t14419: "cuda:0 bf16[4544]"
# t14420 = prims.mul(t4012, t14414) # t14420: "cuda:0 f32[1, 2048, 4544]"
# t14421 = prims.mul(t4011, t14414) # t14421: "cuda:0 f32[1, 2048, 4544]"
# t14422 = prims.sum(t14421, (0, 2)) # t14422: "cuda:0 f32[2048]"
# t14423 = prims.broadcast_in_dim(t14422, [1, 2048, 1], [1]) # t14423: "cuda:0 f32[1, 2048, 1]"
# t14424 = prims.neg(t14420) # t14424: "cuda:0 f32[1, 2048, 4544]"
# t14426 = prims.sum(t14424, (0, 2)) # t14426: "cuda:0 f32[2048]"
# t14427 = prims.broadcast_in_dim(t14426, [1, 2048, 1], [1]) # t14427: "cuda:0 f32[1, 2048, 1]"
# t14428 = prims.mul(-0.5, t14423) # t14428: "cuda:0 f32[1, 2048, 1]"
# t14429 = prims.pow(t4008, 3.0) # t14429: "cuda:0 f32[1, 2048, 1]"
# t14430 = prims.mul(t14428, t14429) # t14430: "cuda:0 f32[1, 2048, 1]"
# t14432 = prims.sum(t14427, (0, 2)) # t14432: "cuda:0 f32[2048]"
# t14433 = prims.broadcast_in_dim(t14432, [1, 2048], [1]) # t14433: "cuda:0 f32[1, 2048]"
# t14434 = prims.sum(t14430, (0, 2)) # t14434: "cuda:0 f32[2048]"
# t14435 = prims.broadcast_in_dim(t14434, [1, 2048], [1]) # t14435: "cuda:0 f32[1, 2048]"
# t14438 = prims.broadcast_in_dim(t14433, [1, 2048, 1], [0, 1]) # t14438: "cuda:0 f32[1, 2048, 1]"
# t14439 = prims.broadcast_in_dim(t14438, (1, 2048, 4544), (0, 1, 2)) # t14439: "cuda:0 f32[1, 2048, 4544]"
# t14440 = prims.mul(0.00022007042253521127, t14439) # t14440: "cuda:0 f32[1, 2048, 4544]"
# t14442 = prims.broadcast_in_dim(t14435, [1, 2048, 1], [0, 1]) # t14442: "cuda:0 f32[1, 2048, 1]"
# t14443 = prims.broadcast_in_dim(t14442, (1, 2048, 4544), (0, 1, 2)) # t14443: "cuda:0 f32[1, 2048, 4544]"
# t14445 = prims.broadcast_in_dim(t4003, [1, 2048, 1], [0, 1]) # t14445: "cuda:0 f32[1, 2048, 1]"
# t14446 = prims.broadcast_in_dim(t14445, (1, 2048, 4544), (0, 1, 2)) # t14446: "cuda:0 f32[1, 2048, 4544]"
# t14447 = prims.mul(2.0, t14443) # t14447: "cuda:0 f32[1, 2048, 4544]"
# t14448 = prims.sub(t3995, t14446) # t14448: "cuda:0 f32[1, 2048, 4544]"
# t14449 = prims.mul(t14447, t14448) # t14449: "cuda:0 f32[1, 2048, 4544]"
# f14450 = prims.convert_element_type(i14441, float) # f14450: "float 4544.0"
# t14451 = prims.div(t14449, f14450) # t14451: "cuda:0 f32[1, 2048, 4544]"
# t14452 = prims.add(t14440, t14451) # t14452: "cuda:0 f32[1, 2048, 4544]"
# t14456 = prims.add(t14420, t14452) # t14456: "cuda:0 f32[1, 2048, 4544]"
# t14460 = prims.add(t14458, t14456) # t14460: "cuda:0 f32[1, 2048, 4544]"
# t14461 = prims.convert_element_type(t14460, dtypes.bfloat16) # t14461: "cuda:0 bf16[1, 2048, 4544]"
del i14441, t14178, t14221, t14400, t3835, t3967, t3988, t4003, t4008, t4014
t14468 = torch.reshape(t14461, (-1, 4544)) # t14468: "cuda:0 bf16[2048, 4544]"
# t14468 = ltorch.reshape(t14461, (-1, 4544)) # t14468: "cuda:0 bf16[2048, 4544]"
# t14468 = prims.reshape(t14461, (2048, 4544)) # t14468: "cuda:0 bf16[2048, 4544]"
t14472 = torch.permute(t14468, (1, 0)) # t14472: "cuda:0 bf16[4544, 2048]"
# t14472 = ltorch.permute(t14468, (1, 0)) # t14472: "cuda:0 bf16[4544, 2048]"
# t14472 = prims.transpose(t14468, (1, 0)) # t14472: "cuda:0 bf16[4544, 2048]"
t14469 = torch.matmul(t14468, t_transformer_h_24_mlp_proj_weight) # t14469: "cuda:0 bf16[2048, 18176]"
# t14469 = ltorch.matmul(t14468, t_transformer_h_24_mlp_proj_weight) # t14469: "cuda:0 bf16[2048, 18176]"
# t14469 = prims.matmul(t14468, t_transformer_h_24_mlp_proj_weight) # t14469: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_24_mlp_proj_weight
t14474 = torch.matmul(t14472, t14473) # t14474: "cuda:0 bf16[4544, 18176]"
# t14474 = ltorch.matmul(t14472, t14473) # t14474: "cuda:0 bf16[4544, 18176]"
# t14474 = prims.matmul(t14472, t14473) # t14474: "cuda:0 bf16[4544, 18176]"
del t14473
t14510 = torch.matmul(t14468, t_transformer_h_24_attn_proj_weight) # t14510: "cuda:0 bf16[2048, 4544]"
# t14510 = ltorch.matmul(t14509, t_transformer_h_24_attn_proj_weight) # t14510: "cuda:0 bf16[2048, 4544]"
# t14510 = prims.matmul(t14509, t_transformer_h_24_attn_proj_weight) # t14510: "cuda:0 bf16[2048, 4544]"
del t14468, t_transformer_h_24_attn_proj_weight
t14515 = torch.matmul(t14472, t14514) # t14515: "cuda:0 bf16[4544, 4544]"
# t14515 = ltorch.matmul(t14513, t14514) # t14515: "cuda:0 bf16[4544, 4544]"
# t14515 = prims.matmul(t14513, t14514) # t14515: "cuda:0 bf16[4544, 4544]"
del t14472, t14514
t14470 = torch.reshape(t14469, (1, 2048, 18176)) # t14470: "cuda:0 bf16[1, 2048, 18176]"
# t14470 = ltorch.reshape(t14469, (1, 2048, 18176)) # t14470: "cuda:0 bf16[1, 2048, 18176]"
# t14470 = prims.reshape(t14469, (1, 2048, 18176)) # t14470: "cuda:0 bf16[1, 2048, 18176]"
del t14469
t14511 = torch.reshape(t14510, (1, 2048, 4544)) # t14511: "cuda:0 bf16[1, 2048, 4544]"
# t14511 = ltorch.reshape(t14510, (1, 2048, 4544)) # t14511: "cuda:0 bf16[1, 2048, 4544]"
# t14511 = prims.reshape(t14510, (1, 2048, 4544)) # t14511: "cuda:0 bf16[1, 2048, 4544]"
del t14510
t14519 = torch.reshape(t14511, (1, 2048, 71, 64)) # t14519: "cuda:0 bf16[1, 2048, 71, 64]"
# t14519 = ltorch.reshape(t14511, (1, 2048, 71, 64)) # t14519: "cuda:0 bf16[1, 2048, 71, 64]"
# t14519 = prims.reshape(t14511, (1, 2048, 71, 64)) # t14519: "cuda:0 bf16[1, 2048, 71, 64]"
del t14511
t14522 = torch.permute(t14519, (0, 2, 1, 3)) # t14522: "cuda:0 bf16[1, 71, 2048, 64]"
# t14522 = ltorch.permute(t14519, (0, 2, 1, 3)) # t14522: "cuda:0 bf16[1, 71, 2048, 64]"
# t14522 = prims.transpose(t14519, (0, 2, 1, 3)) # t14522: "cuda:0 bf16[1, 71, 2048, 64]"
del t14519
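# NOTE: nvFusion22 below appears to be the backward of the exact (erf-based)
# GELU applied to h.24's mlp_fc output t3968. Writing x = t3968 and g = t14470
# (both widened to fp32), the fusion evaluates, up to rounding, the textbook
#     dGELU/dx = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x**2 / 2) / sqrt(2*pi)
# so t14500 == g * dGELU/dx, cast back to bf16 as t14501. f1599 and f1601
# appear to be the baked-in constants sqrt(2) and 0.5 used by that formula.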
[t14501] = nvFusion22(f1599, f1601, t14470, t3968)
# t3969 = prims.convert_element_type(t3968, dtypes.float32) # t3969: "cuda:0 f32[1, 2048, 18176]"
# t3971 = prims.div(t3969, 1.4142135623730951) # t3971: "cuda:0 f32[1, 2048, 18176]"
# t3974 = prims.erf(t3971) # t3974: "cuda:0 f32[1, 2048, 18176]"
# t3978 = prims.mul(0.5, t3974) # t3978: "cuda:0 f32[1, 2048, 18176]"
# t3982 = prims.add(0.5, t3978) # t3982: "cuda:0 f32[1, 2048, 18176]"
# t14475 = prims.convert_element_type(t14470, dtypes.float32) # t14475: "cuda:0 f32[1, 2048, 18176]"
# t14476 = prims.mul(t3982, t14475) # t14476: "cuda:0 f32[1, 2048, 18176]"
# t14477 = prims.mul(t3969, t14475) # t14477: "cuda:0 f32[1, 2048, 18176]"
# t14485 = prims.mul(f1601, t14477) # t14485: "cuda:0 f32[1, 2048, 18176]"
# t14488 = prims.pow(t3971, 2.0) # t14488: "cuda:0 f32[1, 2048, 18176]"
# t14489 = prims.neg(t14488) # t14489: "cuda:0 f32[1, 2048, 18176]"
# t14490 = prims.exp(t14489) # t14490: "cuda:0 f32[1, 2048, 18176]"
# t14491 = prims.mul(1.1283791670955126, t14490) # t14491: "cuda:0 f32[1, 2048, 18176]"
# t14492 = prims.mul(t14491, t14485) # t14492: "cuda:0 f32[1, 2048, 18176]"
# t14496 = prims.div(t14492, f1599) # t14496: "cuda:0 f32[1, 2048, 18176]"
# t14500 = prims.add(t14476, t14496) # t14500: "cuda:0 f32[1, 2048, 18176]"
# t14501 = prims.convert_element_type(t14500, dtypes.bfloat16) # t14501: "cuda:0 bf16[1, 2048, 18176]"
del f1599, f1601, t14470, t3968
t14502 = torch.reshape(t14501, (-1, 18176)) # t14502: "cuda:0 bf16[2048, 18176]"
# t14502 = ltorch.reshape(t14501, (-1, 18176)) # t14502: "cuda:0 bf16[2048, 18176]"
# t14502 = prims.reshape(t14501, (2048, 18176)) # t14502: "cuda:0 bf16[2048, 18176]"
del t14501
t14506 = torch.permute(t14502, (1, 0)) # t14506: "cuda:0 bf16[18176, 2048]"
# t14506 = ltorch.permute(t14502, (1, 0)) # t14506: "cuda:0 bf16[18176, 2048]"
# t14506 = prims.transpose(t14502, (1, 0)) # t14506: "cuda:0 bf16[18176, 2048]"
t14508 = torch.matmul(t14506, t14507) # t14508: "cuda:0 bf16[18176, 4544]"
# t14508 = ltorch.matmul(t14506, t14507) # t14508: "cuda:0 bf16[18176, 4544]"
# t14508 = prims.matmul(t14506, t14507) # t14508: "cuda:0 bf16[18176, 4544]"
del t14506
t14503 = torch.matmul(t14502, t_transformer_h_24_mlp_fc_weight) # t14503: "cuda:0 bf16[2048, 4544]"
# t14503 = ltorch.matmul(t14502, t_transformer_h_24_mlp_fc_weight) # t14503: "cuda:0 bf16[2048, 4544]"
# t14503 = prims.matmul(t14502, t_transformer_h_24_mlp_fc_weight) # t14503: "cuda:0 bf16[2048, 4544]"
del t14502, t_transformer_h_24_mlp_fc_weight
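# NOTE: cudnn_sdpa_bwd below appears to be the cuDNN fused scaled-dot-product-
# attention backward for block h.24: given the gradient of the attention output
# (t14522) plus the saved q/k/v, forward output and softmax statistics, it
# returns the (grad_q, grad_k, grad_v) triple t14523/t14524/t14525. The scalar
# arguments f1590, b1591 and f1592 look like the dropout probability, causal
# flag and softmax scale saved from the forward pass.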
(t14523, t14524, t14525) = cudnn_sdpa_bwd(t14522, t3952, t3955, t3905, None, f1590, b1591, t3956, t3957, t3958, t3959, scale=f1592, cat_grad_qkv=False)
del t14522, t3952, t3955, t3905, f1590, b1591, t3956, t3957, t3958, t3959, f1592
t14527 = torch_slice_prim_impl(t14524, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14527: "cuda:0 bf16[1, 71, 2048, 64]"
del t14524
t14531 = torch_slice_prim_impl(t14523, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14531: "cuda:0 bf16[1, 71, 2048, 64]"
del t14523
t14634 = torch.reshape(t14525, (1, 1, 71, 2048, 64)) # t14634: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14634 = ltorch.reshape(t14525, (1, 1, 71, 2048, 64)) # t14634: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14634 = prims.reshape(t14525, (1, 1, 71, 2048, 64)) # t14634: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t14525
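# NOTE: nvFusion23 below appears to undo the rotary embedding (RoPE) for block
# h.24: t61/t66 look like the cached cos/sin tables, and the slice/neg/pad
# pattern is the adjoint of rotate_half applied to the q and k gradients
# (t14531, t14527). The sums over the 71 query heads collapse the shared-KV
# (multi-query) broadcast back onto the single K and V head, and the final cat
# along dim i1563 reassembles the fused qkv gradient of shape
# [1, 1, 73, 2048, 64] (71 query heads + 1 key head + 1 value head).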
[t14668] = nvFusion23(i1563, t14527, t14531, t14634, t61, t66)
# t14528 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t14528: "cuda:0 bf16[1, 71, 2048, 0]"
# t14529 = prims.pad(t14528, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t14529: "cuda:0 bf16[1, 71, 2048, 64]"
# t14532 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t14532: "cuda:0 bf16[1, 71, 2048, 0]"
# t14533 = prims.pad(t14532, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t14533: "cuda:0 bf16[1, 71, 2048, 64]"
# t14534 = prims.convert_element_type(t14527, dtypes.float32) # t14534: "cuda:0 f32[1, 71, 2048, 64]"
# t14538 = prims.mul(t66, t14534) # t14538: "cuda:0 f32[1, 71, 2048, 64]"
# t14541 = prims.convert_element_type(t14538, dtypes.bfloat16) # t14541: "cuda:0 bf16[1, 71, 2048, 64]"
# t14550 = prims.mul(t61, t14534) # t14550: "cuda:0 f32[1, 71, 2048, 64]"
# t14562 = prims.slice_prim(t14541, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t14562: "cuda:0 bf16[1, 71, 2048, 32]"
# t14563 = prims.slice_prim(t14541, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14563: "cuda:0 bf16[1, 71, 2048, 32]"
# t14564 = prims.convert_element_type(t14562, dtypes.float32) # t14564: "cuda:0 f32[1, 71, 2048, 32]"
# t14565 = prims.neg(t14564) # t14565: "cuda:0 f32[1, 71, 2048, 32]"
# t14566 = prims.convert_element_type(t14565, dtypes.bfloat16) # t14566: "cuda:0 bf16[1, 71, 2048, 32]"
# t14567 = prims.pad(t14566, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14567: "cuda:0 bf16[1, 71, 2048, 64]"
# t14569 = prims.convert_element_type(t14567, dtypes.float32) # t14569: "cuda:0 f32[1, 71, 2048, 64]"
# t14570 = prims.add(t14550, t14569) # t14570: "cuda:0 f32[1, 71, 2048, 64]"
# t14572 = prims.pad(t14563, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14572: "cuda:0 bf16[1, 71, 2048, 64]"
# t14574 = prims.convert_element_type(t14572, dtypes.float32) # t14574: "cuda:0 f32[1, 71, 2048, 64]"
# t14575 = prims.add(t14570, t14574) # t14575: "cuda:0 f32[1, 71, 2048, 64]"
# t14576 = prims.convert_element_type(t14575, dtypes.bfloat16) # t14576: "cuda:0 bf16[1, 71, 2048, 64]"
# t14577 = prims.pad(t14576, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14577: "cuda:0 bf16[1, 71, 2048, 64]"
# t14578 = prims.convert_element_type(t14529, dtypes.float32) # t14578: "cuda:0 f32[1, 71, 2048, 64]"
# t14579 = prims.convert_element_type(t14577, dtypes.float32) # t14579: "cuda:0 f32[1, 71, 2048, 64]"
# t14580 = prims.add(t14578, t14579) # t14580: "cuda:0 f32[1, 71, 2048, 64]"
# t14581 = prims.convert_element_type(t14580, dtypes.bfloat16) # t14581: "cuda:0 bf16[1, 71, 2048, 64]"
# t14582 = prims.convert_element_type(t14531, dtypes.float32) # t14582: "cuda:0 f32[1, 71, 2048, 64]"
# t14586 = prims.mul(t66, t14582) # t14586: "cuda:0 f32[1, 71, 2048, 64]"
# t14589 = prims.convert_element_type(t14586, dtypes.bfloat16) # t14589: "cuda:0 bf16[1, 71, 2048, 64]"
# t14598 = prims.mul(t61, t14582) # t14598: "cuda:0 f32[1, 71, 2048, 64]"
# t14610 = prims.slice_prim(t14589, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t14610: "cuda:0 bf16[1, 71, 2048, 32]"
# t14611 = prims.slice_prim(t14589, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14611: "cuda:0 bf16[1, 71, 2048, 32]"
# t14612 = prims.convert_element_type(t14610, dtypes.float32) # t14612: "cuda:0 f32[1, 71, 2048, 32]"
# t14613 = prims.neg(t14612) # t14613: "cuda:0 f32[1, 71, 2048, 32]"
# t14614 = prims.convert_element_type(t14613, dtypes.bfloat16) # t14614: "cuda:0 bf16[1, 71, 2048, 32]"
# t14615 = prims.pad(t14614, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14615: "cuda:0 bf16[1, 71, 2048, 64]"
# t14617 = prims.convert_element_type(t14615, dtypes.float32) # t14617: "cuda:0 f32[1, 71, 2048, 64]"
# t14618 = prims.add(t14598, t14617) # t14618: "cuda:0 f32[1, 71, 2048, 64]"
# t14620 = prims.pad(t14611, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14620: "cuda:0 bf16[1, 71, 2048, 64]"
# t14622 = prims.convert_element_type(t14620, dtypes.float32) # t14622: "cuda:0 f32[1, 71, 2048, 64]"
# t14623 = prims.add(t14618, t14622) # t14623: "cuda:0 f32[1, 71, 2048, 64]"
# t14624 = prims.convert_element_type(t14623, dtypes.bfloat16) # t14624: "cuda:0 bf16[1, 71, 2048, 64]"
# t14625 = prims.pad(t14624, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14625: "cuda:0 bf16[1, 71, 2048, 64]"
# t14626 = prims.convert_element_type(t14533, dtypes.float32) # t14626: "cuda:0 f32[1, 71, 2048, 64]"
# t14627 = prims.convert_element_type(t14625, dtypes.float32) # t14627: "cuda:0 f32[1, 71, 2048, 64]"
# t14628 = prims.add(t14626, t14627) # t14628: "cuda:0 f32[1, 71, 2048, 64]"
# t14629 = prims.convert_element_type(t14628, dtypes.bfloat16) # t14629: "cuda:0 bf16[1, 71, 2048, 64]"
# t14639 = prims.reshape(t14581, (1, 1, 71, 2048, 64)) # t14639: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14644 = prims.reshape(t14629, (1, 1, 71, 2048, 64)) # t14644: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14650 = prims.convert_element_type(t14634, dtypes.float32) # t14650: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14651 = prims.sum(t14650, (0, 1, 2)) # t14651: "cuda:0 f32[2048, 64]"
# t14652 = prims.convert_element_type(t14651, dtypes.bfloat16) # t14652: "cuda:0 bf16[2048, 64]"
# t14653 = prims.broadcast_in_dim(t14652, [1, 1, 1, 2048, 64], [3, 4]) # t14653: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14659 = prims.convert_element_type(t14639, dtypes.float32) # t14659: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14660 = prims.sum(t14659, (0, 1, 2)) # t14660: "cuda:0 f32[2048, 64]"
# t14661 = prims.convert_element_type(t14660, dtypes.bfloat16) # t14661: "cuda:0 bf16[2048, 64]"
# t14662 = prims.broadcast_in_dim(t14661, [1, 1, 1, 2048, 64], [3, 4]) # t14662: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14668 = prims.cat((t14644, t14662, t14653), i1563) # t14668: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1563, t14527, t14531, t14634
t14674 = torch.permute(t14668, (0, 3, 1, 2, 4)) # t14674: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14674 = ltorch.permute(t14668, (0, 3, 1, 2, 4)) # t14674: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14674 = prims.transpose(t14668, (0, 3, 1, 2, 4)) # t14674: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t14668
t14680 = torch.reshape(t14674, (1, 2048, 4672)) # t14680: "cuda:0 bf16[1, 2048, 4672]"
# t14680 = ltorch.reshape(t14674, (1, 2048, 4672)) # t14680: "cuda:0 bf16[1, 2048, 4672]"
# t14680 = prims.reshape(t14674, (1, 2048, 4672)) # t14680: "cuda:0 bf16[1, 2048, 4672]"
del t14674
t14681 = torch.reshape(t14680, (-1, 4672)) # t14681: "cuda:0 bf16[2048, 4672]"
# t14681 = ltorch.reshape(t14680, (-1, 4672)) # t14681: "cuda:0 bf16[2048, 4672]"
# t14681 = prims.reshape(t14680, (2048, 4672)) # t14681: "cuda:0 bf16[2048, 4672]"
del t14680
t14685 = torch.permute(t14681, (1, 0)) # t14685: "cuda:0 bf16[4672, 2048]"
# t14685 = ltorch.permute(t14681, (1, 0)) # t14685: "cuda:0 bf16[4672, 2048]"
# t14685 = prims.transpose(t14681, (1, 0)) # t14685: "cuda:0 bf16[4672, 2048]"
t14687 = torch.matmul(t14685, t14507) # t14687: "cuda:0 bf16[4672, 4544]"
# t14687 = ltorch.matmul(t14685, t14686) # t14687: "cuda:0 bf16[4672, 4544]"
# t14687 = prims.matmul(t14685, t14686) # t14687: "cuda:0 bf16[4672, 4544]"
del t14685, t14507
t14682 = torch.matmul(t14681, t_transformer_h_24_attn_attn_weight) # t14682: "cuda:0 bf16[2048, 4544]"
# t14682 = ltorch.matmul(t14681, t_transformer_h_24_attn_attn_weight) # t14682: "cuda:0 bf16[2048, 4544]"
# t14682 = prims.matmul(t14681, t_transformer_h_24_attn_attn_weight) # t14682: "cuda:0 bf16[2048, 4544]"
del t14681, t_transformer_h_24_attn_attn_weight
t14504 = torch.reshape(t14503, (1, 2048, 4544)) # t14504: "cuda:0 bf16[1, 2048, 4544]"
# t14504 = ltorch.reshape(t14503, (1, 2048, 4544)) # t14504: "cuda:0 bf16[1, 2048, 4544]"
# t14504 = prims.reshape(t14503, (1, 2048, 4544)) # t14504: "cuda:0 bf16[1, 2048, 4544]"
del t14503
t14683 = torch.reshape(t14682, (1, 2048, 4544)) # t14683: "cuda:0 bf16[1, 2048, 4544]"
# t14683 = ltorch.reshape(t14682, (1, 2048, 4544)) # t14683: "cuda:0 bf16[1, 2048, 4544]"
# t14683 = prims.reshape(t14682, (1, 2048, 4544)) # t14683: "cuda:0 bf16[1, 2048, 4544]"
del t14682
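# NOTE: nvFusion24 below repeats the fused LayerNorm-backward / residual-
# accumulation pattern of nvFusion21, now for block h.24 (bias grad t14696,
# weight grad t14702, residual-stream gradient t14744). The same per-block
# sequence (GELU backward, SDPA backward, RoPE backward, LayerNorm backward)
# continues below for h.23 and h.22.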
[t14696, t14702, t14744] = nvFusion24(i14724, t14461, t14504, t14683, t3674, t3806, t3827, t3842, t3847, t3853)
# t3833 = prims.convert_element_type(t3674, dtypes.float32) # t3833: "cuda:0 f32[1, 2048, 4544]"
# t3828 = prims.convert_element_type(t3827, dtypes.float32) # t3828: "cuda:0 f32[1, 2048, 4544]"
# t3829 = prims.convert_element_type(t3806, dtypes.float32) # t3829: "cuda:0 f32[1, 2048, 4544]"
# t3830 = prims.add(t3828, t3829) # t3830: "cuda:0 f32[1, 2048, 4544]"
# t3834 = prims.add(t3830, t3833) # t3834: "cuda:0 f32[1, 2048, 4544]"
# t3844 = prims.broadcast_in_dim(t3842, [1, 2048, 1], [0, 1]) # t3844: "cuda:0 f32[1, 2048, 1]"
# t3848 = prims.broadcast_in_dim(t3844, (1, 2048, 4544), (0, 1, 2)) # t3848: "cuda:0 f32[1, 2048, 4544]"
# t3850 = prims.sub(t3834, t3848) # t3850: "cuda:0 f32[1, 2048, 4544]"
# t3851 = prims.broadcast_in_dim(t3847, (1, 2048, 4544), (0, 1, 2)) # t3851: "cuda:0 f32[1, 2048, 4544]"
# t3852 = prims.mul(t3850, t3851) # t3852: "cuda:0 f32[1, 2048, 4544]"
# t3854 = prims.convert_element_type(t3853, dtypes.float32) # t3854: "cuda:0 f32[1, 2048, 4544]"
# t14741 = prims.convert_element_type(t14461, dtypes.float32) # t14741: "cuda:0 f32[1, 2048, 4544]"
# t14688 = prims.convert_element_type(t14504, dtypes.float32) # t14688: "cuda:0 f32[1, 2048, 4544]"
# t14689 = prims.convert_element_type(t14683, dtypes.float32) # t14689: "cuda:0 f32[1, 2048, 4544]"
# t14690 = prims.add(t14688, t14689) # t14690: "cuda:0 f32[1, 2048, 4544]"
# t14695 = prims.sum(t14690, (0, 1)) # t14695: "cuda:0 f32[4544]"
# t14696 = prims.convert_element_type(t14695, dtypes.bfloat16) # t14696: "cuda:0 bf16[4544]"
# t14697 = prims.mul(t3854, t14690) # t14697: "cuda:0 f32[1, 2048, 4544]"
# t14698 = prims.mul(t3852, t14690) # t14698: "cuda:0 f32[1, 2048, 4544]"
# t14701 = prims.sum(t14698, (0, 1)) # t14701: "cuda:0 f32[4544]"
# t14702 = prims.convert_element_type(t14701, dtypes.bfloat16) # t14702: "cuda:0 bf16[4544]"
# t14703 = prims.mul(t3851, t14697) # t14703: "cuda:0 f32[1, 2048, 4544]"
# t14704 = prims.mul(t3850, t14697) # t14704: "cuda:0 f32[1, 2048, 4544]"
# t14705 = prims.sum(t14704, (0, 2)) # t14705: "cuda:0 f32[2048]"
# t14706 = prims.broadcast_in_dim(t14705, [1, 2048, 1], [1]) # t14706: "cuda:0 f32[1, 2048, 1]"
# t14707 = prims.neg(t14703) # t14707: "cuda:0 f32[1, 2048, 4544]"
# t14709 = prims.sum(t14707, (0, 2)) # t14709: "cuda:0 f32[2048]"
# t14710 = prims.broadcast_in_dim(t14709, [1, 2048, 1], [1]) # t14710: "cuda:0 f32[1, 2048, 1]"
# t14711 = prims.mul(-0.5, t14706) # t14711: "cuda:0 f32[1, 2048, 1]"
# t14712 = prims.pow(t3847, 3.0) # t14712: "cuda:0 f32[1, 2048, 1]"
# t14713 = prims.mul(t14711, t14712) # t14713: "cuda:0 f32[1, 2048, 1]"
# t14715 = prims.sum(t14710, (0, 2)) # t14715: "cuda:0 f32[2048]"
# t14716 = prims.broadcast_in_dim(t14715, [1, 2048], [1]) # t14716: "cuda:0 f32[1, 2048]"
# t14717 = prims.sum(t14713, (0, 2)) # t14717: "cuda:0 f32[2048]"
# t14718 = prims.broadcast_in_dim(t14717, [1, 2048], [1]) # t14718: "cuda:0 f32[1, 2048]"
# t14721 = prims.broadcast_in_dim(t14716, [1, 2048, 1], [0, 1]) # t14721: "cuda:0 f32[1, 2048, 1]"
# t14722 = prims.broadcast_in_dim(t14721, (1, 2048, 4544), (0, 1, 2)) # t14722: "cuda:0 f32[1, 2048, 4544]"
# t14723 = prims.mul(0.00022007042253521127, t14722) # t14723: "cuda:0 f32[1, 2048, 4544]"
# t14725 = prims.broadcast_in_dim(t14718, [1, 2048, 1], [0, 1]) # t14725: "cuda:0 f32[1, 2048, 1]"
# t14726 = prims.broadcast_in_dim(t14725, (1, 2048, 4544), (0, 1, 2)) # t14726: "cuda:0 f32[1, 2048, 4544]"
# t14728 = prims.broadcast_in_dim(t3842, [1, 2048, 1], [0, 1]) # t14728: "cuda:0 f32[1, 2048, 1]"
# t14729 = prims.broadcast_in_dim(t14728, (1, 2048, 4544), (0, 1, 2)) # t14729: "cuda:0 f32[1, 2048, 4544]"
# t14730 = prims.mul(2.0, t14726) # t14730: "cuda:0 f32[1, 2048, 4544]"
# t14731 = prims.sub(t3834, t14729) # t14731: "cuda:0 f32[1, 2048, 4544]"
# t14732 = prims.mul(t14730, t14731) # t14732: "cuda:0 f32[1, 2048, 4544]"
# f14733 = prims.convert_element_type(i14724, float) # f14733: "float 4544.0"
# t14734 = prims.div(t14732, f14733) # t14734: "cuda:0 f32[1, 2048, 4544]"
# t14735 = prims.add(t14723, t14734) # t14735: "cuda:0 f32[1, 2048, 4544]"
# t14739 = prims.add(t14703, t14735) # t14739: "cuda:0 f32[1, 2048, 4544]"
# t14743 = prims.add(t14741, t14739) # t14743: "cuda:0 f32[1, 2048, 4544]"
# t14744 = prims.convert_element_type(t14743, dtypes.bfloat16) # t14744: "cuda:0 bf16[1, 2048, 4544]"
del i14724, t14461, t14504, t14683, t3674, t3806, t3827, t3842, t3847, t3853
t14751 = torch.reshape(t14744, (-1, 4544)) # t14751: "cuda:0 bf16[2048, 4544]"
# t14751 = ltorch.reshape(t14744, (-1, 4544)) # t14751: "cuda:0 bf16[2048, 4544]"
# t14751 = prims.reshape(t14744, (2048, 4544)) # t14751: "cuda:0 bf16[2048, 4544]"
t14755 = torch.permute(t14751, (1, 0)) # t14755: "cuda:0 bf16[4544, 2048]"
# t14755 = ltorch.permute(t14751, (1, 0)) # t14755: "cuda:0 bf16[4544, 2048]"
# t14755 = prims.transpose(t14751, (1, 0)) # t14755: "cuda:0 bf16[4544, 2048]"
t14757 = torch.matmul(t14755, t14756) # t14757: "cuda:0 bf16[4544, 18176]"
# t14757 = ltorch.matmul(t14755, t14756) # t14757: "cuda:0 bf16[4544, 18176]"
# t14757 = prims.matmul(t14755, t14756) # t14757: "cuda:0 bf16[4544, 18176]"
del t14756
t14793 = torch.matmul(t14751, t_transformer_h_23_attn_proj_weight) # t14793: "cuda:0 bf16[2048, 4544]"
# t14793 = ltorch.matmul(t14792, t_transformer_h_23_attn_proj_weight) # t14793: "cuda:0 bf16[2048, 4544]"
# t14793 = prims.matmul(t14792, t_transformer_h_23_attn_proj_weight) # t14793: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_23_attn_proj_weight
t14798 = torch.matmul(t14755, t14797) # t14798: "cuda:0 bf16[4544, 4544]"
# t14798 = ltorch.matmul(t14796, t14797) # t14798: "cuda:0 bf16[4544, 4544]"
# t14798 = prims.matmul(t14796, t14797) # t14798: "cuda:0 bf16[4544, 4544]"
del t14755, t14797
t14752 = torch.matmul(t14751, t_transformer_h_23_mlp_proj_weight) # t14752: "cuda:0 bf16[2048, 18176]"
# t14752 = ltorch.matmul(t14751, t_transformer_h_23_mlp_proj_weight) # t14752: "cuda:0 bf16[2048, 18176]"
# t14752 = prims.matmul(t14751, t_transformer_h_23_mlp_proj_weight) # t14752: "cuda:0 bf16[2048, 18176]"
del t14751, t_transformer_h_23_mlp_proj_weight
t14794 = torch.reshape(t14793, (1, 2048, 4544)) # t14794: "cuda:0 bf16[1, 2048, 4544]"
# t14794 = ltorch.reshape(t14793, (1, 2048, 4544)) # t14794: "cuda:0 bf16[1, 2048, 4544]"
# t14794 = prims.reshape(t14793, (1, 2048, 4544)) # t14794: "cuda:0 bf16[1, 2048, 4544]"
del t14793
t14802 = torch.reshape(t14794, (1, 2048, 71, 64)) # t14802: "cuda:0 bf16[1, 2048, 71, 64]"
# t14802 = ltorch.reshape(t14794, (1, 2048, 71, 64)) # t14802: "cuda:0 bf16[1, 2048, 71, 64]"
# t14802 = prims.reshape(t14794, (1, 2048, 71, 64)) # t14802: "cuda:0 bf16[1, 2048, 71, 64]"
del t14794
t14805 = torch.permute(t14802, (0, 2, 1, 3)) # t14805: "cuda:0 bf16[1, 71, 2048, 64]"
# t14805 = ltorch.permute(t14802, (0, 2, 1, 3)) # t14805: "cuda:0 bf16[1, 71, 2048, 64]"
# t14805 = prims.transpose(t14802, (0, 2, 1, 3)) # t14805: "cuda:0 bf16[1, 71, 2048, 64]"
del t14802
t14753 = torch.reshape(t14752, (1, 2048, 18176)) # t14753: "cuda:0 bf16[1, 2048, 18176]"
# t14753 = ltorch.reshape(t14752, (1, 2048, 18176)) # t14753: "cuda:0 bf16[1, 2048, 18176]"
# t14753 = prims.reshape(t14752, (1, 2048, 18176)) # t14753: "cuda:0 bf16[1, 2048, 18176]"
del t14752
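# NOTE: nvFusion25 below appears to be the same erf-based GELU backward as
# nvFusion22, applied to block h.23's mlp_fc output t3807.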
[t14784] = nvFusion25(f1535, f1537, t14753, t3807)
# t3808 = prims.convert_element_type(t3807, dtypes.float32) # t3808: "cuda:0 f32[1, 2048, 18176]"
# t3810 = prims.div(t3808, 1.4142135623730951) # t3810: "cuda:0 f32[1, 2048, 18176]"
# t3813 = prims.erf(t3810) # t3813: "cuda:0 f32[1, 2048, 18176]"
# t3817 = prims.mul(0.5, t3813) # t3817: "cuda:0 f32[1, 2048, 18176]"
# t3821 = prims.add(0.5, t3817) # t3821: "cuda:0 f32[1, 2048, 18176]"
# t14758 = prims.convert_element_type(t14753, dtypes.float32) # t14758: "cuda:0 f32[1, 2048, 18176]"
# t14759 = prims.mul(t3821, t14758) # t14759: "cuda:0 f32[1, 2048, 18176]"
# t14760 = prims.mul(t3808, t14758) # t14760: "cuda:0 f32[1, 2048, 18176]"
# t14768 = prims.mul(f1537, t14760) # t14768: "cuda:0 f32[1, 2048, 18176]"
# t14771 = prims.pow(t3810, 2.0) # t14771: "cuda:0 f32[1, 2048, 18176]"
# t14772 = prims.neg(t14771) # t14772: "cuda:0 f32[1, 2048, 18176]"
# t14773 = prims.exp(t14772) # t14773: "cuda:0 f32[1, 2048, 18176]"
# t14774 = prims.mul(1.1283791670955126, t14773) # t14774: "cuda:0 f32[1, 2048, 18176]"
# t14775 = prims.mul(t14774, t14768) # t14775: "cuda:0 f32[1, 2048, 18176]"
# t14779 = prims.div(t14775, f1535) # t14779: "cuda:0 f32[1, 2048, 18176]"
# t14783 = prims.add(t14759, t14779) # t14783: "cuda:0 f32[1, 2048, 18176]"
# t14784 = prims.convert_element_type(t14783, dtypes.bfloat16) # t14784: "cuda:0 bf16[1, 2048, 18176]"
del f1535, f1537, t14753, t3807
t14785 = torch.reshape(t14784, (-1, 18176)) # t14785: "cuda:0 bf16[2048, 18176]"
# t14785 = ltorch.reshape(t14784, (-1, 18176)) # t14785: "cuda:0 bf16[2048, 18176]"
# t14785 = prims.reshape(t14784, (2048, 18176)) # t14785: "cuda:0 bf16[2048, 18176]"
del t14784
t14789 = torch.permute(t14785, (1, 0)) # t14789: "cuda:0 bf16[18176, 2048]"
# t14789 = ltorch.permute(t14785, (1, 0)) # t14789: "cuda:0 bf16[18176, 2048]"
# t14789 = prims.transpose(t14785, (1, 0)) # t14789: "cuda:0 bf16[18176, 2048]"
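# NOTE: cudnn_sdpa_bwd below is the attention backward for block h.23,
# mirroring the h.24 call above.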
(t14806, t14807, t14808) = cudnn_sdpa_bwd(t14805, t3791, t3794, t3744, None, f1526, b1527, t3795, t3796, t3797, t3798, scale=f1528, cat_grad_qkv=False)
del t14805, t3791, t3794, t3744, f1526, b1527, t3795, t3796, t3797, t3798, f1528
t14791 = torch.matmul(t14789, t14790) # t14791: "cuda:0 bf16[18176, 4544]"
# t14791 = ltorch.matmul(t14789, t14790) # t14791: "cuda:0 bf16[18176, 4544]"
# t14791 = prims.matmul(t14789, t14790) # t14791: "cuda:0 bf16[18176, 4544]"
del t14789
t14786 = torch.matmul(t14785, t_transformer_h_23_mlp_fc_weight) # t14786: "cuda:0 bf16[2048, 4544]"
# t14786 = ltorch.matmul(t14785, t_transformer_h_23_mlp_fc_weight) # t14786: "cuda:0 bf16[2048, 4544]"
# t14786 = prims.matmul(t14785, t_transformer_h_23_mlp_fc_weight) # t14786: "cuda:0 bf16[2048, 4544]"
del t14785, t_transformer_h_23_mlp_fc_weight
t14810 = torch_slice_prim_impl(t14807, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14810: "cuda:0 bf16[1, 71, 2048, 64]"
del t14807
t14814 = torch_slice_prim_impl(t14806, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14814: "cuda:0 bf16[1, 71, 2048, 64]"
del t14806
t14917 = torch.reshape(t14808, (1, 1, 71, 2048, 64)) # t14917: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14917 = ltorch.reshape(t14808, (1, 1, 71, 2048, 64)) # t14917: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14917 = prims.reshape(t14808, (1, 1, 71, 2048, 64)) # t14917: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t14808
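# NOTE: nvFusion26 below mirrors nvFusion23 (RoPE backward plus the multi-query
# K/V head reduction and qkv concat), here for block h.23.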
[t14951] = nvFusion26(i1499, t14810, t14814, t14917, t61, t66)
# t14811 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t14811: "cuda:0 bf16[1, 71, 2048, 0]"
# t14812 = prims.pad(t14811, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t14812: "cuda:0 bf16[1, 71, 2048, 64]"
# t14815 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t14815: "cuda:0 bf16[1, 71, 2048, 0]"
# t14816 = prims.pad(t14815, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t14816: "cuda:0 bf16[1, 71, 2048, 64]"
# t14817 = prims.convert_element_type(t14810, dtypes.float32) # t14817: "cuda:0 f32[1, 71, 2048, 64]"
# t14821 = prims.mul(t66, t14817) # t14821: "cuda:0 f32[1, 71, 2048, 64]"
# t14824 = prims.convert_element_type(t14821, dtypes.bfloat16) # t14824: "cuda:0 bf16[1, 71, 2048, 64]"
# t14833 = prims.mul(t61, t14817) # t14833: "cuda:0 f32[1, 71, 2048, 64]"
# t14845 = prims.slice_prim(t14824, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t14845: "cuda:0 bf16[1, 71, 2048, 32]"
# t14846 = prims.slice_prim(t14824, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14846: "cuda:0 bf16[1, 71, 2048, 32]"
# t14847 = prims.convert_element_type(t14845, dtypes.float32) # t14847: "cuda:0 f32[1, 71, 2048, 32]"
# t14848 = prims.neg(t14847) # t14848: "cuda:0 f32[1, 71, 2048, 32]"
# t14849 = prims.convert_element_type(t14848, dtypes.bfloat16) # t14849: "cuda:0 bf16[1, 71, 2048, 32]"
# t14850 = prims.pad(t14849, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14850: "cuda:0 bf16[1, 71, 2048, 64]"
# t14852 = prims.convert_element_type(t14850, dtypes.float32) # t14852: "cuda:0 f32[1, 71, 2048, 64]"
# t14853 = prims.add(t14833, t14852) # t14853: "cuda:0 f32[1, 71, 2048, 64]"
# t14855 = prims.pad(t14846, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14855: "cuda:0 bf16[1, 71, 2048, 64]"
# t14857 = prims.convert_element_type(t14855, dtypes.float32) # t14857: "cuda:0 f32[1, 71, 2048, 64]"
# t14858 = prims.add(t14853, t14857) # t14858: "cuda:0 f32[1, 71, 2048, 64]"
# t14859 = prims.convert_element_type(t14858, dtypes.bfloat16) # t14859: "cuda:0 bf16[1, 71, 2048, 64]"
# t14860 = prims.pad(t14859, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14860: "cuda:0 bf16[1, 71, 2048, 64]"
# t14861 = prims.convert_element_type(t14812, dtypes.float32) # t14861: "cuda:0 f32[1, 71, 2048, 64]"
# t14862 = prims.convert_element_type(t14860, dtypes.float32) # t14862: "cuda:0 f32[1, 71, 2048, 64]"
# t14863 = prims.add(t14861, t14862) # t14863: "cuda:0 f32[1, 71, 2048, 64]"
# t14864 = prims.convert_element_type(t14863, dtypes.bfloat16) # t14864: "cuda:0 bf16[1, 71, 2048, 64]"
# t14865 = prims.convert_element_type(t14814, dtypes.float32) # t14865: "cuda:0 f32[1, 71, 2048, 64]"
# t14869 = prims.mul(t66, t14865) # t14869: "cuda:0 f32[1, 71, 2048, 64]"
# t14872 = prims.convert_element_type(t14869, dtypes.bfloat16) # t14872: "cuda:0 bf16[1, 71, 2048, 64]"
# t14881 = prims.mul(t61, t14865) # t14881: "cuda:0 f32[1, 71, 2048, 64]"
# t14893 = prims.slice_prim(t14872, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t14893: "cuda:0 bf16[1, 71, 2048, 32]"
# t14894 = prims.slice_prim(t14872, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t14894: "cuda:0 bf16[1, 71, 2048, 32]"
# t14895 = prims.convert_element_type(t14893, dtypes.float32) # t14895: "cuda:0 f32[1, 71, 2048, 32]"
# t14896 = prims.neg(t14895) # t14896: "cuda:0 f32[1, 71, 2048, 32]"
# t14897 = prims.convert_element_type(t14896, dtypes.bfloat16) # t14897: "cuda:0 bf16[1, 71, 2048, 32]"
# t14898 = prims.pad(t14897, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t14898: "cuda:0 bf16[1, 71, 2048, 64]"
# t14900 = prims.convert_element_type(t14898, dtypes.float32) # t14900: "cuda:0 f32[1, 71, 2048, 64]"
# t14901 = prims.add(t14881, t14900) # t14901: "cuda:0 f32[1, 71, 2048, 64]"
# t14903 = prims.pad(t14894, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t14903: "cuda:0 bf16[1, 71, 2048, 64]"
# t14905 = prims.convert_element_type(t14903, dtypes.float32) # t14905: "cuda:0 f32[1, 71, 2048, 64]"
# t14906 = prims.add(t14901, t14905) # t14906: "cuda:0 f32[1, 71, 2048, 64]"
# t14907 = prims.convert_element_type(t14906, dtypes.bfloat16) # t14907: "cuda:0 bf16[1, 71, 2048, 64]"
# t14908 = prims.pad(t14907, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t14908: "cuda:0 bf16[1, 71, 2048, 64]"
# t14909 = prims.convert_element_type(t14816, dtypes.float32) # t14909: "cuda:0 f32[1, 71, 2048, 64]"
# t14910 = prims.convert_element_type(t14908, dtypes.float32) # t14910: "cuda:0 f32[1, 71, 2048, 64]"
# t14911 = prims.add(t14909, t14910) # t14911: "cuda:0 f32[1, 71, 2048, 64]"
# t14912 = prims.convert_element_type(t14911, dtypes.bfloat16) # t14912: "cuda:0 bf16[1, 71, 2048, 64]"
# t14922 = prims.reshape(t14864, (1, 1, 71, 2048, 64)) # t14922: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14927 = prims.reshape(t14912, (1, 1, 71, 2048, 64)) # t14927: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t14933 = prims.convert_element_type(t14917, dtypes.float32) # t14933: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14934 = prims.sum(t14933, (0, 1, 2)) # t14934: "cuda:0 f32[2048, 64]"
# t14935 = prims.convert_element_type(t14934, dtypes.bfloat16) # t14935: "cuda:0 bf16[2048, 64]"
# t14936 = prims.broadcast_in_dim(t14935, [1, 1, 1, 2048, 64], [3, 4]) # t14936: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14942 = prims.convert_element_type(t14922, dtypes.float32) # t14942: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t14943 = prims.sum(t14942, (0, 1, 2)) # t14943: "cuda:0 f32[2048, 64]"
# t14944 = prims.convert_element_type(t14943, dtypes.bfloat16) # t14944: "cuda:0 bf16[2048, 64]"
# t14945 = prims.broadcast_in_dim(t14944, [1, 1, 1, 2048, 64], [3, 4]) # t14945: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t14951 = prims.cat((t14927, t14945, t14936), i1499) # t14951: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1499, t14810, t14814, t14917
t14957 = torch.permute(t14951, (0, 3, 1, 2, 4)) # t14957: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14957 = ltorch.permute(t14951, (0, 3, 1, 2, 4)) # t14957: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t14957 = prims.transpose(t14951, (0, 3, 1, 2, 4)) # t14957: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t14951
t14963 = torch.reshape(t14957, (1, 2048, 4672)) # t14963: "cuda:0 bf16[1, 2048, 4672]"
# t14963 = ltorch.reshape(t14957, (1, 2048, 4672)) # t14963: "cuda:0 bf16[1, 2048, 4672]"
# t14963 = prims.reshape(t14957, (1, 2048, 4672)) # t14963: "cuda:0 bf16[1, 2048, 4672]"
del t14957
t14964 = torch.reshape(t14963, (-1, 4672)) # t14964: "cuda:0 bf16[2048, 4672]"
# t14964 = ltorch.reshape(t14963, (-1, 4672)) # t14964: "cuda:0 bf16[2048, 4672]"
# t14964 = prims.reshape(t14963, (2048, 4672)) # t14964: "cuda:0 bf16[2048, 4672]"
del t14963
t14968 = torch.permute(t14964, (1, 0)) # t14968: "cuda:0 bf16[4672, 2048]"
# t14968 = ltorch.permute(t14964, (1, 0)) # t14968: "cuda:0 bf16[4672, 2048]"
# t14968 = prims.transpose(t14964, (1, 0)) # t14968: "cuda:0 bf16[4672, 2048]"
t14970 = torch.matmul(t14968, t14790) # t14970: "cuda:0 bf16[4672, 4544]"
# t14970 = ltorch.matmul(t14968, t14969) # t14970: "cuda:0 bf16[4672, 4544]"
# t14970 = prims.matmul(t14968, t14969) # t14970: "cuda:0 bf16[4672, 4544]"
del t14968, t14790
t14965 = torch.matmul(t14964, t_transformer_h_23_attn_attn_weight) # t14965: "cuda:0 bf16[2048, 4544]"
# t14965 = ltorch.matmul(t14964, t_transformer_h_23_attn_attn_weight) # t14965: "cuda:0 bf16[2048, 4544]"
# t14965 = prims.matmul(t14964, t_transformer_h_23_attn_attn_weight) # t14965: "cuda:0 bf16[2048, 4544]"
del t14964, t_transformer_h_23_attn_attn_weight
t14787 = torch.reshape(t14786, (1, 2048, 4544)) # t14787: "cuda:0 bf16[1, 2048, 4544]"
# t14787 = ltorch.reshape(t14786, (1, 2048, 4544)) # t14787: "cuda:0 bf16[1, 2048, 4544]"
# t14787 = prims.reshape(t14786, (1, 2048, 4544)) # t14787: "cuda:0 bf16[1, 2048, 4544]"
del t14786
t14966 = torch.reshape(t14965, (1, 2048, 4544)) # t14966: "cuda:0 bf16[1, 2048, 4544]"
# t14966 = ltorch.reshape(t14965, (1, 2048, 4544)) # t14966: "cuda:0 bf16[1, 2048, 4544]"
# t14966 = prims.reshape(t14965, (1, 2048, 4544)) # t14966: "cuda:0 bf16[1, 2048, 4544]"
del t14965
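# NOTE: nvFusion27 below is block h.23's instance of the fused LayerNorm
# backward / residual accumulation (cf. nvFusion21/24).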
[t14979, t14985, t15027] = nvFusion27(i15007, t14744, t14787, t14966, t3513, t3645, t3666, t3681, t3686, t3692)
# t3672 = prims.convert_element_type(t3513, dtypes.float32) # t3672: "cuda:0 f32[1, 2048, 4544]"
# t3667 = prims.convert_element_type(t3666, dtypes.float32) # t3667: "cuda:0 f32[1, 2048, 4544]"
# t3668 = prims.convert_element_type(t3645, dtypes.float32) # t3668: "cuda:0 f32[1, 2048, 4544]"
# t3669 = prims.add(t3667, t3668) # t3669: "cuda:0 f32[1, 2048, 4544]"
# t3673 = prims.add(t3669, t3672) # t3673: "cuda:0 f32[1, 2048, 4544]"
# t3683 = prims.broadcast_in_dim(t3681, [1, 2048, 1], [0, 1]) # t3683: "cuda:0 f32[1, 2048, 1]"
# t3687 = prims.broadcast_in_dim(t3683, (1, 2048, 4544), (0, 1, 2)) # t3687: "cuda:0 f32[1, 2048, 4544]"
# t3689 = prims.sub(t3673, t3687) # t3689: "cuda:0 f32[1, 2048, 4544]"
# t3690 = prims.broadcast_in_dim(t3686, (1, 2048, 4544), (0, 1, 2)) # t3690: "cuda:0 f32[1, 2048, 4544]"
# t3691 = prims.mul(t3689, t3690) # t3691: "cuda:0 f32[1, 2048, 4544]"
# t3693 = prims.convert_element_type(t3692, dtypes.float32) # t3693: "cuda:0 f32[1, 2048, 4544]"
# t15024 = prims.convert_element_type(t14744, dtypes.float32) # t15024: "cuda:0 f32[1, 2048, 4544]"
# t14971 = prims.convert_element_type(t14787, dtypes.float32) # t14971: "cuda:0 f32[1, 2048, 4544]"
# t14972 = prims.convert_element_type(t14966, dtypes.float32) # t14972: "cuda:0 f32[1, 2048, 4544]"
# t14973 = prims.add(t14971, t14972) # t14973: "cuda:0 f32[1, 2048, 4544]"
# t14978 = prims.sum(t14973, (0, 1)) # t14978: "cuda:0 f32[4544]"
# t14979 = prims.convert_element_type(t14978, dtypes.bfloat16) # t14979: "cuda:0 bf16[4544]"
# t14980 = prims.mul(t3693, t14973) # t14980: "cuda:0 f32[1, 2048, 4544]"
# t14981 = prims.mul(t3691, t14973) # t14981: "cuda:0 f32[1, 2048, 4544]"
# t14984 = prims.sum(t14981, (0, 1)) # t14984: "cuda:0 f32[4544]"
# t14985 = prims.convert_element_type(t14984, dtypes.bfloat16) # t14985: "cuda:0 bf16[4544]"
# t14986 = prims.mul(t3690, t14980) # t14986: "cuda:0 f32[1, 2048, 4544]"
# t14987 = prims.mul(t3689, t14980) # t14987: "cuda:0 f32[1, 2048, 4544]"
# t14988 = prims.sum(t14987, (0, 2)) # t14988: "cuda:0 f32[2048]"
# t14989 = prims.broadcast_in_dim(t14988, [1, 2048, 1], [1]) # t14989: "cuda:0 f32[1, 2048, 1]"
# t14990 = prims.neg(t14986) # t14990: "cuda:0 f32[1, 2048, 4544]"
# t14992 = prims.sum(t14990, (0, 2)) # t14992: "cuda:0 f32[2048]"
# t14993 = prims.broadcast_in_dim(t14992, [1, 2048, 1], [1]) # t14993: "cuda:0 f32[1, 2048, 1]"
# t14994 = prims.mul(-0.5, t14989) # t14994: "cuda:0 f32[1, 2048, 1]"
# t14995 = prims.pow(t3686, 3.0) # t14995: "cuda:0 f32[1, 2048, 1]"
# t14996 = prims.mul(t14994, t14995) # t14996: "cuda:0 f32[1, 2048, 1]"
# t14998 = prims.sum(t14993, (0, 2)) # t14998: "cuda:0 f32[2048]"
# t14999 = prims.broadcast_in_dim(t14998, [1, 2048], [1]) # t14999: "cuda:0 f32[1, 2048]"
# t15000 = prims.sum(t14996, (0, 2)) # t15000: "cuda:0 f32[2048]"
# t15001 = prims.broadcast_in_dim(t15000, [1, 2048], [1]) # t15001: "cuda:0 f32[1, 2048]"
# t15004 = prims.broadcast_in_dim(t14999, [1, 2048, 1], [0, 1]) # t15004: "cuda:0 f32[1, 2048, 1]"
# t15005 = prims.broadcast_in_dim(t15004, (1, 2048, 4544), (0, 1, 2)) # t15005: "cuda:0 f32[1, 2048, 4544]"
# t15006 = prims.mul(0.00022007042253521127, t15005) # t15006: "cuda:0 f32[1, 2048, 4544]"
# t15008 = prims.broadcast_in_dim(t15001, [1, 2048, 1], [0, 1]) # t15008: "cuda:0 f32[1, 2048, 1]"
# t15009 = prims.broadcast_in_dim(t15008, (1, 2048, 4544), (0, 1, 2)) # t15009: "cuda:0 f32[1, 2048, 4544]"
# t15011 = prims.broadcast_in_dim(t3681, [1, 2048, 1], [0, 1]) # t15011: "cuda:0 f32[1, 2048, 1]"
# t15012 = prims.broadcast_in_dim(t15011, (1, 2048, 4544), (0, 1, 2)) # t15012: "cuda:0 f32[1, 2048, 4544]"
# t15013 = prims.mul(2.0, t15009) # t15013: "cuda:0 f32[1, 2048, 4544]"
# t15014 = prims.sub(t3673, t15012) # t15014: "cuda:0 f32[1, 2048, 4544]"
# t15015 = prims.mul(t15013, t15014) # t15015: "cuda:0 f32[1, 2048, 4544]"
# f15016 = prims.convert_element_type(i15007, float) # f15016: "float 4544.0"
# t15017 = prims.div(t15015, f15016) # t15017: "cuda:0 f32[1, 2048, 4544]"
# t15018 = prims.add(t15006, t15017) # t15018: "cuda:0 f32[1, 2048, 4544]"
# t15022 = prims.add(t14986, t15018) # t15022: "cuda:0 f32[1, 2048, 4544]"
# t15026 = prims.add(t15024, t15022) # t15026: "cuda:0 f32[1, 2048, 4544]"
# t15027 = prims.convert_element_type(t15026, dtypes.bfloat16) # t15027: "cuda:0 bf16[1, 2048, 4544]"
del i15007, t14744, t14787, t14966, t3513, t3645, t3666, t3681, t3686, t3692
t15034 = torch.reshape(t15027, (-1, 4544)) # t15034: "cuda:0 bf16[2048, 4544]"
# t15034 = ltorch.reshape(t15027, (-1, 4544)) # t15034: "cuda:0 bf16[2048, 4544]"
# t15034 = prims.reshape(t15027, (2048, 4544)) # t15034: "cuda:0 bf16[2048, 4544]"
t15038 = torch.permute(t15034, (1, 0)) # t15038: "cuda:0 bf16[4544, 2048]"
# t15038 = ltorch.permute(t15034, (1, 0)) # t15038: "cuda:0 bf16[4544, 2048]"
# t15038 = prims.transpose(t15034, (1, 0)) # t15038: "cuda:0 bf16[4544, 2048]"
t15081 = torch.matmul(t15038, t15080) # t15081: "cuda:0 bf16[4544, 4544]"
# t15081 = ltorch.matmul(t15079, t15080) # t15081: "cuda:0 bf16[4544, 4544]"
# t15081 = prims.matmul(t15079, t15080) # t15081: "cuda:0 bf16[4544, 4544]"
del t15080
t15035 = torch.matmul(t15034, t_transformer_h_22_mlp_proj_weight) # t15035: "cuda:0 bf16[2048, 18176]"
# t15035 = ltorch.matmul(t15034, t_transformer_h_22_mlp_proj_weight) # t15035: "cuda:0 bf16[2048, 18176]"
# t15035 = prims.matmul(t15034, t_transformer_h_22_mlp_proj_weight) # t15035: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_22_mlp_proj_weight
t15040 = torch.matmul(t15038, t15039) # t15040: "cuda:0 bf16[4544, 18176]"
# t15040 = ltorch.matmul(t15038, t15039) # t15040: "cuda:0 bf16[4544, 18176]"
# t15040 = prims.matmul(t15038, t15039) # t15040: "cuda:0 bf16[4544, 18176]"
del t15038, t15039
t15076 = torch.matmul(t15034, t_transformer_h_22_attn_proj_weight) # t15076: "cuda:0 bf16[2048, 4544]"
# t15076 = ltorch.matmul(t15075, t_transformer_h_22_attn_proj_weight) # t15076: "cuda:0 bf16[2048, 4544]"
# t15076 = prims.matmul(t15075, t_transformer_h_22_attn_proj_weight) # t15076: "cuda:0 bf16[2048, 4544]"
del t15034, t_transformer_h_22_attn_proj_weight
t15036 = torch.reshape(t15035, (1, 2048, 18176)) # t15036: "cuda:0 bf16[1, 2048, 18176]"
# t15036 = ltorch.reshape(t15035, (1, 2048, 18176)) # t15036: "cuda:0 bf16[1, 2048, 18176]"
# t15036 = prims.reshape(t15035, (1, 2048, 18176)) # t15036: "cuda:0 bf16[1, 2048, 18176]"
del t15035
t15077 = torch.reshape(t15076, (1, 2048, 4544)) # t15077: "cuda:0 bf16[1, 2048, 4544]"
# t15077 = ltorch.reshape(t15076, (1, 2048, 4544)) # t15077: "cuda:0 bf16[1, 2048, 4544]"
# t15077 = prims.reshape(t15076, (1, 2048, 4544)) # t15077: "cuda:0 bf16[1, 2048, 4544]"
del t15076
t15085 = torch.reshape(t15077, (1, 2048, 71, 64)) # t15085: "cuda:0 bf16[1, 2048, 71, 64]"
# t15085 = ltorch.reshape(t15077, (1, 2048, 71, 64)) # t15085: "cuda:0 bf16[1, 2048, 71, 64]"
# t15085 = prims.reshape(t15077, (1, 2048, 71, 64)) # t15085: "cuda:0 bf16[1, 2048, 71, 64]"
del t15077
t15088 = torch.permute(t15085, (0, 2, 1, 3)) # t15088: "cuda:0 bf16[1, 71, 2048, 64]"
# t15088 = ltorch.permute(t15085, (0, 2, 1, 3)) # t15088: "cuda:0 bf16[1, 71, 2048, 64]"
# t15088 = prims.transpose(t15085, (0, 2, 1, 3)) # t15088: "cuda:0 bf16[1, 71, 2048, 64]"
del t15085
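# NOTE: nvFusion28 below is the erf-based GELU backward for block h.22's
# mlp_fc output t3646 (same structure as nvFusion22/25).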
[t15067] = nvFusion28(f1471, f1473, t15036, t3646)
# t3647 = prims.convert_element_type(t3646, dtypes.float32) # t3647: "cuda:0 f32[1, 2048, 18176]"
# t3649 = prims.div(t3647, 1.4142135623730951) # t3649: "cuda:0 f32[1, 2048, 18176]"
# t3652 = prims.erf(t3649) # t3652: "cuda:0 f32[1, 2048, 18176]"
# t3656 = prims.mul(0.5, t3652) # t3656: "cuda:0 f32[1, 2048, 18176]"
# t3660 = prims.add(0.5, t3656) # t3660: "cuda:0 f32[1, 2048, 18176]"
# t15041 = prims.convert_element_type(t15036, dtypes.float32) # t15041: "cuda:0 f32[1, 2048, 18176]"
# t15042 = prims.mul(t3660, t15041) # t15042: "cuda:0 f32[1, 2048, 18176]"
# t15043 = prims.mul(t3647, t15041) # t15043: "cuda:0 f32[1, 2048, 18176]"
# t15051 = prims.mul(f1473, t15043) # t15051: "cuda:0 f32[1, 2048, 18176]"
# t15054 = prims.pow(t3649, 2.0) # t15054: "cuda:0 f32[1, 2048, 18176]"
# t15055 = prims.neg(t15054) # t15055: "cuda:0 f32[1, 2048, 18176]"
# t15056 = prims.exp(t15055) # t15056: "cuda:0 f32[1, 2048, 18176]"
# t15057 = prims.mul(1.1283791670955126, t15056) # t15057: "cuda:0 f32[1, 2048, 18176]"
# t15058 = prims.mul(t15057, t15051) # t15058: "cuda:0 f32[1, 2048, 18176]"
# t15062 = prims.div(t15058, f1471) # t15062: "cuda:0 f32[1, 2048, 18176]"
# t15066 = prims.add(t15042, t15062) # t15066: "cuda:0 f32[1, 2048, 18176]"
# t15067 = prims.convert_element_type(t15066, dtypes.bfloat16) # t15067: "cuda:0 bf16[1, 2048, 18176]"
del f1471, f1473, t15036, t3646
t15068 = torch.reshape(t15067, (-1, 18176)) # t15068: "cuda:0 bf16[2048, 18176]"
# t15068 = ltorch.reshape(t15067, (-1, 18176)) # t15068: "cuda:0 bf16[2048, 18176]"
# t15068 = prims.reshape(t15067, (2048, 18176)) # t15068: "cuda:0 bf16[2048, 18176]"
del t15067
t15072 = torch.permute(t15068, (1, 0)) # t15072: "cuda:0 bf16[18176, 2048]"
# t15072 = ltorch.permute(t15068, (1, 0)) # t15072: "cuda:0 bf16[18176, 2048]"
# t15072 = prims.transpose(t15068, (1, 0)) # t15072: "cuda:0 bf16[18176, 2048]"
t15074 = torch.matmul(t15072, t15073) # t15074: "cuda:0 bf16[18176, 4544]"
# t15074 = ltorch.matmul(t15072, t15073) # t15074: "cuda:0 bf16[18176, 4544]"
# t15074 = prims.matmul(t15072, t15073) # t15074: "cuda:0 bf16[18176, 4544]"
del t15072
t15069 = torch.matmul(t15068, t_transformer_h_22_mlp_fc_weight) # t15069: "cuda:0 bf16[2048, 4544]"
# t15069 = ltorch.matmul(t15068, t_transformer_h_22_mlp_fc_weight) # t15069: "cuda:0 bf16[2048, 4544]"
# t15069 = prims.matmul(t15068, t_transformer_h_22_mlp_fc_weight) # t15069: "cuda:0 bf16[2048, 4544]"
del t15068, t_transformer_h_22_mlp_fc_weight
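# NOTE: cudnn_sdpa_bwd below is the attention backward for block h.22.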
(t15089, t15090, t15091) = cudnn_sdpa_bwd(t15088, t3630, t3633, t3583, None, f1462, b1463, t3634, t3635, t3636, t3637, scale=f1464, cat_grad_qkv=False)
del t15088, t3630, t3633, t3583, f1462, b1463, t3634, t3635, t3636, t3637, f1464
t15093 = torch_slice_prim_impl(t15090, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15093: "cuda:0 bf16[1, 71, 2048, 64]"
del t15090
t15097 = torch_slice_prim_impl(t15089, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15097: "cuda:0 bf16[1, 71, 2048, 64]"
del t15089
t15200 = torch.reshape(t15091, (1, 1, 71, 2048, 64)) # t15200: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15200 = ltorch.reshape(t15091, (1, 1, 71, 2048, 64)) # t15200: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15200 = prims.reshape(t15091, (1, 1, 71, 2048, 64)) # t15200: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t15091
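# NOTE: nvFusion29 below is block h.22's RoPE backward / qkv-gradient
# reassembly (same pattern as nvFusion23/26).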
[t15234] = nvFusion29(i1435, t15093, t15097, t15200, t61, t66)
# t15094 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15094: "cuda:0 bf16[1, 71, 2048, 0]"
# t15095 = prims.pad(t15094, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15095: "cuda:0 bf16[1, 71, 2048, 64]"
# t15098 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15098: "cuda:0 bf16[1, 71, 2048, 0]"
# t15099 = prims.pad(t15098, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15099: "cuda:0 bf16[1, 71, 2048, 64]"
# t15100 = prims.convert_element_type(t15093, dtypes.float32) # t15100: "cuda:0 f32[1, 71, 2048, 64]"
# t15104 = prims.mul(t66, t15100) # t15104: "cuda:0 f32[1, 71, 2048, 64]"
# t15107 = prims.convert_element_type(t15104, dtypes.bfloat16) # t15107: "cuda:0 bf16[1, 71, 2048, 64]"
# t15116 = prims.mul(t61, t15100) # t15116: "cuda:0 f32[1, 71, 2048, 64]"
# t15128 = prims.slice_prim(t15107, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t15128: "cuda:0 bf16[1, 71, 2048, 32]"
# t15129 = prims.slice_prim(t15107, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15129: "cuda:0 bf16[1, 71, 2048, 32]"
# t15130 = prims.convert_element_type(t15128, dtypes.float32) # t15130: "cuda:0 f32[1, 71, 2048, 32]"
# t15131 = prims.neg(t15130) # t15131: "cuda:0 f32[1, 71, 2048, 32]"
# t15132 = prims.convert_element_type(t15131, dtypes.bfloat16) # t15132: "cuda:0 bf16[1, 71, 2048, 32]"
# t15133 = prims.pad(t15132, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t15133: "cuda:0 bf16[1, 71, 2048, 64]"
# t15135 = prims.convert_element_type(t15133, dtypes.float32) # t15135: "cuda:0 f32[1, 71, 2048, 64]"
# t15136 = prims.add(t15116, t15135) # t15136: "cuda:0 f32[1, 71, 2048, 64]"
# t15138 = prims.pad(t15129, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t15138: "cuda:0 bf16[1, 71, 2048, 64]"
# t15140 = prims.convert_element_type(t15138, dtypes.float32) # t15140: "cuda:0 f32[1, 71, 2048, 64]"
# t15141 = prims.add(t15136, t15140) # t15141: "cuda:0 f32[1, 71, 2048, 64]"
# t15142 = prims.convert_element_type(t15141, dtypes.bfloat16) # t15142: "cuda:0 bf16[1, 71, 2048, 64]"
# t15143 = prims.pad(t15142, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t15143: "cuda:0 bf16[1, 71, 2048, 64]"
# t15144 = prims.convert_element_type(t15095, dtypes.float32) # t15144: "cuda:0 f32[1, 71, 2048, 64]"
# t15145 = prims.convert_element_type(t15143, dtypes.float32) # t15145: "cuda:0 f32[1, 71, 2048, 64]"
# t15146 = prims.add(t15144, t15145) # t15146: "cuda:0 f32[1, 71, 2048, 64]"
# t15147 = prims.convert_element_type(t15146, dtypes.bfloat16) # t15147: "cuda:0 bf16[1, 71, 2048, 64]"
# t15148 = prims.convert_element_type(t15097, dtypes.float32) # t15148: "cuda:0 f32[1, 71, 2048, 64]"
# t15152 = prims.mul(t66, t15148) # t15152: "cuda:0 f32[1, 71, 2048, 64]"
# t15155 = prims.convert_element_type(t15152, dtypes.bfloat16) # t15155: "cuda:0 bf16[1, 71, 2048, 64]"
# t15164 = prims.mul(t61, t15148) # t15164: "cuda:0 f32[1, 71, 2048, 64]"
# t15176 = prims.slice_prim(t15155, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t15176: "cuda:0 bf16[1, 71, 2048, 32]"
# t15177 = prims.slice_prim(t15155, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15177: "cuda:0 bf16[1, 71, 2048, 32]"
# t15178 = prims.convert_element_type(t15176, dtypes.float32) # t15178: "cuda:0 f32[1, 71, 2048, 32]"
# t15179 = prims.neg(t15178) # t15179: "cuda:0 f32[1, 71, 2048, 32]"
# t15180 = prims.convert_element_type(t15179, dtypes.bfloat16) # t15180: "cuda:0 bf16[1, 71, 2048, 32]"
# t15181 = prims.pad(t15180, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t15181: "cuda:0 bf16[1, 71, 2048, 64]"
# t15183 = prims.convert_element_type(t15181, dtypes.float32) # t15183: "cuda:0 f32[1, 71, 2048, 64]"
# t15184 = prims.add(t15164, t15183) # t15184: "cuda:0 f32[1, 71, 2048, 64]"
# t15186 = prims.pad(t15177, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t15186: "cuda:0 bf16[1, 71, 2048, 64]"
# t15188 = prims.convert_element_type(t15186, dtypes.float32) # t15188: "cuda:0 f32[1, 71, 2048, 64]"
# t15189 = prims.add(t15184, t15188) # t15189: "cuda:0 f32[1, 71, 2048, 64]"
# t15190 = prims.convert_element_type(t15189, dtypes.bfloat16) # t15190: "cuda:0 bf16[1, 71, 2048, 64]"
# t15191 = prims.pad(t15190, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t15191: "cuda:0 bf16[1, 71, 2048, 64]"
# t15192 = prims.convert_element_type(t15099, dtypes.float32) # t15192: "cuda:0 f32[1, 71, 2048, 64]"
# t15193 = prims.convert_element_type(t15191, dtypes.float32) # t15193: "cuda:0 f32[1, 71, 2048, 64]"
# t15194 = prims.add(t15192, t15193) # t15194: "cuda:0 f32[1, 71, 2048, 64]"
# t15195 = prims.convert_element_type(t15194, dtypes.bfloat16) # t15195: "cuda:0 bf16[1, 71, 2048, 64]"
# t15205 = prims.reshape(t15147, (1, 1, 71, 2048, 64)) # t15205: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15210 = prims.reshape(t15195, (1, 1, 71, 2048, 64)) # t15210: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15216 = prims.convert_element_type(t15200, dtypes.float32) # t15216: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t15217 = prims.sum(t15216, (0, 1, 2)) # t15217: "cuda:0 f32[2048, 64]"
# t15218 = prims.convert_element_type(t15217, dtypes.bfloat16) # t15218: "cuda:0 bf16[2048, 64]"
# t15219 = prims.broadcast_in_dim(t15218, [1, 1, 1, 2048, 64], [3, 4]) # t15219: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t15225 = prims.convert_element_type(t15205, dtypes.float32) # t15225: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t15226 = prims.sum(t15225, (0, 1, 2)) # t15226: "cuda:0 f32[2048, 64]"
# t15227 = prims.convert_element_type(t15226, dtypes.bfloat16) # t15227: "cuda:0 bf16[2048, 64]"
# t15228 = prims.broadcast_in_dim(t15227, [1, 1, 1, 2048, 64], [3, 4]) # t15228: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t15234 = prims.cat((t15210, t15228, t15219), i1435) # t15234: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1435, t15093, t15097, t15200
t15240 = torch.permute(t15234, (0, 3, 1, 2, 4)) # t15240: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t15240 = ltorch.permute(t15234, (0, 3, 1, 2, 4)) # t15240: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t15240 = prims.transpose(t15234, (0, 3, 1, 2, 4)) # t15240: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t15234
t15246 = torch.reshape(t15240, (1, 2048, 4672)) # t15246: "cuda:0 bf16[1, 2048, 4672]"
# t15246 = ltorch.reshape(t15240, (1, 2048, 4672)) # t15246: "cuda:0 bf16[1, 2048, 4672]"
# t15246 = prims.reshape(t15240, (1, 2048, 4672)) # t15246: "cuda:0 bf16[1, 2048, 4672]"
del t15240
t15247 = torch.reshape(t15246, (-1, 4672)) # t15247: "cuda:0 bf16[2048, 4672]"
# t15247 = ltorch.reshape(t15246, (-1, 4672)) # t15247: "cuda:0 bf16[2048, 4672]"
# t15247 = prims.reshape(t15246, (2048, 4672)) # t15247: "cuda:0 bf16[2048, 4672]"
del t15246
t15251 = torch.permute(t15247, (1, 0)) # t15251: "cuda:0 bf16[4672, 2048]"
# t15251 = ltorch.permute(t15247, (1, 0)) # t15251: "cuda:0 bf16[4672, 2048]"
# t15251 = prims.transpose(t15247, (1, 0)) # t15251: "cuda:0 bf16[4672, 2048]"
t15253 = torch.matmul(t15251, t15073) # t15253: "cuda:0 bf16[4672, 4544]"
# t15253 = ltorch.matmul(t15251, t15252) # t15253: "cuda:0 bf16[4672, 4544]"
# t15253 = prims.matmul(t15251, t15252) # t15253: "cuda:0 bf16[4672, 4544]"
del t15251, t15073
t15248 = torch.matmul(t15247, t_transformer_h_22_attn_attn_weight) # t15248: "cuda:0 bf16[2048, 4544]"
# t15248 = ltorch.matmul(t15247, t_transformer_h_22_attn_attn_weight) # t15248: "cuda:0 bf16[2048, 4544]"
# t15248 = prims.matmul(t15247, t_transformer_h_22_attn_attn_weight) # t15248: "cuda:0 bf16[2048, 4544]"
del t15247, t_transformer_h_22_attn_attn_weight
t15070 = torch.reshape(t15069, (1, 2048, 4544)) # t15070: "cuda:0 bf16[1, 2048, 4544]"
# t15070 = ltorch.reshape(t15069, (1, 2048, 4544)) # t15070: "cuda:0 bf16[1, 2048, 4544]"
# t15070 = prims.reshape(t15069, (1, 2048, 4544)) # t15070: "cuda:0 bf16[1, 2048, 4544]"
del t15069
t15249 = torch.reshape(t15248, (1, 2048, 4544)) # t15249: "cuda:0 bf16[1, 2048, 4544]"
# t15249 = ltorch.reshape(t15248, (1, 2048, 4544)) # t15249: "cuda:0 bf16[1, 2048, 4544]"
# t15249 = prims.reshape(t15248, (1, 2048, 4544)) # t15249: "cuda:0 bf16[1, 2048, 4544]"
del t15248
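# NOTE: nvFusion30 below is block h.22's fused LayerNorm backward / residual
# accumulation (same pattern as nvFusion21/24/27).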
[t15262, t15268, t15310] = nvFusion30(i15290, t15027, t15070, t15249, t3352, t3484, t3505, t3520, t3525, t3531)
# t3511 = prims.convert_element_type(t3352, dtypes.float32) # t3511: "cuda:0 f32[1, 2048, 4544]"
# t3506 = prims.convert_element_type(t3505, dtypes.float32) # t3506: "cuda:0 f32[1, 2048, 4544]"
# t3507 = prims.convert_element_type(t3484, dtypes.float32) # t3507: "cuda:0 f32[1, 2048, 4544]"
# t3508 = prims.add(t3506, t3507) # t3508: "cuda:0 f32[1, 2048, 4544]"
# t3512 = prims.add(t3508, t3511) # t3512: "cuda:0 f32[1, 2048, 4544]"
# t3522 = prims.broadcast_in_dim(t3520, [1, 2048, 1], [0, 1]) # t3522: "cuda:0 f32[1, 2048, 1]"
# t3526 = prims.broadcast_in_dim(t3522, (1, 2048, 4544), (0, 1, 2)) # t3526: "cuda:0 f32[1, 2048, 4544]"
# t3528 = prims.sub(t3512, t3526) # t3528: "cuda:0 f32[1, 2048, 4544]"
# t3529 = prims.broadcast_in_dim(t3525, (1, 2048, 4544), (0, 1, 2)) # t3529: "cuda:0 f32[1, 2048, 4544]"
# t3530 = prims.mul(t3528, t3529) # t3530: "cuda:0 f32[1, 2048, 4544]"
# t3532 = prims.convert_element_type(t3531, dtypes.float32) # t3532: "cuda:0 f32[1, 2048, 4544]"
# t15307 = prims.convert_element_type(t15027, dtypes.float32) # t15307: "cuda:0 f32[1, 2048, 4544]"
# t15254 = prims.convert_element_type(t15070, dtypes.float32) # t15254: "cuda:0 f32[1, 2048, 4544]"
# t15255 = prims.convert_element_type(t15249, dtypes.float32) # t15255: "cuda:0 f32[1, 2048, 4544]"
# t15256 = prims.add(t15254, t15255) # t15256: "cuda:0 f32[1, 2048, 4544]"
# t15261 = prims.sum(t15256, (0, 1)) # t15261: "cuda:0 f32[4544]"
# t15262 = prims.convert_element_type(t15261, dtypes.bfloat16) # t15262: "cuda:0 bf16[4544]"
# t15263 = prims.mul(t3532, t15256) # t15263: "cuda:0 f32[1, 2048, 4544]"
# t15264 = prims.mul(t3530, t15256) # t15264: "cuda:0 f32[1, 2048, 4544]"
# t15267 = prims.sum(t15264, (0, 1)) # t15267: "cuda:0 f32[4544]"
# t15268 = prims.convert_element_type(t15267, dtypes.bfloat16) # t15268: "cuda:0 bf16[4544]"
# t15269 = prims.mul(t3529, t15263) # t15269: "cuda:0 f32[1, 2048, 4544]"
# t15270 = prims.mul(t3528, t15263) # t15270: "cuda:0 f32[1, 2048, 4544]"
# t15271 = prims.sum(t15270, (0, 2)) # t15271: "cuda:0 f32[2048]"
# t15272 = prims.broadcast_in_dim(t15271, [1, 2048, 1], [1]) # t15272: "cuda:0 f32[1, 2048, 1]"
# t15273 = prims.neg(t15269) # t15273: "cuda:0 f32[1, 2048, 4544]"
# t15275 = prims.sum(t15273, (0, 2)) # t15275: "cuda:0 f32[2048]"
# t15276 = prims.broadcast_in_dim(t15275, [1, 2048, 1], [1]) # t15276: "cuda:0 f32[1, 2048, 1]"
# t15277 = prims.mul(-0.5, t15272) # t15277: "cuda:0 f32[1, 2048, 1]"
# t15278 = prims.pow(t3525, 3.0) # t15278: "cuda:0 f32[1, 2048, 1]"
# t15279 = prims.mul(t15277, t15278) # t15279: "cuda:0 f32[1, 2048, 1]"
# t15281 = prims.sum(t15276, (0, 2)) # t15281: "cuda:0 f32[2048]"
# t15282 = prims.broadcast_in_dim(t15281, [1, 2048], [1]) # t15282: "cuda:0 f32[1, 2048]"
# t15283 = prims.sum(t15279, (0, 2)) # t15283: "cuda:0 f32[2048]"
# t15284 = prims.broadcast_in_dim(t15283, [1, 2048], [1]) # t15284: "cuda:0 f32[1, 2048]"
# t15287 = prims.broadcast_in_dim(t15282, [1, 2048, 1], [0, 1]) # t15287: "cuda:0 f32[1, 2048, 1]"
# t15288 = prims.broadcast_in_dim(t15287, (1, 2048, 4544), (0, 1, 2)) # t15288: "cuda:0 f32[1, 2048, 4544]"
# t15289 = prims.mul(0.00022007042253521127, t15288) # t15289: "cuda:0 f32[1, 2048, 4544]"
# t15291 = prims.broadcast_in_dim(t15284, [1, 2048, 1], [0, 1]) # t15291: "cuda:0 f32[1, 2048, 1]"
# t15292 = prims.broadcast_in_dim(t15291, (1, 2048, 4544), (0, 1, 2)) # t15292: "cuda:0 f32[1, 2048, 4544]"
# t15294 = prims.broadcast_in_dim(t3520, [1, 2048, 1], [0, 1]) # t15294: "cuda:0 f32[1, 2048, 1]"
# t15295 = prims.broadcast_in_dim(t15294, (1, 2048, 4544), (0, 1, 2)) # t15295: "cuda:0 f32[1, 2048, 4544]"
# t15296 = prims.mul(2.0, t15292) # t15296: "cuda:0 f32[1, 2048, 4544]"
# t15297 = prims.sub(t3512, t15295) # t15297: "cuda:0 f32[1, 2048, 4544]"
# t15298 = prims.mul(t15296, t15297) # t15298: "cuda:0 f32[1, 2048, 4544]"
# f15299 = prims.convert_element_type(i15290, float) # f15299: "float 4544.0"
# t15300 = prims.div(t15298, f15299) # t15300: "cuda:0 f32[1, 2048, 4544]"
# t15301 = prims.add(t15289, t15300) # t15301: "cuda:0 f32[1, 2048, 4544]"
# t15305 = prims.add(t15269, t15301) # t15305: "cuda:0 f32[1, 2048, 4544]"
# t15309 = prims.add(t15307, t15305) # t15309: "cuda:0 f32[1, 2048, 4544]"
# t15310 = prims.convert_element_type(t15309, dtypes.bfloat16) # t15310: "cuda:0 bf16[1, 2048, 4544]"
del i15290, t15027, t15070, t15249, t3352, t3484, t3505, t3520, t3525, t3531
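# Note: t15310 is flattened to [2048, 4544] and transposed once; the four
# matmuls below then appear to compute, for block h.21, the mlp.proj input grad
# (t15318) and weight grad (t15323) as well as the attn.proj input grad (t15359)
# and weight grad (t15364), reusing t15317/t15321 for both branches.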
t15317 = torch.reshape(t15310, (-1, 4544)) # t15317: "cuda:0 bf16[2048, 4544]"
# t15317 = ltorch.reshape(t15310, (-1, 4544)) # t15317: "cuda:0 bf16[2048, 4544]"
# t15317 = prims.reshape(t15310, (2048, 4544)) # t15317: "cuda:0 bf16[2048, 4544]"
t15321 = torch.permute(t15317, (1, 0)) # t15321: "cuda:0 bf16[4544, 2048]"
# t15321 = ltorch.permute(t15317, (1, 0)) # t15321: "cuda:0 bf16[4544, 2048]"
# t15321 = prims.transpose(t15317, (1, 0)) # t15321: "cuda:0 bf16[4544, 2048]"
t15318 = torch.matmul(t15317, t_transformer_h_21_mlp_proj_weight) # t15318: "cuda:0 bf16[2048, 18176]"
# t15318 = ltorch.matmul(t15317, t_transformer_h_21_mlp_proj_weight) # t15318: "cuda:0 bf16[2048, 18176]"
# t15318 = prims.matmul(t15317, t_transformer_h_21_mlp_proj_weight) # t15318: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_21_mlp_proj_weight
t15323 = torch.matmul(t15321, t15322) # t15323: "cuda:0 bf16[4544, 18176]"
# t15323 = ltorch.matmul(t15321, t15322) # t15323: "cuda:0 bf16[4544, 18176]"
# t15323 = prims.matmul(t15321, t15322) # t15323: "cuda:0 bf16[4544, 18176]"
del t15322
t15359 = torch.matmul(t15317, t_transformer_h_21_attn_proj_weight) # t15359: "cuda:0 bf16[2048, 4544]"
# t15359 = ltorch.matmul(t15358, t_transformer_h_21_attn_proj_weight) # t15359: "cuda:0 bf16[2048, 4544]"
# t15359 = prims.matmul(t15358, t_transformer_h_21_attn_proj_weight) # t15359: "cuda:0 bf16[2048, 4544]"
del t15317, t_transformer_h_21_attn_proj_weight
t15364 = torch.matmul(t15321, t15363) # t15364: "cuda:0 bf16[4544, 4544]"
# t15364 = ltorch.matmul(t15362, t15363) # t15364: "cuda:0 bf16[4544, 4544]"
# t15364 = prims.matmul(t15362, t15363) # t15364: "cuda:0 bf16[4544, 4544]"
del t15321, t15363
t15319 = torch.reshape(t15318, (1, 2048, 18176)) # t15319: "cuda:0 bf16[1, 2048, 18176]"
# t15319 = ltorch.reshape(t15318, (1, 2048, 18176)) # t15319: "cuda:0 bf16[1, 2048, 18176]"
# t15319 = prims.reshape(t15318, (1, 2048, 18176)) # t15319: "cuda:0 bf16[1, 2048, 18176]"
del t15318
t15360 = torch.reshape(t15359, (1, 2048, 4544)) # t15360: "cuda:0 bf16[1, 2048, 4544]"
# t15360 = ltorch.reshape(t15359, (1, 2048, 4544)) # t15360: "cuda:0 bf16[1, 2048, 4544]"
# t15360 = prims.reshape(t15359, (1, 2048, 4544)) # t15360: "cuda:0 bf16[1, 2048, 4544]"
del t15359
t15368 = torch.reshape(t15360, (1, 2048, 71, 64)) # t15368: "cuda:0 bf16[1, 2048, 71, 64]"
# t15368 = ltorch.reshape(t15360, (1, 2048, 71, 64)) # t15368: "cuda:0 bf16[1, 2048, 71, 64]"
# t15368 = prims.reshape(t15360, (1, 2048, 71, 64)) # t15368: "cuda:0 bf16[1, 2048, 71, 64]"
del t15360
t15371 = torch.permute(t15368, (0, 2, 1, 3)) # t15371: "cuda:0 bf16[1, 71, 2048, 64]"
# t15371 = ltorch.permute(t15368, (0, 2, 1, 3)) # t15371: "cuda:0 bf16[1, 71, 2048, 64]"
# t15371 = prims.transpose(t15368, (0, 2, 1, 3)) # t15371: "cuda:0 bf16[1, 71, 2048, 64]"
del t15368
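# Note: nvFusion31 appears to be the backward of the erf-based GELU in block
# h.21's MLP: t3499 reconstructs Phi(x) = 0.5*(1 + erf(x/sqrt(2))) from the
# saved pre-activation t3485, while t15340/t15345 contribute the x*phi(x) term
# (1.1283791670955126 = 2/sqrt(pi)).  Roughly: d/dx gelu(x) = Phi(x) + x*phi(x),
# with phi the standard normal pdf.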
[t15350] = nvFusion31(f1407, f1409, t15319, t3485)
# t3486 = prims.convert_element_type(t3485, dtypes.float32) # t3486: "cuda:0 f32[1, 2048, 18176]"
# t3488 = prims.div(t3486, 1.4142135623730951) # t3488: "cuda:0 f32[1, 2048, 18176]"
# t3491 = prims.erf(t3488) # t3491: "cuda:0 f32[1, 2048, 18176]"
# t3495 = prims.mul(0.5, t3491) # t3495: "cuda:0 f32[1, 2048, 18176]"
# t3499 = prims.add(0.5, t3495) # t3499: "cuda:0 f32[1, 2048, 18176]"
# t15324 = prims.convert_element_type(t15319, dtypes.float32) # t15324: "cuda:0 f32[1, 2048, 18176]"
# t15325 = prims.mul(t3499, t15324) # t15325: "cuda:0 f32[1, 2048, 18176]"
# t15326 = prims.mul(t3486, t15324) # t15326: "cuda:0 f32[1, 2048, 18176]"
# t15334 = prims.mul(f1409, t15326) # t15334: "cuda:0 f32[1, 2048, 18176]"
# t15337 = prims.pow(t3488, 2.0) # t15337: "cuda:0 f32[1, 2048, 18176]"
# t15338 = prims.neg(t15337) # t15338: "cuda:0 f32[1, 2048, 18176]"
# t15339 = prims.exp(t15338) # t15339: "cuda:0 f32[1, 2048, 18176]"
# t15340 = prims.mul(1.1283791670955126, t15339) # t15340: "cuda:0 f32[1, 2048, 18176]"
# t15341 = prims.mul(t15340, t15334) # t15341: "cuda:0 f32[1, 2048, 18176]"
# t15345 = prims.div(t15341, f1407) # t15345: "cuda:0 f32[1, 2048, 18176]"
# t15349 = prims.add(t15325, t15345) # t15349: "cuda:0 f32[1, 2048, 18176]"
# t15350 = prims.convert_element_type(t15349, dtypes.bfloat16) # t15350: "cuda:0 bf16[1, 2048, 18176]"
del f1407, f1409, t15319, t3485
t15351 = torch.reshape(t15350, (-1, 18176)) # t15351: "cuda:0 bf16[2048, 18176]"
# t15351 = ltorch.reshape(t15350, (-1, 18176)) # t15351: "cuda:0 bf16[2048, 18176]"
# t15351 = prims.reshape(t15350, (2048, 18176)) # t15351: "cuda:0 bf16[2048, 18176]"
del t15350
t15355 = torch.permute(t15351, (1, 0)) # t15355: "cuda:0 bf16[18176, 2048]"
# t15355 = ltorch.permute(t15351, (1, 0)) # t15355: "cuda:0 bf16[18176, 2048]"
# t15355 = prims.transpose(t15351, (1, 0)) # t15355: "cuda:0 bf16[18176, 2048]"
t15357 = torch.matmul(t15355, t15356) # t15357: "cuda:0 bf16[18176, 4544]"
# t15357 = ltorch.matmul(t15355, t15356) # t15357: "cuda:0 bf16[18176, 4544]"
# t15357 = prims.matmul(t15355, t15356) # t15357: "cuda:0 bf16[18176, 4544]"
del t15355
t15352 = torch.matmul(t15351, t_transformer_h_21_mlp_fc_weight) # t15352: "cuda:0 bf16[2048, 4544]"
# t15352 = ltorch.matmul(t15351, t_transformer_h_21_mlp_fc_weight) # t15352: "cuda:0 bf16[2048, 4544]"
# t15352 = prims.matmul(t15351, t_transformer_h_21_mlp_fc_weight) # t15352: "cuda:0 bf16[2048, 4544]"
del t15351, t_transformer_h_21_mlp_fc_weight
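# Note: cudnn_sdpa_bwd is the cuDNN scaled-dot-product-attention backward; it
# appears to return the query, key, and value gradients (t15372, t15373, t15374)
# for block h.21's attention from the tensors saved during the forward pass.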
(t15372, t15373, t15374) = cudnn_sdpa_bwd(t15371, t3469, t3472, t3422, None, f1398, b1399, t3473, t3474, t3475, t3476, scale=f1400, cat_grad_qkv=False)
del t15371, t3469, t3472, t3422, f1398, b1399, t3473, t3474, t3475, t3476, f1400
t15376 = torch_slice_prim_impl(t15373, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15376: "cuda:0 bf16[1, 71, 2048, 64]"
del t15373
t15380 = torch_slice_prim_impl(t15372, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15380: "cuda:0 bf16[1, 71, 2048, 64]"
del t15372
t15483 = torch.reshape(t15374, (1, 1, 71, 2048, 64)) # t15483: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15483 = ltorch.reshape(t15374, (1, 1, 71, 2048, 64)) # t15483: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15483 = prims.reshape(t15374, (1, 1, 71, 2048, 64)) # t15483: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t15374
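# Note: nvFusion32 looks like the rotary-embedding (RoPE) backward for block
# h.21's query/key gradients plus re-packing of the fused QKV gradient: with the
# cached cos/sin tables t61/t66, the forward y = x*cos + rotate_half(x)*sin is
# reversed by adding the transposed rotate-half of the sin-scaled grad (second
# half moved to the front, negated first half moved to the back) to the
# cos-scaled grad.  The key and value gradients are also summed over the 71
# broadcast query heads (multi-query attention) before the concat into the
# [1, 1, 73, 2048, 64] fused-QKV layout (71 q heads + 1 k + 1 v).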
[t15517] = nvFusion32(i1371, t15376, t15380, t15483, t61, t66)
# t15377 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15377: "cuda:0 bf16[1, 71, 2048, 0]"
# t15378 = prims.pad(t15377, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15378: "cuda:0 bf16[1, 71, 2048, 64]"
# t15381 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15381: "cuda:0 bf16[1, 71, 2048, 0]"
# t15382 = prims.pad(t15381, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15382: "cuda:0 bf16[1, 71, 2048, 64]"
# t15383 = prims.convert_element_type(t15376, dtypes.float32) # t15383: "cuda:0 f32[1, 71, 2048, 64]"
# t15387 = prims.mul(t66, t15383) # t15387: "cuda:0 f32[1, 71, 2048, 64]"
# t15390 = prims.convert_element_type(t15387, dtypes.bfloat16) # t15390: "cuda:0 bf16[1, 71, 2048, 64]"
# t15399 = prims.mul(t61, t15383) # t15399: "cuda:0 f32[1, 71, 2048, 64]"
# t15411 = prims.slice_prim(t15390, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t15411: "cuda:0 bf16[1, 71, 2048, 32]"
# t15412 = prims.slice_prim(t15390, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15412: "cuda:0 bf16[1, 71, 2048, 32]"
# t15413 = prims.convert_element_type(t15411, dtypes.float32) # t15413: "cuda:0 f32[1, 71, 2048, 32]"
# t15414 = prims.neg(t15413) # t15414: "cuda:0 f32[1, 71, 2048, 32]"
# t15415 = prims.convert_element_type(t15414, dtypes.bfloat16) # t15415: "cuda:0 bf16[1, 71, 2048, 32]"
# t15416 = prims.pad(t15415, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t15416: "cuda:0 bf16[1, 71, 2048, 64]"
# t15418 = prims.convert_element_type(t15416, dtypes.float32) # t15418: "cuda:0 f32[1, 71, 2048, 64]"
# t15419 = prims.add(t15399, t15418) # t15419: "cuda:0 f32[1, 71, 2048, 64]"
# t15421 = prims.pad(t15412, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t15421: "cuda:0 bf16[1, 71, 2048, 64]"
# t15423 = prims.convert_element_type(t15421, dtypes.float32) # t15423: "cuda:0 f32[1, 71, 2048, 64]"
# t15424 = prims.add(t15419, t15423) # t15424: "cuda:0 f32[1, 71, 2048, 64]"
# t15425 = prims.convert_element_type(t15424, dtypes.bfloat16) # t15425: "cuda:0 bf16[1, 71, 2048, 64]"
# t15426 = prims.pad(t15425, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t15426: "cuda:0 bf16[1, 71, 2048, 64]"
# t15427 = prims.convert_element_type(t15378, dtypes.float32) # t15427: "cuda:0 f32[1, 71, 2048, 64]"
# t15428 = prims.convert_element_type(t15426, dtypes.float32) # t15428: "cuda:0 f32[1, 71, 2048, 64]"
# t15429 = prims.add(t15427, t15428) # t15429: "cuda:0 f32[1, 71, 2048, 64]"
# t15430 = prims.convert_element_type(t15429, dtypes.bfloat16) # t15430: "cuda:0 bf16[1, 71, 2048, 64]"
# t15431 = prims.convert_element_type(t15380, dtypes.float32) # t15431: "cuda:0 f32[1, 71, 2048, 64]"
# t15435 = prims.mul(t66, t15431) # t15435: "cuda:0 f32[1, 71, 2048, 64]"
# t15438 = prims.convert_element_type(t15435, dtypes.bfloat16) # t15438: "cuda:0 bf16[1, 71, 2048, 64]"
# t15447 = prims.mul(t61, t15431) # t15447: "cuda:0 f32[1, 71, 2048, 64]"
# t15459 = prims.slice_prim(t15438, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t15459: "cuda:0 bf16[1, 71, 2048, 32]"
# t15460 = prims.slice_prim(t15438, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15460: "cuda:0 bf16[1, 71, 2048, 32]"
# t15461 = prims.convert_element_type(t15459, dtypes.float32) # t15461: "cuda:0 f32[1, 71, 2048, 32]"
# t15462 = prims.neg(t15461) # t15462: "cuda:0 f32[1, 71, 2048, 32]"
# t15463 = prims.convert_element_type(t15462, dtypes.bfloat16) # t15463: "cuda:0 bf16[1, 71, 2048, 32]"
# t15464 = prims.pad(t15463, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t15464: "cuda:0 bf16[1, 71, 2048, 64]"
# t15466 = prims.convert_element_type(t15464, dtypes.float32) # t15466: "cuda:0 f32[1, 71, 2048, 64]"
# t15467 = prims.add(t15447, t15466) # t15467: "cuda:0 f32[1, 71, 2048, 64]"
# t15469 = prims.pad(t15460, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t15469: "cuda:0 bf16[1, 71, 2048, 64]"
# t15471 = prims.convert_element_type(t15469, dtypes.float32) # t15471: "cuda:0 f32[1, 71, 2048, 64]"
# t15472 = prims.add(t15467, t15471) # t15472: "cuda:0 f32[1, 71, 2048, 64]"
# t15473 = prims.convert_element_type(t15472, dtypes.bfloat16) # t15473: "cuda:0 bf16[1, 71, 2048, 64]"
# t15474 = prims.pad(t15473, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t15474: "cuda:0 bf16[1, 71, 2048, 64]"
# t15475 = prims.convert_element_type(t15382, dtypes.float32) # t15475: "cuda:0 f32[1, 71, 2048, 64]"
# t15476 = prims.convert_element_type(t15474, dtypes.float32) # t15476: "cuda:0 f32[1, 71, 2048, 64]"
# t15477 = prims.add(t15475, t15476) # t15477: "cuda:0 f32[1, 71, 2048, 64]"
# t15478 = prims.convert_element_type(t15477, dtypes.bfloat16) # t15478: "cuda:0 bf16[1, 71, 2048, 64]"
# t15488 = prims.reshape(t15430, (1, 1, 71, 2048, 64)) # t15488: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15493 = prims.reshape(t15478, (1, 1, 71, 2048, 64)) # t15493: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15499 = prims.convert_element_type(t15483, dtypes.float32) # t15499: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t15500 = prims.sum(t15499, (0, 1, 2)) # t15500: "cuda:0 f32[2048, 64]"
# t15501 = prims.convert_element_type(t15500, dtypes.bfloat16) # t15501: "cuda:0 bf16[2048, 64]"
# t15502 = prims.broadcast_in_dim(t15501, [1, 1, 1, 2048, 64], [3, 4]) # t15502: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t15508 = prims.convert_element_type(t15488, dtypes.float32) # t15508: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t15509 = prims.sum(t15508, (0, 1, 2)) # t15509: "cuda:0 f32[2048, 64]"
# t15510 = prims.convert_element_type(t15509, dtypes.bfloat16) # t15510: "cuda:0 bf16[2048, 64]"
# t15511 = prims.broadcast_in_dim(t15510, [1, 1, 1, 2048, 64], [3, 4]) # t15511: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t15517 = prims.cat((t15493, t15511, t15502), i1371) # t15517: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1371, t15376, t15380, t15483
t15523 = torch.permute(t15517, (0, 3, 1, 2, 4)) # t15523: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t15523 = ltorch.permute(t15517, (0, 3, 1, 2, 4)) # t15523: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t15523 = prims.transpose(t15517, (0, 3, 1, 2, 4)) # t15523: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t15517
t15529 = torch.reshape(t15523, (1, 2048, 4672)) # t15529: "cuda:0 bf16[1, 2048, 4672]"
# t15529 = ltorch.reshape(t15523, (1, 2048, 4672)) # t15529: "cuda:0 bf16[1, 2048, 4672]"
# t15529 = prims.reshape(t15523, (1, 2048, 4672)) # t15529: "cuda:0 bf16[1, 2048, 4672]"
del t15523
t15530 = torch.reshape(t15529, (-1, 4672)) # t15530: "cuda:0 bf16[2048, 4672]"
# t15530 = ltorch.reshape(t15529, (-1, 4672)) # t15530: "cuda:0 bf16[2048, 4672]"
# t15530 = prims.reshape(t15529, (2048, 4672)) # t15530: "cuda:0 bf16[2048, 4672]"
del t15529
t15534 = torch.permute(t15530, (1, 0)) # t15534: "cuda:0 bf16[4672, 2048]"
# t15534 = ltorch.permute(t15530, (1, 0)) # t15534: "cuda:0 bf16[4672, 2048]"
# t15534 = prims.transpose(t15530, (1, 0)) # t15534: "cuda:0 bf16[4672, 2048]"
t15536 = torch.matmul(t15534, t15356) # t15536: "cuda:0 bf16[4672, 4544]"
# t15536 = ltorch.matmul(t15534, t15535) # t15536: "cuda:0 bf16[4672, 4544]"
# t15536 = prims.matmul(t15534, t15535) # t15536: "cuda:0 bf16[4672, 4544]"
del t15534, t15356
t15531 = torch.matmul(t15530, t_transformer_h_21_attn_attn_weight) # t15531: "cuda:0 bf16[2048, 4544]"
# t15531 = ltorch.matmul(t15530, t_transformer_h_21_attn_attn_weight) # t15531: "cuda:0 bf16[2048, 4544]"
# t15531 = prims.matmul(t15530, t_transformer_h_21_attn_attn_weight) # t15531: "cuda:0 bf16[2048, 4544]"
del t15530, t_transformer_h_21_attn_attn_weight
t15353 = torch.reshape(t15352, (1, 2048, 4544)) # t15353: "cuda:0 bf16[1, 2048, 4544]"
# t15353 = ltorch.reshape(t15352, (1, 2048, 4544)) # t15353: "cuda:0 bf16[1, 2048, 4544]"
# t15353 = prims.reshape(t15352, (1, 2048, 4544)) # t15353: "cuda:0 bf16[1, 2048, 4544]"
del t15352
t15532 = torch.reshape(t15531, (1, 2048, 4544)) # t15532: "cuda:0 bf16[1, 2048, 4544]"
# t15532 = ltorch.reshape(t15531, (1, 2048, 4544)) # t15532: "cuda:0 bf16[1, 2048, 4544]"
# t15532 = prims.reshape(t15531, (1, 2048, 4544)) # t15532: "cuda:0 bf16[1, 2048, 4544]"
del t15531
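# Note: nvFusion33 repeats the fused LayerNorm-backward / residual-accumulation
# pattern of nvFusion30, here apparently for block h.21: bias grad t15545,
# weight grad t15551, and residual-stream grad t15593 passed on to block h.20.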
[t15545, t15551, t15593] = nvFusion33(i15573, t15310, t15353, t15532, t3191, t3323, t3344, t3359, t3364, t3370)
# t3350 = prims.convert_element_type(t3191, dtypes.float32) # t3350: "cuda:0 f32[1, 2048, 4544]"
# t3345 = prims.convert_element_type(t3344, dtypes.float32) # t3345: "cuda:0 f32[1, 2048, 4544]"
# t3346 = prims.convert_element_type(t3323, dtypes.float32) # t3346: "cuda:0 f32[1, 2048, 4544]"
# t3347 = prims.add(t3345, t3346) # t3347: "cuda:0 f32[1, 2048, 4544]"
# t3351 = prims.add(t3347, t3350) # t3351: "cuda:0 f32[1, 2048, 4544]"
# t3361 = prims.broadcast_in_dim(t3359, [1, 2048, 1], [0, 1]) # t3361: "cuda:0 f32[1, 2048, 1]"
# t3365 = prims.broadcast_in_dim(t3361, (1, 2048, 4544), (0, 1, 2)) # t3365: "cuda:0 f32[1, 2048, 4544]"
# t3367 = prims.sub(t3351, t3365) # t3367: "cuda:0 f32[1, 2048, 4544]"
# t3368 = prims.broadcast_in_dim(t3364, (1, 2048, 4544), (0, 1, 2)) # t3368: "cuda:0 f32[1, 2048, 4544]"
# t3369 = prims.mul(t3367, t3368) # t3369: "cuda:0 f32[1, 2048, 4544]"
# t3371 = prims.convert_element_type(t3370, dtypes.float32) # t3371: "cuda:0 f32[1, 2048, 4544]"
# t15590 = prims.convert_element_type(t15310, dtypes.float32) # t15590: "cuda:0 f32[1, 2048, 4544]"
# t15537 = prims.convert_element_type(t15353, dtypes.float32) # t15537: "cuda:0 f32[1, 2048, 4544]"
# t15538 = prims.convert_element_type(t15532, dtypes.float32) # t15538: "cuda:0 f32[1, 2048, 4544]"
# t15539 = prims.add(t15537, t15538) # t15539: "cuda:0 f32[1, 2048, 4544]"
# t15544 = prims.sum(t15539, (0, 1)) # t15544: "cuda:0 f32[4544]"
# t15545 = prims.convert_element_type(t15544, dtypes.bfloat16) # t15545: "cuda:0 bf16[4544]"
# t15546 = prims.mul(t3371, t15539) # t15546: "cuda:0 f32[1, 2048, 4544]"
# t15547 = prims.mul(t3369, t15539) # t15547: "cuda:0 f32[1, 2048, 4544]"
# t15550 = prims.sum(t15547, (0, 1)) # t15550: "cuda:0 f32[4544]"
# t15551 = prims.convert_element_type(t15550, dtypes.bfloat16) # t15551: "cuda:0 bf16[4544]"
# t15552 = prims.mul(t3368, t15546) # t15552: "cuda:0 f32[1, 2048, 4544]"
# t15553 = prims.mul(t3367, t15546) # t15553: "cuda:0 f32[1, 2048, 4544]"
# t15554 = prims.sum(t15553, (0, 2)) # t15554: "cuda:0 f32[2048]"
# t15555 = prims.broadcast_in_dim(t15554, [1, 2048, 1], [1]) # t15555: "cuda:0 f32[1, 2048, 1]"
# t15556 = prims.neg(t15552) # t15556: "cuda:0 f32[1, 2048, 4544]"
# t15558 = prims.sum(t15556, (0, 2)) # t15558: "cuda:0 f32[2048]"
# t15559 = prims.broadcast_in_dim(t15558, [1, 2048, 1], [1]) # t15559: "cuda:0 f32[1, 2048, 1]"
# t15560 = prims.mul(-0.5, t15555) # t15560: "cuda:0 f32[1, 2048, 1]"
# t15561 = prims.pow(t3364, 3.0) # t15561: "cuda:0 f32[1, 2048, 1]"
# t15562 = prims.mul(t15560, t15561) # t15562: "cuda:0 f32[1, 2048, 1]"
# t15564 = prims.sum(t15559, (0, 2)) # t15564: "cuda:0 f32[2048]"
# t15565 = prims.broadcast_in_dim(t15564, [1, 2048], [1]) # t15565: "cuda:0 f32[1, 2048]"
# t15566 = prims.sum(t15562, (0, 2)) # t15566: "cuda:0 f32[2048]"
# t15567 = prims.broadcast_in_dim(t15566, [1, 2048], [1]) # t15567: "cuda:0 f32[1, 2048]"
# t15570 = prims.broadcast_in_dim(t15565, [1, 2048, 1], [0, 1]) # t15570: "cuda:0 f32[1, 2048, 1]"
# t15571 = prims.broadcast_in_dim(t15570, (1, 2048, 4544), (0, 1, 2)) # t15571: "cuda:0 f32[1, 2048, 4544]"
# t15572 = prims.mul(0.00022007042253521127, t15571) # t15572: "cuda:0 f32[1, 2048, 4544]"
# t15574 = prims.broadcast_in_dim(t15567, [1, 2048, 1], [0, 1]) # t15574: "cuda:0 f32[1, 2048, 1]"
# t15575 = prims.broadcast_in_dim(t15574, (1, 2048, 4544), (0, 1, 2)) # t15575: "cuda:0 f32[1, 2048, 4544]"
# t15577 = prims.broadcast_in_dim(t3359, [1, 2048, 1], [0, 1]) # t15577: "cuda:0 f32[1, 2048, 1]"
# t15578 = prims.broadcast_in_dim(t15577, (1, 2048, 4544), (0, 1, 2)) # t15578: "cuda:0 f32[1, 2048, 4544]"
# t15579 = prims.mul(2.0, t15575) # t15579: "cuda:0 f32[1, 2048, 4544]"
# t15580 = prims.sub(t3351, t15578) # t15580: "cuda:0 f32[1, 2048, 4544]"
# t15581 = prims.mul(t15579, t15580) # t15581: "cuda:0 f32[1, 2048, 4544]"
# f15582 = prims.convert_element_type(i15573, float) # f15582: "float 4544.0"
# t15583 = prims.div(t15581, f15582) # t15583: "cuda:0 f32[1, 2048, 4544]"
# t15584 = prims.add(t15572, t15583) # t15584: "cuda:0 f32[1, 2048, 4544]"
# t15588 = prims.add(t15552, t15584) # t15588: "cuda:0 f32[1, 2048, 4544]"
# t15592 = prims.add(t15590, t15588) # t15592: "cuda:0 f32[1, 2048, 4544]"
# t15593 = prims.convert_element_type(t15592, dtypes.bfloat16) # t15593: "cuda:0 bf16[1, 2048, 4544]"
del i15573, t15310, t15353, t15532, t3191, t3323, t3344, t3359, t3364, t3370
t15600 = torch.reshape(t15593, (-1, 4544)) # t15600: "cuda:0 bf16[2048, 4544]"
# t15600 = ltorch.reshape(t15593, (-1, 4544)) # t15600: "cuda:0 bf16[2048, 4544]"
# t15600 = prims.reshape(t15593, (2048, 4544)) # t15600: "cuda:0 bf16[2048, 4544]"
t15604 = torch.permute(t15600, (1, 0)) # t15604: "cuda:0 bf16[4544, 2048]"
# t15604 = ltorch.permute(t15600, (1, 0)) # t15604: "cuda:0 bf16[4544, 2048]"
# t15604 = prims.transpose(t15600, (1, 0)) # t15604: "cuda:0 bf16[4544, 2048]"
t15601 = torch.matmul(t15600, t_transformer_h_20_mlp_proj_weight) # t15601: "cuda:0 bf16[2048, 18176]"
# t15601 = ltorch.matmul(t15600, t_transformer_h_20_mlp_proj_weight) # t15601: "cuda:0 bf16[2048, 18176]"
# t15601 = prims.matmul(t15600, t_transformer_h_20_mlp_proj_weight) # t15601: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_20_mlp_proj_weight
t15606 = torch.matmul(t15604, t15605) # t15606: "cuda:0 bf16[4544, 18176]"
# t15606 = ltorch.matmul(t15604, t15605) # t15606: "cuda:0 bf16[4544, 18176]"
# t15606 = prims.matmul(t15604, t15605) # t15606: "cuda:0 bf16[4544, 18176]"
del t15605
t15642 = torch.matmul(t15600, t_transformer_h_20_attn_proj_weight) # t15642: "cuda:0 bf16[2048, 4544]"
# t15642 = ltorch.matmul(t15641, t_transformer_h_20_attn_proj_weight) # t15642: "cuda:0 bf16[2048, 4544]"
# t15642 = prims.matmul(t15641, t_transformer_h_20_attn_proj_weight) # t15642: "cuda:0 bf16[2048, 4544]"
del t15600, t_transformer_h_20_attn_proj_weight
t15647 = torch.matmul(t15604, t15646) # t15647: "cuda:0 bf16[4544, 4544]"
# t15647 = ltorch.matmul(t15645, t15646) # t15647: "cuda:0 bf16[4544, 4544]"
# t15647 = prims.matmul(t15645, t15646) # t15647: "cuda:0 bf16[4544, 4544]"
del t15604, t15646
t15602 = torch.reshape(t15601, (1, 2048, 18176)) # t15602: "cuda:0 bf16[1, 2048, 18176]"
# t15602 = ltorch.reshape(t15601, (1, 2048, 18176)) # t15602: "cuda:0 bf16[1, 2048, 18176]"
# t15602 = prims.reshape(t15601, (1, 2048, 18176)) # t15602: "cuda:0 bf16[1, 2048, 18176]"
del t15601
t15643 = torch.reshape(t15642, (1, 2048, 4544)) # t15643: "cuda:0 bf16[1, 2048, 4544]"
# t15643 = ltorch.reshape(t15642, (1, 2048, 4544)) # t15643: "cuda:0 bf16[1, 2048, 4544]"
# t15643 = prims.reshape(t15642, (1, 2048, 4544)) # t15643: "cuda:0 bf16[1, 2048, 4544]"
del t15642
t15651 = torch.reshape(t15643, (1, 2048, 71, 64)) # t15651: "cuda:0 bf16[1, 2048, 71, 64]"
# t15651 = ltorch.reshape(t15643, (1, 2048, 71, 64)) # t15651: "cuda:0 bf16[1, 2048, 71, 64]"
# t15651 = prims.reshape(t15643, (1, 2048, 71, 64)) # t15651: "cuda:0 bf16[1, 2048, 71, 64]"
del t15643
t15654 = torch.permute(t15651, (0, 2, 1, 3)) # t15654: "cuda:0 bf16[1, 71, 2048, 64]"
# t15654 = ltorch.permute(t15651, (0, 2, 1, 3)) # t15654: "cuda:0 bf16[1, 71, 2048, 64]"
# t15654 = prims.transpose(t15651, (0, 2, 1, 3)) # t15654: "cuda:0 bf16[1, 71, 2048, 64]"
del t15651
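# Note: nvFusion34 is the same GELU-backward pattern as nvFusion31, apparently
# for block h.20's MLP (saved pre-activation t3324, producing t15633).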
[t15633] = nvFusion34(f1343, f1345, t15602, t3324)
# t3325 = prims.convert_element_type(t3324, dtypes.float32) # t3325: "cuda:0 f32[1, 2048, 18176]"
# t3327 = prims.div(t3325, 1.4142135623730951) # t3327: "cuda:0 f32[1, 2048, 18176]"
# t3330 = prims.erf(t3327) # t3330: "cuda:0 f32[1, 2048, 18176]"
# t3334 = prims.mul(0.5, t3330) # t3334: "cuda:0 f32[1, 2048, 18176]"
# t3338 = prims.add(0.5, t3334) # t3338: "cuda:0 f32[1, 2048, 18176]"
# t15607 = prims.convert_element_type(t15602, dtypes.float32) # t15607: "cuda:0 f32[1, 2048, 18176]"
# t15608 = prims.mul(t3338, t15607) # t15608: "cuda:0 f32[1, 2048, 18176]"
# t15609 = prims.mul(t3325, t15607) # t15609: "cuda:0 f32[1, 2048, 18176]"
# t15617 = prims.mul(f1345, t15609) # t15617: "cuda:0 f32[1, 2048, 18176]"
# t15620 = prims.pow(t3327, 2.0) # t15620: "cuda:0 f32[1, 2048, 18176]"
# t15621 = prims.neg(t15620) # t15621: "cuda:0 f32[1, 2048, 18176]"
# t15622 = prims.exp(t15621) # t15622: "cuda:0 f32[1, 2048, 18176]"
# t15623 = prims.mul(1.1283791670955126, t15622) # t15623: "cuda:0 f32[1, 2048, 18176]"
# t15624 = prims.mul(t15623, t15617) # t15624: "cuda:0 f32[1, 2048, 18176]"
# t15628 = prims.div(t15624, f1343) # t15628: "cuda:0 f32[1, 2048, 18176]"
# t15632 = prims.add(t15608, t15628) # t15632: "cuda:0 f32[1, 2048, 18176]"
# t15633 = prims.convert_element_type(t15632, dtypes.bfloat16) # t15633: "cuda:0 bf16[1, 2048, 18176]"
del f1343, f1345, t15602, t3324
t15634 = torch.reshape(t15633, (-1, 18176)) # t15634: "cuda:0 bf16[2048, 18176]"
# t15634 = ltorch.reshape(t15633, (-1, 18176)) # t15634: "cuda:0 bf16[2048, 18176]"
# t15634 = prims.reshape(t15633, (2048, 18176)) # t15634: "cuda:0 bf16[2048, 18176]"
del t15633
t15638 = torch.permute(t15634, (1, 0)) # t15638: "cuda:0 bf16[18176, 2048]"
# t15638 = ltorch.permute(t15634, (1, 0)) # t15638: "cuda:0 bf16[18176, 2048]"
# t15638 = prims.transpose(t15634, (1, 0)) # t15638: "cuda:0 bf16[18176, 2048]"
t15640 = torch.matmul(t15638, t15639) # t15640: "cuda:0 bf16[18176, 4544]"
# t15640 = ltorch.matmul(t15638, t15639) # t15640: "cuda:0 bf16[18176, 4544]"
# t15640 = prims.matmul(t15638, t15639) # t15640: "cuda:0 bf16[18176, 4544]"
del t15638
t15635 = torch.matmul(t15634, t_transformer_h_20_mlp_fc_weight) # t15635: "cuda:0 bf16[2048, 4544]"
# t15635 = ltorch.matmul(t15634, t_transformer_h_20_mlp_fc_weight) # t15635: "cuda:0 bf16[2048, 4544]"
# t15635 = prims.matmul(t15634, t_transformer_h_20_mlp_fc_weight) # t15635: "cuda:0 bf16[2048, 4544]"
del t15634, t_transformer_h_20_mlp_fc_weight
(t15655, t15656, t15657) = cudnn_sdpa_bwd(t15654, t3308, t3311, t3261, None, f1334, b1335, t3312, t3313, t3314, t3315, scale=f1336, cat_grad_qkv=False)
del t15654, t3308, t3311, t3261, f1334, b1335, t3312, t3313, t3314, t3315, f1336
t15659 = torch_slice_prim_impl(t15656, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15659: "cuda:0 bf16[1, 71, 2048, 64]"
del t15656
t15663 = torch_slice_prim_impl(t15655, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15663: "cuda:0 bf16[1, 71, 2048, 64]"
del t15655
t15766 = torch.reshape(t15657, (1, 1, 71, 2048, 64)) # t15766: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15766 = ltorch.reshape(t15657, (1, 1, 71, 2048, 64)) # t15766: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15766 = prims.reshape(t15657, (1, 1, 71, 2048, 64)) # t15766: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t15657
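# Note: nvFusion35 is the same RoPE-backward / fused-QKV re-packing pattern as
# nvFusion32, apparently for block h.20 (producing t15800).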
[t15800] = nvFusion35(i1307, t15659, t15663, t15766, t61, t66)
# t15660 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15660: "cuda:0 bf16[1, 71, 2048, 0]"
# t15661 = prims.pad(t15660, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15661: "cuda:0 bf16[1, 71, 2048, 64]"
# t15664 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15664: "cuda:0 bf16[1, 71, 2048, 0]"
# t15665 = prims.pad(t15664, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15665: "cuda:0 bf16[1, 71, 2048, 64]"
# t15666 = prims.convert_element_type(t15659, dtypes.float32) # t15666: "cuda:0 f32[1, 71, 2048, 64]"
# t15670 = prims.mul(t66, t15666) # t15670: "cuda:0 f32[1, 71, 2048, 64]"
# t15673 = prims.convert_element_type(t15670, dtypes.bfloat16) # t15673: "cuda:0 bf16[1, 71, 2048, 64]"
# t15682 = prims.mul(t61, t15666) # t15682: "cuda:0 f32[1, 71, 2048, 64]"
# t15694 = prims.slice_prim(t15673, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t15694: "cuda:0 bf16[1, 71, 2048, 32]"
# t15695 = prims.slice_prim(t15673, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15695: "cuda:0 bf16[1, 71, 2048, 32]"
# t15696 = prims.convert_element_type(t15694, dtypes.float32) # t15696: "cuda:0 f32[1, 71, 2048, 32]"
# t15697 = prims.neg(t15696) # t15697: "cuda:0 f32[1, 71, 2048, 32]"
# t15698 = prims.convert_element_type(t15697, dtypes.bfloat16) # t15698: "cuda:0 bf16[1, 71, 2048, 32]"
# t15699 = prims.pad(t15698, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t15699: "cuda:0 bf16[1, 71, 2048, 64]"
# t15701 = prims.convert_element_type(t15699, dtypes.float32) # t15701: "cuda:0 f32[1, 71, 2048, 64]"
# t15702 = prims.add(t15682, t15701) # t15702: "cuda:0 f32[1, 71, 2048, 64]"
# t15704 = prims.pad(t15695, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t15704: "cuda:0 bf16[1, 71, 2048, 64]"
# t15706 = prims.convert_element_type(t15704, dtypes.float32) # t15706: "cuda:0 f32[1, 71, 2048, 64]"
# t15707 = prims.add(t15702, t15706) # t15707: "cuda:0 f32[1, 71, 2048, 64]"
# t15708 = prims.convert_element_type(t15707, dtypes.bfloat16) # t15708: "cuda:0 bf16[1, 71, 2048, 64]"
# t15709 = prims.pad(t15708, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t15709: "cuda:0 bf16[1, 71, 2048, 64]"
# t15710 = prims.convert_element_type(t15661, dtypes.float32) # t15710: "cuda:0 f32[1, 71, 2048, 64]"
# t15711 = prims.convert_element_type(t15709, dtypes.float32) # t15711: "cuda:0 f32[1, 71, 2048, 64]"
# t15712 = prims.add(t15710, t15711) # t15712: "cuda:0 f32[1, 71, 2048, 64]"
# t15713 = prims.convert_element_type(t15712, dtypes.bfloat16) # t15713: "cuda:0 bf16[1, 71, 2048, 64]"
# t15714 = prims.convert_element_type(t15663, dtypes.float32) # t15714: "cuda:0 f32[1, 71, 2048, 64]"
# t15718 = prims.mul(t66, t15714) # t15718: "cuda:0 f32[1, 71, 2048, 64]"
# t15721 = prims.convert_element_type(t15718, dtypes.bfloat16) # t15721: "cuda:0 bf16[1, 71, 2048, 64]"
# t15730 = prims.mul(t61, t15714) # t15730: "cuda:0 f32[1, 71, 2048, 64]"
# t15742 = prims.slice_prim(t15721, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t15742: "cuda:0 bf16[1, 71, 2048, 32]"
# t15743 = prims.slice_prim(t15721, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15743: "cuda:0 bf16[1, 71, 2048, 32]"
# t15744 = prims.convert_element_type(t15742, dtypes.float32) # t15744: "cuda:0 f32[1, 71, 2048, 32]"
# t15745 = prims.neg(t15744) # t15745: "cuda:0 f32[1, 71, 2048, 32]"
# t15746 = prims.convert_element_type(t15745, dtypes.bfloat16) # t15746: "cuda:0 bf16[1, 71, 2048, 32]"
# t15747 = prims.pad(t15746, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t15747: "cuda:0 bf16[1, 71, 2048, 64]"
# t15749 = prims.convert_element_type(t15747, dtypes.float32) # t15749: "cuda:0 f32[1, 71, 2048, 64]"
# t15750 = prims.add(t15730, t15749) # t15750: "cuda:0 f32[1, 71, 2048, 64]"
# t15752 = prims.pad(t15743, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t15752: "cuda:0 bf16[1, 71, 2048, 64]"
# t15754 = prims.convert_element_type(t15752, dtypes.float32) # t15754: "cuda:0 f32[1, 71, 2048, 64]"
# t15755 = prims.add(t15750, t15754) # t15755: "cuda:0 f32[1, 71, 2048, 64]"
# t15756 = prims.convert_element_type(t15755, dtypes.bfloat16) # t15756: "cuda:0 bf16[1, 71, 2048, 64]"
# t15757 = prims.pad(t15756, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t15757: "cuda:0 bf16[1, 71, 2048, 64]"
# t15758 = prims.convert_element_type(t15665, dtypes.float32) # t15758: "cuda:0 f32[1, 71, 2048, 64]"
# t15759 = prims.convert_element_type(t15757, dtypes.float32) # t15759: "cuda:0 f32[1, 71, 2048, 64]"
# t15760 = prims.add(t15758, t15759) # t15760: "cuda:0 f32[1, 71, 2048, 64]"
# t15761 = prims.convert_element_type(t15760, dtypes.bfloat16) # t15761: "cuda:0 bf16[1, 71, 2048, 64]"
# t15771 = prims.reshape(t15713, (1, 1, 71, 2048, 64)) # t15771: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15776 = prims.reshape(t15761, (1, 1, 71, 2048, 64)) # t15776: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t15782 = prims.convert_element_type(t15766, dtypes.float32) # t15782: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t15783 = prims.sum(t15782, (0, 1, 2)) # t15783: "cuda:0 f32[2048, 64]"
# t15784 = prims.convert_element_type(t15783, dtypes.bfloat16) # t15784: "cuda:0 bf16[2048, 64]"
# t15785 = prims.broadcast_in_dim(t15784, [1, 1, 1, 2048, 64], [3, 4]) # t15785: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t15791 = prims.convert_element_type(t15771, dtypes.float32) # t15791: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t15792 = prims.sum(t15791, (0, 1, 2)) # t15792: "cuda:0 f32[2048, 64]"
# t15793 = prims.convert_element_type(t15792, dtypes.bfloat16) # t15793: "cuda:0 bf16[2048, 64]"
# t15794 = prims.broadcast_in_dim(t15793, [1, 1, 1, 2048, 64], [3, 4]) # t15794: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t15800 = prims.cat((t15776, t15794, t15785), i1307) # t15800: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1307, t15659, t15663, t15766
t15806 = torch.permute(t15800, (0, 3, 1, 2, 4)) # t15806: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t15806 = ltorch.permute(t15800, (0, 3, 1, 2, 4)) # t15806: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t15806 = prims.transpose(t15800, (0, 3, 1, 2, 4)) # t15806: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t15800
t15812 = torch.reshape(t15806, (1, 2048, 4672)) # t15812: "cuda:0 bf16[1, 2048, 4672]"
# t15812 = ltorch.reshape(t15806, (1, 2048, 4672)) # t15812: "cuda:0 bf16[1, 2048, 4672]"
# t15812 = prims.reshape(t15806, (1, 2048, 4672)) # t15812: "cuda:0 bf16[1, 2048, 4672]"
del t15806
t15813 = torch.reshape(t15812, (-1, 4672)) # t15813: "cuda:0 bf16[2048, 4672]"
# t15813 = ltorch.reshape(t15812, (-1, 4672)) # t15813: "cuda:0 bf16[2048, 4672]"
# t15813 = prims.reshape(t15812, (2048, 4672)) # t15813: "cuda:0 bf16[2048, 4672]"
del t15812
t15817 = torch.permute(t15813, (1, 0)) # t15817: "cuda:0 bf16[4672, 2048]"
# t15817 = ltorch.permute(t15813, (1, 0)) # t15817: "cuda:0 bf16[4672, 2048]"
# t15817 = prims.transpose(t15813, (1, 0)) # t15817: "cuda:0 bf16[4672, 2048]"
t15819 = torch.matmul(t15817, t15639) # t15819: "cuda:0 bf16[4672, 4544]"
# t15819 = ltorch.matmul(t15817, t15818) # t15819: "cuda:0 bf16[4672, 4544]"
# t15819 = prims.matmul(t15817, t15818) # t15819: "cuda:0 bf16[4672, 4544]"
del t15817, t15639
t15814 = torch.matmul(t15813, t_transformer_h_20_attn_attn_weight) # t15814: "cuda:0 bf16[2048, 4544]"
# t15814 = ltorch.matmul(t15813, t_transformer_h_20_attn_attn_weight) # t15814: "cuda:0 bf16[2048, 4544]"
# t15814 = prims.matmul(t15813, t_transformer_h_20_attn_attn_weight) # t15814: "cuda:0 bf16[2048, 4544]"
del t15813, t_transformer_h_20_attn_attn_weight
t15636 = torch.reshape(t15635, (1, 2048, 4544)) # t15636: "cuda:0 bf16[1, 2048, 4544]"
# t15636 = ltorch.reshape(t15635, (1, 2048, 4544)) # t15636: "cuda:0 bf16[1, 2048, 4544]"
# t15636 = prims.reshape(t15635, (1, 2048, 4544)) # t15636: "cuda:0 bf16[1, 2048, 4544]"
del t15635
t15815 = torch.reshape(t15814, (1, 2048, 4544)) # t15815: "cuda:0 bf16[1, 2048, 4544]"
# t15815 = ltorch.reshape(t15814, (1, 2048, 4544)) # t15815: "cuda:0 bf16[1, 2048, 4544]"
# t15815 = prims.reshape(t15814, (1, 2048, 4544)) # t15815: "cuda:0 bf16[1, 2048, 4544]"
del t15814
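# Note: nvFusion36 is the same LayerNorm-backward / residual-accumulation
# pattern as nvFusion30/33, apparently for block h.20 (bias grad t15828,
# weight grad t15834, residual grad t15876).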
[t15828, t15834, t15876] = nvFusion36(i15856, t15593, t15636, t15815, t3030, t3162, t3183, t3198, t3203, t3209)
# t3189 = prims.convert_element_type(t3030, dtypes.float32) # t3189: "cuda:0 f32[1, 2048, 4544]"
# t3184 = prims.convert_element_type(t3183, dtypes.float32) # t3184: "cuda:0 f32[1, 2048, 4544]"
# t3185 = prims.convert_element_type(t3162, dtypes.float32) # t3185: "cuda:0 f32[1, 2048, 4544]"
# t3186 = prims.add(t3184, t3185) # t3186: "cuda:0 f32[1, 2048, 4544]"
# t3190 = prims.add(t3186, t3189) # t3190: "cuda:0 f32[1, 2048, 4544]"
# t3200 = prims.broadcast_in_dim(t3198, [1, 2048, 1], [0, 1]) # t3200: "cuda:0 f32[1, 2048, 1]"
# t3204 = prims.broadcast_in_dim(t3200, (1, 2048, 4544), (0, 1, 2)) # t3204: "cuda:0 f32[1, 2048, 4544]"
# t3206 = prims.sub(t3190, t3204) # t3206: "cuda:0 f32[1, 2048, 4544]"
# t3207 = prims.broadcast_in_dim(t3203, (1, 2048, 4544), (0, 1, 2)) # t3207: "cuda:0 f32[1, 2048, 4544]"
# t3208 = prims.mul(t3206, t3207) # t3208: "cuda:0 f32[1, 2048, 4544]"
# t3210 = prims.convert_element_type(t3209, dtypes.float32) # t3210: "cuda:0 f32[1, 2048, 4544]"
# t15873 = prims.convert_element_type(t15593, dtypes.float32) # t15873: "cuda:0 f32[1, 2048, 4544]"
# t15820 = prims.convert_element_type(t15636, dtypes.float32) # t15820: "cuda:0 f32[1, 2048, 4544]"
# t15821 = prims.convert_element_type(t15815, dtypes.float32) # t15821: "cuda:0 f32[1, 2048, 4544]"
# t15822 = prims.add(t15820, t15821) # t15822: "cuda:0 f32[1, 2048, 4544]"
# t15827 = prims.sum(t15822, (0, 1)) # t15827: "cuda:0 f32[4544]"
# t15828 = prims.convert_element_type(t15827, dtypes.bfloat16) # t15828: "cuda:0 bf16[4544]"
# t15829 = prims.mul(t3210, t15822) # t15829: "cuda:0 f32[1, 2048, 4544]"
# t15830 = prims.mul(t3208, t15822) # t15830: "cuda:0 f32[1, 2048, 4544]"
# t15833 = prims.sum(t15830, (0, 1)) # t15833: "cuda:0 f32[4544]"
# t15834 = prims.convert_element_type(t15833, dtypes.bfloat16) # t15834: "cuda:0 bf16[4544]"
# t15835 = prims.mul(t3207, t15829) # t15835: "cuda:0 f32[1, 2048, 4544]"
# t15836 = prims.mul(t3206, t15829) # t15836: "cuda:0 f32[1, 2048, 4544]"
# t15837 = prims.sum(t15836, (0, 2)) # t15837: "cuda:0 f32[2048]"
# t15838 = prims.broadcast_in_dim(t15837, [1, 2048, 1], [1]) # t15838: "cuda:0 f32[1, 2048, 1]"
# t15839 = prims.neg(t15835) # t15839: "cuda:0 f32[1, 2048, 4544]"
# t15841 = prims.sum(t15839, (0, 2)) # t15841: "cuda:0 f32[2048]"
# t15842 = prims.broadcast_in_dim(t15841, [1, 2048, 1], [1]) # t15842: "cuda:0 f32[1, 2048, 1]"
# t15843 = prims.mul(-0.5, t15838) # t15843: "cuda:0 f32[1, 2048, 1]"
# t15844 = prims.pow(t3203, 3.0) # t15844: "cuda:0 f32[1, 2048, 1]"
# t15845 = prims.mul(t15843, t15844) # t15845: "cuda:0 f32[1, 2048, 1]"
# t15847 = prims.sum(t15842, (0, 2)) # t15847: "cuda:0 f32[2048]"
# t15848 = prims.broadcast_in_dim(t15847, [1, 2048], [1]) # t15848: "cuda:0 f32[1, 2048]"
# t15849 = prims.sum(t15845, (0, 2)) # t15849: "cuda:0 f32[2048]"
# t15850 = prims.broadcast_in_dim(t15849, [1, 2048], [1]) # t15850: "cuda:0 f32[1, 2048]"
# t15853 = prims.broadcast_in_dim(t15848, [1, 2048, 1], [0, 1]) # t15853: "cuda:0 f32[1, 2048, 1]"
# t15854 = prims.broadcast_in_dim(t15853, (1, 2048, 4544), (0, 1, 2)) # t15854: "cuda:0 f32[1, 2048, 4544]"
# t15855 = prims.mul(0.00022007042253521127, t15854) # t15855: "cuda:0 f32[1, 2048, 4544]"
# t15857 = prims.broadcast_in_dim(t15850, [1, 2048, 1], [0, 1]) # t15857: "cuda:0 f32[1, 2048, 1]"
# t15858 = prims.broadcast_in_dim(t15857, (1, 2048, 4544), (0, 1, 2)) # t15858: "cuda:0 f32[1, 2048, 4544]"
# t15860 = prims.broadcast_in_dim(t3198, [1, 2048, 1], [0, 1]) # t15860: "cuda:0 f32[1, 2048, 1]"
# t15861 = prims.broadcast_in_dim(t15860, (1, 2048, 4544), (0, 1, 2)) # t15861: "cuda:0 f32[1, 2048, 4544]"
# t15862 = prims.mul(2.0, t15858) # t15862: "cuda:0 f32[1, 2048, 4544]"
# t15863 = prims.sub(t3190, t15861) # t15863: "cuda:0 f32[1, 2048, 4544]"
# t15864 = prims.mul(t15862, t15863) # t15864: "cuda:0 f32[1, 2048, 4544]"
# f15865 = prims.convert_element_type(i15856, float) # f15865: "float 4544.0"
# t15866 = prims.div(t15864, f15865) # t15866: "cuda:0 f32[1, 2048, 4544]"
# t15867 = prims.add(t15855, t15866) # t15867: "cuda:0 f32[1, 2048, 4544]"
# t15871 = prims.add(t15835, t15867) # t15871: "cuda:0 f32[1, 2048, 4544]"
# t15875 = prims.add(t15873, t15871) # t15875: "cuda:0 f32[1, 2048, 4544]"
# t15876 = prims.convert_element_type(t15875, dtypes.bfloat16) # t15876: "cuda:0 bf16[1, 2048, 4544]"
del i15856, t15593, t15636, t15815, t3030, t3162, t3183, t3198, t3203, t3209
t15883 = torch.reshape(t15876, (-1, 4544)) # t15883: "cuda:0 bf16[2048, 4544]"
# t15883 = ltorch.reshape(t15876, (-1, 4544)) # t15883: "cuda:0 bf16[2048, 4544]"
# t15883 = prims.reshape(t15876, (2048, 4544)) # t15883: "cuda:0 bf16[2048, 4544]"
t15887 = torch.permute(t15883, (1, 0)) # t15887: "cuda:0 bf16[4544, 2048]"
# t15887 = ltorch.permute(t15883, (1, 0)) # t15887: "cuda:0 bf16[4544, 2048]"
# t15887 = prims.transpose(t15883, (1, 0)) # t15887: "cuda:0 bf16[4544, 2048]"
t15889 = torch.matmul(t15887, t15888) # t15889: "cuda:0 bf16[4544, 18176]"
# t15889 = ltorch.matmul(t15887, t15888) # t15889: "cuda:0 bf16[4544, 18176]"
# t15889 = prims.matmul(t15887, t15888) # t15889: "cuda:0 bf16[4544, 18176]"
del t15888
t15925 = torch.matmul(t15883, t_transformer_h_19_attn_proj_weight) # t15925: "cuda:0 bf16[2048, 4544]"
# t15925 = ltorch.matmul(t15924, t_transformer_h_19_attn_proj_weight) # t15925: "cuda:0 bf16[2048, 4544]"
# t15925 = prims.matmul(t15924, t_transformer_h_19_attn_proj_weight) # t15925: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_19_attn_proj_weight
t15930 = torch.matmul(t15887, t15929) # t15930: "cuda:0 bf16[4544, 4544]"
# t15930 = ltorch.matmul(t15928, t15929) # t15930: "cuda:0 bf16[4544, 4544]"
# t15930 = prims.matmul(t15928, t15929) # t15930: "cuda:0 bf16[4544, 4544]"
del t15887, t15929
t15884 = torch.matmul(t15883, t_transformer_h_19_mlp_proj_weight) # t15884: "cuda:0 bf16[2048, 18176]"
# t15884 = ltorch.matmul(t15883, t_transformer_h_19_mlp_proj_weight) # t15884: "cuda:0 bf16[2048, 18176]"
# t15884 = prims.matmul(t15883, t_transformer_h_19_mlp_proj_weight) # t15884: "cuda:0 bf16[2048, 18176]"
del t15883, t_transformer_h_19_mlp_proj_weight
t15926 = torch.reshape(t15925, (1, 2048, 4544)) # t15926: "cuda:0 bf16[1, 2048, 4544]"
# t15926 = ltorch.reshape(t15925, (1, 2048, 4544)) # t15926: "cuda:0 bf16[1, 2048, 4544]"
# t15926 = prims.reshape(t15925, (1, 2048, 4544)) # t15926: "cuda:0 bf16[1, 2048, 4544]"
del t15925
t15934 = torch.reshape(t15926, (1, 2048, 71, 64)) # t15934: "cuda:0 bf16[1, 2048, 71, 64]"
# t15934 = ltorch.reshape(t15926, (1, 2048, 71, 64)) # t15934: "cuda:0 bf16[1, 2048, 71, 64]"
# t15934 = prims.reshape(t15926, (1, 2048, 71, 64)) # t15934: "cuda:0 bf16[1, 2048, 71, 64]"
del t15926
t15937 = torch.permute(t15934, (0, 2, 1, 3)) # t15937: "cuda:0 bf16[1, 71, 2048, 64]"
# t15937 = ltorch.permute(t15934, (0, 2, 1, 3)) # t15937: "cuda:0 bf16[1, 71, 2048, 64]"
# t15937 = prims.transpose(t15934, (0, 2, 1, 3)) # t15937: "cuda:0 bf16[1, 71, 2048, 64]"
del t15934
t15885 = torch.reshape(t15884, (1, 2048, 18176)) # t15885: "cuda:0 bf16[1, 2048, 18176]"
# t15885 = ltorch.reshape(t15884, (1, 2048, 18176)) # t15885: "cuda:0 bf16[1, 2048, 18176]"
# t15885 = prims.reshape(t15884, (1, 2048, 18176)) # t15885: "cuda:0 bf16[1, 2048, 18176]"
del t15884
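# Note: nvFusion37 is the same GELU-backward pattern as nvFusion31/34,
# apparently for block h.19's MLP (saved pre-activation t3163, producing t15916).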
[t15916] = nvFusion37(f1279, f1281, t15885, t3163)
# t3164 = prims.convert_element_type(t3163, dtypes.float32) # t3164: "cuda:0 f32[1, 2048, 18176]"
# t3166 = prims.div(t3164, 1.4142135623730951) # t3166: "cuda:0 f32[1, 2048, 18176]"
# t3169 = prims.erf(t3166) # t3169: "cuda:0 f32[1, 2048, 18176]"
# t3173 = prims.mul(0.5, t3169) # t3173: "cuda:0 f32[1, 2048, 18176]"
# t3177 = prims.add(0.5, t3173) # t3177: "cuda:0 f32[1, 2048, 18176]"
# t15890 = prims.convert_element_type(t15885, dtypes.float32) # t15890: "cuda:0 f32[1, 2048, 18176]"
# t15891 = prims.mul(t3177, t15890) # t15891: "cuda:0 f32[1, 2048, 18176]"
# t15892 = prims.mul(t3164, t15890) # t15892: "cuda:0 f32[1, 2048, 18176]"
# t15900 = prims.mul(f1281, t15892) # t15900: "cuda:0 f32[1, 2048, 18176]"
# t15903 = prims.pow(t3166, 2.0) # t15903: "cuda:0 f32[1, 2048, 18176]"
# t15904 = prims.neg(t15903) # t15904: "cuda:0 f32[1, 2048, 18176]"
# t15905 = prims.exp(t15904) # t15905: "cuda:0 f32[1, 2048, 18176]"
# t15906 = prims.mul(1.1283791670955126, t15905) # t15906: "cuda:0 f32[1, 2048, 18176]"
# t15907 = prims.mul(t15906, t15900) # t15907: "cuda:0 f32[1, 2048, 18176]"
# t15911 = prims.div(t15907, f1279) # t15911: "cuda:0 f32[1, 2048, 18176]"
# t15915 = prims.add(t15891, t15911) # t15915: "cuda:0 f32[1, 2048, 18176]"
# t15916 = prims.convert_element_type(t15915, dtypes.bfloat16) # t15916: "cuda:0 bf16[1, 2048, 18176]"
del f1279, f1281, t15885, t3163
t15917 = torch.reshape(t15916, (-1, 18176)) # t15917: "cuda:0 bf16[2048, 18176]"
# t15917 = ltorch.reshape(t15916, (-1, 18176)) # t15917: "cuda:0 bf16[2048, 18176]"
# t15917 = prims.reshape(t15916, (2048, 18176)) # t15917: "cuda:0 bf16[2048, 18176]"
del t15916
t15921 = torch.permute(t15917, (1, 0)) # t15921: "cuda:0 bf16[18176, 2048]"
# t15921 = ltorch.permute(t15917, (1, 0)) # t15921: "cuda:0 bf16[18176, 2048]"
# t15921 = prims.transpose(t15917, (1, 0)) # t15921: "cuda:0 bf16[18176, 2048]"
(t15938, t15939, t15940) = cudnn_sdpa_bwd(t15937, t3147, t3150, t3100, None, f1270, b1271, t3151, t3152, t3153, t3154, scale=f1272, cat_grad_qkv=False)
del t15937, t3147, t3150, t3100, f1270, b1271, t3151, t3152, t3153, t3154, f1272
t15923 = torch.matmul(t15921, t15922) # t15923: "cuda:0 bf16[18176, 4544]"
# t15923 = ltorch.matmul(t15921, t15922) # t15923: "cuda:0 bf16[18176, 4544]"
# t15923 = prims.matmul(t15921, t15922) # t15923: "cuda:0 bf16[18176, 4544]"
del t15921
t15918 = torch.matmul(t15917, t_transformer_h_19_mlp_fc_weight) # t15918: "cuda:0 bf16[2048, 4544]"
# t15918 = ltorch.matmul(t15917, t_transformer_h_19_mlp_fc_weight) # t15918: "cuda:0 bf16[2048, 4544]"
# t15918 = prims.matmul(t15917, t_transformer_h_19_mlp_fc_weight) # t15918: "cuda:0 bf16[2048, 4544]"
del t15917, t_transformer_h_19_mlp_fc_weight
t15942 = torch_slice_prim_impl(t15939, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15942: "cuda:0 bf16[1, 71, 2048, 64]"
del t15939
t15946 = torch_slice_prim_impl(t15938, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15946: "cuda:0 bf16[1, 71, 2048, 64]"
del t15938
t16049 = torch.reshape(t15940, (1, 1, 71, 2048, 64)) # t16049: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16049 = ltorch.reshape(t15940, (1, 1, 71, 2048, 64)) # t16049: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16049 = prims.reshape(t15940, (1, 1, 71, 2048, 64)) # t16049: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t15940
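# Note: nvFusion38 is the same RoPE-backward / fused-QKV re-packing pattern as
# nvFusion32/35, apparently for block h.19 (producing t16083).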
[t16083] = nvFusion38(i1243, t15942, t15946, t16049, t61, t66)
# t15943 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15943: "cuda:0 bf16[1, 71, 2048, 0]"
# t15944 = prims.pad(t15943, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15944: "cuda:0 bf16[1, 71, 2048, 64]"
# t15947 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t15947: "cuda:0 bf16[1, 71, 2048, 0]"
# t15948 = prims.pad(t15947, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t15948: "cuda:0 bf16[1, 71, 2048, 64]"
# t15949 = prims.convert_element_type(t15942, dtypes.float32) # t15949: "cuda:0 f32[1, 71, 2048, 64]"
# t15953 = prims.mul(t66, t15949) # t15953: "cuda:0 f32[1, 71, 2048, 64]"
# t15956 = prims.convert_element_type(t15953, dtypes.bfloat16) # t15956: "cuda:0 bf16[1, 71, 2048, 64]"
# t15965 = prims.mul(t61, t15949) # t15965: "cuda:0 f32[1, 71, 2048, 64]"
# t15977 = prims.slice_prim(t15956, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t15977: "cuda:0 bf16[1, 71, 2048, 32]"
# t15978 = prims.slice_prim(t15956, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t15978: "cuda:0 bf16[1, 71, 2048, 32]"
# t15979 = prims.convert_element_type(t15977, dtypes.float32) # t15979: "cuda:0 f32[1, 71, 2048, 32]"
# t15980 = prims.neg(t15979) # t15980: "cuda:0 f32[1, 71, 2048, 32]"
# t15981 = prims.convert_element_type(t15980, dtypes.bfloat16) # t15981: "cuda:0 bf16[1, 71, 2048, 32]"
# t15982 = prims.pad(t15981, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t15982: "cuda:0 bf16[1, 71, 2048, 64]"
# t15984 = prims.convert_element_type(t15982, dtypes.float32) # t15984: "cuda:0 f32[1, 71, 2048, 64]"
# t15985 = prims.add(t15965, t15984) # t15985: "cuda:0 f32[1, 71, 2048, 64]"
# t15987 = prims.pad(t15978, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t15987: "cuda:0 bf16[1, 71, 2048, 64]"
# t15989 = prims.convert_element_type(t15987, dtypes.float32) # t15989: "cuda:0 f32[1, 71, 2048, 64]"
# t15990 = prims.add(t15985, t15989) # t15990: "cuda:0 f32[1, 71, 2048, 64]"
# t15991 = prims.convert_element_type(t15990, dtypes.bfloat16) # t15991: "cuda:0 bf16[1, 71, 2048, 64]"
# t15992 = prims.pad(t15991, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t15992: "cuda:0 bf16[1, 71, 2048, 64]"
# t15993 = prims.convert_element_type(t15944, dtypes.float32) # t15993: "cuda:0 f32[1, 71, 2048, 64]"
# t15994 = prims.convert_element_type(t15992, dtypes.float32) # t15994: "cuda:0 f32[1, 71, 2048, 64]"
# t15995 = prims.add(t15993, t15994) # t15995: "cuda:0 f32[1, 71, 2048, 64]"
# t15996 = prims.convert_element_type(t15995, dtypes.bfloat16) # t15996: "cuda:0 bf16[1, 71, 2048, 64]"
# t15997 = prims.convert_element_type(t15946, dtypes.float32) # t15997: "cuda:0 f32[1, 71, 2048, 64]"
# t16001 = prims.mul(t66, t15997) # t16001: "cuda:0 f32[1, 71, 2048, 64]"
# t16004 = prims.convert_element_type(t16001, dtypes.bfloat16) # t16004: "cuda:0 bf16[1, 71, 2048, 64]"
# t16013 = prims.mul(t61, t15997) # t16013: "cuda:0 f32[1, 71, 2048, 64]"
# t16025 = prims.slice_prim(t16004, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t16025: "cuda:0 bf16[1, 71, 2048, 32]"
# t16026 = prims.slice_prim(t16004, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16026: "cuda:0 bf16[1, 71, 2048, 32]"
# t16027 = prims.convert_element_type(t16025, dtypes.float32) # t16027: "cuda:0 f32[1, 71, 2048, 32]"
# t16028 = prims.neg(t16027) # t16028: "cuda:0 f32[1, 71, 2048, 32]"
# t16029 = prims.convert_element_type(t16028, dtypes.bfloat16) # t16029: "cuda:0 bf16[1, 71, 2048, 32]"
# t16030 = prims.pad(t16029, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t16030: "cuda:0 bf16[1, 71, 2048, 64]"
# t16032 = prims.convert_element_type(t16030, dtypes.float32) # t16032: "cuda:0 f32[1, 71, 2048, 64]"
# t16033 = prims.add(t16013, t16032) # t16033: "cuda:0 f32[1, 71, 2048, 64]"
# t16035 = prims.pad(t16026, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t16035: "cuda:0 bf16[1, 71, 2048, 64]"
# t16037 = prims.convert_element_type(t16035, dtypes.float32) # t16037: "cuda:0 f32[1, 71, 2048, 64]"
# t16038 = prims.add(t16033, t16037) # t16038: "cuda:0 f32[1, 71, 2048, 64]"
# t16039 = prims.convert_element_type(t16038, dtypes.bfloat16) # t16039: "cuda:0 bf16[1, 71, 2048, 64]"
# t16040 = prims.pad(t16039, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t16040: "cuda:0 bf16[1, 71, 2048, 64]"
# t16041 = prims.convert_element_type(t15948, dtypes.float32) # t16041: "cuda:0 f32[1, 71, 2048, 64]"
# t16042 = prims.convert_element_type(t16040, dtypes.float32) # t16042: "cuda:0 f32[1, 71, 2048, 64]"
# t16043 = prims.add(t16041, t16042) # t16043: "cuda:0 f32[1, 71, 2048, 64]"
# t16044 = prims.convert_element_type(t16043, dtypes.bfloat16) # t16044: "cuda:0 bf16[1, 71, 2048, 64]"
# t16054 = prims.reshape(t15996, (1, 1, 71, 2048, 64)) # t16054: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16059 = prims.reshape(t16044, (1, 1, 71, 2048, 64)) # t16059: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16065 = prims.convert_element_type(t16049, dtypes.float32) # t16065: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16066 = prims.sum(t16065, (0, 1, 2)) # t16066: "cuda:0 f32[2048, 64]"
# t16067 = prims.convert_element_type(t16066, dtypes.bfloat16) # t16067: "cuda:0 bf16[2048, 64]"
# t16068 = prims.broadcast_in_dim(t16067, [1, 1, 1, 2048, 64], [3, 4]) # t16068: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16074 = prims.convert_element_type(t16054, dtypes.float32) # t16074: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16075 = prims.sum(t16074, (0, 1, 2)) # t16075: "cuda:0 f32[2048, 64]"
# t16076 = prims.convert_element_type(t16075, dtypes.bfloat16) # t16076: "cuda:0 bf16[2048, 64]"
# t16077 = prims.broadcast_in_dim(t16076, [1, 1, 1, 2048, 64], [3, 4]) # t16077: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16083 = prims.cat((t16059, t16077, t16068), i1243) # t16083: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1243, t15942, t15946, t16049
t16089 = torch.permute(t16083, (0, 3, 1, 2, 4)) # t16089: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16089 = ltorch.permute(t16083, (0, 3, 1, 2, 4)) # t16089: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16089 = prims.transpose(t16083, (0, 3, 1, 2, 4)) # t16089: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t16083
t16095 = torch.reshape(t16089, (1, 2048, 4672)) # t16095: "cuda:0 bf16[1, 2048, 4672]"
# t16095 = ltorch.reshape(t16089, (1, 2048, 4672)) # t16095: "cuda:0 bf16[1, 2048, 4672]"
# t16095 = prims.reshape(t16089, (1, 2048, 4672)) # t16095: "cuda:0 bf16[1, 2048, 4672]"
del t16089
t16096 = torch.reshape(t16095, (-1, 4672)) # t16096: "cuda:0 bf16[2048, 4672]"
# t16096 = ltorch.reshape(t16095, (-1, 4672)) # t16096: "cuda:0 bf16[2048, 4672]"
# t16096 = prims.reshape(t16095, (2048, 4672)) # t16096: "cuda:0 bf16[2048, 4672]"
del t16095
t16100 = torch.permute(t16096, (1, 0)) # t16100: "cuda:0 bf16[4672, 2048]"
# t16100 = ltorch.permute(t16096, (1, 0)) # t16100: "cuda:0 bf16[4672, 2048]"
# t16100 = prims.transpose(t16096, (1, 0)) # t16100: "cuda:0 bf16[4672, 2048]"
t16102 = torch.matmul(t16100, t15922) # t16102: "cuda:0 bf16[4672, 4544]"
# t16102 = ltorch.matmul(t16100, t16101) # t16102: "cuda:0 bf16[4672, 4544]"
# t16102 = prims.matmul(t16100, t16101) # t16102: "cuda:0 bf16[4672, 4544]"
del t16100, t15922
t16097 = torch.matmul(t16096, t_transformer_h_19_attn_attn_weight) # t16097: "cuda:0 bf16[2048, 4544]"
# t16097 = ltorch.matmul(t16096, t_transformer_h_19_attn_attn_weight) # t16097: "cuda:0 bf16[2048, 4544]"
# t16097 = prims.matmul(t16096, t_transformer_h_19_attn_attn_weight) # t16097: "cuda:0 bf16[2048, 4544]"
del t16096, t_transformer_h_19_attn_attn_weight
t15919 = torch.reshape(t15918, (1, 2048, 4544)) # t15919: "cuda:0 bf16[1, 2048, 4544]"
# t15919 = ltorch.reshape(t15918, (1, 2048, 4544)) # t15919: "cuda:0 bf16[1, 2048, 4544]"
# t15919 = prims.reshape(t15918, (1, 2048, 4544)) # t15919: "cuda:0 bf16[1, 2048, 4544]"
del t15918
t16098 = torch.reshape(t16097, (1, 2048, 4544)) # t16098: "cuda:0 bf16[1, 2048, 4544]"
# t16098 = ltorch.reshape(t16097, (1, 2048, 4544)) # t16098: "cuda:0 bf16[1, 2048, 4544]"
# t16098 = prims.reshape(t16097, (1, 2048, 4544)) # t16098: "cuda:0 bf16[1, 2048, 4544]"
del t16097
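# Note: nvFusion39 is the same LayerNorm-backward / residual-accumulation
# pattern as nvFusion30/33/36, apparently for block h.19 (bias grad t16111,
# weight grad t16117, residual grad t16159).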
[t16111, t16117, t16159] = nvFusion39(i16139, t15876, t15919, t16098, t2869, t3001, t3022, t3037, t3042, t3048)
# t3028 = prims.convert_element_type(t2869, dtypes.float32) # t3028: "cuda:0 f32[1, 2048, 4544]"
# t3023 = prims.convert_element_type(t3022, dtypes.float32) # t3023: "cuda:0 f32[1, 2048, 4544]"
# t3024 = prims.convert_element_type(t3001, dtypes.float32) # t3024: "cuda:0 f32[1, 2048, 4544]"
# t3025 = prims.add(t3023, t3024) # t3025: "cuda:0 f32[1, 2048, 4544]"
# t3029 = prims.add(t3025, t3028) # t3029: "cuda:0 f32[1, 2048, 4544]"
# t3039 = prims.broadcast_in_dim(t3037, [1, 2048, 1], [0, 1]) # t3039: "cuda:0 f32[1, 2048, 1]"
# t3043 = prims.broadcast_in_dim(t3039, (1, 2048, 4544), (0, 1, 2)) # t3043: "cuda:0 f32[1, 2048, 4544]"
# t3045 = prims.sub(t3029, t3043) # t3045: "cuda:0 f32[1, 2048, 4544]"
# t3046 = prims.broadcast_in_dim(t3042, (1, 2048, 4544), (0, 1, 2)) # t3046: "cuda:0 f32[1, 2048, 4544]"
# t3047 = prims.mul(t3045, t3046) # t3047: "cuda:0 f32[1, 2048, 4544]"
# t3049 = prims.convert_element_type(t3048, dtypes.float32) # t3049: "cuda:0 f32[1, 2048, 4544]"
# t16156 = prims.convert_element_type(t15876, dtypes.float32) # t16156: "cuda:0 f32[1, 2048, 4544]"
# t16103 = prims.convert_element_type(t15919, dtypes.float32) # t16103: "cuda:0 f32[1, 2048, 4544]"
# t16104 = prims.convert_element_type(t16098, dtypes.float32) # t16104: "cuda:0 f32[1, 2048, 4544]"
# t16105 = prims.add(t16103, t16104) # t16105: "cuda:0 f32[1, 2048, 4544]"
# t16110 = prims.sum(t16105, (0, 1)) # t16110: "cuda:0 f32[4544]"
# t16111 = prims.convert_element_type(t16110, dtypes.bfloat16) # t16111: "cuda:0 bf16[4544]"
# t16112 = prims.mul(t3049, t16105) # t16112: "cuda:0 f32[1, 2048, 4544]"
# t16113 = prims.mul(t3047, t16105) # t16113: "cuda:0 f32[1, 2048, 4544]"
# t16116 = prims.sum(t16113, (0, 1)) # t16116: "cuda:0 f32[4544]"
# t16117 = prims.convert_element_type(t16116, dtypes.bfloat16) # t16117: "cuda:0 bf16[4544]"
# t16118 = prims.mul(t3046, t16112) # t16118: "cuda:0 f32[1, 2048, 4544]"
# t16119 = prims.mul(t3045, t16112) # t16119: "cuda:0 f32[1, 2048, 4544]"
# t16120 = prims.sum(t16119, (0, 2)) # t16120: "cuda:0 f32[2048]"
# t16121 = prims.broadcast_in_dim(t16120, [1, 2048, 1], [1]) # t16121: "cuda:0 f32[1, 2048, 1]"
# t16122 = prims.neg(t16118) # t16122: "cuda:0 f32[1, 2048, 4544]"
# t16124 = prims.sum(t16122, (0, 2)) # t16124: "cuda:0 f32[2048]"
# t16125 = prims.broadcast_in_dim(t16124, [1, 2048, 1], [1]) # t16125: "cuda:0 f32[1, 2048, 1]"
# t16126 = prims.mul(-0.5, t16121) # t16126: "cuda:0 f32[1, 2048, 1]"
# t16127 = prims.pow(t3042, 3.0) # t16127: "cuda:0 f32[1, 2048, 1]"
# t16128 = prims.mul(t16126, t16127) # t16128: "cuda:0 f32[1, 2048, 1]"
# t16130 = prims.sum(t16125, (0, 2)) # t16130: "cuda:0 f32[2048]"
# t16131 = prims.broadcast_in_dim(t16130, [1, 2048], [1]) # t16131: "cuda:0 f32[1, 2048]"
# t16132 = prims.sum(t16128, (0, 2)) # t16132: "cuda:0 f32[2048]"
# t16133 = prims.broadcast_in_dim(t16132, [1, 2048], [1]) # t16133: "cuda:0 f32[1, 2048]"
# t16136 = prims.broadcast_in_dim(t16131, [1, 2048, 1], [0, 1]) # t16136: "cuda:0 f32[1, 2048, 1]"
# t16137 = prims.broadcast_in_dim(t16136, (1, 2048, 4544), (0, 1, 2)) # t16137: "cuda:0 f32[1, 2048, 4544]"
# t16138 = prims.mul(0.00022007042253521127, t16137) # t16138: "cuda:0 f32[1, 2048, 4544]"
# t16140 = prims.broadcast_in_dim(t16133, [1, 2048, 1], [0, 1]) # t16140: "cuda:0 f32[1, 2048, 1]"
# t16141 = prims.broadcast_in_dim(t16140, (1, 2048, 4544), (0, 1, 2)) # t16141: "cuda:0 f32[1, 2048, 4544]"
# t16143 = prims.broadcast_in_dim(t3037, [1, 2048, 1], [0, 1]) # t16143: "cuda:0 f32[1, 2048, 1]"
# t16144 = prims.broadcast_in_dim(t16143, (1, 2048, 4544), (0, 1, 2)) # t16144: "cuda:0 f32[1, 2048, 4544]"
# t16145 = prims.mul(2.0, t16141) # t16145: "cuda:0 f32[1, 2048, 4544]"
# t16146 = prims.sub(t3029, t16144) # t16146: "cuda:0 f32[1, 2048, 4544]"
# t16147 = prims.mul(t16145, t16146) # t16147: "cuda:0 f32[1, 2048, 4544]"
# f16148 = prims.convert_element_type(i16139, float) # f16148: "float 4544.0"
# t16149 = prims.div(t16147, f16148) # t16149: "cuda:0 f32[1, 2048, 4544]"
# t16150 = prims.add(t16138, t16149) # t16150: "cuda:0 f32[1, 2048, 4544]"
# t16154 = prims.add(t16118, t16150) # t16154: "cuda:0 f32[1, 2048, 4544]"
# t16158 = prims.add(t16156, t16154) # t16158: "cuda:0 f32[1, 2048, 4544]"
# t16159 = prims.convert_element_type(t16158, dtypes.bfloat16) # t16159: "cuda:0 bf16[1, 2048, 4544]"
del i16139, t15876, t15919, t16098, t2869, t3001, t3022, t3037, t3042, t3048
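# NOTE: the reshape/permute/matmul sequence below looks like the standard linear-layer backward,
# applied to both branches of the block's parallel attention/MLP residual: with the incoming grad
# flattened to [2048, 4544], the matmuls against saved activations (t16212, t16171) give the
# weight grads (dW = dY^T @ X), and the matmuls against the saved weights give the grads w.r.t.
# the layer inputs (dX = dY @ W).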
t16166 = torch.reshape(t16159, (-1, 4544)) # t16166: "cuda:0 bf16[2048, 4544]"
# t16166 = ltorch.reshape(t16159, (-1, 4544)) # t16166: "cuda:0 bf16[2048, 4544]"
# t16166 = prims.reshape(t16159, (2048, 4544)) # t16166: "cuda:0 bf16[2048, 4544]"
t16170 = torch.permute(t16166, (1, 0)) # t16170: "cuda:0 bf16[4544, 2048]"
# t16170 = ltorch.permute(t16166, (1, 0)) # t16170: "cuda:0 bf16[4544, 2048]"
# t16170 = prims.transpose(t16166, (1, 0)) # t16170: "cuda:0 bf16[4544, 2048]"
t16213 = torch.matmul(t16170, t16212) # t16213: "cuda:0 bf16[4544, 4544]"
# t16213 = ltorch.matmul(t16211, t16212) # t16213: "cuda:0 bf16[4544, 4544]"
# t16213 = prims.matmul(t16211, t16212) # t16213: "cuda:0 bf16[4544, 4544]"
del t16212
t16167 = torch.matmul(t16166, t_transformer_h_18_mlp_proj_weight) # t16167: "cuda:0 bf16[2048, 18176]"
# t16167 = ltorch.matmul(t16166, t_transformer_h_18_mlp_proj_weight) # t16167: "cuda:0 bf16[2048, 18176]"
# t16167 = prims.matmul(t16166, t_transformer_h_18_mlp_proj_weight) # t16167: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_18_mlp_proj_weight
t16172 = torch.matmul(t16170, t16171) # t16172: "cuda:0 bf16[4544, 18176]"
# t16172 = ltorch.matmul(t16170, t16171) # t16172: "cuda:0 bf16[4544, 18176]"
# t16172 = prims.matmul(t16170, t16171) # t16172: "cuda:0 bf16[4544, 18176]"
del t16170, t16171
t16208 = torch.matmul(t16166, t_transformer_h_18_attn_proj_weight) # t16208: "cuda:0 bf16[2048, 4544]"
# t16208 = ltorch.matmul(t16207, t_transformer_h_18_attn_proj_weight) # t16208: "cuda:0 bf16[2048, 4544]"
# t16208 = prims.matmul(t16207, t_transformer_h_18_attn_proj_weight) # t16208: "cuda:0 bf16[2048, 4544]"
del t16166, t_transformer_h_18_attn_proj_weight
t16168 = torch.reshape(t16167, (1, 2048, 18176)) # t16168: "cuda:0 bf16[1, 2048, 18176]"
# t16168 = ltorch.reshape(t16167, (1, 2048, 18176)) # t16168: "cuda:0 bf16[1, 2048, 18176]"
# t16168 = prims.reshape(t16167, (1, 2048, 18176)) # t16168: "cuda:0 bf16[1, 2048, 18176]"
del t16167
t16209 = torch.reshape(t16208, (1, 2048, 4544)) # t16209: "cuda:0 bf16[1, 2048, 4544]"
# t16209 = ltorch.reshape(t16208, (1, 2048, 4544)) # t16209: "cuda:0 bf16[1, 2048, 4544]"
# t16209 = prims.reshape(t16208, (1, 2048, 4544)) # t16209: "cuda:0 bf16[1, 2048, 4544]"
del t16208
t16217 = torch.reshape(t16209, (1, 2048, 71, 64)) # t16217: "cuda:0 bf16[1, 2048, 71, 64]"
# t16217 = ltorch.reshape(t16209, (1, 2048, 71, 64)) # t16217: "cuda:0 bf16[1, 2048, 71, 64]"
# t16217 = prims.reshape(t16209, (1, 2048, 71, 64)) # t16217: "cuda:0 bf16[1, 2048, 71, 64]"
del t16209
t16220 = torch.permute(t16217, (0, 2, 1, 3)) # t16220: "cuda:0 bf16[1, 71, 2048, 64]"
# t16220 = ltorch.permute(t16217, (0, 2, 1, 3)) # t16220: "cuda:0 bf16[1, 71, 2048, 64]"
# t16220 = prims.transpose(t16217, (0, 2, 1, 3)) # t16220: "cuda:0 bf16[1, 71, 2048, 64]"
del t16217
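# NOTE: nvFusion40 below appears to be the GELU backward for this block's MLP. It recomputes
# 0.5 * (1 + erf(x / sqrt(2))) from the saved fc output (t3002) and adds the derivative term built
# from exp(-(x/sqrt(2))^2) * 2/sqrt(pi) (the constant 1.1283791670955126), yielding the gradient
# w.r.t. the fc activation (t16199).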
[t16199] = nvFusion40(f1215, f1217, t16168, t3002)
# t3003 = prims.convert_element_type(t3002, dtypes.float32) # t3003: "cuda:0 f32[1, 2048, 18176]"
# t3005 = prims.div(t3003, 1.4142135623730951) # t3005: "cuda:0 f32[1, 2048, 18176]"
# t3008 = prims.erf(t3005) # t3008: "cuda:0 f32[1, 2048, 18176]"
# t3012 = prims.mul(0.5, t3008) # t3012: "cuda:0 f32[1, 2048, 18176]"
# t3016 = prims.add(0.5, t3012) # t3016: "cuda:0 f32[1, 2048, 18176]"
# t16173 = prims.convert_element_type(t16168, dtypes.float32) # t16173: "cuda:0 f32[1, 2048, 18176]"
# t16174 = prims.mul(t3016, t16173) # t16174: "cuda:0 f32[1, 2048, 18176]"
# t16175 = prims.mul(t3003, t16173) # t16175: "cuda:0 f32[1, 2048, 18176]"
# t16183 = prims.mul(f1217, t16175) # t16183: "cuda:0 f32[1, 2048, 18176]"
# t16186 = prims.pow(t3005, 2.0) # t16186: "cuda:0 f32[1, 2048, 18176]"
# t16187 = prims.neg(t16186) # t16187: "cuda:0 f32[1, 2048, 18176]"
# t16188 = prims.exp(t16187) # t16188: "cuda:0 f32[1, 2048, 18176]"
# t16189 = prims.mul(1.1283791670955126, t16188) # t16189: "cuda:0 f32[1, 2048, 18176]"
# t16190 = prims.mul(t16189, t16183) # t16190: "cuda:0 f32[1, 2048, 18176]"
# t16194 = prims.div(t16190, f1215) # t16194: "cuda:0 f32[1, 2048, 18176]"
# t16198 = prims.add(t16174, t16194) # t16198: "cuda:0 f32[1, 2048, 18176]"
# t16199 = prims.convert_element_type(t16198, dtypes.bfloat16) # t16199: "cuda:0 bf16[1, 2048, 18176]"
del f1215, f1217, t16168, t3002
t16200 = torch.reshape(t16199, (-1, 18176)) # t16200: "cuda:0 bf16[2048, 18176]"
# t16200 = ltorch.reshape(t16199, (-1, 18176)) # t16200: "cuda:0 bf16[2048, 18176]"
# t16200 = prims.reshape(t16199, (2048, 18176)) # t16200: "cuda:0 bf16[2048, 18176]"
del t16199
t16204 = torch.permute(t16200, (1, 0)) # t16204: "cuda:0 bf16[18176, 2048]"
# t16204 = ltorch.permute(t16200, (1, 0)) # t16204: "cuda:0 bf16[18176, 2048]"
# t16204 = prims.transpose(t16200, (1, 0)) # t16204: "cuda:0 bf16[18176, 2048]"
t16206 = torch.matmul(t16204, t16205) # t16206: "cuda:0 bf16[18176, 4544]"
# t16206 = ltorch.matmul(t16204, t16205) # t16206: "cuda:0 bf16[18176, 4544]"
# t16206 = prims.matmul(t16204, t16205) # t16206: "cuda:0 bf16[18176, 4544]"
del t16204
t16201 = torch.matmul(t16200, t_transformer_h_18_mlp_fc_weight) # t16201: "cuda:0 bf16[2048, 4544]"
# t16201 = ltorch.matmul(t16200, t_transformer_h_18_mlp_fc_weight) # t16201: "cuda:0 bf16[2048, 4544]"
# t16201 = prims.matmul(t16200, t_transformer_h_18_mlp_fc_weight) # t16201: "cuda:0 bf16[2048, 4544]"
del t16200, t_transformer_h_18_mlp_fc_weight
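# NOTE: cudnn_sdpa_bwd appears to be the cuDNN scaled-dot-product-attention backward. Given the
# incoming grad (t16220) and the saved query/key/value and softmax statistics, it presumably
# returns the gradients w.r.t. query, key, and value (t16221, t16222, t16223).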
(t16221, t16222, t16223) = cudnn_sdpa_bwd(t16220, t2986, t2989, t2939, None, f1206, b1207, t2990, t2991, t2992, t2993, scale=f1208, cat_grad_qkv=False)
del t16220, t2986, t2989, t2939, f1206, b1207, t2990, t2991, t2992, t2993, f1208
t16225 = torch_slice_prim_impl(t16222, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16225: "cuda:0 bf16[1, 71, 2048, 64]"
del t16222
t16229 = torch_slice_prim_impl(t16221, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16229: "cuda:0 bf16[1, 71, 2048, 64]"
del t16221
t16332 = torch.reshape(t16223, (1, 1, 71, 2048, 64)) # t16332: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16332 = ltorch.reshape(t16223, (1, 1, 71, 2048, 64)) # t16332: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16332 = prims.reshape(t16223, (1, 1, 71, 2048, 64)) # t16332: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t16223
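# NOTE: nvFusion41 below appears to undo the rotary position embedding on the query/key grads
# (the multiplications by the saved cos/sin tensors t66/t61 plus the rotate-half slice/pad
# pattern), reduces the broadcast key/value grads back to the single shared head, and concatenates
# the pieces into the gradient of the fused qkv projection output (t16366).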
[t16366] = nvFusion41(i1179, t16225, t16229, t16332, t61, t66)
# t16226 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t16226: "cuda:0 bf16[1, 71, 2048, 0]"
# t16227 = prims.pad(t16226, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t16227: "cuda:0 bf16[1, 71, 2048, 64]"
# t16230 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t16230: "cuda:0 bf16[1, 71, 2048, 0]"
# t16231 = prims.pad(t16230, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t16231: "cuda:0 bf16[1, 71, 2048, 64]"
# t16232 = prims.convert_element_type(t16225, dtypes.float32) # t16232: "cuda:0 f32[1, 71, 2048, 64]"
# t16236 = prims.mul(t66, t16232) # t16236: "cuda:0 f32[1, 71, 2048, 64]"
# t16239 = prims.convert_element_type(t16236, dtypes.bfloat16) # t16239: "cuda:0 bf16[1, 71, 2048, 64]"
# t16248 = prims.mul(t61, t16232) # t16248: "cuda:0 f32[1, 71, 2048, 64]"
# t16260 = prims.slice_prim(t16239, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t16260: "cuda:0 bf16[1, 71, 2048, 32]"
# t16261 = prims.slice_prim(t16239, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16261: "cuda:0 bf16[1, 71, 2048, 32]"
# t16262 = prims.convert_element_type(t16260, dtypes.float32) # t16262: "cuda:0 f32[1, 71, 2048, 32]"
# t16263 = prims.neg(t16262) # t16263: "cuda:0 f32[1, 71, 2048, 32]"
# t16264 = prims.convert_element_type(t16263, dtypes.bfloat16) # t16264: "cuda:0 bf16[1, 71, 2048, 32]"
# t16265 = prims.pad(t16264, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t16265: "cuda:0 bf16[1, 71, 2048, 64]"
# t16267 = prims.convert_element_type(t16265, dtypes.float32) # t16267: "cuda:0 f32[1, 71, 2048, 64]"
# t16268 = prims.add(t16248, t16267) # t16268: "cuda:0 f32[1, 71, 2048, 64]"
# t16270 = prims.pad(t16261, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t16270: "cuda:0 bf16[1, 71, 2048, 64]"
# t16272 = prims.convert_element_type(t16270, dtypes.float32) # t16272: "cuda:0 f32[1, 71, 2048, 64]"
# t16273 = prims.add(t16268, t16272) # t16273: "cuda:0 f32[1, 71, 2048, 64]"
# t16274 = prims.convert_element_type(t16273, dtypes.bfloat16) # t16274: "cuda:0 bf16[1, 71, 2048, 64]"
# t16275 = prims.pad(t16274, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t16275: "cuda:0 bf16[1, 71, 2048, 64]"
# t16276 = prims.convert_element_type(t16227, dtypes.float32) # t16276: "cuda:0 f32[1, 71, 2048, 64]"
# t16277 = prims.convert_element_type(t16275, dtypes.float32) # t16277: "cuda:0 f32[1, 71, 2048, 64]"
# t16278 = prims.add(t16276, t16277) # t16278: "cuda:0 f32[1, 71, 2048, 64]"
# t16279 = prims.convert_element_type(t16278, dtypes.bfloat16) # t16279: "cuda:0 bf16[1, 71, 2048, 64]"
# t16280 = prims.convert_element_type(t16229, dtypes.float32) # t16280: "cuda:0 f32[1, 71, 2048, 64]"
# t16284 = prims.mul(t66, t16280) # t16284: "cuda:0 f32[1, 71, 2048, 64]"
# t16287 = prims.convert_element_type(t16284, dtypes.bfloat16) # t16287: "cuda:0 bf16[1, 71, 2048, 64]"
# t16296 = prims.mul(t61, t16280) # t16296: "cuda:0 f32[1, 71, 2048, 64]"
# t16308 = prims.slice_prim(t16287, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t16308: "cuda:0 bf16[1, 71, 2048, 32]"
# t16309 = prims.slice_prim(t16287, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16309: "cuda:0 bf16[1, 71, 2048, 32]"
# t16310 = prims.convert_element_type(t16308, dtypes.float32) # t16310: "cuda:0 f32[1, 71, 2048, 32]"
# t16311 = prims.neg(t16310) # t16311: "cuda:0 f32[1, 71, 2048, 32]"
# t16312 = prims.convert_element_type(t16311, dtypes.bfloat16) # t16312: "cuda:0 bf16[1, 71, 2048, 32]"
# t16313 = prims.pad(t16312, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t16313: "cuda:0 bf16[1, 71, 2048, 64]"
# t16315 = prims.convert_element_type(t16313, dtypes.float32) # t16315: "cuda:0 f32[1, 71, 2048, 64]"
# t16316 = prims.add(t16296, t16315) # t16316: "cuda:0 f32[1, 71, 2048, 64]"
# t16318 = prims.pad(t16309, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t16318: "cuda:0 bf16[1, 71, 2048, 64]"
# t16320 = prims.convert_element_type(t16318, dtypes.float32) # t16320: "cuda:0 f32[1, 71, 2048, 64]"
# t16321 = prims.add(t16316, t16320) # t16321: "cuda:0 f32[1, 71, 2048, 64]"
# t16322 = prims.convert_element_type(t16321, dtypes.bfloat16) # t16322: "cuda:0 bf16[1, 71, 2048, 64]"
# t16323 = prims.pad(t16322, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t16323: "cuda:0 bf16[1, 71, 2048, 64]"
# t16324 = prims.convert_element_type(t16231, dtypes.float32) # t16324: "cuda:0 f32[1, 71, 2048, 64]"
# t16325 = prims.convert_element_type(t16323, dtypes.float32) # t16325: "cuda:0 f32[1, 71, 2048, 64]"
# t16326 = prims.add(t16324, t16325) # t16326: "cuda:0 f32[1, 71, 2048, 64]"
# t16327 = prims.convert_element_type(t16326, dtypes.bfloat16) # t16327: "cuda:0 bf16[1, 71, 2048, 64]"
# t16337 = prims.reshape(t16279, (1, 1, 71, 2048, 64)) # t16337: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16342 = prims.reshape(t16327, (1, 1, 71, 2048, 64)) # t16342: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16348 = prims.convert_element_type(t16332, dtypes.float32) # t16348: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16349 = prims.sum(t16348, (0, 1, 2)) # t16349: "cuda:0 f32[2048, 64]"
# t16350 = prims.convert_element_type(t16349, dtypes.bfloat16) # t16350: "cuda:0 bf16[2048, 64]"
# t16351 = prims.broadcast_in_dim(t16350, [1, 1, 1, 2048, 64], [3, 4]) # t16351: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16357 = prims.convert_element_type(t16337, dtypes.float32) # t16357: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16358 = prims.sum(t16357, (0, 1, 2)) # t16358: "cuda:0 f32[2048, 64]"
# t16359 = prims.convert_element_type(t16358, dtypes.bfloat16) # t16359: "cuda:0 bf16[2048, 64]"
# t16360 = prims.broadcast_in_dim(t16359, [1, 1, 1, 2048, 64], [3, 4]) # t16360: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16366 = prims.cat((t16342, t16360, t16351), i1179) # t16366: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1179, t16225, t16229, t16332
t16372 = torch.permute(t16366, (0, 3, 1, 2, 4)) # t16372: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16372 = ltorch.permute(t16366, (0, 3, 1, 2, 4)) # t16372: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16372 = prims.transpose(t16366, (0, 3, 1, 2, 4)) # t16372: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t16366
t16378 = torch.reshape(t16372, (1, 2048, 4672)) # t16378: "cuda:0 bf16[1, 2048, 4672]"
# t16378 = ltorch.reshape(t16372, (1, 2048, 4672)) # t16378: "cuda:0 bf16[1, 2048, 4672]"
# t16378 = prims.reshape(t16372, (1, 2048, 4672)) # t16378: "cuda:0 bf16[1, 2048, 4672]"
del t16372
t16379 = torch.reshape(t16378, (-1, 4672)) # t16379: "cuda:0 bf16[2048, 4672]"
# t16379 = ltorch.reshape(t16378, (-1, 4672)) # t16379: "cuda:0 bf16[2048, 4672]"
# t16379 = prims.reshape(t16378, (2048, 4672)) # t16379: "cuda:0 bf16[2048, 4672]"
del t16378
t16383 = torch.permute(t16379, (1, 0)) # t16383: "cuda:0 bf16[4672, 2048]"
# t16383 = ltorch.permute(t16379, (1, 0)) # t16383: "cuda:0 bf16[4672, 2048]"
# t16383 = prims.transpose(t16379, (1, 0)) # t16383: "cuda:0 bf16[4672, 2048]"
t16385 = torch.matmul(t16383, t16205) # t16385: "cuda:0 bf16[4672, 4544]"
# t16385 = ltorch.matmul(t16383, t16384) # t16385: "cuda:0 bf16[4672, 4544]"
# t16385 = prims.matmul(t16383, t16384) # t16385: "cuda:0 bf16[4672, 4544]"
del t16383, t16205
t16380 = torch.matmul(t16379, t_transformer_h_18_attn_attn_weight) # t16380: "cuda:0 bf16[2048, 4544]"
# t16380 = ltorch.matmul(t16379, t_transformer_h_18_attn_attn_weight) # t16380: "cuda:0 bf16[2048, 4544]"
# t16380 = prims.matmul(t16379, t_transformer_h_18_attn_attn_weight) # t16380: "cuda:0 bf16[2048, 4544]"
del t16379, t_transformer_h_18_attn_attn_weight
t16202 = torch.reshape(t16201, (1, 2048, 4544)) # t16202: "cuda:0 bf16[1, 2048, 4544]"
# t16202 = ltorch.reshape(t16201, (1, 2048, 4544)) # t16202: "cuda:0 bf16[1, 2048, 4544]"
# t16202 = prims.reshape(t16201, (1, 2048, 4544)) # t16202: "cuda:0 bf16[1, 2048, 4544]"
del t16201
t16381 = torch.reshape(t16380, (1, 2048, 4544)) # t16381: "cuda:0 bf16[1, 2048, 4544]"
# t16381 = ltorch.reshape(t16380, (1, 2048, 4544)) # t16381: "cuda:0 bf16[1, 2048, 4544]"
# t16381 = prims.reshape(t16380, (1, 2048, 4544)) # t16381: "cuda:0 bf16[1, 2048, 4544]"
del t16380
[t16394, t16400, t16442] = nvFusion42(i16422, t16159, t16202, t16381, t2708, t2840, t2861, t2876, t2881, t2887)
# t2867 = prims.convert_element_type(t2708, dtypes.float32) # t2867: "cuda:0 f32[1, 2048, 4544]"
# t2862 = prims.convert_element_type(t2861, dtypes.float32) # t2862: "cuda:0 f32[1, 2048, 4544]"
# t2863 = prims.convert_element_type(t2840, dtypes.float32) # t2863: "cuda:0 f32[1, 2048, 4544]"
# t2864 = prims.add(t2862, t2863) # t2864: "cuda:0 f32[1, 2048, 4544]"
# t2868 = prims.add(t2864, t2867) # t2868: "cuda:0 f32[1, 2048, 4544]"
# t2878 = prims.broadcast_in_dim(t2876, [1, 2048, 1], [0, 1]) # t2878: "cuda:0 f32[1, 2048, 1]"
# t2882 = prims.broadcast_in_dim(t2878, (1, 2048, 4544), (0, 1, 2)) # t2882: "cuda:0 f32[1, 2048, 4544]"
# t2884 = prims.sub(t2868, t2882) # t2884: "cuda:0 f32[1, 2048, 4544]"
# t2885 = prims.broadcast_in_dim(t2881, (1, 2048, 4544), (0, 1, 2)) # t2885: "cuda:0 f32[1, 2048, 4544]"
# t2886 = prims.mul(t2884, t2885) # t2886: "cuda:0 f32[1, 2048, 4544]"
# t2888 = prims.convert_element_type(t2887, dtypes.float32) # t2888: "cuda:0 f32[1, 2048, 4544]"
# t16439 = prims.convert_element_type(t16159, dtypes.float32) # t16439: "cuda:0 f32[1, 2048, 4544]"
# t16386 = prims.convert_element_type(t16202, dtypes.float32) # t16386: "cuda:0 f32[1, 2048, 4544]"
# t16387 = prims.convert_element_type(t16381, dtypes.float32) # t16387: "cuda:0 f32[1, 2048, 4544]"
# t16388 = prims.add(t16386, t16387) # t16388: "cuda:0 f32[1, 2048, 4544]"
# t16393 = prims.sum(t16388, (0, 1)) # t16393: "cuda:0 f32[4544]"
# t16394 = prims.convert_element_type(t16393, dtypes.bfloat16) # t16394: "cuda:0 bf16[4544]"
# t16395 = prims.mul(t2888, t16388) # t16395: "cuda:0 f32[1, 2048, 4544]"
# t16396 = prims.mul(t2886, t16388) # t16396: "cuda:0 f32[1, 2048, 4544]"
# t16399 = prims.sum(t16396, (0, 1)) # t16399: "cuda:0 f32[4544]"
# t16400 = prims.convert_element_type(t16399, dtypes.bfloat16) # t16400: "cuda:0 bf16[4544]"
# t16401 = prims.mul(t2885, t16395) # t16401: "cuda:0 f32[1, 2048, 4544]"
# t16402 = prims.mul(t2884, t16395) # t16402: "cuda:0 f32[1, 2048, 4544]"
# t16403 = prims.sum(t16402, (0, 2)) # t16403: "cuda:0 f32[2048]"
# t16404 = prims.broadcast_in_dim(t16403, [1, 2048, 1], [1]) # t16404: "cuda:0 f32[1, 2048, 1]"
# t16405 = prims.neg(t16401) # t16405: "cuda:0 f32[1, 2048, 4544]"
# t16407 = prims.sum(t16405, (0, 2)) # t16407: "cuda:0 f32[2048]"
# t16408 = prims.broadcast_in_dim(t16407, [1, 2048, 1], [1]) # t16408: "cuda:0 f32[1, 2048, 1]"
# t16409 = prims.mul(-0.5, t16404) # t16409: "cuda:0 f32[1, 2048, 1]"
# t16410 = prims.pow(t2881, 3.0) # t16410: "cuda:0 f32[1, 2048, 1]"
# t16411 = prims.mul(t16409, t16410) # t16411: "cuda:0 f32[1, 2048, 1]"
# t16413 = prims.sum(t16408, (0, 2)) # t16413: "cuda:0 f32[2048]"
# t16414 = prims.broadcast_in_dim(t16413, [1, 2048], [1]) # t16414: "cuda:0 f32[1, 2048]"
# t16415 = prims.sum(t16411, (0, 2)) # t16415: "cuda:0 f32[2048]"
# t16416 = prims.broadcast_in_dim(t16415, [1, 2048], [1]) # t16416: "cuda:0 f32[1, 2048]"
# t16419 = prims.broadcast_in_dim(t16414, [1, 2048, 1], [0, 1]) # t16419: "cuda:0 f32[1, 2048, 1]"
# t16420 = prims.broadcast_in_dim(t16419, (1, 2048, 4544), (0, 1, 2)) # t16420: "cuda:0 f32[1, 2048, 4544]"
# t16421 = prims.mul(0.00022007042253521127, t16420) # t16421: "cuda:0 f32[1, 2048, 4544]"
# t16423 = prims.broadcast_in_dim(t16416, [1, 2048, 1], [0, 1]) # t16423: "cuda:0 f32[1, 2048, 1]"
# t16424 = prims.broadcast_in_dim(t16423, (1, 2048, 4544), (0, 1, 2)) # t16424: "cuda:0 f32[1, 2048, 4544]"
# t16426 = prims.broadcast_in_dim(t2876, [1, 2048, 1], [0, 1]) # t16426: "cuda:0 f32[1, 2048, 1]"
# t16427 = prims.broadcast_in_dim(t16426, (1, 2048, 4544), (0, 1, 2)) # t16427: "cuda:0 f32[1, 2048, 4544]"
# t16428 = prims.mul(2.0, t16424) # t16428: "cuda:0 f32[1, 2048, 4544]"
# t16429 = prims.sub(t2868, t16427) # t16429: "cuda:0 f32[1, 2048, 4544]"
# t16430 = prims.mul(t16428, t16429) # t16430: "cuda:0 f32[1, 2048, 4544]"
# f16431 = prims.convert_element_type(i16422, float) # f16431: "float 4544.0"
# t16432 = prims.div(t16430, f16431) # t16432: "cuda:0 f32[1, 2048, 4544]"
# t16433 = prims.add(t16421, t16432) # t16433: "cuda:0 f32[1, 2048, 4544]"
# t16437 = prims.add(t16401, t16433) # t16437: "cuda:0 f32[1, 2048, 4544]"
# t16441 = prims.add(t16439, t16437) # t16441: "cuda:0 f32[1, 2048, 4544]"
# t16442 = prims.convert_element_type(t16441, dtypes.bfloat16) # t16442: "cuda:0 bf16[1, 2048, 4544]"
del i16422, t16159, t16202, t16381, t2708, t2840, t2861, t2876, t2881, t2887
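# NOTE: from here the same per-block pattern repeats for the remaining transformer blocks
# (nvFusion45, nvFusion48, ...): linear-layer backward for the proj weights, GELU backward,
# fc weight grad, cudnn_sdpa_bwd, the rotary/concat fusion for the fused qkv grad, and the
# LayerNorm backward feeding the next residual grad.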
t16449 = torch.reshape(t16442, (-1, 4544)) # t16449: "cuda:0 bf16[2048, 4544]"
# t16449 = ltorch.reshape(t16442, (-1, 4544)) # t16449: "cuda:0 bf16[2048, 4544]"
# t16449 = prims.reshape(t16442, (2048, 4544)) # t16449: "cuda:0 bf16[2048, 4544]"
t16453 = torch.permute(t16449, (1, 0)) # t16453: "cuda:0 bf16[4544, 2048]"
# t16453 = ltorch.permute(t16449, (1, 0)) # t16453: "cuda:0 bf16[4544, 2048]"
# t16453 = prims.transpose(t16449, (1, 0)) # t16453: "cuda:0 bf16[4544, 2048]"
t16450 = torch.matmul(t16449, t_transformer_h_17_mlp_proj_weight) # t16450: "cuda:0 bf16[2048, 18176]"
# t16450 = ltorch.matmul(t16449, t_transformer_h_17_mlp_proj_weight) # t16450: "cuda:0 bf16[2048, 18176]"
# t16450 = prims.matmul(t16449, t_transformer_h_17_mlp_proj_weight) # t16450: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_17_mlp_proj_weight
t16455 = torch.matmul(t16453, t16454) # t16455: "cuda:0 bf16[4544, 18176]"
# t16455 = ltorch.matmul(t16453, t16454) # t16455: "cuda:0 bf16[4544, 18176]"
# t16455 = prims.matmul(t16453, t16454) # t16455: "cuda:0 bf16[4544, 18176]"
del t16454
t16491 = torch.matmul(t16449, t_transformer_h_17_attn_proj_weight) # t16491: "cuda:0 bf16[2048, 4544]"
# t16491 = ltorch.matmul(t16490, t_transformer_h_17_attn_proj_weight) # t16491: "cuda:0 bf16[2048, 4544]"
# t16491 = prims.matmul(t16490, t_transformer_h_17_attn_proj_weight) # t16491: "cuda:0 bf16[2048, 4544]"
del t16449, t_transformer_h_17_attn_proj_weight
t16496 = torch.matmul(t16453, t16495) # t16496: "cuda:0 bf16[4544, 4544]"
# t16496 = ltorch.matmul(t16494, t16495) # t16496: "cuda:0 bf16[4544, 4544]"
# t16496 = prims.matmul(t16494, t16495) # t16496: "cuda:0 bf16[4544, 4544]"
del t16453, t16495
t16451 = torch.reshape(t16450, (1, 2048, 18176)) # t16451: "cuda:0 bf16[1, 2048, 18176]"
# t16451 = ltorch.reshape(t16450, (1, 2048, 18176)) # t16451: "cuda:0 bf16[1, 2048, 18176]"
# t16451 = prims.reshape(t16450, (1, 2048, 18176)) # t16451: "cuda:0 bf16[1, 2048, 18176]"
del t16450
t16492 = torch.reshape(t16491, (1, 2048, 4544)) # t16492: "cuda:0 bf16[1, 2048, 4544]"
# t16492 = ltorch.reshape(t16491, (1, 2048, 4544)) # t16492: "cuda:0 bf16[1, 2048, 4544]"
# t16492 = prims.reshape(t16491, (1, 2048, 4544)) # t16492: "cuda:0 bf16[1, 2048, 4544]"
del t16491
t16500 = torch.reshape(t16492, (1, 2048, 71, 64)) # t16500: "cuda:0 bf16[1, 2048, 71, 64]"
# t16500 = ltorch.reshape(t16492, (1, 2048, 71, 64)) # t16500: "cuda:0 bf16[1, 2048, 71, 64]"
# t16500 = prims.reshape(t16492, (1, 2048, 71, 64)) # t16500: "cuda:0 bf16[1, 2048, 71, 64]"
del t16492
t16503 = torch.permute(t16500, (0, 2, 1, 3)) # t16503: "cuda:0 bf16[1, 71, 2048, 64]"
# t16503 = ltorch.permute(t16500, (0, 2, 1, 3)) # t16503: "cuda:0 bf16[1, 71, 2048, 64]"
# t16503 = prims.transpose(t16500, (0, 2, 1, 3)) # t16503: "cuda:0 bf16[1, 71, 2048, 64]"
del t16500
[t16482] = nvFusion43(f1151, f1153, t16451, t2841)
# t2842 = prims.convert_element_type(t2841, dtypes.float32) # t2842: "cuda:0 f32[1, 2048, 18176]"
# t2844 = prims.div(t2842, 1.4142135623730951) # t2844: "cuda:0 f32[1, 2048, 18176]"
# t2847 = prims.erf(t2844) # t2847: "cuda:0 f32[1, 2048, 18176]"
# t2851 = prims.mul(0.5, t2847) # t2851: "cuda:0 f32[1, 2048, 18176]"
# t2855 = prims.add(0.5, t2851) # t2855: "cuda:0 f32[1, 2048, 18176]"
# t16456 = prims.convert_element_type(t16451, dtypes.float32) # t16456: "cuda:0 f32[1, 2048, 18176]"
# t16457 = prims.mul(t2855, t16456) # t16457: "cuda:0 f32[1, 2048, 18176]"
# t16458 = prims.mul(t2842, t16456) # t16458: "cuda:0 f32[1, 2048, 18176]"
# t16466 = prims.mul(f1153, t16458) # t16466: "cuda:0 f32[1, 2048, 18176]"
# t16469 = prims.pow(t2844, 2.0) # t16469: "cuda:0 f32[1, 2048, 18176]"
# t16470 = prims.neg(t16469) # t16470: "cuda:0 f32[1, 2048, 18176]"
# t16471 = prims.exp(t16470) # t16471: "cuda:0 f32[1, 2048, 18176]"
# t16472 = prims.mul(1.1283791670955126, t16471) # t16472: "cuda:0 f32[1, 2048, 18176]"
# t16473 = prims.mul(t16472, t16466) # t16473: "cuda:0 f32[1, 2048, 18176]"
# t16477 = prims.div(t16473, f1151) # t16477: "cuda:0 f32[1, 2048, 18176]"
# t16481 = prims.add(t16457, t16477) # t16481: "cuda:0 f32[1, 2048, 18176]"
# t16482 = prims.convert_element_type(t16481, dtypes.bfloat16) # t16482: "cuda:0 bf16[1, 2048, 18176]"
del f1151, f1153, t16451, t2841
t16483 = torch.reshape(t16482, (-1, 18176)) # t16483: "cuda:0 bf16[2048, 18176]"
# t16483 = ltorch.reshape(t16482, (-1, 18176)) # t16483: "cuda:0 bf16[2048, 18176]"
# t16483 = prims.reshape(t16482, (2048, 18176)) # t16483: "cuda:0 bf16[2048, 18176]"
del t16482
t16487 = torch.permute(t16483, (1, 0)) # t16487: "cuda:0 bf16[18176, 2048]"
# t16487 = ltorch.permute(t16483, (1, 0)) # t16487: "cuda:0 bf16[18176, 2048]"
# t16487 = prims.transpose(t16483, (1, 0)) # t16487: "cuda:0 bf16[18176, 2048]"
t16489 = torch.matmul(t16487, t16488) # t16489: "cuda:0 bf16[18176, 4544]"
# t16489 = ltorch.matmul(t16487, t16488) # t16489: "cuda:0 bf16[18176, 4544]"
# t16489 = prims.matmul(t16487, t16488) # t16489: "cuda:0 bf16[18176, 4544]"
del t16487
t16484 = torch.matmul(t16483, t_transformer_h_17_mlp_fc_weight) # t16484: "cuda:0 bf16[2048, 4544]"
# t16484 = ltorch.matmul(t16483, t_transformer_h_17_mlp_fc_weight) # t16484: "cuda:0 bf16[2048, 4544]"
# t16484 = prims.matmul(t16483, t_transformer_h_17_mlp_fc_weight) # t16484: "cuda:0 bf16[2048, 4544]"
del t16483, t_transformer_h_17_mlp_fc_weight
(t16504, t16505, t16506) = cudnn_sdpa_bwd(t16503, t2825, t2828, t2778, None, f1142, b1143, t2829, t2830, t2831, t2832, scale=f1144, cat_grad_qkv=False)
del t16503, t2825, t2828, t2778, f1142, b1143, t2829, t2830, t2831, t2832, f1144
t16508 = torch_slice_prim_impl(t16505, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16508: "cuda:0 bf16[1, 71, 2048, 64]"
del t16505
t16512 = torch_slice_prim_impl(t16504, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16512: "cuda:0 bf16[1, 71, 2048, 64]"
del t16504
t16615 = torch.reshape(t16506, (1, 1, 71, 2048, 64)) # t16615: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16615 = ltorch.reshape(t16506, (1, 1, 71, 2048, 64)) # t16615: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16615 = prims.reshape(t16506, (1, 1, 71, 2048, 64)) # t16615: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t16506
[t16649] = nvFusion44(i1115, t16508, t16512, t16615, t61, t66)
# t16509 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t16509: "cuda:0 bf16[1, 71, 2048, 0]"
# t16510 = prims.pad(t16509, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t16510: "cuda:0 bf16[1, 71, 2048, 64]"
# t16513 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t16513: "cuda:0 bf16[1, 71, 2048, 0]"
# t16514 = prims.pad(t16513, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t16514: "cuda:0 bf16[1, 71, 2048, 64]"
# t16515 = prims.convert_element_type(t16508, dtypes.float32) # t16515: "cuda:0 f32[1, 71, 2048, 64]"
# t16519 = prims.mul(t66, t16515) # t16519: "cuda:0 f32[1, 71, 2048, 64]"
# t16522 = prims.convert_element_type(t16519, dtypes.bfloat16) # t16522: "cuda:0 bf16[1, 71, 2048, 64]"
# t16531 = prims.mul(t61, t16515) # t16531: "cuda:0 f32[1, 71, 2048, 64]"
# t16543 = prims.slice_prim(t16522, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t16543: "cuda:0 bf16[1, 71, 2048, 32]"
# t16544 = prims.slice_prim(t16522, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16544: "cuda:0 bf16[1, 71, 2048, 32]"
# t16545 = prims.convert_element_type(t16543, dtypes.float32) # t16545: "cuda:0 f32[1, 71, 2048, 32]"
# t16546 = prims.neg(t16545) # t16546: "cuda:0 f32[1, 71, 2048, 32]"
# t16547 = prims.convert_element_type(t16546, dtypes.bfloat16) # t16547: "cuda:0 bf16[1, 71, 2048, 32]"
# t16548 = prims.pad(t16547, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t16548: "cuda:0 bf16[1, 71, 2048, 64]"
# t16550 = prims.convert_element_type(t16548, dtypes.float32) # t16550: "cuda:0 f32[1, 71, 2048, 64]"
# t16551 = prims.add(t16531, t16550) # t16551: "cuda:0 f32[1, 71, 2048, 64]"
# t16553 = prims.pad(t16544, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t16553: "cuda:0 bf16[1, 71, 2048, 64]"
# t16555 = prims.convert_element_type(t16553, dtypes.float32) # t16555: "cuda:0 f32[1, 71, 2048, 64]"
# t16556 = prims.add(t16551, t16555) # t16556: "cuda:0 f32[1, 71, 2048, 64]"
# t16557 = prims.convert_element_type(t16556, dtypes.bfloat16) # t16557: "cuda:0 bf16[1, 71, 2048, 64]"
# t16558 = prims.pad(t16557, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t16558: "cuda:0 bf16[1, 71, 2048, 64]"
# t16559 = prims.convert_element_type(t16510, dtypes.float32) # t16559: "cuda:0 f32[1, 71, 2048, 64]"
# t16560 = prims.convert_element_type(t16558, dtypes.float32) # t16560: "cuda:0 f32[1, 71, 2048, 64]"
# t16561 = prims.add(t16559, t16560) # t16561: "cuda:0 f32[1, 71, 2048, 64]"
# t16562 = prims.convert_element_type(t16561, dtypes.bfloat16) # t16562: "cuda:0 bf16[1, 71, 2048, 64]"
# t16563 = prims.convert_element_type(t16512, dtypes.float32) # t16563: "cuda:0 f32[1, 71, 2048, 64]"
# t16567 = prims.mul(t66, t16563) # t16567: "cuda:0 f32[1, 71, 2048, 64]"
# t16570 = prims.convert_element_type(t16567, dtypes.bfloat16) # t16570: "cuda:0 bf16[1, 71, 2048, 64]"
# t16579 = prims.mul(t61, t16563) # t16579: "cuda:0 f32[1, 71, 2048, 64]"
# t16591 = prims.slice_prim(t16570, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t16591: "cuda:0 bf16[1, 71, 2048, 32]"
# t16592 = prims.slice_prim(t16570, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16592: "cuda:0 bf16[1, 71, 2048, 32]"
# t16593 = prims.convert_element_type(t16591, dtypes.float32) # t16593: "cuda:0 f32[1, 71, 2048, 32]"
# t16594 = prims.neg(t16593) # t16594: "cuda:0 f32[1, 71, 2048, 32]"
# t16595 = prims.convert_element_type(t16594, dtypes.bfloat16) # t16595: "cuda:0 bf16[1, 71, 2048, 32]"
# t16596 = prims.pad(t16595, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t16596: "cuda:0 bf16[1, 71, 2048, 64]"
# t16598 = prims.convert_element_type(t16596, dtypes.float32) # t16598: "cuda:0 f32[1, 71, 2048, 64]"
# t16599 = prims.add(t16579, t16598) # t16599: "cuda:0 f32[1, 71, 2048, 64]"
# t16601 = prims.pad(t16592, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t16601: "cuda:0 bf16[1, 71, 2048, 64]"
# t16603 = prims.convert_element_type(t16601, dtypes.float32) # t16603: "cuda:0 f32[1, 71, 2048, 64]"
# t16604 = prims.add(t16599, t16603) # t16604: "cuda:0 f32[1, 71, 2048, 64]"
# t16605 = prims.convert_element_type(t16604, dtypes.bfloat16) # t16605: "cuda:0 bf16[1, 71, 2048, 64]"
# t16606 = prims.pad(t16605, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t16606: "cuda:0 bf16[1, 71, 2048, 64]"
# t16607 = prims.convert_element_type(t16514, dtypes.float32) # t16607: "cuda:0 f32[1, 71, 2048, 64]"
# t16608 = prims.convert_element_type(t16606, dtypes.float32) # t16608: "cuda:0 f32[1, 71, 2048, 64]"
# t16609 = prims.add(t16607, t16608) # t16609: "cuda:0 f32[1, 71, 2048, 64]"
# t16610 = prims.convert_element_type(t16609, dtypes.bfloat16) # t16610: "cuda:0 bf16[1, 71, 2048, 64]"
# t16620 = prims.reshape(t16562, (1, 1, 71, 2048, 64)) # t16620: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16625 = prims.reshape(t16610, (1, 1, 71, 2048, 64)) # t16625: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16631 = prims.convert_element_type(t16615, dtypes.float32) # t16631: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16632 = prims.sum(t16631, (0, 1, 2)) # t16632: "cuda:0 f32[2048, 64]"
# t16633 = prims.convert_element_type(t16632, dtypes.bfloat16) # t16633: "cuda:0 bf16[2048, 64]"
# t16634 = prims.broadcast_in_dim(t16633, [1, 1, 1, 2048, 64], [3, 4]) # t16634: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16640 = prims.convert_element_type(t16620, dtypes.float32) # t16640: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16641 = prims.sum(t16640, (0, 1, 2)) # t16641: "cuda:0 f32[2048, 64]"
# t16642 = prims.convert_element_type(t16641, dtypes.bfloat16) # t16642: "cuda:0 bf16[2048, 64]"
# t16643 = prims.broadcast_in_dim(t16642, [1, 1, 1, 2048, 64], [3, 4]) # t16643: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16649 = prims.cat((t16625, t16643, t16634), i1115) # t16649: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1115, t16508, t16512, t16615
t16655 = torch.permute(t16649, (0, 3, 1, 2, 4)) # t16655: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16655 = ltorch.permute(t16649, (0, 3, 1, 2, 4)) # t16655: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16655 = prims.transpose(t16649, (0, 3, 1, 2, 4)) # t16655: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t16649
t16661 = torch.reshape(t16655, (1, 2048, 4672)) # t16661: "cuda:0 bf16[1, 2048, 4672]"
# t16661 = ltorch.reshape(t16655, (1, 2048, 4672)) # t16661: "cuda:0 bf16[1, 2048, 4672]"
# t16661 = prims.reshape(t16655, (1, 2048, 4672)) # t16661: "cuda:0 bf16[1, 2048, 4672]"
del t16655
t16662 = torch.reshape(t16661, (-1, 4672)) # t16662: "cuda:0 bf16[2048, 4672]"
# t16662 = ltorch.reshape(t16661, (-1, 4672)) # t16662: "cuda:0 bf16[2048, 4672]"
# t16662 = prims.reshape(t16661, (2048, 4672)) # t16662: "cuda:0 bf16[2048, 4672]"
del t16661
t16666 = torch.permute(t16662, (1, 0)) # t16666: "cuda:0 bf16[4672, 2048]"
# t16666 = ltorch.permute(t16662, (1, 0)) # t16666: "cuda:0 bf16[4672, 2048]"
# t16666 = prims.transpose(t16662, (1, 0)) # t16666: "cuda:0 bf16[4672, 2048]"
t16668 = torch.matmul(t16666, t16488) # t16668: "cuda:0 bf16[4672, 4544]"
# t16668 = ltorch.matmul(t16666, t16667) # t16668: "cuda:0 bf16[4672, 4544]"
# t16668 = prims.matmul(t16666, t16667) # t16668: "cuda:0 bf16[4672, 4544]"
del t16666, t16488
t16663 = torch.matmul(t16662, t_transformer_h_17_attn_attn_weight) # t16663: "cuda:0 bf16[2048, 4544]"
# t16663 = ltorch.matmul(t16662, t_transformer_h_17_attn_attn_weight) # t16663: "cuda:0 bf16[2048, 4544]"
# t16663 = prims.matmul(t16662, t_transformer_h_17_attn_attn_weight) # t16663: "cuda:0 bf16[2048, 4544]"
del t16662, t_transformer_h_17_attn_attn_weight
t16485 = torch.reshape(t16484, (1, 2048, 4544)) # t16485: "cuda:0 bf16[1, 2048, 4544]"
# t16485 = ltorch.reshape(t16484, (1, 2048, 4544)) # t16485: "cuda:0 bf16[1, 2048, 4544]"
# t16485 = prims.reshape(t16484, (1, 2048, 4544)) # t16485: "cuda:0 bf16[1, 2048, 4544]"
del t16484
t16664 = torch.reshape(t16663, (1, 2048, 4544)) # t16664: "cuda:0 bf16[1, 2048, 4544]"
# t16664 = ltorch.reshape(t16663, (1, 2048, 4544)) # t16664: "cuda:0 bf16[1, 2048, 4544]"
# t16664 = prims.reshape(t16663, (1, 2048, 4544)) # t16664: "cuda:0 bf16[1, 2048, 4544]"
del t16663
[t16677, t16683, t16725] = nvFusion45(i16705, t16442, t16485, t16664, t2547, t2679, t2700, t2715, t2720, t2726)
# t2706 = prims.convert_element_type(t2547, dtypes.float32) # t2706: "cuda:0 f32[1, 2048, 4544]"
# t2701 = prims.convert_element_type(t2700, dtypes.float32) # t2701: "cuda:0 f32[1, 2048, 4544]"
# t2702 = prims.convert_element_type(t2679, dtypes.float32) # t2702: "cuda:0 f32[1, 2048, 4544]"
# t2703 = prims.add(t2701, t2702) # t2703: "cuda:0 f32[1, 2048, 4544]"
# t2707 = prims.add(t2703, t2706) # t2707: "cuda:0 f32[1, 2048, 4544]"
# t2717 = prims.broadcast_in_dim(t2715, [1, 2048, 1], [0, 1]) # t2717: "cuda:0 f32[1, 2048, 1]"
# t2721 = prims.broadcast_in_dim(t2717, (1, 2048, 4544), (0, 1, 2)) # t2721: "cuda:0 f32[1, 2048, 4544]"
# t2723 = prims.sub(t2707, t2721) # t2723: "cuda:0 f32[1, 2048, 4544]"
# t2724 = prims.broadcast_in_dim(t2720, (1, 2048, 4544), (0, 1, 2)) # t2724: "cuda:0 f32[1, 2048, 4544]"
# t2725 = prims.mul(t2723, t2724) # t2725: "cuda:0 f32[1, 2048, 4544]"
# t2727 = prims.convert_element_type(t2726, dtypes.float32) # t2727: "cuda:0 f32[1, 2048, 4544]"
# t16722 = prims.convert_element_type(t16442, dtypes.float32) # t16722: "cuda:0 f32[1, 2048, 4544]"
# t16669 = prims.convert_element_type(t16485, dtypes.float32) # t16669: "cuda:0 f32[1, 2048, 4544]"
# t16670 = prims.convert_element_type(t16664, dtypes.float32) # t16670: "cuda:0 f32[1, 2048, 4544]"
# t16671 = prims.add(t16669, t16670) # t16671: "cuda:0 f32[1, 2048, 4544]"
# t16676 = prims.sum(t16671, (0, 1)) # t16676: "cuda:0 f32[4544]"
# t16677 = prims.convert_element_type(t16676, dtypes.bfloat16) # t16677: "cuda:0 bf16[4544]"
# t16678 = prims.mul(t2727, t16671) # t16678: "cuda:0 f32[1, 2048, 4544]"
# t16679 = prims.mul(t2725, t16671) # t16679: "cuda:0 f32[1, 2048, 4544]"
# t16682 = prims.sum(t16679, (0, 1)) # t16682: "cuda:0 f32[4544]"
# t16683 = prims.convert_element_type(t16682, dtypes.bfloat16) # t16683: "cuda:0 bf16[4544]"
# t16684 = prims.mul(t2724, t16678) # t16684: "cuda:0 f32[1, 2048, 4544]"
# t16685 = prims.mul(t2723, t16678) # t16685: "cuda:0 f32[1, 2048, 4544]"
# t16686 = prims.sum(t16685, (0, 2)) # t16686: "cuda:0 f32[2048]"
# t16687 = prims.broadcast_in_dim(t16686, [1, 2048, 1], [1]) # t16687: "cuda:0 f32[1, 2048, 1]"
# t16688 = prims.neg(t16684) # t16688: "cuda:0 f32[1, 2048, 4544]"
# t16690 = prims.sum(t16688, (0, 2)) # t16690: "cuda:0 f32[2048]"
# t16691 = prims.broadcast_in_dim(t16690, [1, 2048, 1], [1]) # t16691: "cuda:0 f32[1, 2048, 1]"
# t16692 = prims.mul(-0.5, t16687) # t16692: "cuda:0 f32[1, 2048, 1]"
# t16693 = prims.pow(t2720, 3.0) # t16693: "cuda:0 f32[1, 2048, 1]"
# t16694 = prims.mul(t16692, t16693) # t16694: "cuda:0 f32[1, 2048, 1]"
# t16696 = prims.sum(t16691, (0, 2)) # t16696: "cuda:0 f32[2048]"
# t16697 = prims.broadcast_in_dim(t16696, [1, 2048], [1]) # t16697: "cuda:0 f32[1, 2048]"
# t16698 = prims.sum(t16694, (0, 2)) # t16698: "cuda:0 f32[2048]"
# t16699 = prims.broadcast_in_dim(t16698, [1, 2048], [1]) # t16699: "cuda:0 f32[1, 2048]"
# t16702 = prims.broadcast_in_dim(t16697, [1, 2048, 1], [0, 1]) # t16702: "cuda:0 f32[1, 2048, 1]"
# t16703 = prims.broadcast_in_dim(t16702, (1, 2048, 4544), (0, 1, 2)) # t16703: "cuda:0 f32[1, 2048, 4544]"
# t16704 = prims.mul(0.00022007042253521127, t16703) # t16704: "cuda:0 f32[1, 2048, 4544]"
# t16706 = prims.broadcast_in_dim(t16699, [1, 2048, 1], [0, 1]) # t16706: "cuda:0 f32[1, 2048, 1]"
# t16707 = prims.broadcast_in_dim(t16706, (1, 2048, 4544), (0, 1, 2)) # t16707: "cuda:0 f32[1, 2048, 4544]"
# t16709 = prims.broadcast_in_dim(t2715, [1, 2048, 1], [0, 1]) # t16709: "cuda:0 f32[1, 2048, 1]"
# t16710 = prims.broadcast_in_dim(t16709, (1, 2048, 4544), (0, 1, 2)) # t16710: "cuda:0 f32[1, 2048, 4544]"
# t16711 = prims.mul(2.0, t16707) # t16711: "cuda:0 f32[1, 2048, 4544]"
# t16712 = prims.sub(t2707, t16710) # t16712: "cuda:0 f32[1, 2048, 4544]"
# t16713 = prims.mul(t16711, t16712) # t16713: "cuda:0 f32[1, 2048, 4544]"
# f16714 = prims.convert_element_type(i16705, float) # f16714: "float 4544.0"
# t16715 = prims.div(t16713, f16714) # t16715: "cuda:0 f32[1, 2048, 4544]"
# t16716 = prims.add(t16704, t16715) # t16716: "cuda:0 f32[1, 2048, 4544]"
# t16720 = prims.add(t16684, t16716) # t16720: "cuda:0 f32[1, 2048, 4544]"
# t16724 = prims.add(t16722, t16720) # t16724: "cuda:0 f32[1, 2048, 4544]"
# t16725 = prims.convert_element_type(t16724, dtypes.bfloat16) # t16725: "cuda:0 bf16[1, 2048, 4544]"
del i16705, t16442, t16485, t16664, t2547, t2679, t2700, t2715, t2720, t2726
t16732 = torch.reshape(t16725, (-1, 4544)) # t16732: "cuda:0 bf16[2048, 4544]"
# t16732 = ltorch.reshape(t16725, (-1, 4544)) # t16732: "cuda:0 bf16[2048, 4544]"
# t16732 = prims.reshape(t16725, (2048, 4544)) # t16732: "cuda:0 bf16[2048, 4544]"
t16736 = torch.permute(t16732, (1, 0)) # t16736: "cuda:0 bf16[4544, 2048]"
# t16736 = ltorch.permute(t16732, (1, 0)) # t16736: "cuda:0 bf16[4544, 2048]"
# t16736 = prims.transpose(t16732, (1, 0)) # t16736: "cuda:0 bf16[4544, 2048]"
t16733 = torch.matmul(t16732, t_transformer_h_16_mlp_proj_weight) # t16733: "cuda:0 bf16[2048, 18176]"
# t16733 = ltorch.matmul(t16732, t_transformer_h_16_mlp_proj_weight) # t16733: "cuda:0 bf16[2048, 18176]"
# t16733 = prims.matmul(t16732, t_transformer_h_16_mlp_proj_weight) # t16733: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_16_mlp_proj_weight
t16738 = torch.matmul(t16736, t16737) # t16738: "cuda:0 bf16[4544, 18176]"
# t16738 = ltorch.matmul(t16736, t16737) # t16738: "cuda:0 bf16[4544, 18176]"
# t16738 = prims.matmul(t16736, t16737) # t16738: "cuda:0 bf16[4544, 18176]"
del t16737
t16774 = torch.matmul(t16732, t_transformer_h_16_attn_proj_weight) # t16774: "cuda:0 bf16[2048, 4544]"
# t16774 = ltorch.matmul(t16773, t_transformer_h_16_attn_proj_weight) # t16774: "cuda:0 bf16[2048, 4544]"
# t16774 = prims.matmul(t16773, t_transformer_h_16_attn_proj_weight) # t16774: "cuda:0 bf16[2048, 4544]"
del t16732, t_transformer_h_16_attn_proj_weight
t16779 = torch.matmul(t16736, t16778) # t16779: "cuda:0 bf16[4544, 4544]"
# t16779 = ltorch.matmul(t16777, t16778) # t16779: "cuda:0 bf16[4544, 4544]"
# t16779 = prims.matmul(t16777, t16778) # t16779: "cuda:0 bf16[4544, 4544]"
del t16736, t16778
t16734 = torch.reshape(t16733, (1, 2048, 18176)) # t16734: "cuda:0 bf16[1, 2048, 18176]"
# t16734 = ltorch.reshape(t16733, (1, 2048, 18176)) # t16734: "cuda:0 bf16[1, 2048, 18176]"
# t16734 = prims.reshape(t16733, (1, 2048, 18176)) # t16734: "cuda:0 bf16[1, 2048, 18176]"
del t16733
t16775 = torch.reshape(t16774, (1, 2048, 4544)) # t16775: "cuda:0 bf16[1, 2048, 4544]"
# t16775 = ltorch.reshape(t16774, (1, 2048, 4544)) # t16775: "cuda:0 bf16[1, 2048, 4544]"
# t16775 = prims.reshape(t16774, (1, 2048, 4544)) # t16775: "cuda:0 bf16[1, 2048, 4544]"
del t16774
t16783 = torch.reshape(t16775, (1, 2048, 71, 64)) # t16783: "cuda:0 bf16[1, 2048, 71, 64]"
# t16783 = ltorch.reshape(t16775, (1, 2048, 71, 64)) # t16783: "cuda:0 bf16[1, 2048, 71, 64]"
# t16783 = prims.reshape(t16775, (1, 2048, 71, 64)) # t16783: "cuda:0 bf16[1, 2048, 71, 64]"
del t16775
t16786 = torch.permute(t16783, (0, 2, 1, 3)) # t16786: "cuda:0 bf16[1, 71, 2048, 64]"
# t16786 = ltorch.permute(t16783, (0, 2, 1, 3)) # t16786: "cuda:0 bf16[1, 71, 2048, 64]"
# t16786 = prims.transpose(t16783, (0, 2, 1, 3)) # t16786: "cuda:0 bf16[1, 71, 2048, 64]"
del t16783
[t16765] = nvFusion46(f1087, f1089, t16734, t2680)
# t2681 = prims.convert_element_type(t2680, dtypes.float32) # t2681: "cuda:0 f32[1, 2048, 18176]"
# t2683 = prims.div(t2681, 1.4142135623730951) # t2683: "cuda:0 f32[1, 2048, 18176]"
# t2686 = prims.erf(t2683) # t2686: "cuda:0 f32[1, 2048, 18176]"
# t2690 = prims.mul(0.5, t2686) # t2690: "cuda:0 f32[1, 2048, 18176]"
# t2694 = prims.add(0.5, t2690) # t2694: "cuda:0 f32[1, 2048, 18176]"
# t16739 = prims.convert_element_type(t16734, dtypes.float32) # t16739: "cuda:0 f32[1, 2048, 18176]"
# t16740 = prims.mul(t2694, t16739) # t16740: "cuda:0 f32[1, 2048, 18176]"
# t16741 = prims.mul(t2681, t16739) # t16741: "cuda:0 f32[1, 2048, 18176]"
# t16749 = prims.mul(f1089, t16741) # t16749: "cuda:0 f32[1, 2048, 18176]"
# t16752 = prims.pow(t2683, 2.0) # t16752: "cuda:0 f32[1, 2048, 18176]"
# t16753 = prims.neg(t16752) # t16753: "cuda:0 f32[1, 2048, 18176]"
# t16754 = prims.exp(t16753) # t16754: "cuda:0 f32[1, 2048, 18176]"
# t16755 = prims.mul(1.1283791670955126, t16754) # t16755: "cuda:0 f32[1, 2048, 18176]"
# t16756 = prims.mul(t16755, t16749) # t16756: "cuda:0 f32[1, 2048, 18176]"
# t16760 = prims.div(t16756, f1087) # t16760: "cuda:0 f32[1, 2048, 18176]"
# t16764 = prims.add(t16740, t16760) # t16764: "cuda:0 f32[1, 2048, 18176]"
# t16765 = prims.convert_element_type(t16764, dtypes.bfloat16) # t16765: "cuda:0 bf16[1, 2048, 18176]"
del f1087, f1089, t16734, t2680
t16766 = torch.reshape(t16765, (-1, 18176)) # t16766: "cuda:0 bf16[2048, 18176]"
# t16766 = ltorch.reshape(t16765, (-1, 18176)) # t16766: "cuda:0 bf16[2048, 18176]"
# t16766 = prims.reshape(t16765, (2048, 18176)) # t16766: "cuda:0 bf16[2048, 18176]"
del t16765
t16770 = torch.permute(t16766, (1, 0)) # t16770: "cuda:0 bf16[18176, 2048]"
# t16770 = ltorch.permute(t16766, (1, 0)) # t16770: "cuda:0 bf16[18176, 2048]"
# t16770 = prims.transpose(t16766, (1, 0)) # t16770: "cuda:0 bf16[18176, 2048]"
t16772 = torch.matmul(t16770, t16771) # t16772: "cuda:0 bf16[18176, 4544]"
# t16772 = ltorch.matmul(t16770, t16771) # t16772: "cuda:0 bf16[18176, 4544]"
# t16772 = prims.matmul(t16770, t16771) # t16772: "cuda:0 bf16[18176, 4544]"
del t16770
t16767 = torch.matmul(t16766, t_transformer_h_16_mlp_fc_weight) # t16767: "cuda:0 bf16[2048, 4544]"
# t16767 = ltorch.matmul(t16766, t_transformer_h_16_mlp_fc_weight) # t16767: "cuda:0 bf16[2048, 4544]"
# t16767 = prims.matmul(t16766, t_transformer_h_16_mlp_fc_weight) # t16767: "cuda:0 bf16[2048, 4544]"
del t16766, t_transformer_h_16_mlp_fc_weight
(t16787, t16788, t16789) = cudnn_sdpa_bwd(t16786, t2664, t2667, t2617, None, f1078, b1079, t2668, t2669, t2670, t2671, scale=f1080, cat_grad_qkv=False)
del t16786, t2664, t2667, t2617, f1078, b1079, t2668, t2669, t2670, t2671, f1080
t16791 = torch_slice_prim_impl(t16788, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16791: "cuda:0 bf16[1, 71, 2048, 64]"
del t16788
t16795 = torch_slice_prim_impl(t16787, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16795: "cuda:0 bf16[1, 71, 2048, 64]"
del t16787
t16898 = torch.reshape(t16789, (1, 1, 71, 2048, 64)) # t16898: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16898 = ltorch.reshape(t16789, (1, 1, 71, 2048, 64)) # t16898: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16898 = prims.reshape(t16789, (1, 1, 71, 2048, 64)) # t16898: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t16789
[t16932] = nvFusion47(i1051, t16791, t16795, t16898, t61, t66)
# t16792 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t16792: "cuda:0 bf16[1, 71, 2048, 0]"
# t16793 = prims.pad(t16792, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t16793: "cuda:0 bf16[1, 71, 2048, 64]"
# t16796 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t16796: "cuda:0 bf16[1, 71, 2048, 0]"
# t16797 = prims.pad(t16796, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t16797: "cuda:0 bf16[1, 71, 2048, 64]"
# t16798 = prims.convert_element_type(t16791, dtypes.float32) # t16798: "cuda:0 f32[1, 71, 2048, 64]"
# t16802 = prims.mul(t66, t16798) # t16802: "cuda:0 f32[1, 71, 2048, 64]"
# t16805 = prims.convert_element_type(t16802, dtypes.bfloat16) # t16805: "cuda:0 bf16[1, 71, 2048, 64]"
# t16814 = prims.mul(t61, t16798) # t16814: "cuda:0 f32[1, 71, 2048, 64]"
# t16826 = prims.slice_prim(t16805, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t16826: "cuda:0 bf16[1, 71, 2048, 32]"
# t16827 = prims.slice_prim(t16805, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16827: "cuda:0 bf16[1, 71, 2048, 32]"
# t16828 = prims.convert_element_type(t16826, dtypes.float32) # t16828: "cuda:0 f32[1, 71, 2048, 32]"
# t16829 = prims.neg(t16828) # t16829: "cuda:0 f32[1, 71, 2048, 32]"
# t16830 = prims.convert_element_type(t16829, dtypes.bfloat16) # t16830: "cuda:0 bf16[1, 71, 2048, 32]"
# t16831 = prims.pad(t16830, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t16831: "cuda:0 bf16[1, 71, 2048, 64]"
# t16833 = prims.convert_element_type(t16831, dtypes.float32) # t16833: "cuda:0 f32[1, 71, 2048, 64]"
# t16834 = prims.add(t16814, t16833) # t16834: "cuda:0 f32[1, 71, 2048, 64]"
# t16836 = prims.pad(t16827, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t16836: "cuda:0 bf16[1, 71, 2048, 64]"
# t16838 = prims.convert_element_type(t16836, dtypes.float32) # t16838: "cuda:0 f32[1, 71, 2048, 64]"
# t16839 = prims.add(t16834, t16838) # t16839: "cuda:0 f32[1, 71, 2048, 64]"
# t16840 = prims.convert_element_type(t16839, dtypes.bfloat16) # t16840: "cuda:0 bf16[1, 71, 2048, 64]"
# t16841 = prims.pad(t16840, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t16841: "cuda:0 bf16[1, 71, 2048, 64]"
# t16842 = prims.convert_element_type(t16793, dtypes.float32) # t16842: "cuda:0 f32[1, 71, 2048, 64]"
# t16843 = prims.convert_element_type(t16841, dtypes.float32) # t16843: "cuda:0 f32[1, 71, 2048, 64]"
# t16844 = prims.add(t16842, t16843) # t16844: "cuda:0 f32[1, 71, 2048, 64]"
# t16845 = prims.convert_element_type(t16844, dtypes.bfloat16) # t16845: "cuda:0 bf16[1, 71, 2048, 64]"
# t16846 = prims.convert_element_type(t16795, dtypes.float32) # t16846: "cuda:0 f32[1, 71, 2048, 64]"
# t16850 = prims.mul(t66, t16846) # t16850: "cuda:0 f32[1, 71, 2048, 64]"
# t16853 = prims.convert_element_type(t16850, dtypes.bfloat16) # t16853: "cuda:0 bf16[1, 71, 2048, 64]"
# t16862 = prims.mul(t61, t16846) # t16862: "cuda:0 f32[1, 71, 2048, 64]"
# t16874 = prims.slice_prim(t16853, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t16874: "cuda:0 bf16[1, 71, 2048, 32]"
# t16875 = prims.slice_prim(t16853, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t16875: "cuda:0 bf16[1, 71, 2048, 32]"
# t16876 = prims.convert_element_type(t16874, dtypes.float32) # t16876: "cuda:0 f32[1, 71, 2048, 32]"
# t16877 = prims.neg(t16876) # t16877: "cuda:0 f32[1, 71, 2048, 32]"
# t16878 = prims.convert_element_type(t16877, dtypes.bfloat16) # t16878: "cuda:0 bf16[1, 71, 2048, 32]"
# t16879 = prims.pad(t16878, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t16879: "cuda:0 bf16[1, 71, 2048, 64]"
# t16881 = prims.convert_element_type(t16879, dtypes.float32) # t16881: "cuda:0 f32[1, 71, 2048, 64]"
# t16882 = prims.add(t16862, t16881) # t16882: "cuda:0 f32[1, 71, 2048, 64]"
# t16884 = prims.pad(t16875, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t16884: "cuda:0 bf16[1, 71, 2048, 64]"
# t16886 = prims.convert_element_type(t16884, dtypes.float32) # t16886: "cuda:0 f32[1, 71, 2048, 64]"
# t16887 = prims.add(t16882, t16886) # t16887: "cuda:0 f32[1, 71, 2048, 64]"
# t16888 = prims.convert_element_type(t16887, dtypes.bfloat16) # t16888: "cuda:0 bf16[1, 71, 2048, 64]"
# t16889 = prims.pad(t16888, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t16889: "cuda:0 bf16[1, 71, 2048, 64]"
# t16890 = prims.convert_element_type(t16797, dtypes.float32) # t16890: "cuda:0 f32[1, 71, 2048, 64]"
# t16891 = prims.convert_element_type(t16889, dtypes.float32) # t16891: "cuda:0 f32[1, 71, 2048, 64]"
# t16892 = prims.add(t16890, t16891) # t16892: "cuda:0 f32[1, 71, 2048, 64]"
# t16893 = prims.convert_element_type(t16892, dtypes.bfloat16) # t16893: "cuda:0 bf16[1, 71, 2048, 64]"
# t16903 = prims.reshape(t16845, (1, 1, 71, 2048, 64)) # t16903: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16908 = prims.reshape(t16893, (1, 1, 71, 2048, 64)) # t16908: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t16914 = prims.convert_element_type(t16898, dtypes.float32) # t16914: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16915 = prims.sum(t16914, (0, 1, 2)) # t16915: "cuda:0 f32[2048, 64]"
# t16916 = prims.convert_element_type(t16915, dtypes.bfloat16) # t16916: "cuda:0 bf16[2048, 64]"
# t16917 = prims.broadcast_in_dim(t16916, [1, 1, 1, 2048, 64], [3, 4]) # t16917: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16923 = prims.convert_element_type(t16903, dtypes.float32) # t16923: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t16924 = prims.sum(t16923, (0, 1, 2)) # t16924: "cuda:0 f32[2048, 64]"
# t16925 = prims.convert_element_type(t16924, dtypes.bfloat16) # t16925: "cuda:0 bf16[2048, 64]"
# t16926 = prims.broadcast_in_dim(t16925, [1, 1, 1, 2048, 64], [3, 4]) # t16926: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t16932 = prims.cat((t16908, t16926, t16917), i1051) # t16932: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i1051, t16791, t16795, t16898
t16938 = torch.permute(t16932, (0, 3, 1, 2, 4)) # t16938: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16938 = ltorch.permute(t16932, (0, 3, 1, 2, 4)) # t16938: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t16938 = prims.transpose(t16932, (0, 3, 1, 2, 4)) # t16938: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t16932
t16944 = torch.reshape(t16938, (1, 2048, 4672)) # t16944: "cuda:0 bf16[1, 2048, 4672]"
# t16944 = ltorch.reshape(t16938, (1, 2048, 4672)) # t16944: "cuda:0 bf16[1, 2048, 4672]"
# t16944 = prims.reshape(t16938, (1, 2048, 4672)) # t16944: "cuda:0 bf16[1, 2048, 4672]"
del t16938
t16945 = torch.reshape(t16944, (-1, 4672)) # t16945: "cuda:0 bf16[2048, 4672]"
# t16945 = ltorch.reshape(t16944, (-1, 4672)) # t16945: "cuda:0 bf16[2048, 4672]"
# t16945 = prims.reshape(t16944, (2048, 4672)) # t16945: "cuda:0 bf16[2048, 4672]"
del t16944
t16949 = torch.permute(t16945, (1, 0)) # t16949: "cuda:0 bf16[4672, 2048]"
# t16949 = ltorch.permute(t16945, (1, 0)) # t16949: "cuda:0 bf16[4672, 2048]"
# t16949 = prims.transpose(t16945, (1, 0)) # t16949: "cuda:0 bf16[4672, 2048]"
t16951 = torch.matmul(t16949, t16771) # t16951: "cuda:0 bf16[4672, 4544]"
# t16951 = ltorch.matmul(t16949, t16950) # t16951: "cuda:0 bf16[4672, 4544]"
# t16951 = prims.matmul(t16949, t16950) # t16951: "cuda:0 bf16[4672, 4544]"
del t16949, t16771
t16946 = torch.matmul(t16945, t_transformer_h_16_attn_attn_weight) # t16946: "cuda:0 bf16[2048, 4544]"
# t16946 = ltorch.matmul(t16945, t_transformer_h_16_attn_attn_weight) # t16946: "cuda:0 bf16[2048, 4544]"
# t16946 = prims.matmul(t16945, t_transformer_h_16_attn_attn_weight) # t16946: "cuda:0 bf16[2048, 4544]"
del t16945, t_transformer_h_16_attn_attn_weight
t16768 = torch.reshape(t16767, (1, 2048, 4544)) # t16768: "cuda:0 bf16[1, 2048, 4544]"
# t16768 = ltorch.reshape(t16767, (1, 2048, 4544)) # t16768: "cuda:0 bf16[1, 2048, 4544]"
# t16768 = prims.reshape(t16767, (1, 2048, 4544)) # t16768: "cuda:0 bf16[1, 2048, 4544]"
del t16767
t16947 = torch.reshape(t16946, (1, 2048, 4544)) # t16947: "cuda:0 bf16[1, 2048, 4544]"
# t16947 = ltorch.reshape(t16946, (1, 2048, 4544)) # t16947: "cuda:0 bf16[1, 2048, 4544]"
# t16947 = prims.reshape(t16946, (1, 2048, 4544)) # t16947: "cuda:0 bf16[1, 2048, 4544]"
del t16946
[t16960, t16966, t17008] = nvFusion48(i16988, t16725, t16768, t16947, t2386, t2518, t2539, t2554, t2559, t2565)
# t2545 = prims.convert_element_type(t2386, dtypes.float32) # t2545: "cuda:0 f32[1, 2048, 4544]"
# t2540 = prims.convert_element_type(t2539, dtypes.float32) # t2540: "cuda:0 f32[1, 2048, 4544]"
# t2541 = prims.convert_element_type(t2518, dtypes.float32) # t2541: "cuda:0 f32[1, 2048, 4544]"
# t2542 = prims.add(t2540, t2541) # t2542: "cuda:0 f32[1, 2048, 4544]"
# t2546 = prims.add(t2542, t2545) # t2546: "cuda:0 f32[1, 2048, 4544]"
# t2556 = prims.broadcast_in_dim(t2554, [1, 2048, 1], [0, 1]) # t2556: "cuda:0 f32[1, 2048, 1]"
# t2560 = prims.broadcast_in_dim(t2556, (1, 2048, 4544), (0, 1, 2)) # t2560: "cuda:0 f32[1, 2048, 4544]"
# t2562 = prims.sub(t2546, t2560) # t2562: "cuda:0 f32[1, 2048, 4544]"
# t2563 = prims.broadcast_in_dim(t2559, (1, 2048, 4544), (0, 1, 2)) # t2563: "cuda:0 f32[1, 2048, 4544]"
# t2564 = prims.mul(t2562, t2563) # t2564: "cuda:0 f32[1, 2048, 4544]"
# t2566 = prims.convert_element_type(t2565, dtypes.float32) # t2566: "cuda:0 f32[1, 2048, 4544]"
# t17005 = prims.convert_element_type(t16725, dtypes.float32) # t17005: "cuda:0 f32[1, 2048, 4544]"
# t16952 = prims.convert_element_type(t16768, dtypes.float32) # t16952: "cuda:0 f32[1, 2048, 4544]"
# t16953 = prims.convert_element_type(t16947, dtypes.float32) # t16953: "cuda:0 f32[1, 2048, 4544]"
# t16954 = prims.add(t16952, t16953) # t16954: "cuda:0 f32[1, 2048, 4544]"
# t16959 = prims.sum(t16954, (0, 1)) # t16959: "cuda:0 f32[4544]"
# t16960 = prims.convert_element_type(t16959, dtypes.bfloat16) # t16960: "cuda:0 bf16[4544]"
# t16961 = prims.mul(t2566, t16954) # t16961: "cuda:0 f32[1, 2048, 4544]"
# t16962 = prims.mul(t2564, t16954) # t16962: "cuda:0 f32[1, 2048, 4544]"
# t16965 = prims.sum(t16962, (0, 1)) # t16965: "cuda:0 f32[4544]"
# t16966 = prims.convert_element_type(t16965, dtypes.bfloat16) # t16966: "cuda:0 bf16[4544]"
# t16967 = prims.mul(t2563, t16961) # t16967: "cuda:0 f32[1, 2048, 4544]"
# t16968 = prims.mul(t2562, t16961) # t16968: "cuda:0 f32[1, 2048, 4544]"
# t16969 = prims.sum(t16968, (0, 2)) # t16969: "cuda:0 f32[2048]"
# t16970 = prims.broadcast_in_dim(t16969, [1, 2048, 1], [1]) # t16970: "cuda:0 f32[1, 2048, 1]"
# t16971 = prims.neg(t16967) # t16971: "cuda:0 f32[1, 2048, 4544]"
# t16973 = prims.sum(t16971, (0, 2)) # t16973: "cuda:0 f32[2048]"
# t16974 = prims.broadcast_in_dim(t16973, [1, 2048, 1], [1]) # t16974: "cuda:0 f32[1, 2048, 1]"
# t16975 = prims.mul(-0.5, t16970) # t16975: "cuda:0 f32[1, 2048, 1]"
# t16976 = prims.pow(t2559, 3.0) # t16976: "cuda:0 f32[1, 2048, 1]"
# t16977 = prims.mul(t16975, t16976) # t16977: "cuda:0 f32[1, 2048, 1]"
# t16979 = prims.sum(t16974, (0, 2)) # t16979: "cuda:0 f32[2048]"
# t16980 = prims.broadcast_in_dim(t16979, [1, 2048], [1]) # t16980: "cuda:0 f32[1, 2048]"
# t16981 = prims.sum(t16977, (0, 2)) # t16981: "cuda:0 f32[2048]"
# t16982 = prims.broadcast_in_dim(t16981, [1, 2048], [1]) # t16982: "cuda:0 f32[1, 2048]"
# t16985 = prims.broadcast_in_dim(t16980, [1, 2048, 1], [0, 1]) # t16985: "cuda:0 f32[1, 2048, 1]"
# t16986 = prims.broadcast_in_dim(t16985, (1, 2048, 4544), (0, 1, 2)) # t16986: "cuda:0 f32[1, 2048, 4544]"
# t16987 = prims.mul(0.00022007042253521127, t16986) # t16987: "cuda:0 f32[1, 2048, 4544]"
# t16989 = prims.broadcast_in_dim(t16982, [1, 2048, 1], [0, 1]) # t16989: "cuda:0 f32[1, 2048, 1]"
# t16990 = prims.broadcast_in_dim(t16989, (1, 2048, 4544), (0, 1, 2)) # t16990: "cuda:0 f32[1, 2048, 4544]"
# t16992 = prims.broadcast_in_dim(t2554, [1, 2048, 1], [0, 1]) # t16992: "cuda:0 f32[1, 2048, 1]"
# t16993 = prims.broadcast_in_dim(t16992, (1, 2048, 4544), (0, 1, 2)) # t16993: "cuda:0 f32[1, 2048, 4544]"
# t16994 = prims.mul(2.0, t16990) # t16994: "cuda:0 f32[1, 2048, 4544]"
# t16995 = prims.sub(t2546, t16993) # t16995: "cuda:0 f32[1, 2048, 4544]"
# t16996 = prims.mul(t16994, t16995) # t16996: "cuda:0 f32[1, 2048, 4544]"
# f16997 = prims.convert_element_type(i16988, float) # f16997: "float 4544.0"
# t16998 = prims.div(t16996, f16997) # t16998: "cuda:0 f32[1, 2048, 4544]"
# t16999 = prims.add(t16987, t16998) # t16999: "cuda:0 f32[1, 2048, 4544]"
# t17003 = prims.add(t16967, t16999) # t17003: "cuda:0 f32[1, 2048, 4544]"
# t17007 = prims.add(t17005, t17003) # t17007: "cuda:0 f32[1, 2048, 4544]"
# t17008 = prims.convert_element_type(t17007, dtypes.bfloat16) # t17008: "cuda:0 bf16[1, 2048, 4544]"
del i16988, t16725, t16768, t16947, t2386, t2518, t2539, t2554, t2559, t2565
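# Backward for transformer block 15 starts here: the residual gradient t17008 is flattened to 2D
# (t17015) and transposed (t17019), serving both as the data gradient pushed through the saved
# weights (t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_proj_weight) and as the
# left operand of the weight-gradient matmuls (t17021: mlp.proj weight grad, t17062: attn.proj
# weight grad).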
t17015 = torch.reshape(t17008, (-1, 4544)) # t17015: "cuda:0 bf16[2048, 4544]"
# t17015 = ltorch.reshape(t17008, (-1, 4544)) # t17015: "cuda:0 bf16[2048, 4544]"
# t17015 = prims.reshape(t17008, (2048, 4544)) # t17015: "cuda:0 bf16[2048, 4544]"
t17019 = torch.permute(t17015, (1, 0)) # t17019: "cuda:0 bf16[4544, 2048]"
# t17019 = ltorch.permute(t17015, (1, 0)) # t17019: "cuda:0 bf16[4544, 2048]"
# t17019 = prims.transpose(t17015, (1, 0)) # t17019: "cuda:0 bf16[4544, 2048]"
t17021 = torch.matmul(t17019, t17020) # t17021: "cuda:0 bf16[4544, 18176]"
# t17021 = ltorch.matmul(t17019, t17020) # t17021: "cuda:0 bf16[4544, 18176]"
# t17021 = prims.matmul(t17019, t17020) # t17021: "cuda:0 bf16[4544, 18176]"
del t17020
t17057 = torch.matmul(t17015, t_transformer_h_15_attn_proj_weight) # t17057: "cuda:0 bf16[2048, 4544]"
# t17057 = ltorch.matmul(t17056, t_transformer_h_15_attn_proj_weight) # t17057: "cuda:0 bf16[2048, 4544]"
# t17057 = prims.matmul(t17056, t_transformer_h_15_attn_proj_weight) # t17057: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_15_attn_proj_weight
t17062 = torch.matmul(t17019, t17061) # t17062: "cuda:0 bf16[4544, 4544]"
# t17062 = ltorch.matmul(t17060, t17061) # t17062: "cuda:0 bf16[4544, 4544]"
# t17062 = prims.matmul(t17060, t17061) # t17062: "cuda:0 bf16[4544, 4544]"
del t17019, t17061
t17016 = torch.matmul(t17015, t_transformer_h_15_mlp_proj_weight) # t17016: "cuda:0 bf16[2048, 18176]"
# t17016 = ltorch.matmul(t17015, t_transformer_h_15_mlp_proj_weight) # t17016: "cuda:0 bf16[2048, 18176]"
# t17016 = prims.matmul(t17015, t_transformer_h_15_mlp_proj_weight) # t17016: "cuda:0 bf16[2048, 18176]"
del t17015, t_transformer_h_15_mlp_proj_weight
t17058 = torch.reshape(t17057, (1, 2048, 4544)) # t17058: "cuda:0 bf16[1, 2048, 4544]"
# t17058 = ltorch.reshape(t17057, (1, 2048, 4544)) # t17058: "cuda:0 bf16[1, 2048, 4544]"
# t17058 = prims.reshape(t17057, (1, 2048, 4544)) # t17058: "cuda:0 bf16[1, 2048, 4544]"
del t17057
t17066 = torch.reshape(t17058, (1, 2048, 71, 64)) # t17066: "cuda:0 bf16[1, 2048, 71, 64]"
# t17066 = ltorch.reshape(t17058, (1, 2048, 71, 64)) # t17066: "cuda:0 bf16[1, 2048, 71, 64]"
# t17066 = prims.reshape(t17058, (1, 2048, 71, 64)) # t17066: "cuda:0 bf16[1, 2048, 71, 64]"
del t17058
t17069 = torch.permute(t17066, (0, 2, 1, 3)) # t17069: "cuda:0 bf16[1, 71, 2048, 64]"
# t17069 = ltorch.permute(t17066, (0, 2, 1, 3)) # t17069: "cuda:0 bf16[1, 71, 2048, 64]"
# t17069 = prims.transpose(t17066, (0, 2, 1, 3)) # t17069: "cuda:0 bf16[1, 71, 2048, 64]"
del t17066
t17017 = torch.reshape(t17016, (1, 2048, 18176)) # t17017: "cuda:0 bf16[1, 2048, 18176]"
# t17017 = ltorch.reshape(t17016, (1, 2048, 18176)) # t17017: "cuda:0 bf16[1, 2048, 18176]"
# t17017 = prims.reshape(t17016, (1, 2048, 18176)) # t17017: "cuda:0 bf16[1, 2048, 18176]"
del t17016
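# nvFusion49: backward of the exact (erf-based) GELU in block 15's MLP; the derivative is
# recomputed from the saved pre-GELU activation t2519, and t17048 is the gradient w.r.t. the
# mlp.fc output.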
[t17048] = nvFusion49(f1023, f1025, t17017, t2519)
# t2520 = prims.convert_element_type(t2519, dtypes.float32) # t2520: "cuda:0 f32[1, 2048, 18176]"
# t2522 = prims.div(t2520, 1.4142135623730951) # t2522: "cuda:0 f32[1, 2048, 18176]"
# t2525 = prims.erf(t2522) # t2525: "cuda:0 f32[1, 2048, 18176]"
# t2529 = prims.mul(0.5, t2525) # t2529: "cuda:0 f32[1, 2048, 18176]"
# t2533 = prims.add(0.5, t2529) # t2533: "cuda:0 f32[1, 2048, 18176]"
# t17022 = prims.convert_element_type(t17017, dtypes.float32) # t17022: "cuda:0 f32[1, 2048, 18176]"
# t17023 = prims.mul(t2533, t17022) # t17023: "cuda:0 f32[1, 2048, 18176]"
# t17024 = prims.mul(t2520, t17022) # t17024: "cuda:0 f32[1, 2048, 18176]"
# t17032 = prims.mul(f1025, t17024) # t17032: "cuda:0 f32[1, 2048, 18176]"
# t17035 = prims.pow(t2522, 2.0) # t17035: "cuda:0 f32[1, 2048, 18176]"
# t17036 = prims.neg(t17035) # t17036: "cuda:0 f32[1, 2048, 18176]"
# t17037 = prims.exp(t17036) # t17037: "cuda:0 f32[1, 2048, 18176]"
# t17038 = prims.mul(1.1283791670955126, t17037) # t17038: "cuda:0 f32[1, 2048, 18176]"
# t17039 = prims.mul(t17038, t17032) # t17039: "cuda:0 f32[1, 2048, 18176]"
# t17043 = prims.div(t17039, f1023) # t17043: "cuda:0 f32[1, 2048, 18176]"
# t17047 = prims.add(t17023, t17043) # t17047: "cuda:0 f32[1, 2048, 18176]"
# t17048 = prims.convert_element_type(t17047, dtypes.bfloat16) # t17048: "cuda:0 bf16[1, 2048, 18176]"
del f1023, f1025, t17017, t2519
t17049 = torch.reshape(t17048, (-1, 18176)) # t17049: "cuda:0 bf16[2048, 18176]"
# t17049 = ltorch.reshape(t17048, (-1, 18176)) # t17049: "cuda:0 bf16[2048, 18176]"
# t17049 = prims.reshape(t17048, (2048, 18176)) # t17049: "cuda:0 bf16[2048, 18176]"
del t17048
t17053 = torch.permute(t17049, (1, 0)) # t17053: "cuda:0 bf16[18176, 2048]"
# t17053 = ltorch.permute(t17049, (1, 0)) # t17053: "cuda:0 bf16[18176, 2048]"
# t17053 = prims.transpose(t17049, (1, 0)) # t17053: "cuda:0 bf16[18176, 2048]"
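# cuDNN scaled-dot-product-attention backward for block 15: given the upstream gradient t17069
# and the saved attention inputs and statistics, it returns what appear to be the query (t17070),
# key (t17071), and value (t17072) gradients.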
(t17070, t17071, t17072) = cudnn_sdpa_bwd(t17069, t2503, t2506, t2456, None, f1014, b1015, t2507, t2508, t2509, t2510, scale=f1016, cat_grad_qkv=False)
del t17069, t2503, t2506, t2456, f1014, b1015, t2507, t2508, t2509, t2510, f1016
t17055 = torch.matmul(t17053, t17054) # t17055: "cuda:0 bf16[18176, 4544]"
# t17055 = ltorch.matmul(t17053, t17054) # t17055: "cuda:0 bf16[18176, 4544]"
# t17055 = prims.matmul(t17053, t17054) # t17055: "cuda:0 bf16[18176, 4544]"
del t17053
t17050 = torch.matmul(t17049, t_transformer_h_15_mlp_fc_weight) # t17050: "cuda:0 bf16[2048, 4544]"
# t17050 = ltorch.matmul(t17049, t_transformer_h_15_mlp_fc_weight) # t17050: "cuda:0 bf16[2048, 4544]"
# t17050 = prims.matmul(t17049, t_transformer_h_15_mlp_fc_weight) # t17050: "cuda:0 bf16[2048, 4544]"
del t17049, t_transformer_h_15_mlp_fc_weight
t17074 = torch_slice_prim_impl(t17071, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17074: "cuda:0 bf16[1, 71, 2048, 64]"
del t17071
t17078 = torch_slice_prim_impl(t17070, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17078: "cuda:0 bf16[1, 71, 2048, 64]"
del t17070
t17181 = torch.reshape(t17072, (1, 1, 71, 2048, 64)) # t17181: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17181 = ltorch.reshape(t17072, (1, 1, 71, 2048, 64)) # t17181: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17181 = prims.reshape(t17072, (1, 1, 71, 2048, 64)) # t17181: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t17072
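# nvFusion50: rotary-embedding (RoPE) backward for block 15's query/key gradients using the saved
# cos/sin tensors (t61, t66), followed by re-assembly of the fused QKV gradient: the per-head
# query grads, the head-summed key grad, and the head-summed value grad are concatenated into a
# [1, 1, 73, 2048, 64] tensor (71 query heads + 1 key + 1 value, i.e. multi-query attention, so
# the shared key/value grads are summed over the query heads).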
[t17215] = nvFusion50(i987, t17074, t17078, t17181, t61, t66)
# t17075 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17075: "cuda:0 bf16[1, 71, 2048, 0]"
# t17076 = prims.pad(t17075, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17076: "cuda:0 bf16[1, 71, 2048, 64]"
# t17079 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17079: "cuda:0 bf16[1, 71, 2048, 0]"
# t17080 = prims.pad(t17079, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17080: "cuda:0 bf16[1, 71, 2048, 64]"
# t17081 = prims.convert_element_type(t17074, dtypes.float32) # t17081: "cuda:0 f32[1, 71, 2048, 64]"
# t17085 = prims.mul(t66, t17081) # t17085: "cuda:0 f32[1, 71, 2048, 64]"
# t17088 = prims.convert_element_type(t17085, dtypes.bfloat16) # t17088: "cuda:0 bf16[1, 71, 2048, 64]"
# t17097 = prims.mul(t61, t17081) # t17097: "cuda:0 f32[1, 71, 2048, 64]"
# t17109 = prims.slice_prim(t17088, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t17109: "cuda:0 bf16[1, 71, 2048, 32]"
# t17110 = prims.slice_prim(t17088, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17110: "cuda:0 bf16[1, 71, 2048, 32]"
# t17111 = prims.convert_element_type(t17109, dtypes.float32) # t17111: "cuda:0 f32[1, 71, 2048, 32]"
# t17112 = prims.neg(t17111) # t17112: "cuda:0 f32[1, 71, 2048, 32]"
# t17113 = prims.convert_element_type(t17112, dtypes.bfloat16) # t17113: "cuda:0 bf16[1, 71, 2048, 32]"
# t17114 = prims.pad(t17113, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t17114: "cuda:0 bf16[1, 71, 2048, 64]"
# t17116 = prims.convert_element_type(t17114, dtypes.float32) # t17116: "cuda:0 f32[1, 71, 2048, 64]"
# t17117 = prims.add(t17097, t17116) # t17117: "cuda:0 f32[1, 71, 2048, 64]"
# t17119 = prims.pad(t17110, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t17119: "cuda:0 bf16[1, 71, 2048, 64]"
# t17121 = prims.convert_element_type(t17119, dtypes.float32) # t17121: "cuda:0 f32[1, 71, 2048, 64]"
# t17122 = prims.add(t17117, t17121) # t17122: "cuda:0 f32[1, 71, 2048, 64]"
# t17123 = prims.convert_element_type(t17122, dtypes.bfloat16) # t17123: "cuda:0 bf16[1, 71, 2048, 64]"
# t17124 = prims.pad(t17123, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t17124: "cuda:0 bf16[1, 71, 2048, 64]"
# t17125 = prims.convert_element_type(t17076, dtypes.float32) # t17125: "cuda:0 f32[1, 71, 2048, 64]"
# t17126 = prims.convert_element_type(t17124, dtypes.float32) # t17126: "cuda:0 f32[1, 71, 2048, 64]"
# t17127 = prims.add(t17125, t17126) # t17127: "cuda:0 f32[1, 71, 2048, 64]"
# t17128 = prims.convert_element_type(t17127, dtypes.bfloat16) # t17128: "cuda:0 bf16[1, 71, 2048, 64]"
# t17129 = prims.convert_element_type(t17078, dtypes.float32) # t17129: "cuda:0 f32[1, 71, 2048, 64]"
# t17133 = prims.mul(t66, t17129) # t17133: "cuda:0 f32[1, 71, 2048, 64]"
# t17136 = prims.convert_element_type(t17133, dtypes.bfloat16) # t17136: "cuda:0 bf16[1, 71, 2048, 64]"
# t17145 = prims.mul(t61, t17129) # t17145: "cuda:0 f32[1, 71, 2048, 64]"
# t17157 = prims.slice_prim(t17136, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t17157: "cuda:0 bf16[1, 71, 2048, 32]"
# t17158 = prims.slice_prim(t17136, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17158: "cuda:0 bf16[1, 71, 2048, 32]"
# t17159 = prims.convert_element_type(t17157, dtypes.float32) # t17159: "cuda:0 f32[1, 71, 2048, 32]"
# t17160 = prims.neg(t17159) # t17160: "cuda:0 f32[1, 71, 2048, 32]"
# t17161 = prims.convert_element_type(t17160, dtypes.bfloat16) # t17161: "cuda:0 bf16[1, 71, 2048, 32]"
# t17162 = prims.pad(t17161, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t17162: "cuda:0 bf16[1, 71, 2048, 64]"
# t17164 = prims.convert_element_type(t17162, dtypes.float32) # t17164: "cuda:0 f32[1, 71, 2048, 64]"
# t17165 = prims.add(t17145, t17164) # t17165: "cuda:0 f32[1, 71, 2048, 64]"
# t17167 = prims.pad(t17158, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t17167: "cuda:0 bf16[1, 71, 2048, 64]"
# t17169 = prims.convert_element_type(t17167, dtypes.float32) # t17169: "cuda:0 f32[1, 71, 2048, 64]"
# t17170 = prims.add(t17165, t17169) # t17170: "cuda:0 f32[1, 71, 2048, 64]"
# t17171 = prims.convert_element_type(t17170, dtypes.bfloat16) # t17171: "cuda:0 bf16[1, 71, 2048, 64]"
# t17172 = prims.pad(t17171, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t17172: "cuda:0 bf16[1, 71, 2048, 64]"
# t17173 = prims.convert_element_type(t17080, dtypes.float32) # t17173: "cuda:0 f32[1, 71, 2048, 64]"
# t17174 = prims.convert_element_type(t17172, dtypes.float32) # t17174: "cuda:0 f32[1, 71, 2048, 64]"
# t17175 = prims.add(t17173, t17174) # t17175: "cuda:0 f32[1, 71, 2048, 64]"
# t17176 = prims.convert_element_type(t17175, dtypes.bfloat16) # t17176: "cuda:0 bf16[1, 71, 2048, 64]"
# t17186 = prims.reshape(t17128, (1, 1, 71, 2048, 64)) # t17186: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17191 = prims.reshape(t17176, (1, 1, 71, 2048, 64)) # t17191: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17197 = prims.convert_element_type(t17181, dtypes.float32) # t17197: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t17198 = prims.sum(t17197, (0, 1, 2)) # t17198: "cuda:0 f32[2048, 64]"
# t17199 = prims.convert_element_type(t17198, dtypes.bfloat16) # t17199: "cuda:0 bf16[2048, 64]"
# t17200 = prims.broadcast_in_dim(t17199, [1, 1, 1, 2048, 64], [3, 4]) # t17200: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t17206 = prims.convert_element_type(t17186, dtypes.float32) # t17206: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t17207 = prims.sum(t17206, (0, 1, 2)) # t17207: "cuda:0 f32[2048, 64]"
# t17208 = prims.convert_element_type(t17207, dtypes.bfloat16) # t17208: "cuda:0 bf16[2048, 64]"
# t17209 = prims.broadcast_in_dim(t17208, [1, 1, 1, 2048, 64], [3, 4]) # t17209: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t17215 = prims.cat((t17191, t17209, t17200), i987) # t17215: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i987, t17074, t17078, t17181
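# The fused QKV gradient is permuted and flattened back to [2048, 4672] (t17228) and used for the
# fused attn.attn weight gradient t17234 and the data gradient t17229 flowing into block 15's
# input LayerNorm backward.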
t17221 = torch.permute(t17215, (0, 3, 1, 2, 4)) # t17221: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t17221 = ltorch.permute(t17215, (0, 3, 1, 2, 4)) # t17221: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t17221 = prims.transpose(t17215, (0, 3, 1, 2, 4)) # t17221: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t17215
t17227 = torch.reshape(t17221, (1, 2048, 4672)) # t17227: "cuda:0 bf16[1, 2048, 4672]"
# t17227 = ltorch.reshape(t17221, (1, 2048, 4672)) # t17227: "cuda:0 bf16[1, 2048, 4672]"
# t17227 = prims.reshape(t17221, (1, 2048, 4672)) # t17227: "cuda:0 bf16[1, 2048, 4672]"
del t17221
t17228 = torch.reshape(t17227, (-1, 4672)) # t17228: "cuda:0 bf16[2048, 4672]"
# t17228 = ltorch.reshape(t17227, (-1, 4672)) # t17228: "cuda:0 bf16[2048, 4672]"
# t17228 = prims.reshape(t17227, (2048, 4672)) # t17228: "cuda:0 bf16[2048, 4672]"
del t17227
t17232 = torch.permute(t17228, (1, 0)) # t17232: "cuda:0 bf16[4672, 2048]"
# t17232 = ltorch.permute(t17228, (1, 0)) # t17232: "cuda:0 bf16[4672, 2048]"
# t17232 = prims.transpose(t17228, (1, 0)) # t17232: "cuda:0 bf16[4672, 2048]"
t17234 = torch.matmul(t17232, t17054) # t17234: "cuda:0 bf16[4672, 4544]"
# t17234 = ltorch.matmul(t17232, t17233) # t17234: "cuda:0 bf16[4672, 4544]"
# t17234 = prims.matmul(t17232, t17233) # t17234: "cuda:0 bf16[4672, 4544]"
del t17232, t17054
t17229 = torch.matmul(t17228, t_transformer_h_15_attn_attn_weight) # t17229: "cuda:0 bf16[2048, 4544]"
# t17229 = ltorch.matmul(t17228, t_transformer_h_15_attn_attn_weight) # t17229: "cuda:0 bf16[2048, 4544]"
# t17229 = prims.matmul(t17228, t_transformer_h_15_attn_attn_weight) # t17229: "cuda:0 bf16[2048, 4544]"
del t17228, t_transformer_h_15_attn_attn_weight
t17051 = torch.reshape(t17050, (1, 2048, 4544)) # t17051: "cuda:0 bf16[1, 2048, 4544]"
# t17051 = ltorch.reshape(t17050, (1, 2048, 4544)) # t17051: "cuda:0 bf16[1, 2048, 4544]"
# t17051 = prims.reshape(t17050, (1, 2048, 4544)) # t17051: "cuda:0 bf16[1, 2048, 4544]"
del t17050
t17230 = torch.reshape(t17229, (1, 2048, 4544)) # t17230: "cuda:0 bf16[1, 2048, 4544]"
# t17230 = ltorch.reshape(t17229, (1, 2048, 4544)) # t17230: "cuda:0 bf16[1, 2048, 4544]"
# t17230 = prims.reshape(t17229, (1, 2048, 4544)) # t17230: "cuda:0 bf16[1, 2048, 4544]"
del t17229
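# nvFusion51: LayerNorm backward for block 15 (same structure as nvFusion48); t17243/t17249 look
# like the bias/weight gradients, and t17291 is the accumulated residual-stream gradient,
# including the incoming residual gradient t17008.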
[t17243, t17249, t17291] = nvFusion51(i17271, t17008, t17051, t17230, t2225, t2357, t2378, t2393, t2398, t2404)
# t2384 = prims.convert_element_type(t2225, dtypes.float32) # t2384: "cuda:0 f32[1, 2048, 4544]"
# t2379 = prims.convert_element_type(t2378, dtypes.float32) # t2379: "cuda:0 f32[1, 2048, 4544]"
# t2380 = prims.convert_element_type(t2357, dtypes.float32) # t2380: "cuda:0 f32[1, 2048, 4544]"
# t2381 = prims.add(t2379, t2380) # t2381: "cuda:0 f32[1, 2048, 4544]"
# t2385 = prims.add(t2381, t2384) # t2385: "cuda:0 f32[1, 2048, 4544]"
# t2395 = prims.broadcast_in_dim(t2393, [1, 2048, 1], [0, 1]) # t2395: "cuda:0 f32[1, 2048, 1]"
# t2399 = prims.broadcast_in_dim(t2395, (1, 2048, 4544), (0, 1, 2)) # t2399: "cuda:0 f32[1, 2048, 4544]"
# t2401 = prims.sub(t2385, t2399) # t2401: "cuda:0 f32[1, 2048, 4544]"
# t2402 = prims.broadcast_in_dim(t2398, (1, 2048, 4544), (0, 1, 2)) # t2402: "cuda:0 f32[1, 2048, 4544]"
# t2403 = prims.mul(t2401, t2402) # t2403: "cuda:0 f32[1, 2048, 4544]"
# t2405 = prims.convert_element_type(t2404, dtypes.float32) # t2405: "cuda:0 f32[1, 2048, 4544]"
# t17288 = prims.convert_element_type(t17008, dtypes.float32) # t17288: "cuda:0 f32[1, 2048, 4544]"
# t17235 = prims.convert_element_type(t17051, dtypes.float32) # t17235: "cuda:0 f32[1, 2048, 4544]"
# t17236 = prims.convert_element_type(t17230, dtypes.float32) # t17236: "cuda:0 f32[1, 2048, 4544]"
# t17237 = prims.add(t17235, t17236) # t17237: "cuda:0 f32[1, 2048, 4544]"
# t17242 = prims.sum(t17237, (0, 1)) # t17242: "cuda:0 f32[4544]"
# t17243 = prims.convert_element_type(t17242, dtypes.bfloat16) # t17243: "cuda:0 bf16[4544]"
# t17244 = prims.mul(t2405, t17237) # t17244: "cuda:0 f32[1, 2048, 4544]"
# t17245 = prims.mul(t2403, t17237) # t17245: "cuda:0 f32[1, 2048, 4544]"
# t17248 = prims.sum(t17245, (0, 1)) # t17248: "cuda:0 f32[4544]"
# t17249 = prims.convert_element_type(t17248, dtypes.bfloat16) # t17249: "cuda:0 bf16[4544]"
# t17250 = prims.mul(t2402, t17244) # t17250: "cuda:0 f32[1, 2048, 4544]"
# t17251 = prims.mul(t2401, t17244) # t17251: "cuda:0 f32[1, 2048, 4544]"
# t17252 = prims.sum(t17251, (0, 2)) # t17252: "cuda:0 f32[2048]"
# t17253 = prims.broadcast_in_dim(t17252, [1, 2048, 1], [1]) # t17253: "cuda:0 f32[1, 2048, 1]"
# t17254 = prims.neg(t17250) # t17254: "cuda:0 f32[1, 2048, 4544]"
# t17256 = prims.sum(t17254, (0, 2)) # t17256: "cuda:0 f32[2048]"
# t17257 = prims.broadcast_in_dim(t17256, [1, 2048, 1], [1]) # t17257: "cuda:0 f32[1, 2048, 1]"
# t17258 = prims.mul(-0.5, t17253) # t17258: "cuda:0 f32[1, 2048, 1]"
# t17259 = prims.pow(t2398, 3.0) # t17259: "cuda:0 f32[1, 2048, 1]"
# t17260 = prims.mul(t17258, t17259) # t17260: "cuda:0 f32[1, 2048, 1]"
# t17262 = prims.sum(t17257, (0, 2)) # t17262: "cuda:0 f32[2048]"
# t17263 = prims.broadcast_in_dim(t17262, [1, 2048], [1]) # t17263: "cuda:0 f32[1, 2048]"
# t17264 = prims.sum(t17260, (0, 2)) # t17264: "cuda:0 f32[2048]"
# t17265 = prims.broadcast_in_dim(t17264, [1, 2048], [1]) # t17265: "cuda:0 f32[1, 2048]"
# t17268 = prims.broadcast_in_dim(t17263, [1, 2048, 1], [0, 1]) # t17268: "cuda:0 f32[1, 2048, 1]"
# t17269 = prims.broadcast_in_dim(t17268, (1, 2048, 4544), (0, 1, 2)) # t17269: "cuda:0 f32[1, 2048, 4544]"
# t17270 = prims.mul(0.00022007042253521127, t17269) # t17270: "cuda:0 f32[1, 2048, 4544]"
# t17272 = prims.broadcast_in_dim(t17265, [1, 2048, 1], [0, 1]) # t17272: "cuda:0 f32[1, 2048, 1]"
# t17273 = prims.broadcast_in_dim(t17272, (1, 2048, 4544), (0, 1, 2)) # t17273: "cuda:0 f32[1, 2048, 4544]"
# t17275 = prims.broadcast_in_dim(t2393, [1, 2048, 1], [0, 1]) # t17275: "cuda:0 f32[1, 2048, 1]"
# t17276 = prims.broadcast_in_dim(t17275, (1, 2048, 4544), (0, 1, 2)) # t17276: "cuda:0 f32[1, 2048, 4544]"
# t17277 = prims.mul(2.0, t17273) # t17277: "cuda:0 f32[1, 2048, 4544]"
# t17278 = prims.sub(t2385, t17276) # t17278: "cuda:0 f32[1, 2048, 4544]"
# t17279 = prims.mul(t17277, t17278) # t17279: "cuda:0 f32[1, 2048, 4544]"
# f17280 = prims.convert_element_type(i17271, float) # f17280: "float 4544.0"
# t17281 = prims.div(t17279, f17280) # t17281: "cuda:0 f32[1, 2048, 4544]"
# t17282 = prims.add(t17270, t17281) # t17282: "cuda:0 f32[1, 2048, 4544]"
# t17286 = prims.add(t17250, t17282) # t17286: "cuda:0 f32[1, 2048, 4544]"
# t17290 = prims.add(t17288, t17286) # t17290: "cuda:0 f32[1, 2048, 4544]"
# t17291 = prims.convert_element_type(t17290, dtypes.bfloat16) # t17291: "cuda:0 bf16[1, 2048, 4544]"
del i17271, t17008, t17051, t17230, t2225, t2357, t2378, t2393, t2398, t2404
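# The same backward sequence now repeats for transformer block 14: weight/data gradients for
# attn.proj and mlp.proj, GELU backward (nvFusion52), cuDNN SDPA backward, RoPE backward and
# QKV-gradient assembly (nvFusion53), fused attn.attn gradients, and LayerNorm backward
# (nvFusion54).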
t17298 = torch.reshape(t17291, (-1, 4544)) # t17298: "cuda:0 bf16[2048, 4544]"
# t17298 = ltorch.reshape(t17291, (-1, 4544)) # t17298: "cuda:0 bf16[2048, 4544]"
# t17298 = prims.reshape(t17291, (2048, 4544)) # t17298: "cuda:0 bf16[2048, 4544]"
t17302 = torch.permute(t17298, (1, 0)) # t17302: "cuda:0 bf16[4544, 2048]"
# t17302 = ltorch.permute(t17298, (1, 0)) # t17302: "cuda:0 bf16[4544, 2048]"
# t17302 = prims.transpose(t17298, (1, 0)) # t17302: "cuda:0 bf16[4544, 2048]"
t17345 = torch.matmul(t17302, t17344) # t17345: "cuda:0 bf16[4544, 4544]"
# t17345 = ltorch.matmul(t17343, t17344) # t17345: "cuda:0 bf16[4544, 4544]"
# t17345 = prims.matmul(t17343, t17344) # t17345: "cuda:0 bf16[4544, 4544]"
del t17344
t17299 = torch.matmul(t17298, t_transformer_h_14_mlp_proj_weight) # t17299: "cuda:0 bf16[2048, 18176]"
# t17299 = ltorch.matmul(t17298, t_transformer_h_14_mlp_proj_weight) # t17299: "cuda:0 bf16[2048, 18176]"
# t17299 = prims.matmul(t17298, t_transformer_h_14_mlp_proj_weight) # t17299: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_14_mlp_proj_weight
t17304 = torch.matmul(t17302, t17303) # t17304: "cuda:0 bf16[4544, 18176]"
# t17304 = ltorch.matmul(t17302, t17303) # t17304: "cuda:0 bf16[4544, 18176]"
# t17304 = prims.matmul(t17302, t17303) # t17304: "cuda:0 bf16[4544, 18176]"
del t17302, t17303
t17340 = torch.matmul(t17298, t_transformer_h_14_attn_proj_weight) # t17340: "cuda:0 bf16[2048, 4544]"
# t17340 = ltorch.matmul(t17339, t_transformer_h_14_attn_proj_weight) # t17340: "cuda:0 bf16[2048, 4544]"
# t17340 = prims.matmul(t17339, t_transformer_h_14_attn_proj_weight) # t17340: "cuda:0 bf16[2048, 4544]"
del t17298, t_transformer_h_14_attn_proj_weight
t17300 = torch.reshape(t17299, (1, 2048, 18176)) # t17300: "cuda:0 bf16[1, 2048, 18176]"
# t17300 = ltorch.reshape(t17299, (1, 2048, 18176)) # t17300: "cuda:0 bf16[1, 2048, 18176]"
# t17300 = prims.reshape(t17299, (1, 2048, 18176)) # t17300: "cuda:0 bf16[1, 2048, 18176]"
del t17299
t17341 = torch.reshape(t17340, (1, 2048, 4544)) # t17341: "cuda:0 bf16[1, 2048, 4544]"
# t17341 = ltorch.reshape(t17340, (1, 2048, 4544)) # t17341: "cuda:0 bf16[1, 2048, 4544]"
# t17341 = prims.reshape(t17340, (1, 2048, 4544)) # t17341: "cuda:0 bf16[1, 2048, 4544]"
del t17340
t17349 = torch.reshape(t17341, (1, 2048, 71, 64)) # t17349: "cuda:0 bf16[1, 2048, 71, 64]"
# t17349 = ltorch.reshape(t17341, (1, 2048, 71, 64)) # t17349: "cuda:0 bf16[1, 2048, 71, 64]"
# t17349 = prims.reshape(t17341, (1, 2048, 71, 64)) # t17349: "cuda:0 bf16[1, 2048, 71, 64]"
del t17341
t17352 = torch.permute(t17349, (0, 2, 1, 3)) # t17352: "cuda:0 bf16[1, 71, 2048, 64]"
# t17352 = ltorch.permute(t17349, (0, 2, 1, 3)) # t17352: "cuda:0 bf16[1, 71, 2048, 64]"
# t17352 = prims.transpose(t17349, (0, 2, 1, 3)) # t17352: "cuda:0 bf16[1, 71, 2048, 64]"
del t17349
[t17331] = nvFusion52(f959, f961, t17300, t2358)
# t2359 = prims.convert_element_type(t2358, dtypes.float32) # t2359: "cuda:0 f32[1, 2048, 18176]"
# t2361 = prims.div(t2359, 1.4142135623730951) # t2361: "cuda:0 f32[1, 2048, 18176]"
# t2364 = prims.erf(t2361) # t2364: "cuda:0 f32[1, 2048, 18176]"
# t2368 = prims.mul(0.5, t2364) # t2368: "cuda:0 f32[1, 2048, 18176]"
# t2372 = prims.add(0.5, t2368) # t2372: "cuda:0 f32[1, 2048, 18176]"
# t17305 = prims.convert_element_type(t17300, dtypes.float32) # t17305: "cuda:0 f32[1, 2048, 18176]"
# t17306 = prims.mul(t2372, t17305) # t17306: "cuda:0 f32[1, 2048, 18176]"
# t17307 = prims.mul(t2359, t17305) # t17307: "cuda:0 f32[1, 2048, 18176]"
# t17315 = prims.mul(f961, t17307) # t17315: "cuda:0 f32[1, 2048, 18176]"
# t17318 = prims.pow(t2361, 2.0) # t17318: "cuda:0 f32[1, 2048, 18176]"
# t17319 = prims.neg(t17318) # t17319: "cuda:0 f32[1, 2048, 18176]"
# t17320 = prims.exp(t17319) # t17320: "cuda:0 f32[1, 2048, 18176]"
# t17321 = prims.mul(1.1283791670955126, t17320) # t17321: "cuda:0 f32[1, 2048, 18176]"
# t17322 = prims.mul(t17321, t17315) # t17322: "cuda:0 f32[1, 2048, 18176]"
# t17326 = prims.div(t17322, f959) # t17326: "cuda:0 f32[1, 2048, 18176]"
# t17330 = prims.add(t17306, t17326) # t17330: "cuda:0 f32[1, 2048, 18176]"
# t17331 = prims.convert_element_type(t17330, dtypes.bfloat16) # t17331: "cuda:0 bf16[1, 2048, 18176]"
del f959, f961, t17300, t2358
t17332 = torch.reshape(t17331, (-1, 18176)) # t17332: "cuda:0 bf16[2048, 18176]"
# t17332 = ltorch.reshape(t17331, (-1, 18176)) # t17332: "cuda:0 bf16[2048, 18176]"
# t17332 = prims.reshape(t17331, (2048, 18176)) # t17332: "cuda:0 bf16[2048, 18176]"
del t17331
t17336 = torch.permute(t17332, (1, 0)) # t17336: "cuda:0 bf16[18176, 2048]"
# t17336 = ltorch.permute(t17332, (1, 0)) # t17336: "cuda:0 bf16[18176, 2048]"
# t17336 = prims.transpose(t17332, (1, 0)) # t17336: "cuda:0 bf16[18176, 2048]"
t17338 = torch.matmul(t17336, t17337) # t17338: "cuda:0 bf16[18176, 4544]"
# t17338 = ltorch.matmul(t17336, t17337) # t17338: "cuda:0 bf16[18176, 4544]"
# t17338 = prims.matmul(t17336, t17337) # t17338: "cuda:0 bf16[18176, 4544]"
del t17336
t17333 = torch.matmul(t17332, t_transformer_h_14_mlp_fc_weight) # t17333: "cuda:0 bf16[2048, 4544]"
# t17333 = ltorch.matmul(t17332, t_transformer_h_14_mlp_fc_weight) # t17333: "cuda:0 bf16[2048, 4544]"
# t17333 = prims.matmul(t17332, t_transformer_h_14_mlp_fc_weight) # t17333: "cuda:0 bf16[2048, 4544]"
del t17332, t_transformer_h_14_mlp_fc_weight
(t17353, t17354, t17355) = cudnn_sdpa_bwd(t17352, t2342, t2345, t2295, None, f950, b951, t2346, t2347, t2348, t2349, scale=f952, cat_grad_qkv=False)
del t17352, t2342, t2345, t2295, f950, b951, t2346, t2347, t2348, t2349, f952
t17357 = torch_slice_prim_impl(t17354, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17357: "cuda:0 bf16[1, 71, 2048, 64]"
del t17354
t17361 = torch_slice_prim_impl(t17353, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17361: "cuda:0 bf16[1, 71, 2048, 64]"
del t17353
t17464 = torch.reshape(t17355, (1, 1, 71, 2048, 64)) # t17464: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17464 = ltorch.reshape(t17355, (1, 1, 71, 2048, 64)) # t17464: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17464 = prims.reshape(t17355, (1, 1, 71, 2048, 64)) # t17464: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t17355
[t17498] = nvFusion53(i923, t17357, t17361, t17464, t61, t66)
# t17358 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17358: "cuda:0 bf16[1, 71, 2048, 0]"
# t17359 = prims.pad(t17358, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17359: "cuda:0 bf16[1, 71, 2048, 64]"
# t17362 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17362: "cuda:0 bf16[1, 71, 2048, 0]"
# t17363 = prims.pad(t17362, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17363: "cuda:0 bf16[1, 71, 2048, 64]"
# t17364 = prims.convert_element_type(t17357, dtypes.float32) # t17364: "cuda:0 f32[1, 71, 2048, 64]"
# t17368 = prims.mul(t66, t17364) # t17368: "cuda:0 f32[1, 71, 2048, 64]"
# t17371 = prims.convert_element_type(t17368, dtypes.bfloat16) # t17371: "cuda:0 bf16[1, 71, 2048, 64]"
# t17380 = prims.mul(t61, t17364) # t17380: "cuda:0 f32[1, 71, 2048, 64]"
# t17392 = prims.slice_prim(t17371, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t17392: "cuda:0 bf16[1, 71, 2048, 32]"
# t17393 = prims.slice_prim(t17371, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17393: "cuda:0 bf16[1, 71, 2048, 32]"
# t17394 = prims.convert_element_type(t17392, dtypes.float32) # t17394: "cuda:0 f32[1, 71, 2048, 32]"
# t17395 = prims.neg(t17394) # t17395: "cuda:0 f32[1, 71, 2048, 32]"
# t17396 = prims.convert_element_type(t17395, dtypes.bfloat16) # t17396: "cuda:0 bf16[1, 71, 2048, 32]"
# t17397 = prims.pad(t17396, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t17397: "cuda:0 bf16[1, 71, 2048, 64]"
# t17399 = prims.convert_element_type(t17397, dtypes.float32) # t17399: "cuda:0 f32[1, 71, 2048, 64]"
# t17400 = prims.add(t17380, t17399) # t17400: "cuda:0 f32[1, 71, 2048, 64]"
# t17402 = prims.pad(t17393, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t17402: "cuda:0 bf16[1, 71, 2048, 64]"
# t17404 = prims.convert_element_type(t17402, dtypes.float32) # t17404: "cuda:0 f32[1, 71, 2048, 64]"
# t17405 = prims.add(t17400, t17404) # t17405: "cuda:0 f32[1, 71, 2048, 64]"
# t17406 = prims.convert_element_type(t17405, dtypes.bfloat16) # t17406: "cuda:0 bf16[1, 71, 2048, 64]"
# t17407 = prims.pad(t17406, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t17407: "cuda:0 bf16[1, 71, 2048, 64]"
# t17408 = prims.convert_element_type(t17359, dtypes.float32) # t17408: "cuda:0 f32[1, 71, 2048, 64]"
# t17409 = prims.convert_element_type(t17407, dtypes.float32) # t17409: "cuda:0 f32[1, 71, 2048, 64]"
# t17410 = prims.add(t17408, t17409) # t17410: "cuda:0 f32[1, 71, 2048, 64]"
# t17411 = prims.convert_element_type(t17410, dtypes.bfloat16) # t17411: "cuda:0 bf16[1, 71, 2048, 64]"
# t17412 = prims.convert_element_type(t17361, dtypes.float32) # t17412: "cuda:0 f32[1, 71, 2048, 64]"
# t17416 = prims.mul(t66, t17412) # t17416: "cuda:0 f32[1, 71, 2048, 64]"
# t17419 = prims.convert_element_type(t17416, dtypes.bfloat16) # t17419: "cuda:0 bf16[1, 71, 2048, 64]"
# t17428 = prims.mul(t61, t17412) # t17428: "cuda:0 f32[1, 71, 2048, 64]"
# t17440 = prims.slice_prim(t17419, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t17440: "cuda:0 bf16[1, 71, 2048, 32]"
# t17441 = prims.slice_prim(t17419, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17441: "cuda:0 bf16[1, 71, 2048, 32]"
# t17442 = prims.convert_element_type(t17440, dtypes.float32) # t17442: "cuda:0 f32[1, 71, 2048, 32]"
# t17443 = prims.neg(t17442) # t17443: "cuda:0 f32[1, 71, 2048, 32]"
# t17444 = prims.convert_element_type(t17443, dtypes.bfloat16) # t17444: "cuda:0 bf16[1, 71, 2048, 32]"
# t17445 = prims.pad(t17444, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t17445: "cuda:0 bf16[1, 71, 2048, 64]"
# t17447 = prims.convert_element_type(t17445, dtypes.float32) # t17447: "cuda:0 f32[1, 71, 2048, 64]"
# t17448 = prims.add(t17428, t17447) # t17448: "cuda:0 f32[1, 71, 2048, 64]"
# t17450 = prims.pad(t17441, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t17450: "cuda:0 bf16[1, 71, 2048, 64]"
# t17452 = prims.convert_element_type(t17450, dtypes.float32) # t17452: "cuda:0 f32[1, 71, 2048, 64]"
# t17453 = prims.add(t17448, t17452) # t17453: "cuda:0 f32[1, 71, 2048, 64]"
# t17454 = prims.convert_element_type(t17453, dtypes.bfloat16) # t17454: "cuda:0 bf16[1, 71, 2048, 64]"
# t17455 = prims.pad(t17454, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t17455: "cuda:0 bf16[1, 71, 2048, 64]"
# t17456 = prims.convert_element_type(t17363, dtypes.float32) # t17456: "cuda:0 f32[1, 71, 2048, 64]"
# t17457 = prims.convert_element_type(t17455, dtypes.float32) # t17457: "cuda:0 f32[1, 71, 2048, 64]"
# t17458 = prims.add(t17456, t17457) # t17458: "cuda:0 f32[1, 71, 2048, 64]"
# t17459 = prims.convert_element_type(t17458, dtypes.bfloat16) # t17459: "cuda:0 bf16[1, 71, 2048, 64]"
# t17469 = prims.reshape(t17411, (1, 1, 71, 2048, 64)) # t17469: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17474 = prims.reshape(t17459, (1, 1, 71, 2048, 64)) # t17474: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17480 = prims.convert_element_type(t17464, dtypes.float32) # t17480: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t17481 = prims.sum(t17480, (0, 1, 2)) # t17481: "cuda:0 f32[2048, 64]"
# t17482 = prims.convert_element_type(t17481, dtypes.bfloat16) # t17482: "cuda:0 bf16[2048, 64]"
# t17483 = prims.broadcast_in_dim(t17482, [1, 1, 1, 2048, 64], [3, 4]) # t17483: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t17489 = prims.convert_element_type(t17469, dtypes.float32) # t17489: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t17490 = prims.sum(t17489, (0, 1, 2)) # t17490: "cuda:0 f32[2048, 64]"
# t17491 = prims.convert_element_type(t17490, dtypes.bfloat16) # t17491: "cuda:0 bf16[2048, 64]"
# t17492 = prims.broadcast_in_dim(t17491, [1, 1, 1, 2048, 64], [3, 4]) # t17492: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t17498 = prims.cat((t17474, t17492, t17483), i923) # t17498: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i923, t17357, t17361, t17464
t17504 = torch.permute(t17498, (0, 3, 1, 2, 4)) # t17504: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t17504 = ltorch.permute(t17498, (0, 3, 1, 2, 4)) # t17504: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t17504 = prims.transpose(t17498, (0, 3, 1, 2, 4)) # t17504: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t17498
t17510 = torch.reshape(t17504, (1, 2048, 4672)) # t17510: "cuda:0 bf16[1, 2048, 4672]"
# t17510 = ltorch.reshape(t17504, (1, 2048, 4672)) # t17510: "cuda:0 bf16[1, 2048, 4672]"
# t17510 = prims.reshape(t17504, (1, 2048, 4672)) # t17510: "cuda:0 bf16[1, 2048, 4672]"
del t17504
t17511 = torch.reshape(t17510, (-1, 4672)) # t17511: "cuda:0 bf16[2048, 4672]"
# t17511 = ltorch.reshape(t17510, (-1, 4672)) # t17511: "cuda:0 bf16[2048, 4672]"
# t17511 = prims.reshape(t17510, (2048, 4672)) # t17511: "cuda:0 bf16[2048, 4672]"
del t17510
t17515 = torch.permute(t17511, (1, 0)) # t17515: "cuda:0 bf16[4672, 2048]"
# t17515 = ltorch.permute(t17511, (1, 0)) # t17515: "cuda:0 bf16[4672, 2048]"
# t17515 = prims.transpose(t17511, (1, 0)) # t17515: "cuda:0 bf16[4672, 2048]"
t17517 = torch.matmul(t17515, t17337) # t17517: "cuda:0 bf16[4672, 4544]"
# t17517 = ltorch.matmul(t17515, t17516) # t17517: "cuda:0 bf16[4672, 4544]"
# t17517 = prims.matmul(t17515, t17516) # t17517: "cuda:0 bf16[4672, 4544]"
del t17515, t17337
t17512 = torch.matmul(t17511, t_transformer_h_14_attn_attn_weight) # t17512: "cuda:0 bf16[2048, 4544]"
# t17512 = ltorch.matmul(t17511, t_transformer_h_14_attn_attn_weight) # t17512: "cuda:0 bf16[2048, 4544]"
# t17512 = prims.matmul(t17511, t_transformer_h_14_attn_attn_weight) # t17512: "cuda:0 bf16[2048, 4544]"
del t17511, t_transformer_h_14_attn_attn_weight
t17334 = torch.reshape(t17333, (1, 2048, 4544)) # t17334: "cuda:0 bf16[1, 2048, 4544]"
# t17334 = ltorch.reshape(t17333, (1, 2048, 4544)) # t17334: "cuda:0 bf16[1, 2048, 4544]"
# t17334 = prims.reshape(t17333, (1, 2048, 4544)) # t17334: "cuda:0 bf16[1, 2048, 4544]"
del t17333
t17513 = torch.reshape(t17512, (1, 2048, 4544)) # t17513: "cuda:0 bf16[1, 2048, 4544]"
# t17513 = ltorch.reshape(t17512, (1, 2048, 4544)) # t17513: "cuda:0 bf16[1, 2048, 4544]"
# t17513 = prims.reshape(t17512, (1, 2048, 4544)) # t17513: "cuda:0 bf16[1, 2048, 4544]"
del t17512
[t17526, t17532, t17574] = nvFusion54(i17554, t17291, t17334, t17513, t2064, t2196, t2217, t2232, t2237, t2243)
# t2223 = prims.convert_element_type(t2064, dtypes.float32) # t2223: "cuda:0 f32[1, 2048, 4544]"
# t2218 = prims.convert_element_type(t2217, dtypes.float32) # t2218: "cuda:0 f32[1, 2048, 4544]"
# t2219 = prims.convert_element_type(t2196, dtypes.float32) # t2219: "cuda:0 f32[1, 2048, 4544]"
# t2220 = prims.add(t2218, t2219) # t2220: "cuda:0 f32[1, 2048, 4544]"
# t2224 = prims.add(t2220, t2223) # t2224: "cuda:0 f32[1, 2048, 4544]"
# t2234 = prims.broadcast_in_dim(t2232, [1, 2048, 1], [0, 1]) # t2234: "cuda:0 f32[1, 2048, 1]"
# t2238 = prims.broadcast_in_dim(t2234, (1, 2048, 4544), (0, 1, 2)) # t2238: "cuda:0 f32[1, 2048, 4544]"
# t2240 = prims.sub(t2224, t2238) # t2240: "cuda:0 f32[1, 2048, 4544]"
# t2241 = prims.broadcast_in_dim(t2237, (1, 2048, 4544), (0, 1, 2)) # t2241: "cuda:0 f32[1, 2048, 4544]"
# t2242 = prims.mul(t2240, t2241) # t2242: "cuda:0 f32[1, 2048, 4544]"
# t2244 = prims.convert_element_type(t2243, dtypes.float32) # t2244: "cuda:0 f32[1, 2048, 4544]"
# t17571 = prims.convert_element_type(t17291, dtypes.float32) # t17571: "cuda:0 f32[1, 2048, 4544]"
# t17518 = prims.convert_element_type(t17334, dtypes.float32) # t17518: "cuda:0 f32[1, 2048, 4544]"
# t17519 = prims.convert_element_type(t17513, dtypes.float32) # t17519: "cuda:0 f32[1, 2048, 4544]"
# t17520 = prims.add(t17518, t17519) # t17520: "cuda:0 f32[1, 2048, 4544]"
# t17525 = prims.sum(t17520, (0, 1)) # t17525: "cuda:0 f32[4544]"
# t17526 = prims.convert_element_type(t17525, dtypes.bfloat16) # t17526: "cuda:0 bf16[4544]"
# t17527 = prims.mul(t2244, t17520) # t17527: "cuda:0 f32[1, 2048, 4544]"
# t17528 = prims.mul(t2242, t17520) # t17528: "cuda:0 f32[1, 2048, 4544]"
# t17531 = prims.sum(t17528, (0, 1)) # t17531: "cuda:0 f32[4544]"
# t17532 = prims.convert_element_type(t17531, dtypes.bfloat16) # t17532: "cuda:0 bf16[4544]"
# t17533 = prims.mul(t2241, t17527) # t17533: "cuda:0 f32[1, 2048, 4544]"
# t17534 = prims.mul(t2240, t17527) # t17534: "cuda:0 f32[1, 2048, 4544]"
# t17535 = prims.sum(t17534, (0, 2)) # t17535: "cuda:0 f32[2048]"
# t17536 = prims.broadcast_in_dim(t17535, [1, 2048, 1], [1]) # t17536: "cuda:0 f32[1, 2048, 1]"
# t17537 = prims.neg(t17533) # t17537: "cuda:0 f32[1, 2048, 4544]"
# t17539 = prims.sum(t17537, (0, 2)) # t17539: "cuda:0 f32[2048]"
# t17540 = prims.broadcast_in_dim(t17539, [1, 2048, 1], [1]) # t17540: "cuda:0 f32[1, 2048, 1]"
# t17541 = prims.mul(-0.5, t17536) # t17541: "cuda:0 f32[1, 2048, 1]"
# t17542 = prims.pow(t2237, 3.0) # t17542: "cuda:0 f32[1, 2048, 1]"
# t17543 = prims.mul(t17541, t17542) # t17543: "cuda:0 f32[1, 2048, 1]"
# t17545 = prims.sum(t17540, (0, 2)) # t17545: "cuda:0 f32[2048]"
# t17546 = prims.broadcast_in_dim(t17545, [1, 2048], [1]) # t17546: "cuda:0 f32[1, 2048]"
# t17547 = prims.sum(t17543, (0, 2)) # t17547: "cuda:0 f32[2048]"
# t17548 = prims.broadcast_in_dim(t17547, [1, 2048], [1]) # t17548: "cuda:0 f32[1, 2048]"
# t17551 = prims.broadcast_in_dim(t17546, [1, 2048, 1], [0, 1]) # t17551: "cuda:0 f32[1, 2048, 1]"
# t17552 = prims.broadcast_in_dim(t17551, (1, 2048, 4544), (0, 1, 2)) # t17552: "cuda:0 f32[1, 2048, 4544]"
# t17553 = prims.mul(0.00022007042253521127, t17552) # t17553: "cuda:0 f32[1, 2048, 4544]"
# t17555 = prims.broadcast_in_dim(t17548, [1, 2048, 1], [0, 1]) # t17555: "cuda:0 f32[1, 2048, 1]"
# t17556 = prims.broadcast_in_dim(t17555, (1, 2048, 4544), (0, 1, 2)) # t17556: "cuda:0 f32[1, 2048, 4544]"
# t17558 = prims.broadcast_in_dim(t2232, [1, 2048, 1], [0, 1]) # t17558: "cuda:0 f32[1, 2048, 1]"
# t17559 = prims.broadcast_in_dim(t17558, (1, 2048, 4544), (0, 1, 2)) # t17559: "cuda:0 f32[1, 2048, 4544]"
# t17560 = prims.mul(2.0, t17556) # t17560: "cuda:0 f32[1, 2048, 4544]"
# t17561 = prims.sub(t2224, t17559) # t17561: "cuda:0 f32[1, 2048, 4544]"
# t17562 = prims.mul(t17560, t17561) # t17562: "cuda:0 f32[1, 2048, 4544]"
# f17563 = prims.convert_element_type(i17554, float) # f17563: "float 4544.0"
# t17564 = prims.div(t17562, f17563) # t17564: "cuda:0 f32[1, 2048, 4544]"
# t17565 = prims.add(t17553, t17564) # t17565: "cuda:0 f32[1, 2048, 4544]"
# t17569 = prims.add(t17533, t17565) # t17569: "cuda:0 f32[1, 2048, 4544]"
# t17573 = prims.add(t17571, t17569) # t17573: "cuda:0 f32[1, 2048, 4544]"
# t17574 = prims.convert_element_type(t17573, dtypes.bfloat16) # t17574: "cuda:0 bf16[1, 2048, 4544]"
del i17554, t17291, t17334, t17513, t2064, t2196, t2217, t2232, t2237, t2243
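# Transformer block 13 follows the same pattern (nvFusion55: GELU backward, nvFusion56: RoPE
# backward and QKV-gradient assembly, nvFusion57: LayerNorm backward).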
t17581 = torch.reshape(t17574, (-1, 4544)) # t17581: "cuda:0 bf16[2048, 4544]"
# t17581 = ltorch.reshape(t17574, (-1, 4544)) # t17581: "cuda:0 bf16[2048, 4544]"
# t17581 = prims.reshape(t17574, (2048, 4544)) # t17581: "cuda:0 bf16[2048, 4544]"
t17585 = torch.permute(t17581, (1, 0)) # t17585: "cuda:0 bf16[4544, 2048]"
# t17585 = ltorch.permute(t17581, (1, 0)) # t17585: "cuda:0 bf16[4544, 2048]"
# t17585 = prims.transpose(t17581, (1, 0)) # t17585: "cuda:0 bf16[4544, 2048]"
t17582 = torch.matmul(t17581, t_transformer_h_13_mlp_proj_weight) # t17582: "cuda:0 bf16[2048, 18176]"
# t17582 = ltorch.matmul(t17581, t_transformer_h_13_mlp_proj_weight) # t17582: "cuda:0 bf16[2048, 18176]"
# t17582 = prims.matmul(t17581, t_transformer_h_13_mlp_proj_weight) # t17582: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_13_mlp_proj_weight
t17587 = torch.matmul(t17585, t17586) # t17587: "cuda:0 bf16[4544, 18176]"
# t17587 = ltorch.matmul(t17585, t17586) # t17587: "cuda:0 bf16[4544, 18176]"
# t17587 = prims.matmul(t17585, t17586) # t17587: "cuda:0 bf16[4544, 18176]"
del t17586
t17623 = torch.matmul(t17581, t_transformer_h_13_attn_proj_weight) # t17623: "cuda:0 bf16[2048, 4544]"
# t17623 = ltorch.matmul(t17622, t_transformer_h_13_attn_proj_weight) # t17623: "cuda:0 bf16[2048, 4544]"
# t17623 = prims.matmul(t17622, t_transformer_h_13_attn_proj_weight) # t17623: "cuda:0 bf16[2048, 4544]"
del t17581, t_transformer_h_13_attn_proj_weight
t17628 = torch.matmul(t17585, t17627) # t17628: "cuda:0 bf16[4544, 4544]"
# t17628 = ltorch.matmul(t17626, t17627) # t17628: "cuda:0 bf16[4544, 4544]"
# t17628 = prims.matmul(t17626, t17627) # t17628: "cuda:0 bf16[4544, 4544]"
del t17585, t17627
t17583 = torch.reshape(t17582, (1, 2048, 18176)) # t17583: "cuda:0 bf16[1, 2048, 18176]"
# t17583 = ltorch.reshape(t17582, (1, 2048, 18176)) # t17583: "cuda:0 bf16[1, 2048, 18176]"
# t17583 = prims.reshape(t17582, (1, 2048, 18176)) # t17583: "cuda:0 bf16[1, 2048, 18176]"
del t17582
t17624 = torch.reshape(t17623, (1, 2048, 4544)) # t17624: "cuda:0 bf16[1, 2048, 4544]"
# t17624 = ltorch.reshape(t17623, (1, 2048, 4544)) # t17624: "cuda:0 bf16[1, 2048, 4544]"
# t17624 = prims.reshape(t17623, (1, 2048, 4544)) # t17624: "cuda:0 bf16[1, 2048, 4544]"
del t17623
t17632 = torch.reshape(t17624, (1, 2048, 71, 64)) # t17632: "cuda:0 bf16[1, 2048, 71, 64]"
# t17632 = ltorch.reshape(t17624, (1, 2048, 71, 64)) # t17632: "cuda:0 bf16[1, 2048, 71, 64]"
# t17632 = prims.reshape(t17624, (1, 2048, 71, 64)) # t17632: "cuda:0 bf16[1, 2048, 71, 64]"
del t17624
t17635 = torch.permute(t17632, (0, 2, 1, 3)) # t17635: "cuda:0 bf16[1, 71, 2048, 64]"
# t17635 = ltorch.permute(t17632, (0, 2, 1, 3)) # t17635: "cuda:0 bf16[1, 71, 2048, 64]"
# t17635 = prims.transpose(t17632, (0, 2, 1, 3)) # t17635: "cuda:0 bf16[1, 71, 2048, 64]"
del t17632
[t17614] = nvFusion55(f895, f897, t17583, t2197)
# t2198 = prims.convert_element_type(t2197, dtypes.float32) # t2198: "cuda:0 f32[1, 2048, 18176]"
# t2200 = prims.div(t2198, 1.4142135623730951) # t2200: "cuda:0 f32[1, 2048, 18176]"
# t2203 = prims.erf(t2200) # t2203: "cuda:0 f32[1, 2048, 18176]"
# t2207 = prims.mul(0.5, t2203) # t2207: "cuda:0 f32[1, 2048, 18176]"
# t2211 = prims.add(0.5, t2207) # t2211: "cuda:0 f32[1, 2048, 18176]"
# t17588 = prims.convert_element_type(t17583, dtypes.float32) # t17588: "cuda:0 f32[1, 2048, 18176]"
# t17589 = prims.mul(t2211, t17588) # t17589: "cuda:0 f32[1, 2048, 18176]"
# t17590 = prims.mul(t2198, t17588) # t17590: "cuda:0 f32[1, 2048, 18176]"
# t17598 = prims.mul(f897, t17590) # t17598: "cuda:0 f32[1, 2048, 18176]"
# t17601 = prims.pow(t2200, 2.0) # t17601: "cuda:0 f32[1, 2048, 18176]"
# t17602 = prims.neg(t17601) # t17602: "cuda:0 f32[1, 2048, 18176]"
# t17603 = prims.exp(t17602) # t17603: "cuda:0 f32[1, 2048, 18176]"
# t17604 = prims.mul(1.1283791670955126, t17603) # t17604: "cuda:0 f32[1, 2048, 18176]"
# t17605 = prims.mul(t17604, t17598) # t17605: "cuda:0 f32[1, 2048, 18176]"
# t17609 = prims.div(t17605, f895) # t17609: "cuda:0 f32[1, 2048, 18176]"
# t17613 = prims.add(t17589, t17609) # t17613: "cuda:0 f32[1, 2048, 18176]"
# t17614 = prims.convert_element_type(t17613, dtypes.bfloat16) # t17614: "cuda:0 bf16[1, 2048, 18176]"
del f895, f897, t17583, t2197
t17615 = torch.reshape(t17614, (-1, 18176)) # t17615: "cuda:0 bf16[2048, 18176]"
# t17615 = ltorch.reshape(t17614, (-1, 18176)) # t17615: "cuda:0 bf16[2048, 18176]"
# t17615 = prims.reshape(t17614, (2048, 18176)) # t17615: "cuda:0 bf16[2048, 18176]"
del t17614
t17619 = torch.permute(t17615, (1, 0)) # t17619: "cuda:0 bf16[18176, 2048]"
# t17619 = ltorch.permute(t17615, (1, 0)) # t17619: "cuda:0 bf16[18176, 2048]"
# t17619 = prims.transpose(t17615, (1, 0)) # t17619: "cuda:0 bf16[18176, 2048]"
t17621 = torch.matmul(t17619, t17620) # t17621: "cuda:0 bf16[18176, 4544]"
# t17621 = ltorch.matmul(t17619, t17620) # t17621: "cuda:0 bf16[18176, 4544]"
# t17621 = prims.matmul(t17619, t17620) # t17621: "cuda:0 bf16[18176, 4544]"
del t17619
t17616 = torch.matmul(t17615, t_transformer_h_13_mlp_fc_weight) # t17616: "cuda:0 bf16[2048, 4544]"
# t17616 = ltorch.matmul(t17615, t_transformer_h_13_mlp_fc_weight) # t17616: "cuda:0 bf16[2048, 4544]"
# t17616 = prims.matmul(t17615, t_transformer_h_13_mlp_fc_weight) # t17616: "cuda:0 bf16[2048, 4544]"
del t17615, t_transformer_h_13_mlp_fc_weight
(t17636, t17637, t17638) = cudnn_sdpa_bwd(t17635, t2181, t2184, t2134, None, f886, b887, t2185, t2186, t2187, t2188, scale=f888, cat_grad_qkv=False)
del t17635, t2181, t2184, t2134, f886, b887, t2185, t2186, t2187, t2188, f888
t17640 = torch_slice_prim_impl(t17637, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17640: "cuda:0 bf16[1, 71, 2048, 64]"
del t17637
t17644 = torch_slice_prim_impl(t17636, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17644: "cuda:0 bf16[1, 71, 2048, 64]"
del t17636
t17747 = torch.reshape(t17638, (1, 1, 71, 2048, 64)) # t17747: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17747 = ltorch.reshape(t17638, (1, 1, 71, 2048, 64)) # t17747: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17747 = prims.reshape(t17638, (1, 1, 71, 2048, 64)) # t17747: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t17638
[t17781] = nvFusion56(i859, t17640, t17644, t17747, t61, t66)
# t17641 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17641: "cuda:0 bf16[1, 71, 2048, 0]"
# t17642 = prims.pad(t17641, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17642: "cuda:0 bf16[1, 71, 2048, 64]"
# t17645 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17645: "cuda:0 bf16[1, 71, 2048, 0]"
# t17646 = prims.pad(t17645, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17646: "cuda:0 bf16[1, 71, 2048, 64]"
# t17647 = prims.convert_element_type(t17640, dtypes.float32) # t17647: "cuda:0 f32[1, 71, 2048, 64]"
# t17651 = prims.mul(t66, t17647) # t17651: "cuda:0 f32[1, 71, 2048, 64]"
# t17654 = prims.convert_element_type(t17651, dtypes.bfloat16) # t17654: "cuda:0 bf16[1, 71, 2048, 64]"
# t17663 = prims.mul(t61, t17647) # t17663: "cuda:0 f32[1, 71, 2048, 64]"
# t17675 = prims.slice_prim(t17654, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t17675: "cuda:0 bf16[1, 71, 2048, 32]"
# t17676 = prims.slice_prim(t17654, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17676: "cuda:0 bf16[1, 71, 2048, 32]"
# t17677 = prims.convert_element_type(t17675, dtypes.float32) # t17677: "cuda:0 f32[1, 71, 2048, 32]"
# t17678 = prims.neg(t17677) # t17678: "cuda:0 f32[1, 71, 2048, 32]"
# t17679 = prims.convert_element_type(t17678, dtypes.bfloat16) # t17679: "cuda:0 bf16[1, 71, 2048, 32]"
# t17680 = prims.pad(t17679, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t17680: "cuda:0 bf16[1, 71, 2048, 64]"
# t17682 = prims.convert_element_type(t17680, dtypes.float32) # t17682: "cuda:0 f32[1, 71, 2048, 64]"
# t17683 = prims.add(t17663, t17682) # t17683: "cuda:0 f32[1, 71, 2048, 64]"
# t17685 = prims.pad(t17676, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t17685: "cuda:0 bf16[1, 71, 2048, 64]"
# t17687 = prims.convert_element_type(t17685, dtypes.float32) # t17687: "cuda:0 f32[1, 71, 2048, 64]"
# t17688 = prims.add(t17683, t17687) # t17688: "cuda:0 f32[1, 71, 2048, 64]"
# t17689 = prims.convert_element_type(t17688, dtypes.bfloat16) # t17689: "cuda:0 bf16[1, 71, 2048, 64]"
# t17690 = prims.pad(t17689, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t17690: "cuda:0 bf16[1, 71, 2048, 64]"
# t17691 = prims.convert_element_type(t17642, dtypes.float32) # t17691: "cuda:0 f32[1, 71, 2048, 64]"
# t17692 = prims.convert_element_type(t17690, dtypes.float32) # t17692: "cuda:0 f32[1, 71, 2048, 64]"
# t17693 = prims.add(t17691, t17692) # t17693: "cuda:0 f32[1, 71, 2048, 64]"
# t17694 = prims.convert_element_type(t17693, dtypes.bfloat16) # t17694: "cuda:0 bf16[1, 71, 2048, 64]"
# t17695 = prims.convert_element_type(t17644, dtypes.float32) # t17695: "cuda:0 f32[1, 71, 2048, 64]"
# t17699 = prims.mul(t66, t17695) # t17699: "cuda:0 f32[1, 71, 2048, 64]"
# t17702 = prims.convert_element_type(t17699, dtypes.bfloat16) # t17702: "cuda:0 bf16[1, 71, 2048, 64]"
# t17711 = prims.mul(t61, t17695) # t17711: "cuda:0 f32[1, 71, 2048, 64]"
# t17723 = prims.slice_prim(t17702, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t17723: "cuda:0 bf16[1, 71, 2048, 32]"
# t17724 = prims.slice_prim(t17702, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17724: "cuda:0 bf16[1, 71, 2048, 32]"
# t17725 = prims.convert_element_type(t17723, dtypes.float32) # t17725: "cuda:0 f32[1, 71, 2048, 32]"
# t17726 = prims.neg(t17725) # t17726: "cuda:0 f32[1, 71, 2048, 32]"
# t17727 = prims.convert_element_type(t17726, dtypes.bfloat16) # t17727: "cuda:0 bf16[1, 71, 2048, 32]"
# t17728 = prims.pad(t17727, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t17728: "cuda:0 bf16[1, 71, 2048, 64]"
# t17730 = prims.convert_element_type(t17728, dtypes.float32) # t17730: "cuda:0 f32[1, 71, 2048, 64]"
# t17731 = prims.add(t17711, t17730) # t17731: "cuda:0 f32[1, 71, 2048, 64]"
# t17733 = prims.pad(t17724, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t17733: "cuda:0 bf16[1, 71, 2048, 64]"
# t17735 = prims.convert_element_type(t17733, dtypes.float32) # t17735: "cuda:0 f32[1, 71, 2048, 64]"
# t17736 = prims.add(t17731, t17735) # t17736: "cuda:0 f32[1, 71, 2048, 64]"
# t17737 = prims.convert_element_type(t17736, dtypes.bfloat16) # t17737: "cuda:0 bf16[1, 71, 2048, 64]"
# t17738 = prims.pad(t17737, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t17738: "cuda:0 bf16[1, 71, 2048, 64]"
# t17739 = prims.convert_element_type(t17646, dtypes.float32) # t17739: "cuda:0 f32[1, 71, 2048, 64]"
# t17740 = prims.convert_element_type(t17738, dtypes.float32) # t17740: "cuda:0 f32[1, 71, 2048, 64]"
# t17741 = prims.add(t17739, t17740) # t17741: "cuda:0 f32[1, 71, 2048, 64]"
# t17742 = prims.convert_element_type(t17741, dtypes.bfloat16) # t17742: "cuda:0 bf16[1, 71, 2048, 64]"
# t17752 = prims.reshape(t17694, (1, 1, 71, 2048, 64)) # t17752: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17757 = prims.reshape(t17742, (1, 1, 71, 2048, 64)) # t17757: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t17763 = prims.convert_element_type(t17747, dtypes.float32) # t17763: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t17764 = prims.sum(t17763, (0, 1, 2)) # t17764: "cuda:0 f32[2048, 64]"
# t17765 = prims.convert_element_type(t17764, dtypes.bfloat16) # t17765: "cuda:0 bf16[2048, 64]"
# t17766 = prims.broadcast_in_dim(t17765, [1, 1, 1, 2048, 64], [3, 4]) # t17766: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t17772 = prims.convert_element_type(t17752, dtypes.float32) # t17772: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t17773 = prims.sum(t17772, (0, 1, 2)) # t17773: "cuda:0 f32[2048, 64]"
# t17774 = prims.convert_element_type(t17773, dtypes.bfloat16) # t17774: "cuda:0 bf16[2048, 64]"
# t17775 = prims.broadcast_in_dim(t17774, [1, 1, 1, 2048, 64], [3, 4]) # t17775: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t17781 = prims.cat((t17757, t17775, t17766), i859) # t17781: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i859, t17640, t17644, t17747
t17787 = torch.permute(t17781, (0, 3, 1, 2, 4)) # t17787: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t17787 = ltorch.permute(t17781, (0, 3, 1, 2, 4)) # t17787: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t17787 = prims.transpose(t17781, (0, 3, 1, 2, 4)) # t17787: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t17781
t17793 = torch.reshape(t17787, (1, 2048, 4672)) # t17793: "cuda:0 bf16[1, 2048, 4672]"
# t17793 = ltorch.reshape(t17787, (1, 2048, 4672)) # t17793: "cuda:0 bf16[1, 2048, 4672]"
# t17793 = prims.reshape(t17787, (1, 2048, 4672)) # t17793: "cuda:0 bf16[1, 2048, 4672]"
del t17787
t17794 = torch.reshape(t17793, (-1, 4672)) # t17794: "cuda:0 bf16[2048, 4672]"
# t17794 = ltorch.reshape(t17793, (-1, 4672)) # t17794: "cuda:0 bf16[2048, 4672]"
# t17794 = prims.reshape(t17793, (2048, 4672)) # t17794: "cuda:0 bf16[2048, 4672]"
del t17793
t17798 = torch.permute(t17794, (1, 0)) # t17798: "cuda:0 bf16[4672, 2048]"
# t17798 = ltorch.permute(t17794, (1, 0)) # t17798: "cuda:0 bf16[4672, 2048]"
# t17798 = prims.transpose(t17794, (1, 0)) # t17798: "cuda:0 bf16[4672, 2048]"
t17800 = torch.matmul(t17798, t17620) # t17800: "cuda:0 bf16[4672, 4544]"
# t17800 = ltorch.matmul(t17798, t17799) # t17800: "cuda:0 bf16[4672, 4544]"
# t17800 = prims.matmul(t17798, t17799) # t17800: "cuda:0 bf16[4672, 4544]"
del t17798, t17620
t17795 = torch.matmul(t17794, t_transformer_h_13_attn_attn_weight) # t17795: "cuda:0 bf16[2048, 4544]"
# t17795 = ltorch.matmul(t17794, t_transformer_h_13_attn_attn_weight) # t17795: "cuda:0 bf16[2048, 4544]"
# t17795 = prims.matmul(t17794, t_transformer_h_13_attn_attn_weight) # t17795: "cuda:0 bf16[2048, 4544]"
del t17794, t_transformer_h_13_attn_attn_weight
t17617 = torch.reshape(t17616, (1, 2048, 4544)) # t17617: "cuda:0 bf16[1, 2048, 4544]"
# t17617 = ltorch.reshape(t17616, (1, 2048, 4544)) # t17617: "cuda:0 bf16[1, 2048, 4544]"
# t17617 = prims.reshape(t17616, (1, 2048, 4544)) # t17617: "cuda:0 bf16[1, 2048, 4544]"
del t17616
t17796 = torch.reshape(t17795, (1, 2048, 4544)) # t17796: "cuda:0 bf16[1, 2048, 4544]"
# t17796 = ltorch.reshape(t17795, (1, 2048, 4544)) # t17796: "cuda:0 bf16[1, 2048, 4544]"
# t17796 = prims.reshape(t17795, (1, 2048, 4544)) # t17796: "cuda:0 bf16[1, 2048, 4544]"
del t17795
[t17809, t17815, t17857] = nvFusion57(i17837, t17574, t17617, t17796, t1903, t2035, t2056, t2071, t2076, t2082)
# t2062 = prims.convert_element_type(t1903, dtypes.float32) # t2062: "cuda:0 f32[1, 2048, 4544]"
# t2057 = prims.convert_element_type(t2056, dtypes.float32) # t2057: "cuda:0 f32[1, 2048, 4544]"
# t2058 = prims.convert_element_type(t2035, dtypes.float32) # t2058: "cuda:0 f32[1, 2048, 4544]"
# t2059 = prims.add(t2057, t2058) # t2059: "cuda:0 f32[1, 2048, 4544]"
# t2063 = prims.add(t2059, t2062) # t2063: "cuda:0 f32[1, 2048, 4544]"
# t2073 = prims.broadcast_in_dim(t2071, [1, 2048, 1], [0, 1]) # t2073: "cuda:0 f32[1, 2048, 1]"
# t2077 = prims.broadcast_in_dim(t2073, (1, 2048, 4544), (0, 1, 2)) # t2077: "cuda:0 f32[1, 2048, 4544]"
# t2079 = prims.sub(t2063, t2077) # t2079: "cuda:0 f32[1, 2048, 4544]"
# t2080 = prims.broadcast_in_dim(t2076, (1, 2048, 4544), (0, 1, 2)) # t2080: "cuda:0 f32[1, 2048, 4544]"
# t2081 = prims.mul(t2079, t2080) # t2081: "cuda:0 f32[1, 2048, 4544]"
# t2083 = prims.convert_element_type(t2082, dtypes.float32) # t2083: "cuda:0 f32[1, 2048, 4544]"
# t17854 = prims.convert_element_type(t17574, dtypes.float32) # t17854: "cuda:0 f32[1, 2048, 4544]"
# t17801 = prims.convert_element_type(t17617, dtypes.float32) # t17801: "cuda:0 f32[1, 2048, 4544]"
# t17802 = prims.convert_element_type(t17796, dtypes.float32) # t17802: "cuda:0 f32[1, 2048, 4544]"
# t17803 = prims.add(t17801, t17802) # t17803: "cuda:0 f32[1, 2048, 4544]"
# t17808 = prims.sum(t17803, (0, 1)) # t17808: "cuda:0 f32[4544]"
# t17809 = prims.convert_element_type(t17808, dtypes.bfloat16) # t17809: "cuda:0 bf16[4544]"
# t17810 = prims.mul(t2083, t17803) # t17810: "cuda:0 f32[1, 2048, 4544]"
# t17811 = prims.mul(t2081, t17803) # t17811: "cuda:0 f32[1, 2048, 4544]"
# t17814 = prims.sum(t17811, (0, 1)) # t17814: "cuda:0 f32[4544]"
# t17815 = prims.convert_element_type(t17814, dtypes.bfloat16) # t17815: "cuda:0 bf16[4544]"
# t17816 = prims.mul(t2080, t17810) # t17816: "cuda:0 f32[1, 2048, 4544]"
# t17817 = prims.mul(t2079, t17810) # t17817: "cuda:0 f32[1, 2048, 4544]"
# t17818 = prims.sum(t17817, (0, 2)) # t17818: "cuda:0 f32[2048]"
# t17819 = prims.broadcast_in_dim(t17818, [1, 2048, 1], [1]) # t17819: "cuda:0 f32[1, 2048, 1]"
# t17820 = prims.neg(t17816) # t17820: "cuda:0 f32[1, 2048, 4544]"
# t17822 = prims.sum(t17820, (0, 2)) # t17822: "cuda:0 f32[2048]"
# t17823 = prims.broadcast_in_dim(t17822, [1, 2048, 1], [1]) # t17823: "cuda:0 f32[1, 2048, 1]"
# t17824 = prims.mul(-0.5, t17819) # t17824: "cuda:0 f32[1, 2048, 1]"
# t17825 = prims.pow(t2076, 3.0) # t17825: "cuda:0 f32[1, 2048, 1]"
# t17826 = prims.mul(t17824, t17825) # t17826: "cuda:0 f32[1, 2048, 1]"
# t17828 = prims.sum(t17823, (0, 2)) # t17828: "cuda:0 f32[2048]"
# t17829 = prims.broadcast_in_dim(t17828, [1, 2048], [1]) # t17829: "cuda:0 f32[1, 2048]"
# t17830 = prims.sum(t17826, (0, 2)) # t17830: "cuda:0 f32[2048]"
# t17831 = prims.broadcast_in_dim(t17830, [1, 2048], [1]) # t17831: "cuda:0 f32[1, 2048]"
# t17834 = prims.broadcast_in_dim(t17829, [1, 2048, 1], [0, 1]) # t17834: "cuda:0 f32[1, 2048, 1]"
# t17835 = prims.broadcast_in_dim(t17834, (1, 2048, 4544), (0, 1, 2)) # t17835: "cuda:0 f32[1, 2048, 4544]"
# t17836 = prims.mul(0.00022007042253521127, t17835) # t17836: "cuda:0 f32[1, 2048, 4544]"
# t17838 = prims.broadcast_in_dim(t17831, [1, 2048, 1], [0, 1]) # t17838: "cuda:0 f32[1, 2048, 1]"
# t17839 = prims.broadcast_in_dim(t17838, (1, 2048, 4544), (0, 1, 2)) # t17839: "cuda:0 f32[1, 2048, 4544]"
# t17841 = prims.broadcast_in_dim(t2071, [1, 2048, 1], [0, 1]) # t17841: "cuda:0 f32[1, 2048, 1]"
# t17842 = prims.broadcast_in_dim(t17841, (1, 2048, 4544), (0, 1, 2)) # t17842: "cuda:0 f32[1, 2048, 4544]"
# t17843 = prims.mul(2.0, t17839) # t17843: "cuda:0 f32[1, 2048, 4544]"
# t17844 = prims.sub(t2063, t17842) # t17844: "cuda:0 f32[1, 2048, 4544]"
# t17845 = prims.mul(t17843, t17844) # t17845: "cuda:0 f32[1, 2048, 4544]"
# f17846 = prims.convert_element_type(i17837, float) # f17846: "float 4544.0"
# t17847 = prims.div(t17845, f17846) # t17847: "cuda:0 f32[1, 2048, 4544]"
# t17848 = prims.add(t17836, t17847) # t17848: "cuda:0 f32[1, 2048, 4544]"
# t17852 = prims.add(t17816, t17848) # t17852: "cuda:0 f32[1, 2048, 4544]"
# t17856 = prims.add(t17854, t17852) # t17856: "cuda:0 f32[1, 2048, 4544]"
# t17857 = prims.convert_element_type(t17856, dtypes.bfloat16) # t17857: "cuda:0 bf16[1, 2048, 4544]"
del i17837, t17574, t17617, t17796, t1903, t2035, t2056, t2071, t2076, t2082
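  # t17857 (the accumulated grad of the block output) is flattened to (2048, 4544) and reused for
  # both branches of what looks like a parallel-attention block: matmuls against the saved proj
  # weights give the grads flowing into the MLP (t17865) and attention (t17906) branches, and the
  # transposed grad (t17868) matmul'd with the saved branch activations (t17869, t17910) gives the
  # corresponding proj weight grads.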
t17864 = torch.reshape(t17857, (-1, 4544)) # t17864: "cuda:0 bf16[2048, 4544]"
# t17864 = ltorch.reshape(t17857, (-1, 4544)) # t17864: "cuda:0 bf16[2048, 4544]"
# t17864 = prims.reshape(t17857, (2048, 4544)) # t17864: "cuda:0 bf16[2048, 4544]"
t17868 = torch.permute(t17864, (1, 0)) # t17868: "cuda:0 bf16[4544, 2048]"
# t17868 = ltorch.permute(t17864, (1, 0)) # t17868: "cuda:0 bf16[4544, 2048]"
# t17868 = prims.transpose(t17864, (1, 0)) # t17868: "cuda:0 bf16[4544, 2048]"
t17865 = torch.matmul(t17864, t_transformer_h_12_mlp_proj_weight) # t17865: "cuda:0 bf16[2048, 18176]"
# t17865 = ltorch.matmul(t17864, t_transformer_h_12_mlp_proj_weight) # t17865: "cuda:0 bf16[2048, 18176]"
# t17865 = prims.matmul(t17864, t_transformer_h_12_mlp_proj_weight) # t17865: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_12_mlp_proj_weight
t17870 = torch.matmul(t17868, t17869) # t17870: "cuda:0 bf16[4544, 18176]"
# t17870 = ltorch.matmul(t17868, t17869) # t17870: "cuda:0 bf16[4544, 18176]"
# t17870 = prims.matmul(t17868, t17869) # t17870: "cuda:0 bf16[4544, 18176]"
del t17869
t17906 = torch.matmul(t17864, t_transformer_h_12_attn_proj_weight) # t17906: "cuda:0 bf16[2048, 4544]"
# t17906 = ltorch.matmul(t17905, t_transformer_h_12_attn_proj_weight) # t17906: "cuda:0 bf16[2048, 4544]"
# t17906 = prims.matmul(t17905, t_transformer_h_12_attn_proj_weight) # t17906: "cuda:0 bf16[2048, 4544]"
del t17864, t_transformer_h_12_attn_proj_weight
t17911 = torch.matmul(t17868, t17910) # t17911: "cuda:0 bf16[4544, 4544]"
# t17911 = ltorch.matmul(t17909, t17910) # t17911: "cuda:0 bf16[4544, 4544]"
# t17911 = prims.matmul(t17909, t17910) # t17911: "cuda:0 bf16[4544, 4544]"
del t17868, t17910
t17866 = torch.reshape(t17865, (1, 2048, 18176)) # t17866: "cuda:0 bf16[1, 2048, 18176]"
# t17866 = ltorch.reshape(t17865, (1, 2048, 18176)) # t17866: "cuda:0 bf16[1, 2048, 18176]"
# t17866 = prims.reshape(t17865, (1, 2048, 18176)) # t17866: "cuda:0 bf16[1, 2048, 18176]"
del t17865
t17907 = torch.reshape(t17906, (1, 2048, 4544)) # t17907: "cuda:0 bf16[1, 2048, 4544]"
# t17907 = ltorch.reshape(t17906, (1, 2048, 4544)) # t17907: "cuda:0 bf16[1, 2048, 4544]"
# t17907 = prims.reshape(t17906, (1, 2048, 4544)) # t17907: "cuda:0 bf16[1, 2048, 4544]"
del t17906
t17915 = torch.reshape(t17907, (1, 2048, 71, 64)) # t17915: "cuda:0 bf16[1, 2048, 71, 64]"
# t17915 = ltorch.reshape(t17907, (1, 2048, 71, 64)) # t17915: "cuda:0 bf16[1, 2048, 71, 64]"
# t17915 = prims.reshape(t17907, (1, 2048, 71, 64)) # t17915: "cuda:0 bf16[1, 2048, 71, 64]"
del t17907
t17918 = torch.permute(t17915, (0, 2, 1, 3)) # t17918: "cuda:0 bf16[1, 71, 2048, 64]"
# t17918 = ltorch.permute(t17915, (0, 2, 1, 3)) # t17918: "cuda:0 bf16[1, 71, 2048, 64]"
# t17918 = prims.transpose(t17915, (0, 2, 1, 3)) # t17918: "cuda:0 bf16[1, 71, 2048, 64]"
del t17915
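  # nvFusion58 below appears to be the exact (erf-based) GELU backward: 0.5 * (1 + erf(x / sqrt(2)))
  # is recomputed from the saved pre-activation t2036, and the derivative term uses
  # 1.1283791670955126 = 2/sqrt(pi) together with exp(-x^2/2), matching
  # d/dx gelu(x) = Phi(x) + x * phi(x). The scalars f831 and f833 are opaque in this trace and
  # presumably constants folded out of that formula.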
[t17897] = nvFusion58(f831, f833, t17866, t2036)
# t2037 = prims.convert_element_type(t2036, dtypes.float32) # t2037: "cuda:0 f32[1, 2048, 18176]"
# t2039 = prims.div(t2037, 1.4142135623730951) # t2039: "cuda:0 f32[1, 2048, 18176]"
# t2042 = prims.erf(t2039) # t2042: "cuda:0 f32[1, 2048, 18176]"
# t2046 = prims.mul(0.5, t2042) # t2046: "cuda:0 f32[1, 2048, 18176]"
# t2050 = prims.add(0.5, t2046) # t2050: "cuda:0 f32[1, 2048, 18176]"
# t17871 = prims.convert_element_type(t17866, dtypes.float32) # t17871: "cuda:0 f32[1, 2048, 18176]"
# t17872 = prims.mul(t2050, t17871) # t17872: "cuda:0 f32[1, 2048, 18176]"
# t17873 = prims.mul(t2037, t17871) # t17873: "cuda:0 f32[1, 2048, 18176]"
# t17881 = prims.mul(f833, t17873) # t17881: "cuda:0 f32[1, 2048, 18176]"
# t17884 = prims.pow(t2039, 2.0) # t17884: "cuda:0 f32[1, 2048, 18176]"
# t17885 = prims.neg(t17884) # t17885: "cuda:0 f32[1, 2048, 18176]"
# t17886 = prims.exp(t17885) # t17886: "cuda:0 f32[1, 2048, 18176]"
# t17887 = prims.mul(1.1283791670955126, t17886) # t17887: "cuda:0 f32[1, 2048, 18176]"
# t17888 = prims.mul(t17887, t17881) # t17888: "cuda:0 f32[1, 2048, 18176]"
# t17892 = prims.div(t17888, f831) # t17892: "cuda:0 f32[1, 2048, 18176]"
# t17896 = prims.add(t17872, t17892) # t17896: "cuda:0 f32[1, 2048, 18176]"
# t17897 = prims.convert_element_type(t17896, dtypes.bfloat16) # t17897: "cuda:0 bf16[1, 2048, 18176]"
del f831, f833, t17866, t2036
t17898 = torch.reshape(t17897, (-1, 18176)) # t17898: "cuda:0 bf16[2048, 18176]"
# t17898 = ltorch.reshape(t17897, (-1, 18176)) # t17898: "cuda:0 bf16[2048, 18176]"
# t17898 = prims.reshape(t17897, (2048, 18176)) # t17898: "cuda:0 bf16[2048, 18176]"
del t17897
t17902 = torch.permute(t17898, (1, 0)) # t17902: "cuda:0 bf16[18176, 2048]"
# t17902 = ltorch.permute(t17898, (1, 0)) # t17902: "cuda:0 bf16[18176, 2048]"
# t17902 = prims.transpose(t17898, (1, 0)) # t17902: "cuda:0 bf16[18176, 2048]"
t17904 = torch.matmul(t17902, t17903) # t17904: "cuda:0 bf16[18176, 4544]"
# t17904 = ltorch.matmul(t17902, t17903) # t17904: "cuda:0 bf16[18176, 4544]"
# t17904 = prims.matmul(t17902, t17903) # t17904: "cuda:0 bf16[18176, 4544]"
del t17902
t17899 = torch.matmul(t17898, t_transformer_h_12_mlp_fc_weight) # t17899: "cuda:0 bf16[2048, 4544]"
# t17899 = ltorch.matmul(t17898, t_transformer_h_12_mlp_fc_weight) # t17899: "cuda:0 bf16[2048, 4544]"
# t17899 = prims.matmul(t17898, t_transformer_h_12_mlp_fc_weight) # t17899: "cuda:0 bf16[2048, 4544]"
del t17898, t_transformer_h_12_mlp_fc_weight
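  # cudnn_sdpa_bwd is the cuDNN scaled-dot-product-attention backward; given the incoming grad
  # t17918 and the q/k/v and softmax statistics saved by the forward, it appears to return the
  # grads for query (t17919), key (t17920) and value (t17921). The i/f/b arguments are saved
  # scalars (e.g. the attention scale f824) whose exact meaning is not visible here.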
(t17919, t17920, t17921) = cudnn_sdpa_bwd(t17918, t2020, t2023, t1973, None, f822, b823, t2024, t2025, t2026, t2027, scale=f824, cat_grad_qkv=False)
del t17918, t2020, t2023, t1973, f822, b823, t2024, t2025, t2026, t2027, f824
t17923 = torch_slice_prim_impl(t17920, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17923: "cuda:0 bf16[1, 71, 2048, 64]"
del t17920
t17927 = torch_slice_prim_impl(t17919, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17927: "cuda:0 bf16[1, 71, 2048, 64]"
del t17919
t18030 = torch.reshape(t17921, (1, 1, 71, 2048, 64)) # t18030: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18030 = ltorch.reshape(t17921, (1, 1, 71, 2048, 64)) # t18030: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18030 = prims.reshape(t17921, (1, 1, 71, 2048, 64)) # t18030: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t17921
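  # nvFusion59 below appears to undo the rotary position embedding on the query/key grads: t66 and
  # t61 look like the saved cos/sin tables, and the slice / negate / pad pattern is the backward of
  # the "rotate-half" trick. The key and value grads are then summed over the 71 broadcast heads
  # (consistent with multi-query attention) and concatenated with the query grad into the fused
  # qkv-grad layout [1, 1, 73, 2048, 64] (71 query heads + 1 key + 1 value).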
[t18064] = nvFusion59(i795, t17923, t17927, t18030, t61, t66)
# t17924 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17924: "cuda:0 bf16[1, 71, 2048, 0]"
# t17925 = prims.pad(t17924, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17925: "cuda:0 bf16[1, 71, 2048, 64]"
# t17928 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t17928: "cuda:0 bf16[1, 71, 2048, 0]"
# t17929 = prims.pad(t17928, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t17929: "cuda:0 bf16[1, 71, 2048, 64]"
# t17930 = prims.convert_element_type(t17923, dtypes.float32) # t17930: "cuda:0 f32[1, 71, 2048, 64]"
# t17934 = prims.mul(t66, t17930) # t17934: "cuda:0 f32[1, 71, 2048, 64]"
# t17937 = prims.convert_element_type(t17934, dtypes.bfloat16) # t17937: "cuda:0 bf16[1, 71, 2048, 64]"
# t17946 = prims.mul(t61, t17930) # t17946: "cuda:0 f32[1, 71, 2048, 64]"
# t17958 = prims.slice_prim(t17937, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t17958: "cuda:0 bf16[1, 71, 2048, 32]"
# t17959 = prims.slice_prim(t17937, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t17959: "cuda:0 bf16[1, 71, 2048, 32]"
# t17960 = prims.convert_element_type(t17958, dtypes.float32) # t17960: "cuda:0 f32[1, 71, 2048, 32]"
# t17961 = prims.neg(t17960) # t17961: "cuda:0 f32[1, 71, 2048, 32]"
# t17962 = prims.convert_element_type(t17961, dtypes.bfloat16) # t17962: "cuda:0 bf16[1, 71, 2048, 32]"
# t17963 = prims.pad(t17962, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t17963: "cuda:0 bf16[1, 71, 2048, 64]"
# t17965 = prims.convert_element_type(t17963, dtypes.float32) # t17965: "cuda:0 f32[1, 71, 2048, 64]"
# t17966 = prims.add(t17946, t17965) # t17966: "cuda:0 f32[1, 71, 2048, 64]"
# t17968 = prims.pad(t17959, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t17968: "cuda:0 bf16[1, 71, 2048, 64]"
# t17970 = prims.convert_element_type(t17968, dtypes.float32) # t17970: "cuda:0 f32[1, 71, 2048, 64]"
# t17971 = prims.add(t17966, t17970) # t17971: "cuda:0 f32[1, 71, 2048, 64]"
# t17972 = prims.convert_element_type(t17971, dtypes.bfloat16) # t17972: "cuda:0 bf16[1, 71, 2048, 64]"
# t17973 = prims.pad(t17972, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t17973: "cuda:0 bf16[1, 71, 2048, 64]"
# t17974 = prims.convert_element_type(t17925, dtypes.float32) # t17974: "cuda:0 f32[1, 71, 2048, 64]"
# t17975 = prims.convert_element_type(t17973, dtypes.float32) # t17975: "cuda:0 f32[1, 71, 2048, 64]"
# t17976 = prims.add(t17974, t17975) # t17976: "cuda:0 f32[1, 71, 2048, 64]"
# t17977 = prims.convert_element_type(t17976, dtypes.bfloat16) # t17977: "cuda:0 bf16[1, 71, 2048, 64]"
# t17978 = prims.convert_element_type(t17927, dtypes.float32) # t17978: "cuda:0 f32[1, 71, 2048, 64]"
# t17982 = prims.mul(t66, t17978) # t17982: "cuda:0 f32[1, 71, 2048, 64]"
# t17985 = prims.convert_element_type(t17982, dtypes.bfloat16) # t17985: "cuda:0 bf16[1, 71, 2048, 64]"
# t17994 = prims.mul(t61, t17978) # t17994: "cuda:0 f32[1, 71, 2048, 64]"
# t18006 = prims.slice_prim(t17985, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t18006: "cuda:0 bf16[1, 71, 2048, 32]"
# t18007 = prims.slice_prim(t17985, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18007: "cuda:0 bf16[1, 71, 2048, 32]"
# t18008 = prims.convert_element_type(t18006, dtypes.float32) # t18008: "cuda:0 f32[1, 71, 2048, 32]"
# t18009 = prims.neg(t18008) # t18009: "cuda:0 f32[1, 71, 2048, 32]"
# t18010 = prims.convert_element_type(t18009, dtypes.bfloat16) # t18010: "cuda:0 bf16[1, 71, 2048, 32]"
# t18011 = prims.pad(t18010, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t18011: "cuda:0 bf16[1, 71, 2048, 64]"
# t18013 = prims.convert_element_type(t18011, dtypes.float32) # t18013: "cuda:0 f32[1, 71, 2048, 64]"
# t18014 = prims.add(t17994, t18013) # t18014: "cuda:0 f32[1, 71, 2048, 64]"
# t18016 = prims.pad(t18007, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t18016: "cuda:0 bf16[1, 71, 2048, 64]"
# t18018 = prims.convert_element_type(t18016, dtypes.float32) # t18018: "cuda:0 f32[1, 71, 2048, 64]"
# t18019 = prims.add(t18014, t18018) # t18019: "cuda:0 f32[1, 71, 2048, 64]"
# t18020 = prims.convert_element_type(t18019, dtypes.bfloat16) # t18020: "cuda:0 bf16[1, 71, 2048, 64]"
# t18021 = prims.pad(t18020, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t18021: "cuda:0 bf16[1, 71, 2048, 64]"
# t18022 = prims.convert_element_type(t17929, dtypes.float32) # t18022: "cuda:0 f32[1, 71, 2048, 64]"
# t18023 = prims.convert_element_type(t18021, dtypes.float32) # t18023: "cuda:0 f32[1, 71, 2048, 64]"
# t18024 = prims.add(t18022, t18023) # t18024: "cuda:0 f32[1, 71, 2048, 64]"
# t18025 = prims.convert_element_type(t18024, dtypes.bfloat16) # t18025: "cuda:0 bf16[1, 71, 2048, 64]"
# t18035 = prims.reshape(t17977, (1, 1, 71, 2048, 64)) # t18035: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18040 = prims.reshape(t18025, (1, 1, 71, 2048, 64)) # t18040: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18046 = prims.convert_element_type(t18030, dtypes.float32) # t18046: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18047 = prims.sum(t18046, (0, 1, 2)) # t18047: "cuda:0 f32[2048, 64]"
# t18048 = prims.convert_element_type(t18047, dtypes.bfloat16) # t18048: "cuda:0 bf16[2048, 64]"
# t18049 = prims.broadcast_in_dim(t18048, [1, 1, 1, 2048, 64], [3, 4]) # t18049: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18055 = prims.convert_element_type(t18035, dtypes.float32) # t18055: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18056 = prims.sum(t18055, (0, 1, 2)) # t18056: "cuda:0 f32[2048, 64]"
# t18057 = prims.convert_element_type(t18056, dtypes.bfloat16) # t18057: "cuda:0 bf16[2048, 64]"
# t18058 = prims.broadcast_in_dim(t18057, [1, 1, 1, 2048, 64], [3, 4]) # t18058: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18064 = prims.cat((t18040, t18058, t18049), i795) # t18064: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i795, t17923, t17927, t18030
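  # The fused qkv grad t18064 is permuted and flattened back to (2048, 4672) so that one pair of
  # matmuls handles the fused attention input projection: against t17903 (what appears to be the
  # saved LayerNorm output) for the weight grad t18083, and against
  # t_transformer_h_12_attn_attn_weight for the grad flowing back into the block input (t18078).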
t18070 = torch.permute(t18064, (0, 3, 1, 2, 4)) # t18070: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18070 = ltorch.permute(t18064, (0, 3, 1, 2, 4)) # t18070: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18070 = prims.transpose(t18064, (0, 3, 1, 2, 4)) # t18070: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t18064
t18076 = torch.reshape(t18070, (1, 2048, 4672)) # t18076: "cuda:0 bf16[1, 2048, 4672]"
# t18076 = ltorch.reshape(t18070, (1, 2048, 4672)) # t18076: "cuda:0 bf16[1, 2048, 4672]"
# t18076 = prims.reshape(t18070, (1, 2048, 4672)) # t18076: "cuda:0 bf16[1, 2048, 4672]"
del t18070
t18077 = torch.reshape(t18076, (-1, 4672)) # t18077: "cuda:0 bf16[2048, 4672]"
# t18077 = ltorch.reshape(t18076, (-1, 4672)) # t18077: "cuda:0 bf16[2048, 4672]"
# t18077 = prims.reshape(t18076, (2048, 4672)) # t18077: "cuda:0 bf16[2048, 4672]"
del t18076
t18081 = torch.permute(t18077, (1, 0)) # t18081: "cuda:0 bf16[4672, 2048]"
# t18081 = ltorch.permute(t18077, (1, 0)) # t18081: "cuda:0 bf16[4672, 2048]"
# t18081 = prims.transpose(t18077, (1, 0)) # t18081: "cuda:0 bf16[4672, 2048]"
t18083 = torch.matmul(t18081, t17903) # t18083: "cuda:0 bf16[4672, 4544]"
# t18083 = ltorch.matmul(t18081, t18082) # t18083: "cuda:0 bf16[4672, 4544]"
# t18083 = prims.matmul(t18081, t18082) # t18083: "cuda:0 bf16[4672, 4544]"
del t18081, t17903
t18078 = torch.matmul(t18077, t_transformer_h_12_attn_attn_weight) # t18078: "cuda:0 bf16[2048, 4544]"
# t18078 = ltorch.matmul(t18077, t_transformer_h_12_attn_attn_weight) # t18078: "cuda:0 bf16[2048, 4544]"
# t18078 = prims.matmul(t18077, t_transformer_h_12_attn_attn_weight) # t18078: "cuda:0 bf16[2048, 4544]"
del t18077, t_transformer_h_12_attn_attn_weight
t17900 = torch.reshape(t17899, (1, 2048, 4544)) # t17900: "cuda:0 bf16[1, 2048, 4544]"
# t17900 = ltorch.reshape(t17899, (1, 2048, 4544)) # t17900: "cuda:0 bf16[1, 2048, 4544]"
# t17900 = prims.reshape(t17899, (1, 2048, 4544)) # t17900: "cuda:0 bf16[1, 2048, 4544]"
del t17899
t18079 = torch.reshape(t18078, (1, 2048, 4544)) # t18079: "cuda:0 bf16[1, 2048, 4544]"
# t18079 = ltorch.reshape(t18078, (1, 2048, 4544)) # t18079: "cuda:0 bf16[1, 2048, 4544]"
# t18079 = prims.reshape(t18078, (1, 2048, 4544)) # t18079: "cuda:0 bf16[1, 2048, 4544]"
del t18078
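  # From here on the same per-block pattern repeats for the remaining transformer blocks
  # (t_transformer_h_11, h_10, h_9, ...): a fused LayerNorm backward (nvFusion60, 63, 66, ...),
  # the GELU backward, the cuDNN SDPA backward, and the RoPE / fused-qkv grad fusion, each followed
  # by the corresponding weight-grad and input-grad matmuls.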
[t18092, t18098, t18140] = nvFusion60(i18120, t1742, t17857, t17900, t18079, t1874, t1895, t1910, t1915, t1921)
# t1901 = prims.convert_element_type(t1742, dtypes.float32) # t1901: "cuda:0 f32[1, 2048, 4544]"
# t1896 = prims.convert_element_type(t1895, dtypes.float32) # t1896: "cuda:0 f32[1, 2048, 4544]"
# t1897 = prims.convert_element_type(t1874, dtypes.float32) # t1897: "cuda:0 f32[1, 2048, 4544]"
# t1898 = prims.add(t1896, t1897) # t1898: "cuda:0 f32[1, 2048, 4544]"
# t1902 = prims.add(t1898, t1901) # t1902: "cuda:0 f32[1, 2048, 4544]"
# t1912 = prims.broadcast_in_dim(t1910, [1, 2048, 1], [0, 1]) # t1912: "cuda:0 f32[1, 2048, 1]"
# t1916 = prims.broadcast_in_dim(t1912, (1, 2048, 4544), (0, 1, 2)) # t1916: "cuda:0 f32[1, 2048, 4544]"
# t1918 = prims.sub(t1902, t1916) # t1918: "cuda:0 f32[1, 2048, 4544]"
# t1919 = prims.broadcast_in_dim(t1915, (1, 2048, 4544), (0, 1, 2)) # t1919: "cuda:0 f32[1, 2048, 4544]"
# t1920 = prims.mul(t1918, t1919) # t1920: "cuda:0 f32[1, 2048, 4544]"
# t1922 = prims.convert_element_type(t1921, dtypes.float32) # t1922: "cuda:0 f32[1, 2048, 4544]"
# t18137 = prims.convert_element_type(t17857, dtypes.float32) # t18137: "cuda:0 f32[1, 2048, 4544]"
# t18084 = prims.convert_element_type(t17900, dtypes.float32) # t18084: "cuda:0 f32[1, 2048, 4544]"
# t18085 = prims.convert_element_type(t18079, dtypes.float32) # t18085: "cuda:0 f32[1, 2048, 4544]"
# t18086 = prims.add(t18084, t18085) # t18086: "cuda:0 f32[1, 2048, 4544]"
# t18091 = prims.sum(t18086, (0, 1)) # t18091: "cuda:0 f32[4544]"
# t18092 = prims.convert_element_type(t18091, dtypes.bfloat16) # t18092: "cuda:0 bf16[4544]"
# t18093 = prims.mul(t1922, t18086) # t18093: "cuda:0 f32[1, 2048, 4544]"
# t18094 = prims.mul(t1920, t18086) # t18094: "cuda:0 f32[1, 2048, 4544]"
# t18097 = prims.sum(t18094, (0, 1)) # t18097: "cuda:0 f32[4544]"
# t18098 = prims.convert_element_type(t18097, dtypes.bfloat16) # t18098: "cuda:0 bf16[4544]"
# t18099 = prims.mul(t1919, t18093) # t18099: "cuda:0 f32[1, 2048, 4544]"
# t18100 = prims.mul(t1918, t18093) # t18100: "cuda:0 f32[1, 2048, 4544]"
# t18101 = prims.sum(t18100, (0, 2)) # t18101: "cuda:0 f32[2048]"
# t18102 = prims.broadcast_in_dim(t18101, [1, 2048, 1], [1]) # t18102: "cuda:0 f32[1, 2048, 1]"
# t18103 = prims.neg(t18099) # t18103: "cuda:0 f32[1, 2048, 4544]"
# t18105 = prims.sum(t18103, (0, 2)) # t18105: "cuda:0 f32[2048]"
# t18106 = prims.broadcast_in_dim(t18105, [1, 2048, 1], [1]) # t18106: "cuda:0 f32[1, 2048, 1]"
# t18107 = prims.mul(-0.5, t18102) # t18107: "cuda:0 f32[1, 2048, 1]"
# t18108 = prims.pow(t1915, 3.0) # t18108: "cuda:0 f32[1, 2048, 1]"
# t18109 = prims.mul(t18107, t18108) # t18109: "cuda:0 f32[1, 2048, 1]"
# t18111 = prims.sum(t18106, (0, 2)) # t18111: "cuda:0 f32[2048]"
# t18112 = prims.broadcast_in_dim(t18111, [1, 2048], [1]) # t18112: "cuda:0 f32[1, 2048]"
# t18113 = prims.sum(t18109, (0, 2)) # t18113: "cuda:0 f32[2048]"
# t18114 = prims.broadcast_in_dim(t18113, [1, 2048], [1]) # t18114: "cuda:0 f32[1, 2048]"
# t18117 = prims.broadcast_in_dim(t18112, [1, 2048, 1], [0, 1]) # t18117: "cuda:0 f32[1, 2048, 1]"
# t18118 = prims.broadcast_in_dim(t18117, (1, 2048, 4544), (0, 1, 2)) # t18118: "cuda:0 f32[1, 2048, 4544]"
# t18119 = prims.mul(0.00022007042253521127, t18118) # t18119: "cuda:0 f32[1, 2048, 4544]"
# t18121 = prims.broadcast_in_dim(t18114, [1, 2048, 1], [0, 1]) # t18121: "cuda:0 f32[1, 2048, 1]"
# t18122 = prims.broadcast_in_dim(t18121, (1, 2048, 4544), (0, 1, 2)) # t18122: "cuda:0 f32[1, 2048, 4544]"
# t18124 = prims.broadcast_in_dim(t1910, [1, 2048, 1], [0, 1]) # t18124: "cuda:0 f32[1, 2048, 1]"
# t18125 = prims.broadcast_in_dim(t18124, (1, 2048, 4544), (0, 1, 2)) # t18125: "cuda:0 f32[1, 2048, 4544]"
# t18126 = prims.mul(2.0, t18122) # t18126: "cuda:0 f32[1, 2048, 4544]"
# t18127 = prims.sub(t1902, t18125) # t18127: "cuda:0 f32[1, 2048, 4544]"
# t18128 = prims.mul(t18126, t18127) # t18128: "cuda:0 f32[1, 2048, 4544]"
# f18129 = prims.convert_element_type(i18120, float) # f18129: "float 4544.0"
# t18130 = prims.div(t18128, f18129) # t18130: "cuda:0 f32[1, 2048, 4544]"
# t18131 = prims.add(t18119, t18130) # t18131: "cuda:0 f32[1, 2048, 4544]"
# t18135 = prims.add(t18099, t18131) # t18135: "cuda:0 f32[1, 2048, 4544]"
# t18139 = prims.add(t18137, t18135) # t18139: "cuda:0 f32[1, 2048, 4544]"
# t18140 = prims.convert_element_type(t18139, dtypes.bfloat16) # t18140: "cuda:0 bf16[1, 2048, 4544]"
del i18120, t1742, t17857, t17900, t18079, t1874, t1895, t1910, t1915, t1921
t18147 = torch.reshape(t18140, (-1, 4544)) # t18147: "cuda:0 bf16[2048, 4544]"
# t18147 = ltorch.reshape(t18140, (-1, 4544)) # t18147: "cuda:0 bf16[2048, 4544]"
# t18147 = prims.reshape(t18140, (2048, 4544)) # t18147: "cuda:0 bf16[2048, 4544]"
t18151 = torch.permute(t18147, (1, 0)) # t18151: "cuda:0 bf16[4544, 2048]"
# t18151 = ltorch.permute(t18147, (1, 0)) # t18151: "cuda:0 bf16[4544, 2048]"
# t18151 = prims.transpose(t18147, (1, 0)) # t18151: "cuda:0 bf16[4544, 2048]"
t18153 = torch.matmul(t18151, t18152) # t18153: "cuda:0 bf16[4544, 18176]"
# t18153 = ltorch.matmul(t18151, t18152) # t18153: "cuda:0 bf16[4544, 18176]"
# t18153 = prims.matmul(t18151, t18152) # t18153: "cuda:0 bf16[4544, 18176]"
del t18152
t18189 = torch.matmul(t18147, t_transformer_h_11_attn_proj_weight) # t18189: "cuda:0 bf16[2048, 4544]"
# t18189 = ltorch.matmul(t18188, t_transformer_h_11_attn_proj_weight) # t18189: "cuda:0 bf16[2048, 4544]"
# t18189 = prims.matmul(t18188, t_transformer_h_11_attn_proj_weight) # t18189: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_11_attn_proj_weight
t18194 = torch.matmul(t18151, t18193) # t18194: "cuda:0 bf16[4544, 4544]"
# t18194 = ltorch.matmul(t18192, t18193) # t18194: "cuda:0 bf16[4544, 4544]"
# t18194 = prims.matmul(t18192, t18193) # t18194: "cuda:0 bf16[4544, 4544]"
del t18151, t18193
t18148 = torch.matmul(t18147, t_transformer_h_11_mlp_proj_weight) # t18148: "cuda:0 bf16[2048, 18176]"
# t18148 = ltorch.matmul(t18147, t_transformer_h_11_mlp_proj_weight) # t18148: "cuda:0 bf16[2048, 18176]"
# t18148 = prims.matmul(t18147, t_transformer_h_11_mlp_proj_weight) # t18148: "cuda:0 bf16[2048, 18176]"
del t18147, t_transformer_h_11_mlp_proj_weight
t18190 = torch.reshape(t18189, (1, 2048, 4544)) # t18190: "cuda:0 bf16[1, 2048, 4544]"
# t18190 = ltorch.reshape(t18189, (1, 2048, 4544)) # t18190: "cuda:0 bf16[1, 2048, 4544]"
# t18190 = prims.reshape(t18189, (1, 2048, 4544)) # t18190: "cuda:0 bf16[1, 2048, 4544]"
del t18189
t18198 = torch.reshape(t18190, (1, 2048, 71, 64)) # t18198: "cuda:0 bf16[1, 2048, 71, 64]"
# t18198 = ltorch.reshape(t18190, (1, 2048, 71, 64)) # t18198: "cuda:0 bf16[1, 2048, 71, 64]"
# t18198 = prims.reshape(t18190, (1, 2048, 71, 64)) # t18198: "cuda:0 bf16[1, 2048, 71, 64]"
del t18190
t18201 = torch.permute(t18198, (0, 2, 1, 3)) # t18201: "cuda:0 bf16[1, 71, 2048, 64]"
# t18201 = ltorch.permute(t18198, (0, 2, 1, 3)) # t18201: "cuda:0 bf16[1, 71, 2048, 64]"
# t18201 = prims.transpose(t18198, (0, 2, 1, 3)) # t18201: "cuda:0 bf16[1, 71, 2048, 64]"
del t18198
t18149 = torch.reshape(t18148, (1, 2048, 18176)) # t18149: "cuda:0 bf16[1, 2048, 18176]"
# t18149 = ltorch.reshape(t18148, (1, 2048, 18176)) # t18149: "cuda:0 bf16[1, 2048, 18176]"
# t18149 = prims.reshape(t18148, (1, 2048, 18176)) # t18149: "cuda:0 bf16[1, 2048, 18176]"
del t18148
[t18180] = nvFusion61(f767, f769, t18149, t1875)
# t1876 = prims.convert_element_type(t1875, dtypes.float32) # t1876: "cuda:0 f32[1, 2048, 18176]"
# t1878 = prims.div(t1876, 1.4142135623730951) # t1878: "cuda:0 f32[1, 2048, 18176]"
# t1881 = prims.erf(t1878) # t1881: "cuda:0 f32[1, 2048, 18176]"
# t1885 = prims.mul(0.5, t1881) # t1885: "cuda:0 f32[1, 2048, 18176]"
# t1889 = prims.add(0.5, t1885) # t1889: "cuda:0 f32[1, 2048, 18176]"
# t18154 = prims.convert_element_type(t18149, dtypes.float32) # t18154: "cuda:0 f32[1, 2048, 18176]"
# t18155 = prims.mul(t1889, t18154) # t18155: "cuda:0 f32[1, 2048, 18176]"
# t18156 = prims.mul(t1876, t18154) # t18156: "cuda:0 f32[1, 2048, 18176]"
# t18164 = prims.mul(f769, t18156) # t18164: "cuda:0 f32[1, 2048, 18176]"
# t18167 = prims.pow(t1878, 2.0) # t18167: "cuda:0 f32[1, 2048, 18176]"
# t18168 = prims.neg(t18167) # t18168: "cuda:0 f32[1, 2048, 18176]"
# t18169 = prims.exp(t18168) # t18169: "cuda:0 f32[1, 2048, 18176]"
# t18170 = prims.mul(1.1283791670955126, t18169) # t18170: "cuda:0 f32[1, 2048, 18176]"
# t18171 = prims.mul(t18170, t18164) # t18171: "cuda:0 f32[1, 2048, 18176]"
# t18175 = prims.div(t18171, f767) # t18175: "cuda:0 f32[1, 2048, 18176]"
# t18179 = prims.add(t18155, t18175) # t18179: "cuda:0 f32[1, 2048, 18176]"
# t18180 = prims.convert_element_type(t18179, dtypes.bfloat16) # t18180: "cuda:0 bf16[1, 2048, 18176]"
del f767, f769, t18149, t1875
t18181 = torch.reshape(t18180, (-1, 18176)) # t18181: "cuda:0 bf16[2048, 18176]"
# t18181 = ltorch.reshape(t18180, (-1, 18176)) # t18181: "cuda:0 bf16[2048, 18176]"
# t18181 = prims.reshape(t18180, (2048, 18176)) # t18181: "cuda:0 bf16[2048, 18176]"
del t18180
t18185 = torch.permute(t18181, (1, 0)) # t18185: "cuda:0 bf16[18176, 2048]"
# t18185 = ltorch.permute(t18181, (1, 0)) # t18185: "cuda:0 bf16[18176, 2048]"
# t18185 = prims.transpose(t18181, (1, 0)) # t18185: "cuda:0 bf16[18176, 2048]"
(t18202, t18203, t18204) = cudnn_sdpa_bwd(t18201, t1859, t1862, t1812, None, f758, b759, t1863, t1864, t1865, t1866, scale=f760, cat_grad_qkv=False)
del t18201, t1859, t1862, t1812, f758, b759, t1863, t1864, t1865, t1866, f760
t18187 = torch.matmul(t18185, t18186) # t18187: "cuda:0 bf16[18176, 4544]"
# t18187 = ltorch.matmul(t18185, t18186) # t18187: "cuda:0 bf16[18176, 4544]"
# t18187 = prims.matmul(t18185, t18186) # t18187: "cuda:0 bf16[18176, 4544]"
del t18185
t18182 = torch.matmul(t18181, t_transformer_h_11_mlp_fc_weight) # t18182: "cuda:0 bf16[2048, 4544]"
# t18182 = ltorch.matmul(t18181, t_transformer_h_11_mlp_fc_weight) # t18182: "cuda:0 bf16[2048, 4544]"
# t18182 = prims.matmul(t18181, t_transformer_h_11_mlp_fc_weight) # t18182: "cuda:0 bf16[2048, 4544]"
del t18181, t_transformer_h_11_mlp_fc_weight
t18206 = torch_slice_prim_impl(t18203, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18206: "cuda:0 bf16[1, 71, 2048, 64]"
del t18203
t18210 = torch_slice_prim_impl(t18202, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18210: "cuda:0 bf16[1, 71, 2048, 64]"
del t18202
t18313 = torch.reshape(t18204, (1, 1, 71, 2048, 64)) # t18313: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18313 = ltorch.reshape(t18204, (1, 1, 71, 2048, 64)) # t18313: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18313 = prims.reshape(t18204, (1, 1, 71, 2048, 64)) # t18313: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t18204
[t18347] = nvFusion62(i731, t18206, t18210, t18313, t61, t66)
# t18207 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t18207: "cuda:0 bf16[1, 71, 2048, 0]"
# t18208 = prims.pad(t18207, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t18208: "cuda:0 bf16[1, 71, 2048, 64]"
# t18211 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t18211: "cuda:0 bf16[1, 71, 2048, 0]"
# t18212 = prims.pad(t18211, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t18212: "cuda:0 bf16[1, 71, 2048, 64]"
# t18213 = prims.convert_element_type(t18206, dtypes.float32) # t18213: "cuda:0 f32[1, 71, 2048, 64]"
# t18217 = prims.mul(t66, t18213) # t18217: "cuda:0 f32[1, 71, 2048, 64]"
# t18220 = prims.convert_element_type(t18217, dtypes.bfloat16) # t18220: "cuda:0 bf16[1, 71, 2048, 64]"
# t18229 = prims.mul(t61, t18213) # t18229: "cuda:0 f32[1, 71, 2048, 64]"
# t18241 = prims.slice_prim(t18220, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t18241: "cuda:0 bf16[1, 71, 2048, 32]"
# t18242 = prims.slice_prim(t18220, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18242: "cuda:0 bf16[1, 71, 2048, 32]"
# t18243 = prims.convert_element_type(t18241, dtypes.float32) # t18243: "cuda:0 f32[1, 71, 2048, 32]"
# t18244 = prims.neg(t18243) # t18244: "cuda:0 f32[1, 71, 2048, 32]"
# t18245 = prims.convert_element_type(t18244, dtypes.bfloat16) # t18245: "cuda:0 bf16[1, 71, 2048, 32]"
# t18246 = prims.pad(t18245, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t18246: "cuda:0 bf16[1, 71, 2048, 64]"
# t18248 = prims.convert_element_type(t18246, dtypes.float32) # t18248: "cuda:0 f32[1, 71, 2048, 64]"
# t18249 = prims.add(t18229, t18248) # t18249: "cuda:0 f32[1, 71, 2048, 64]"
# t18251 = prims.pad(t18242, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t18251: "cuda:0 bf16[1, 71, 2048, 64]"
# t18253 = prims.convert_element_type(t18251, dtypes.float32) # t18253: "cuda:0 f32[1, 71, 2048, 64]"
# t18254 = prims.add(t18249, t18253) # t18254: "cuda:0 f32[1, 71, 2048, 64]"
# t18255 = prims.convert_element_type(t18254, dtypes.bfloat16) # t18255: "cuda:0 bf16[1, 71, 2048, 64]"
# t18256 = prims.pad(t18255, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t18256: "cuda:0 bf16[1, 71, 2048, 64]"
# t18257 = prims.convert_element_type(t18208, dtypes.float32) # t18257: "cuda:0 f32[1, 71, 2048, 64]"
# t18258 = prims.convert_element_type(t18256, dtypes.float32) # t18258: "cuda:0 f32[1, 71, 2048, 64]"
# t18259 = prims.add(t18257, t18258) # t18259: "cuda:0 f32[1, 71, 2048, 64]"
# t18260 = prims.convert_element_type(t18259, dtypes.bfloat16) # t18260: "cuda:0 bf16[1, 71, 2048, 64]"
# t18261 = prims.convert_element_type(t18210, dtypes.float32) # t18261: "cuda:0 f32[1, 71, 2048, 64]"
# t18265 = prims.mul(t66, t18261) # t18265: "cuda:0 f32[1, 71, 2048, 64]"
# t18268 = prims.convert_element_type(t18265, dtypes.bfloat16) # t18268: "cuda:0 bf16[1, 71, 2048, 64]"
# t18277 = prims.mul(t61, t18261) # t18277: "cuda:0 f32[1, 71, 2048, 64]"
# t18289 = prims.slice_prim(t18268, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t18289: "cuda:0 bf16[1, 71, 2048, 32]"
# t18290 = prims.slice_prim(t18268, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18290: "cuda:0 bf16[1, 71, 2048, 32]"
# t18291 = prims.convert_element_type(t18289, dtypes.float32) # t18291: "cuda:0 f32[1, 71, 2048, 32]"
# t18292 = prims.neg(t18291) # t18292: "cuda:0 f32[1, 71, 2048, 32]"
# t18293 = prims.convert_element_type(t18292, dtypes.bfloat16) # t18293: "cuda:0 bf16[1, 71, 2048, 32]"
# t18294 = prims.pad(t18293, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t18294: "cuda:0 bf16[1, 71, 2048, 64]"
# t18296 = prims.convert_element_type(t18294, dtypes.float32) # t18296: "cuda:0 f32[1, 71, 2048, 64]"
# t18297 = prims.add(t18277, t18296) # t18297: "cuda:0 f32[1, 71, 2048, 64]"
# t18299 = prims.pad(t18290, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t18299: "cuda:0 bf16[1, 71, 2048, 64]"
# t18301 = prims.convert_element_type(t18299, dtypes.float32) # t18301: "cuda:0 f32[1, 71, 2048, 64]"
# t18302 = prims.add(t18297, t18301) # t18302: "cuda:0 f32[1, 71, 2048, 64]"
# t18303 = prims.convert_element_type(t18302, dtypes.bfloat16) # t18303: "cuda:0 bf16[1, 71, 2048, 64]"
# t18304 = prims.pad(t18303, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t18304: "cuda:0 bf16[1, 71, 2048, 64]"
# t18305 = prims.convert_element_type(t18212, dtypes.float32) # t18305: "cuda:0 f32[1, 71, 2048, 64]"
# t18306 = prims.convert_element_type(t18304, dtypes.float32) # t18306: "cuda:0 f32[1, 71, 2048, 64]"
# t18307 = prims.add(t18305, t18306) # t18307: "cuda:0 f32[1, 71, 2048, 64]"
# t18308 = prims.convert_element_type(t18307, dtypes.bfloat16) # t18308: "cuda:0 bf16[1, 71, 2048, 64]"
# t18318 = prims.reshape(t18260, (1, 1, 71, 2048, 64)) # t18318: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18323 = prims.reshape(t18308, (1, 1, 71, 2048, 64)) # t18323: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18329 = prims.convert_element_type(t18313, dtypes.float32) # t18329: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18330 = prims.sum(t18329, (0, 1, 2)) # t18330: "cuda:0 f32[2048, 64]"
# t18331 = prims.convert_element_type(t18330, dtypes.bfloat16) # t18331: "cuda:0 bf16[2048, 64]"
# t18332 = prims.broadcast_in_dim(t18331, [1, 1, 1, 2048, 64], [3, 4]) # t18332: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18338 = prims.convert_element_type(t18318, dtypes.float32) # t18338: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18339 = prims.sum(t18338, (0, 1, 2)) # t18339: "cuda:0 f32[2048, 64]"
# t18340 = prims.convert_element_type(t18339, dtypes.bfloat16) # t18340: "cuda:0 bf16[2048, 64]"
# t18341 = prims.broadcast_in_dim(t18340, [1, 1, 1, 2048, 64], [3, 4]) # t18341: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18347 = prims.cat((t18323, t18341, t18332), i731) # t18347: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i731, t18206, t18210, t18313
t18353 = torch.permute(t18347, (0, 3, 1, 2, 4)) # t18353: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18353 = ltorch.permute(t18347, (0, 3, 1, 2, 4)) # t18353: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18353 = prims.transpose(t18347, (0, 3, 1, 2, 4)) # t18353: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t18347
t18359 = torch.reshape(t18353, (1, 2048, 4672)) # t18359: "cuda:0 bf16[1, 2048, 4672]"
# t18359 = ltorch.reshape(t18353, (1, 2048, 4672)) # t18359: "cuda:0 bf16[1, 2048, 4672]"
# t18359 = prims.reshape(t18353, (1, 2048, 4672)) # t18359: "cuda:0 bf16[1, 2048, 4672]"
del t18353
t18360 = torch.reshape(t18359, (-1, 4672)) # t18360: "cuda:0 bf16[2048, 4672]"
# t18360 = ltorch.reshape(t18359, (-1, 4672)) # t18360: "cuda:0 bf16[2048, 4672]"
# t18360 = prims.reshape(t18359, (2048, 4672)) # t18360: "cuda:0 bf16[2048, 4672]"
del t18359
t18364 = torch.permute(t18360, (1, 0)) # t18364: "cuda:0 bf16[4672, 2048]"
# t18364 = ltorch.permute(t18360, (1, 0)) # t18364: "cuda:0 bf16[4672, 2048]"
# t18364 = prims.transpose(t18360, (1, 0)) # t18364: "cuda:0 bf16[4672, 2048]"
t18366 = torch.matmul(t18364, t18186) # t18366: "cuda:0 bf16[4672, 4544]"
# t18366 = ltorch.matmul(t18364, t18365) # t18366: "cuda:0 bf16[4672, 4544]"
# t18366 = prims.matmul(t18364, t18365) # t18366: "cuda:0 bf16[4672, 4544]"
del t18364, t18186
t18361 = torch.matmul(t18360, t_transformer_h_11_attn_attn_weight) # t18361: "cuda:0 bf16[2048, 4544]"
# t18361 = ltorch.matmul(t18360, t_transformer_h_11_attn_attn_weight) # t18361: "cuda:0 bf16[2048, 4544]"
# t18361 = prims.matmul(t18360, t_transformer_h_11_attn_attn_weight) # t18361: "cuda:0 bf16[2048, 4544]"
del t18360, t_transformer_h_11_attn_attn_weight
t18183 = torch.reshape(t18182, (1, 2048, 4544)) # t18183: "cuda:0 bf16[1, 2048, 4544]"
# t18183 = ltorch.reshape(t18182, (1, 2048, 4544)) # t18183: "cuda:0 bf16[1, 2048, 4544]"
# t18183 = prims.reshape(t18182, (1, 2048, 4544)) # t18183: "cuda:0 bf16[1, 2048, 4544]"
del t18182
t18362 = torch.reshape(t18361, (1, 2048, 4544)) # t18362: "cuda:0 bf16[1, 2048, 4544]"
# t18362 = ltorch.reshape(t18361, (1, 2048, 4544)) # t18362: "cuda:0 bf16[1, 2048, 4544]"
# t18362 = prims.reshape(t18361, (1, 2048, 4544)) # t18362: "cuda:0 bf16[1, 2048, 4544]"
del t18361
[t18375, t18381, t18423] = nvFusion63(i18403, t1581, t1713, t1734, t1749, t1754, t1760, t18140, t18183, t18362)
# t1740 = prims.convert_element_type(t1581, dtypes.float32) # t1740: "cuda:0 f32[1, 2048, 4544]"
# t1735 = prims.convert_element_type(t1734, dtypes.float32) # t1735: "cuda:0 f32[1, 2048, 4544]"
# t1736 = prims.convert_element_type(t1713, dtypes.float32) # t1736: "cuda:0 f32[1, 2048, 4544]"
# t1737 = prims.add(t1735, t1736) # t1737: "cuda:0 f32[1, 2048, 4544]"
# t1741 = prims.add(t1737, t1740) # t1741: "cuda:0 f32[1, 2048, 4544]"
# t1751 = prims.broadcast_in_dim(t1749, [1, 2048, 1], [0, 1]) # t1751: "cuda:0 f32[1, 2048, 1]"
# t1755 = prims.broadcast_in_dim(t1751, (1, 2048, 4544), (0, 1, 2)) # t1755: "cuda:0 f32[1, 2048, 4544]"
# t1757 = prims.sub(t1741, t1755) # t1757: "cuda:0 f32[1, 2048, 4544]"
# t1758 = prims.broadcast_in_dim(t1754, (1, 2048, 4544), (0, 1, 2)) # t1758: "cuda:0 f32[1, 2048, 4544]"
# t1759 = prims.mul(t1757, t1758) # t1759: "cuda:0 f32[1, 2048, 4544]"
# t1761 = prims.convert_element_type(t1760, dtypes.float32) # t1761: "cuda:0 f32[1, 2048, 4544]"
# t18420 = prims.convert_element_type(t18140, dtypes.float32) # t18420: "cuda:0 f32[1, 2048, 4544]"
# t18367 = prims.convert_element_type(t18183, dtypes.float32) # t18367: "cuda:0 f32[1, 2048, 4544]"
# t18368 = prims.convert_element_type(t18362, dtypes.float32) # t18368: "cuda:0 f32[1, 2048, 4544]"
# t18369 = prims.add(t18367, t18368) # t18369: "cuda:0 f32[1, 2048, 4544]"
# t18374 = prims.sum(t18369, (0, 1)) # t18374: "cuda:0 f32[4544]"
# t18375 = prims.convert_element_type(t18374, dtypes.bfloat16) # t18375: "cuda:0 bf16[4544]"
# t18376 = prims.mul(t1761, t18369) # t18376: "cuda:0 f32[1, 2048, 4544]"
# t18377 = prims.mul(t1759, t18369) # t18377: "cuda:0 f32[1, 2048, 4544]"
# t18380 = prims.sum(t18377, (0, 1)) # t18380: "cuda:0 f32[4544]"
# t18381 = prims.convert_element_type(t18380, dtypes.bfloat16) # t18381: "cuda:0 bf16[4544]"
# t18382 = prims.mul(t1758, t18376) # t18382: "cuda:0 f32[1, 2048, 4544]"
# t18383 = prims.mul(t1757, t18376) # t18383: "cuda:0 f32[1, 2048, 4544]"
# t18384 = prims.sum(t18383, (0, 2)) # t18384: "cuda:0 f32[2048]"
# t18385 = prims.broadcast_in_dim(t18384, [1, 2048, 1], [1]) # t18385: "cuda:0 f32[1, 2048, 1]"
# t18386 = prims.neg(t18382) # t18386: "cuda:0 f32[1, 2048, 4544]"
# t18388 = prims.sum(t18386, (0, 2)) # t18388: "cuda:0 f32[2048]"
# t18389 = prims.broadcast_in_dim(t18388, [1, 2048, 1], [1]) # t18389: "cuda:0 f32[1, 2048, 1]"
# t18390 = prims.mul(-0.5, t18385) # t18390: "cuda:0 f32[1, 2048, 1]"
# t18391 = prims.pow(t1754, 3.0) # t18391: "cuda:0 f32[1, 2048, 1]"
# t18392 = prims.mul(t18390, t18391) # t18392: "cuda:0 f32[1, 2048, 1]"
# t18394 = prims.sum(t18389, (0, 2)) # t18394: "cuda:0 f32[2048]"
# t18395 = prims.broadcast_in_dim(t18394, [1, 2048], [1]) # t18395: "cuda:0 f32[1, 2048]"
# t18396 = prims.sum(t18392, (0, 2)) # t18396: "cuda:0 f32[2048]"
# t18397 = prims.broadcast_in_dim(t18396, [1, 2048], [1]) # t18397: "cuda:0 f32[1, 2048]"
# t18400 = prims.broadcast_in_dim(t18395, [1, 2048, 1], [0, 1]) # t18400: "cuda:0 f32[1, 2048, 1]"
# t18401 = prims.broadcast_in_dim(t18400, (1, 2048, 4544), (0, 1, 2)) # t18401: "cuda:0 f32[1, 2048, 4544]"
# t18402 = prims.mul(0.00022007042253521127, t18401) # t18402: "cuda:0 f32[1, 2048, 4544]"
# t18404 = prims.broadcast_in_dim(t18397, [1, 2048, 1], [0, 1]) # t18404: "cuda:0 f32[1, 2048, 1]"
# t18405 = prims.broadcast_in_dim(t18404, (1, 2048, 4544), (0, 1, 2)) # t18405: "cuda:0 f32[1, 2048, 4544]"
# t18407 = prims.broadcast_in_dim(t1749, [1, 2048, 1], [0, 1]) # t18407: "cuda:0 f32[1, 2048, 1]"
# t18408 = prims.broadcast_in_dim(t18407, (1, 2048, 4544), (0, 1, 2)) # t18408: "cuda:0 f32[1, 2048, 4544]"
# t18409 = prims.mul(2.0, t18405) # t18409: "cuda:0 f32[1, 2048, 4544]"
# t18410 = prims.sub(t1741, t18408) # t18410: "cuda:0 f32[1, 2048, 4544]"
# t18411 = prims.mul(t18409, t18410) # t18411: "cuda:0 f32[1, 2048, 4544]"
# f18412 = prims.convert_element_type(i18403, float) # f18412: "float 4544.0"
# t18413 = prims.div(t18411, f18412) # t18413: "cuda:0 f32[1, 2048, 4544]"
# t18414 = prims.add(t18402, t18413) # t18414: "cuda:0 f32[1, 2048, 4544]"
# t18418 = prims.add(t18382, t18414) # t18418: "cuda:0 f32[1, 2048, 4544]"
# t18422 = prims.add(t18420, t18418) # t18422: "cuda:0 f32[1, 2048, 4544]"
# t18423 = prims.convert_element_type(t18422, dtypes.bfloat16) # t18423: "cuda:0 bf16[1, 2048, 4544]"
del i18403, t1581, t1713, t1734, t1749, t1754, t1760, t18140, t18183, t18362
t18430 = torch.reshape(t18423, (-1, 4544)) # t18430: "cuda:0 bf16[2048, 4544]"
# t18430 = ltorch.reshape(t18423, (-1, 4544)) # t18430: "cuda:0 bf16[2048, 4544]"
# t18430 = prims.reshape(t18423, (2048, 4544)) # t18430: "cuda:0 bf16[2048, 4544]"
t18434 = torch.permute(t18430, (1, 0)) # t18434: "cuda:0 bf16[4544, 2048]"
# t18434 = ltorch.permute(t18430, (1, 0)) # t18434: "cuda:0 bf16[4544, 2048]"
# t18434 = prims.transpose(t18430, (1, 0)) # t18434: "cuda:0 bf16[4544, 2048]"
t18477 = torch.matmul(t18434, t18476) # t18477: "cuda:0 bf16[4544, 4544]"
# t18477 = ltorch.matmul(t18475, t18476) # t18477: "cuda:0 bf16[4544, 4544]"
# t18477 = prims.matmul(t18475, t18476) # t18477: "cuda:0 bf16[4544, 4544]"
del t18476
t18431 = torch.matmul(t18430, t_transformer_h_10_mlp_proj_weight) # t18431: "cuda:0 bf16[2048, 18176]"
# t18431 = ltorch.matmul(t18430, t_transformer_h_10_mlp_proj_weight) # t18431: "cuda:0 bf16[2048, 18176]"
# t18431 = prims.matmul(t18430, t_transformer_h_10_mlp_proj_weight) # t18431: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_10_mlp_proj_weight
t18436 = torch.matmul(t18434, t18435) # t18436: "cuda:0 bf16[4544, 18176]"
# t18436 = ltorch.matmul(t18434, t18435) # t18436: "cuda:0 bf16[4544, 18176]"
# t18436 = prims.matmul(t18434, t18435) # t18436: "cuda:0 bf16[4544, 18176]"
del t18434, t18435
t18472 = torch.matmul(t18430, t_transformer_h_10_attn_proj_weight) # t18472: "cuda:0 bf16[2048, 4544]"
# t18472 = ltorch.matmul(t18471, t_transformer_h_10_attn_proj_weight) # t18472: "cuda:0 bf16[2048, 4544]"
# t18472 = prims.matmul(t18471, t_transformer_h_10_attn_proj_weight) # t18472: "cuda:0 bf16[2048, 4544]"
del t18430, t_transformer_h_10_attn_proj_weight
t18432 = torch.reshape(t18431, (1, 2048, 18176)) # t18432: "cuda:0 bf16[1, 2048, 18176]"
# t18432 = ltorch.reshape(t18431, (1, 2048, 18176)) # t18432: "cuda:0 bf16[1, 2048, 18176]"
# t18432 = prims.reshape(t18431, (1, 2048, 18176)) # t18432: "cuda:0 bf16[1, 2048, 18176]"
del t18431
t18473 = torch.reshape(t18472, (1, 2048, 4544)) # t18473: "cuda:0 bf16[1, 2048, 4544]"
# t18473 = ltorch.reshape(t18472, (1, 2048, 4544)) # t18473: "cuda:0 bf16[1, 2048, 4544]"
# t18473 = prims.reshape(t18472, (1, 2048, 4544)) # t18473: "cuda:0 bf16[1, 2048, 4544]"
del t18472
t18481 = torch.reshape(t18473, (1, 2048, 71, 64)) # t18481: "cuda:0 bf16[1, 2048, 71, 64]"
# t18481 = ltorch.reshape(t18473, (1, 2048, 71, 64)) # t18481: "cuda:0 bf16[1, 2048, 71, 64]"
# t18481 = prims.reshape(t18473, (1, 2048, 71, 64)) # t18481: "cuda:0 bf16[1, 2048, 71, 64]"
del t18473
t18484 = torch.permute(t18481, (0, 2, 1, 3)) # t18484: "cuda:0 bf16[1, 71, 2048, 64]"
# t18484 = ltorch.permute(t18481, (0, 2, 1, 3)) # t18484: "cuda:0 bf16[1, 71, 2048, 64]"
# t18484 = prims.transpose(t18481, (0, 2, 1, 3)) # t18484: "cuda:0 bf16[1, 71, 2048, 64]"
del t18481
[t18463] = nvFusion64(f703, f705, t1714, t18432)
# t1715 = prims.convert_element_type(t1714, dtypes.float32) # t1715: "cuda:0 f32[1, 2048, 18176]"
# t1717 = prims.div(t1715, 1.4142135623730951) # t1717: "cuda:0 f32[1, 2048, 18176]"
# t1720 = prims.erf(t1717) # t1720: "cuda:0 f32[1, 2048, 18176]"
# t1724 = prims.mul(0.5, t1720) # t1724: "cuda:0 f32[1, 2048, 18176]"
# t1728 = prims.add(0.5, t1724) # t1728: "cuda:0 f32[1, 2048, 18176]"
# t18437 = prims.convert_element_type(t18432, dtypes.float32) # t18437: "cuda:0 f32[1, 2048, 18176]"
# t18438 = prims.mul(t1728, t18437) # t18438: "cuda:0 f32[1, 2048, 18176]"
# t18439 = prims.mul(t1715, t18437) # t18439: "cuda:0 f32[1, 2048, 18176]"
# t18447 = prims.mul(f705, t18439) # t18447: "cuda:0 f32[1, 2048, 18176]"
# t18450 = prims.pow(t1717, 2.0) # t18450: "cuda:0 f32[1, 2048, 18176]"
# t18451 = prims.neg(t18450) # t18451: "cuda:0 f32[1, 2048, 18176]"
# t18452 = prims.exp(t18451) # t18452: "cuda:0 f32[1, 2048, 18176]"
# t18453 = prims.mul(1.1283791670955126, t18452) # t18453: "cuda:0 f32[1, 2048, 18176]"
# t18454 = prims.mul(t18453, t18447) # t18454: "cuda:0 f32[1, 2048, 18176]"
# t18458 = prims.div(t18454, f703) # t18458: "cuda:0 f32[1, 2048, 18176]"
# t18462 = prims.add(t18438, t18458) # t18462: "cuda:0 f32[1, 2048, 18176]"
# t18463 = prims.convert_element_type(t18462, dtypes.bfloat16) # t18463: "cuda:0 bf16[1, 2048, 18176]"
del f703, f705, t1714, t18432
t18464 = torch.reshape(t18463, (-1, 18176)) # t18464: "cuda:0 bf16[2048, 18176]"
# t18464 = ltorch.reshape(t18463, (-1, 18176)) # t18464: "cuda:0 bf16[2048, 18176]"
# t18464 = prims.reshape(t18463, (2048, 18176)) # t18464: "cuda:0 bf16[2048, 18176]"
del t18463
t18468 = torch.permute(t18464, (1, 0)) # t18468: "cuda:0 bf16[18176, 2048]"
# t18468 = ltorch.permute(t18464, (1, 0)) # t18468: "cuda:0 bf16[18176, 2048]"
# t18468 = prims.transpose(t18464, (1, 0)) # t18468: "cuda:0 bf16[18176, 2048]"
t18470 = torch.matmul(t18468, t18469) # t18470: "cuda:0 bf16[18176, 4544]"
# t18470 = ltorch.matmul(t18468, t18469) # t18470: "cuda:0 bf16[18176, 4544]"
# t18470 = prims.matmul(t18468, t18469) # t18470: "cuda:0 bf16[18176, 4544]"
del t18468
t18465 = torch.matmul(t18464, t_transformer_h_10_mlp_fc_weight) # t18465: "cuda:0 bf16[2048, 4544]"
# t18465 = ltorch.matmul(t18464, t_transformer_h_10_mlp_fc_weight) # t18465: "cuda:0 bf16[2048, 4544]"
# t18465 = prims.matmul(t18464, t_transformer_h_10_mlp_fc_weight) # t18465: "cuda:0 bf16[2048, 4544]"
del t18464, t_transformer_h_10_mlp_fc_weight
(t18485, t18486, t18487) = cudnn_sdpa_bwd(t18484, t1698, t1701, t1651, None, f694, b695, t1702, t1703, t1704, t1705, scale=f696, cat_grad_qkv=False)
del t18484, t1698, t1701, t1651, f694, b695, t1702, t1703, t1704, t1705, f696
t18489 = torch_slice_prim_impl(t18486, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18489: "cuda:0 bf16[1, 71, 2048, 64]"
del t18486
t18493 = torch_slice_prim_impl(t18485, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18493: "cuda:0 bf16[1, 71, 2048, 64]"
del t18485
t18596 = torch.reshape(t18487, (1, 1, 71, 2048, 64)) # t18596: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18596 = ltorch.reshape(t18487, (1, 1, 71, 2048, 64)) # t18596: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18596 = prims.reshape(t18487, (1, 1, 71, 2048, 64)) # t18596: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t18487
[t18630] = nvFusion65(i667, t18489, t18493, t18596, t61, t66)
# t18490 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t18490: "cuda:0 bf16[1, 71, 2048, 0]"
# t18491 = prims.pad(t18490, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t18491: "cuda:0 bf16[1, 71, 2048, 64]"
# t18494 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t18494: "cuda:0 bf16[1, 71, 2048, 0]"
# t18495 = prims.pad(t18494, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t18495: "cuda:0 bf16[1, 71, 2048, 64]"
# t18496 = prims.convert_element_type(t18489, dtypes.float32) # t18496: "cuda:0 f32[1, 71, 2048, 64]"
# t18500 = prims.mul(t66, t18496) # t18500: "cuda:0 f32[1, 71, 2048, 64]"
# t18503 = prims.convert_element_type(t18500, dtypes.bfloat16) # t18503: "cuda:0 bf16[1, 71, 2048, 64]"
# t18512 = prims.mul(t61, t18496) # t18512: "cuda:0 f32[1, 71, 2048, 64]"
# t18524 = prims.slice_prim(t18503, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t18524: "cuda:0 bf16[1, 71, 2048, 32]"
# t18525 = prims.slice_prim(t18503, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18525: "cuda:0 bf16[1, 71, 2048, 32]"
# t18526 = prims.convert_element_type(t18524, dtypes.float32) # t18526: "cuda:0 f32[1, 71, 2048, 32]"
# t18527 = prims.neg(t18526) # t18527: "cuda:0 f32[1, 71, 2048, 32]"
# t18528 = prims.convert_element_type(t18527, dtypes.bfloat16) # t18528: "cuda:0 bf16[1, 71, 2048, 32]"
# t18529 = prims.pad(t18528, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t18529: "cuda:0 bf16[1, 71, 2048, 64]"
# t18531 = prims.convert_element_type(t18529, dtypes.float32) # t18531: "cuda:0 f32[1, 71, 2048, 64]"
# t18532 = prims.add(t18512, t18531) # t18532: "cuda:0 f32[1, 71, 2048, 64]"
# t18534 = prims.pad(t18525, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t18534: "cuda:0 bf16[1, 71, 2048, 64]"
# t18536 = prims.convert_element_type(t18534, dtypes.float32) # t18536: "cuda:0 f32[1, 71, 2048, 64]"
# t18537 = prims.add(t18532, t18536) # t18537: "cuda:0 f32[1, 71, 2048, 64]"
# t18538 = prims.convert_element_type(t18537, dtypes.bfloat16) # t18538: "cuda:0 bf16[1, 71, 2048, 64]"
# t18539 = prims.pad(t18538, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t18539: "cuda:0 bf16[1, 71, 2048, 64]"
# t18540 = prims.convert_element_type(t18491, dtypes.float32) # t18540: "cuda:0 f32[1, 71, 2048, 64]"
# t18541 = prims.convert_element_type(t18539, dtypes.float32) # t18541: "cuda:0 f32[1, 71, 2048, 64]"
# t18542 = prims.add(t18540, t18541) # t18542: "cuda:0 f32[1, 71, 2048, 64]"
# t18543 = prims.convert_element_type(t18542, dtypes.bfloat16) # t18543: "cuda:0 bf16[1, 71, 2048, 64]"
# t18544 = prims.convert_element_type(t18493, dtypes.float32) # t18544: "cuda:0 f32[1, 71, 2048, 64]"
# t18548 = prims.mul(t66, t18544) # t18548: "cuda:0 f32[1, 71, 2048, 64]"
# t18551 = prims.convert_element_type(t18548, dtypes.bfloat16) # t18551: "cuda:0 bf16[1, 71, 2048, 64]"
# t18560 = prims.mul(t61, t18544) # t18560: "cuda:0 f32[1, 71, 2048, 64]"
# t18572 = prims.slice_prim(t18551, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t18572: "cuda:0 bf16[1, 71, 2048, 32]"
# t18573 = prims.slice_prim(t18551, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18573: "cuda:0 bf16[1, 71, 2048, 32]"
# t18574 = prims.convert_element_type(t18572, dtypes.float32) # t18574: "cuda:0 f32[1, 71, 2048, 32]"
# t18575 = prims.neg(t18574) # t18575: "cuda:0 f32[1, 71, 2048, 32]"
# t18576 = prims.convert_element_type(t18575, dtypes.bfloat16) # t18576: "cuda:0 bf16[1, 71, 2048, 32]"
# t18577 = prims.pad(t18576, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t18577: "cuda:0 bf16[1, 71, 2048, 64]"
# t18579 = prims.convert_element_type(t18577, dtypes.float32) # t18579: "cuda:0 f32[1, 71, 2048, 64]"
# t18580 = prims.add(t18560, t18579) # t18580: "cuda:0 f32[1, 71, 2048, 64]"
# t18582 = prims.pad(t18573, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t18582: "cuda:0 bf16[1, 71, 2048, 64]"
# t18584 = prims.convert_element_type(t18582, dtypes.float32) # t18584: "cuda:0 f32[1, 71, 2048, 64]"
# t18585 = prims.add(t18580, t18584) # t18585: "cuda:0 f32[1, 71, 2048, 64]"
# t18586 = prims.convert_element_type(t18585, dtypes.bfloat16) # t18586: "cuda:0 bf16[1, 71, 2048, 64]"
# t18587 = prims.pad(t18586, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t18587: "cuda:0 bf16[1, 71, 2048, 64]"
# t18588 = prims.convert_element_type(t18495, dtypes.float32) # t18588: "cuda:0 f32[1, 71, 2048, 64]"
# t18589 = prims.convert_element_type(t18587, dtypes.float32) # t18589: "cuda:0 f32[1, 71, 2048, 64]"
# t18590 = prims.add(t18588, t18589) # t18590: "cuda:0 f32[1, 71, 2048, 64]"
# t18591 = prims.convert_element_type(t18590, dtypes.bfloat16) # t18591: "cuda:0 bf16[1, 71, 2048, 64]"
# t18601 = prims.reshape(t18543, (1, 1, 71, 2048, 64)) # t18601: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18606 = prims.reshape(t18591, (1, 1, 71, 2048, 64)) # t18606: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18612 = prims.convert_element_type(t18596, dtypes.float32) # t18612: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18613 = prims.sum(t18612, (0, 1, 2)) # t18613: "cuda:0 f32[2048, 64]"
# t18614 = prims.convert_element_type(t18613, dtypes.bfloat16) # t18614: "cuda:0 bf16[2048, 64]"
# t18615 = prims.broadcast_in_dim(t18614, [1, 1, 1, 2048, 64], [3, 4]) # t18615: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18621 = prims.convert_element_type(t18601, dtypes.float32) # t18621: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18622 = prims.sum(t18621, (0, 1, 2)) # t18622: "cuda:0 f32[2048, 64]"
# t18623 = prims.convert_element_type(t18622, dtypes.bfloat16) # t18623: "cuda:0 bf16[2048, 64]"
# t18624 = prims.broadcast_in_dim(t18623, [1, 1, 1, 2048, 64], [3, 4]) # t18624: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18630 = prims.cat((t18606, t18624, t18615), i667) # t18630: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i667, t18489, t18493, t18596
t18636 = torch.permute(t18630, (0, 3, 1, 2, 4)) # t18636: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18636 = ltorch.permute(t18630, (0, 3, 1, 2, 4)) # t18636: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18636 = prims.transpose(t18630, (0, 3, 1, 2, 4)) # t18636: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t18630
t18642 = torch.reshape(t18636, (1, 2048, 4672)) # t18642: "cuda:0 bf16[1, 2048, 4672]"
# t18642 = ltorch.reshape(t18636, (1, 2048, 4672)) # t18642: "cuda:0 bf16[1, 2048, 4672]"
# t18642 = prims.reshape(t18636, (1, 2048, 4672)) # t18642: "cuda:0 bf16[1, 2048, 4672]"
del t18636
t18643 = torch.reshape(t18642, (-1, 4672)) # t18643: "cuda:0 bf16[2048, 4672]"
# t18643 = ltorch.reshape(t18642, (-1, 4672)) # t18643: "cuda:0 bf16[2048, 4672]"
# t18643 = prims.reshape(t18642, (2048, 4672)) # t18643: "cuda:0 bf16[2048, 4672]"
del t18642
t18647 = torch.permute(t18643, (1, 0)) # t18647: "cuda:0 bf16[4672, 2048]"
# t18647 = ltorch.permute(t18643, (1, 0)) # t18647: "cuda:0 bf16[4672, 2048]"
# t18647 = prims.transpose(t18643, (1, 0)) # t18647: "cuda:0 bf16[4672, 2048]"
t18649 = torch.matmul(t18647, t18469) # t18649: "cuda:0 bf16[4672, 4544]"
# t18649 = ltorch.matmul(t18647, t18648) # t18649: "cuda:0 bf16[4672, 4544]"
# t18649 = prims.matmul(t18647, t18648) # t18649: "cuda:0 bf16[4672, 4544]"
del t18647, t18469
t18644 = torch.matmul(t18643, t_transformer_h_10_attn_attn_weight) # t18644: "cuda:0 bf16[2048, 4544]"
# t18644 = ltorch.matmul(t18643, t_transformer_h_10_attn_attn_weight) # t18644: "cuda:0 bf16[2048, 4544]"
# t18644 = prims.matmul(t18643, t_transformer_h_10_attn_attn_weight) # t18644: "cuda:0 bf16[2048, 4544]"
del t18643, t_transformer_h_10_attn_attn_weight
t18466 = torch.reshape(t18465, (1, 2048, 4544)) # t18466: "cuda:0 bf16[1, 2048, 4544]"
# t18466 = ltorch.reshape(t18465, (1, 2048, 4544)) # t18466: "cuda:0 bf16[1, 2048, 4544]"
# t18466 = prims.reshape(t18465, (1, 2048, 4544)) # t18466: "cuda:0 bf16[1, 2048, 4544]"
del t18465
t18645 = torch.reshape(t18644, (1, 2048, 4544)) # t18645: "cuda:0 bf16[1, 2048, 4544]"
# t18645 = ltorch.reshape(t18644, (1, 2048, 4544)) # t18645: "cuda:0 bf16[1, 2048, 4544]"
# t18645 = prims.reshape(t18644, (1, 2048, 4544)) # t18645: "cuda:0 bf16[1, 2048, 4544]"
del t18644
[t18658, t18664, t18706] = nvFusion66(i18686, t1420, t1552, t1573, t1588, t1593, t1599, t18423, t18466, t18645)
# t1579 = prims.convert_element_type(t1420, dtypes.float32) # t1579: "cuda:0 f32[1, 2048, 4544]"
# t1574 = prims.convert_element_type(t1573, dtypes.float32) # t1574: "cuda:0 f32[1, 2048, 4544]"
# t1575 = prims.convert_element_type(t1552, dtypes.float32) # t1575: "cuda:0 f32[1, 2048, 4544]"
# t1576 = prims.add(t1574, t1575) # t1576: "cuda:0 f32[1, 2048, 4544]"
# t1580 = prims.add(t1576, t1579) # t1580: "cuda:0 f32[1, 2048, 4544]"
# t1590 = prims.broadcast_in_dim(t1588, [1, 2048, 1], [0, 1]) # t1590: "cuda:0 f32[1, 2048, 1]"
# t1594 = prims.broadcast_in_dim(t1590, (1, 2048, 4544), (0, 1, 2)) # t1594: "cuda:0 f32[1, 2048, 4544]"
# t1596 = prims.sub(t1580, t1594) # t1596: "cuda:0 f32[1, 2048, 4544]"
# t1597 = prims.broadcast_in_dim(t1593, (1, 2048, 4544), (0, 1, 2)) # t1597: "cuda:0 f32[1, 2048, 4544]"
# t1598 = prims.mul(t1596, t1597) # t1598: "cuda:0 f32[1, 2048, 4544]"
# t1600 = prims.convert_element_type(t1599, dtypes.float32) # t1600: "cuda:0 f32[1, 2048, 4544]"
# t18703 = prims.convert_element_type(t18423, dtypes.float32) # t18703: "cuda:0 f32[1, 2048, 4544]"
# t18650 = prims.convert_element_type(t18466, dtypes.float32) # t18650: "cuda:0 f32[1, 2048, 4544]"
# t18651 = prims.convert_element_type(t18645, dtypes.float32) # t18651: "cuda:0 f32[1, 2048, 4544]"
# t18652 = prims.add(t18650, t18651) # t18652: "cuda:0 f32[1, 2048, 4544]"
# t18657 = prims.sum(t18652, (0, 1)) # t18657: "cuda:0 f32[4544]"
# t18658 = prims.convert_element_type(t18657, dtypes.bfloat16) # t18658: "cuda:0 bf16[4544]"
# t18659 = prims.mul(t1600, t18652) # t18659: "cuda:0 f32[1, 2048, 4544]"
# t18660 = prims.mul(t1598, t18652) # t18660: "cuda:0 f32[1, 2048, 4544]"
# t18663 = prims.sum(t18660, (0, 1)) # t18663: "cuda:0 f32[4544]"
# t18664 = prims.convert_element_type(t18663, dtypes.bfloat16) # t18664: "cuda:0 bf16[4544]"
# t18665 = prims.mul(t1597, t18659) # t18665: "cuda:0 f32[1, 2048, 4544]"
# t18666 = prims.mul(t1596, t18659) # t18666: "cuda:0 f32[1, 2048, 4544]"
# t18667 = prims.sum(t18666, (0, 2)) # t18667: "cuda:0 f32[2048]"
# t18668 = prims.broadcast_in_dim(t18667, [1, 2048, 1], [1]) # t18668: "cuda:0 f32[1, 2048, 1]"
# t18669 = prims.neg(t18665) # t18669: "cuda:0 f32[1, 2048, 4544]"
# t18671 = prims.sum(t18669, (0, 2)) # t18671: "cuda:0 f32[2048]"
# t18672 = prims.broadcast_in_dim(t18671, [1, 2048, 1], [1]) # t18672: "cuda:0 f32[1, 2048, 1]"
# t18673 = prims.mul(-0.5, t18668) # t18673: "cuda:0 f32[1, 2048, 1]"
# t18674 = prims.pow(t1593, 3.0) # t18674: "cuda:0 f32[1, 2048, 1]"
# t18675 = prims.mul(t18673, t18674) # t18675: "cuda:0 f32[1, 2048, 1]"
# t18677 = prims.sum(t18672, (0, 2)) # t18677: "cuda:0 f32[2048]"
# t18678 = prims.broadcast_in_dim(t18677, [1, 2048], [1]) # t18678: "cuda:0 f32[1, 2048]"
# t18679 = prims.sum(t18675, (0, 2)) # t18679: "cuda:0 f32[2048]"
# t18680 = prims.broadcast_in_dim(t18679, [1, 2048], [1]) # t18680: "cuda:0 f32[1, 2048]"
# t18683 = prims.broadcast_in_dim(t18678, [1, 2048, 1], [0, 1]) # t18683: "cuda:0 f32[1, 2048, 1]"
# t18684 = prims.broadcast_in_dim(t18683, (1, 2048, 4544), (0, 1, 2)) # t18684: "cuda:0 f32[1, 2048, 4544]"
# t18685 = prims.mul(0.00022007042253521127, t18684) # t18685: "cuda:0 f32[1, 2048, 4544]"
# t18687 = prims.broadcast_in_dim(t18680, [1, 2048, 1], [0, 1]) # t18687: "cuda:0 f32[1, 2048, 1]"
# t18688 = prims.broadcast_in_dim(t18687, (1, 2048, 4544), (0, 1, 2)) # t18688: "cuda:0 f32[1, 2048, 4544]"
# t18690 = prims.broadcast_in_dim(t1588, [1, 2048, 1], [0, 1]) # t18690: "cuda:0 f32[1, 2048, 1]"
# t18691 = prims.broadcast_in_dim(t18690, (1, 2048, 4544), (0, 1, 2)) # t18691: "cuda:0 f32[1, 2048, 4544]"
# t18692 = prims.mul(2.0, t18688) # t18692: "cuda:0 f32[1, 2048, 4544]"
# t18693 = prims.sub(t1580, t18691) # t18693: "cuda:0 f32[1, 2048, 4544]"
# t18694 = prims.mul(t18692, t18693) # t18694: "cuda:0 f32[1, 2048, 4544]"
# f18695 = prims.convert_element_type(i18686, float) # f18695: "float 4544.0"
# t18696 = prims.div(t18694, f18695) # t18696: "cuda:0 f32[1, 2048, 4544]"
# t18697 = prims.add(t18685, t18696) # t18697: "cuda:0 f32[1, 2048, 4544]"
# t18701 = prims.add(t18665, t18697) # t18701: "cuda:0 f32[1, 2048, 4544]"
# t18705 = prims.add(t18703, t18701) # t18705: "cuda:0 f32[1, 2048, 4544]"
# t18706 = prims.convert_element_type(t18705, dtypes.bfloat16) # t18706: "cuda:0 bf16[1, 2048, 4544]"
del i18686, t1420, t1552, t1573, t1588, t1593, t1599, t18423, t18466, t18645
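  # The fusion above appears to implement the standard LayerNorm backward: the
  # (0, 1) reductions give the bias grad (t18658) and weight grad (t18664), the
  # remaining terms rebuild d(input) from what look like the saved mean (t1588)
  # and rstd (t1593), and the result is accumulated with the other incoming
  # gradient (t18423, upcast as t18703) before the final cast back to bf16
  # (t18706), which then feeds the two h.9 projection matmuls below. Everything
  # is upcast to f32 for the reductions, matching the forward's mixed precision.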
t18713 = torch.reshape(t18706, (-1, 4544)) # t18713: "cuda:0 bf16[2048, 4544]"
# t18713 = ltorch.reshape(t18706, (-1, 4544)) # t18713: "cuda:0 bf16[2048, 4544]"
# t18713 = prims.reshape(t18706, (2048, 4544)) # t18713: "cuda:0 bf16[2048, 4544]"
t18717 = torch.permute(t18713, (1, 0)) # t18717: "cuda:0 bf16[4544, 2048]"
# t18717 = ltorch.permute(t18713, (1, 0)) # t18717: "cuda:0 bf16[4544, 2048]"
# t18717 = prims.transpose(t18713, (1, 0)) # t18717: "cuda:0 bf16[4544, 2048]"
t18714 = torch.matmul(t18713, t_transformer_h_9_mlp_proj_weight) # t18714: "cuda:0 bf16[2048, 18176]"
# t18714 = ltorch.matmul(t18713, t_transformer_h_9_mlp_proj_weight) # t18714: "cuda:0 bf16[2048, 18176]"
# t18714 = prims.matmul(t18713, t_transformer_h_9_mlp_proj_weight) # t18714: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_9_mlp_proj_weight
t18719 = torch.matmul(t18717, t18718) # t18719: "cuda:0 bf16[4544, 18176]"
# t18719 = ltorch.matmul(t18717, t18718) # t18719: "cuda:0 bf16[4544, 18176]"
# t18719 = prims.matmul(t18717, t18718) # t18719: "cuda:0 bf16[4544, 18176]"
del t18718
t18755 = torch.matmul(t18713, t_transformer_h_9_attn_proj_weight) # t18755: "cuda:0 bf16[2048, 4544]"
# t18755 = ltorch.matmul(t18754, t_transformer_h_9_attn_proj_weight) # t18755: "cuda:0 bf16[2048, 4544]"
# t18755 = prims.matmul(t18754, t_transformer_h_9_attn_proj_weight) # t18755: "cuda:0 bf16[2048, 4544]"
del t18713, t_transformer_h_9_attn_proj_weight
t18760 = torch.matmul(t18717, t18759) # t18760: "cuda:0 bf16[4544, 4544]"
# t18760 = ltorch.matmul(t18758, t18759) # t18760: "cuda:0 bf16[4544, 4544]"
# t18760 = prims.matmul(t18758, t18759) # t18760: "cuda:0 bf16[4544, 4544]"
del t18717, t18759
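  # The four matmuls above are the backward of the two output projections of
  # transformer.h.9: both the mlp proj and the attn proj receive the same
  # incoming gradient t18713, because in this (Falcon-style, judging by the
  # 4544/18176/71-head shapes) parallel block both branches are added into the
  # same residual stream. The pattern is the usual linear backward for
  # y = x @ W.T; a minimal sketch with hypothetical names:
  #   grad_x = grad_y @ W        # here t18714 and t18755 (input grads)
  #   grad_W = grad_y.T @ x      # here t18719 and t18760 (weight grads),
  #                              # with t18718/t18759 presumably the saved inputs
  # Note that the nested ltorch/prims comments still reference t18754/t18758;
  # those appear to have been aliased to t18713/t18717 by a later trace
  # transform, so the executed calls reuse the already materialized tensors
  # (the same pattern recurs in several matmul comments further down).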
t18715 = torch.reshape(t18714, (1, 2048, 18176)) # t18715: "cuda:0 bf16[1, 2048, 18176]"
# t18715 = ltorch.reshape(t18714, (1, 2048, 18176)) # t18715: "cuda:0 bf16[1, 2048, 18176]"
# t18715 = prims.reshape(t18714, (1, 2048, 18176)) # t18715: "cuda:0 bf16[1, 2048, 18176]"
del t18714
t18756 = torch.reshape(t18755, (1, 2048, 4544)) # t18756: "cuda:0 bf16[1, 2048, 4544]"
# t18756 = ltorch.reshape(t18755, (1, 2048, 4544)) # t18756: "cuda:0 bf16[1, 2048, 4544]"
# t18756 = prims.reshape(t18755, (1, 2048, 4544)) # t18756: "cuda:0 bf16[1, 2048, 4544]"
del t18755
t18764 = torch.reshape(t18756, (1, 2048, 71, 64)) # t18764: "cuda:0 bf16[1, 2048, 71, 64]"
# t18764 = ltorch.reshape(t18756, (1, 2048, 71, 64)) # t18764: "cuda:0 bf16[1, 2048, 71, 64]"
# t18764 = prims.reshape(t18756, (1, 2048, 71, 64)) # t18764: "cuda:0 bf16[1, 2048, 71, 64]"
del t18756
t18767 = torch.permute(t18764, (0, 2, 1, 3)) # t18767: "cuda:0 bf16[1, 71, 2048, 64]"
# t18767 = ltorch.permute(t18764, (0, 2, 1, 3)) # t18767: "cuda:0 bf16[1, 71, 2048, 64]"
# t18767 = prims.transpose(t18764, (0, 2, 1, 3)) # t18767: "cuda:0 bf16[1, 71, 2048, 64]"
del t18764
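  # t18756 -> t18764 -> t18767 undoes the forward's head merge: the attention
  # projection's input gradient is reshaped back to (batch, seq, heads, head_dim)
  # = (1, 2048, 71, 64) and transposed to (1, 71, 2048, 64) so it can enter the
  # SDPA backward below as the gradient of the attention output.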
[t18746] = nvFusion67(f639, f641, t1553, t18715)
# t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: "cuda:0 f32[1, 2048, 18176]"
# t1556 = prims.div(t1554, 1.4142135623730951) # t1556: "cuda:0 f32[1, 2048, 18176]"
# t1559 = prims.erf(t1556) # t1559: "cuda:0 f32[1, 2048, 18176]"
# t1563 = prims.mul(0.5, t1559) # t1563: "cuda:0 f32[1, 2048, 18176]"
# t1567 = prims.add(0.5, t1563) # t1567: "cuda:0 f32[1, 2048, 18176]"
# t18720 = prims.convert_element_type(t18715, dtypes.float32) # t18720: "cuda:0 f32[1, 2048, 18176]"
# t18721 = prims.mul(t1567, t18720) # t18721: "cuda:0 f32[1, 2048, 18176]"
# t18722 = prims.mul(t1554, t18720) # t18722: "cuda:0 f32[1, 2048, 18176]"
# t18730 = prims.mul(f641, t18722) # t18730: "cuda:0 f32[1, 2048, 18176]"
# t18733 = prims.pow(t1556, 2.0) # t18733: "cuda:0 f32[1, 2048, 18176]"
# t18734 = prims.neg(t18733) # t18734: "cuda:0 f32[1, 2048, 18176]"
# t18735 = prims.exp(t18734) # t18735: "cuda:0 f32[1, 2048, 18176]"
# t18736 = prims.mul(1.1283791670955126, t18735) # t18736: "cuda:0 f32[1, 2048, 18176]"
# t18737 = prims.mul(t18736, t18730) # t18737: "cuda:0 f32[1, 2048, 18176]"
# t18741 = prims.div(t18737, f639) # t18741: "cuda:0 f32[1, 2048, 18176]"
# t18745 = prims.add(t18721, t18741) # t18745: "cuda:0 f32[1, 2048, 18176]"
# t18746 = prims.convert_element_type(t18745, dtypes.bfloat16) # t18746: "cuda:0 bf16[1, 2048, 18176]"
del f639, f641, t1553, t18715
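  # nvFusion67 is the backward of the exact (erf-based) GELU in the h.9 MLP:
  # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is recomputed from the saved fc output
  # t1553, and the incoming gradient t18715 is scaled by the GELU derivative
  # Phi(x) + x * exp(-x^2 / 2) / sqrt(2 * pi); f641 and f639 appear to be the
  # constants 0.5 and sqrt(2) used in that second term. A minimal sketch with
  # hypothetical names:
  #   pdf = torch.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)
  #   cdf = 0.5 * (1 + torch.erf(x / math.sqrt(2)))
  #   grad_in = grad_out * (cdf + x * pdf)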
t18747 = torch.reshape(t18746, (-1, 18176)) # t18747: "cuda:0 bf16[2048, 18176]"
# t18747 = ltorch.reshape(t18746, (-1, 18176)) # t18747: "cuda:0 bf16[2048, 18176]"
# t18747 = prims.reshape(t18746, (2048, 18176)) # t18747: "cuda:0 bf16[2048, 18176]"
del t18746
t18751 = torch.permute(t18747, (1, 0)) # t18751: "cuda:0 bf16[18176, 2048]"
# t18751 = ltorch.permute(t18747, (1, 0)) # t18751: "cuda:0 bf16[18176, 2048]"
# t18751 = prims.transpose(t18747, (1, 0)) # t18751: "cuda:0 bf16[18176, 2048]"
t18753 = torch.matmul(t18751, t18752) # t18753: "cuda:0 bf16[18176, 4544]"
# t18753 = ltorch.matmul(t18751, t18752) # t18753: "cuda:0 bf16[18176, 4544]"
# t18753 = prims.matmul(t18751, t18752) # t18753: "cuda:0 bf16[18176, 4544]"
del t18751
t18748 = torch.matmul(t18747, t_transformer_h_9_mlp_fc_weight) # t18748: "cuda:0 bf16[2048, 4544]"
# t18748 = ltorch.matmul(t18747, t_transformer_h_9_mlp_fc_weight) # t18748: "cuda:0 bf16[2048, 4544]"
# t18748 = prims.matmul(t18747, t_transformer_h_9_mlp_fc_weight) # t18748: "cuda:0 bf16[2048, 4544]"
del t18747, t_transformer_h_9_mlp_fc_weight
(t18768, t18769, t18770) = cudnn_sdpa_bwd(t18767, t1537, t1540, t1490, None, f630, b631, t1541, t1542, t1543, t1544, scale=f632, cat_grad_qkv=False)
del t18767, t1537, t1540, t1490, f630, b631, t1541, t1542, t1543, t1544, f632
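  # cudnn_sdpa_bwd consumes the attention-output gradient t18767 together with
  # what appear to be the saved query/key/value (t1537, t1540, t1490) and the
  # softmax statistics from the forward, and returns the input gradients
  # (t18768, t18769, t18770), presumably for query, key and value respectively;
  # the value gradient t18770 is reshaped below into the kv slot of the fused
  # qkv layout.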
t18772 = torch_slice_prim_impl(t18769, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18772: "cuda:0 bf16[1, 71, 2048, 64]"
del t18769
t18776 = torch_slice_prim_impl(t18768, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18776: "cuda:0 bf16[1, 71, 2048, 64]"
del t18768
t18879 = torch.reshape(t18770, (1, 1, 71, 2048, 64)) # t18879: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18879 = ltorch.reshape(t18770, (1, 1, 71, 2048, 64)) # t18879: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18879 = prims.reshape(t18770, (1, 1, 71, 2048, 64)) # t18879: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t18770
[t18913] = nvFusion68(i603, t18772, t18776, t18879, t61, t66)
# t18773 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t18773: "cuda:0 bf16[1, 71, 2048, 0]"
# t18774 = prims.pad(t18773, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t18774: "cuda:0 bf16[1, 71, 2048, 64]"
# t18777 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t18777: "cuda:0 bf16[1, 71, 2048, 0]"
# t18778 = prims.pad(t18777, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t18778: "cuda:0 bf16[1, 71, 2048, 64]"
# t18779 = prims.convert_element_type(t18772, dtypes.float32) # t18779: "cuda:0 f32[1, 71, 2048, 64]"
# t18783 = prims.mul(t66, t18779) # t18783: "cuda:0 f32[1, 71, 2048, 64]"
# t18786 = prims.convert_element_type(t18783, dtypes.bfloat16) # t18786: "cuda:0 bf16[1, 71, 2048, 64]"
# t18795 = prims.mul(t61, t18779) # t18795: "cuda:0 f32[1, 71, 2048, 64]"
# t18807 = prims.slice_prim(t18786, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t18807: "cuda:0 bf16[1, 71, 2048, 32]"
# t18808 = prims.slice_prim(t18786, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18808: "cuda:0 bf16[1, 71, 2048, 32]"
# t18809 = prims.convert_element_type(t18807, dtypes.float32) # t18809: "cuda:0 f32[1, 71, 2048, 32]"
# t18810 = prims.neg(t18809) # t18810: "cuda:0 f32[1, 71, 2048, 32]"
# t18811 = prims.convert_element_type(t18810, dtypes.bfloat16) # t18811: "cuda:0 bf16[1, 71, 2048, 32]"
# t18812 = prims.pad(t18811, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t18812: "cuda:0 bf16[1, 71, 2048, 64]"
# t18814 = prims.convert_element_type(t18812, dtypes.float32) # t18814: "cuda:0 f32[1, 71, 2048, 64]"
# t18815 = prims.add(t18795, t18814) # t18815: "cuda:0 f32[1, 71, 2048, 64]"
# t18817 = prims.pad(t18808, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t18817: "cuda:0 bf16[1, 71, 2048, 64]"
# t18819 = prims.convert_element_type(t18817, dtypes.float32) # t18819: "cuda:0 f32[1, 71, 2048, 64]"
# t18820 = prims.add(t18815, t18819) # t18820: "cuda:0 f32[1, 71, 2048, 64]"
# t18821 = prims.convert_element_type(t18820, dtypes.bfloat16) # t18821: "cuda:0 bf16[1, 71, 2048, 64]"
# t18822 = prims.pad(t18821, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t18822: "cuda:0 bf16[1, 71, 2048, 64]"
# t18823 = prims.convert_element_type(t18774, dtypes.float32) # t18823: "cuda:0 f32[1, 71, 2048, 64]"
# t18824 = prims.convert_element_type(t18822, dtypes.float32) # t18824: "cuda:0 f32[1, 71, 2048, 64]"
# t18825 = prims.add(t18823, t18824) # t18825: "cuda:0 f32[1, 71, 2048, 64]"
# t18826 = prims.convert_element_type(t18825, dtypes.bfloat16) # t18826: "cuda:0 bf16[1, 71, 2048, 64]"
# t18827 = prims.convert_element_type(t18776, dtypes.float32) # t18827: "cuda:0 f32[1, 71, 2048, 64]"
# t18831 = prims.mul(t66, t18827) # t18831: "cuda:0 f32[1, 71, 2048, 64]"
# t18834 = prims.convert_element_type(t18831, dtypes.bfloat16) # t18834: "cuda:0 bf16[1, 71, 2048, 64]"
# t18843 = prims.mul(t61, t18827) # t18843: "cuda:0 f32[1, 71, 2048, 64]"
# t18855 = prims.slice_prim(t18834, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t18855: "cuda:0 bf16[1, 71, 2048, 32]"
# t18856 = prims.slice_prim(t18834, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t18856: "cuda:0 bf16[1, 71, 2048, 32]"
# t18857 = prims.convert_element_type(t18855, dtypes.float32) # t18857: "cuda:0 f32[1, 71, 2048, 32]"
# t18858 = prims.neg(t18857) # t18858: "cuda:0 f32[1, 71, 2048, 32]"
# t18859 = prims.convert_element_type(t18858, dtypes.bfloat16) # t18859: "cuda:0 bf16[1, 71, 2048, 32]"
# t18860 = prims.pad(t18859, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t18860: "cuda:0 bf16[1, 71, 2048, 64]"
# t18862 = prims.convert_element_type(t18860, dtypes.float32) # t18862: "cuda:0 f32[1, 71, 2048, 64]"
# t18863 = prims.add(t18843, t18862) # t18863: "cuda:0 f32[1, 71, 2048, 64]"
# t18865 = prims.pad(t18856, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t18865: "cuda:0 bf16[1, 71, 2048, 64]"
# t18867 = prims.convert_element_type(t18865, dtypes.float32) # t18867: "cuda:0 f32[1, 71, 2048, 64]"
# t18868 = prims.add(t18863, t18867) # t18868: "cuda:0 f32[1, 71, 2048, 64]"
# t18869 = prims.convert_element_type(t18868, dtypes.bfloat16) # t18869: "cuda:0 bf16[1, 71, 2048, 64]"
# t18870 = prims.pad(t18869, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t18870: "cuda:0 bf16[1, 71, 2048, 64]"
# t18871 = prims.convert_element_type(t18778, dtypes.float32) # t18871: "cuda:0 f32[1, 71, 2048, 64]"
# t18872 = prims.convert_element_type(t18870, dtypes.float32) # t18872: "cuda:0 f32[1, 71, 2048, 64]"
# t18873 = prims.add(t18871, t18872) # t18873: "cuda:0 f32[1, 71, 2048, 64]"
# t18874 = prims.convert_element_type(t18873, dtypes.bfloat16) # t18874: "cuda:0 bf16[1, 71, 2048, 64]"
# t18884 = prims.reshape(t18826, (1, 1, 71, 2048, 64)) # t18884: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18889 = prims.reshape(t18874, (1, 1, 71, 2048, 64)) # t18889: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t18895 = prims.convert_element_type(t18879, dtypes.float32) # t18895: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18896 = prims.sum(t18895, (0, 1, 2)) # t18896: "cuda:0 f32[2048, 64]"
# t18897 = prims.convert_element_type(t18896, dtypes.bfloat16) # t18897: "cuda:0 bf16[2048, 64]"
# t18898 = prims.broadcast_in_dim(t18897, [1, 1, 1, 2048, 64], [3, 4]) # t18898: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18904 = prims.convert_element_type(t18884, dtypes.float32) # t18904: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t18905 = prims.sum(t18904, (0, 1, 2)) # t18905: "cuda:0 f32[2048, 64]"
# t18906 = prims.convert_element_type(t18905, dtypes.bfloat16) # t18906: "cuda:0 bf16[2048, 64]"
# t18907 = prims.broadcast_in_dim(t18906, [1, 1, 1, 2048, 64], [3, 4]) # t18907: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t18913 = prims.cat((t18889, t18907, t18898), i603) # t18913: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i603, t18772, t18776, t18879
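  # nvFusion68 appears to be the backward of the rotary embedding applied to the
  # query and key heads: the q/k gradients are multiplied by the saved cos/sin
  # tables (t66 and t61, presumably), the rotate-half trick is undone via the
  # slice / neg / pad sequence, the key and value gradients are summed over the
  # 71 heads onto the single shared kv head, and the pieces are concatenated
  # back into the fused qkv gradient layout (1, 1, 73, 2048, 64) =
  # 71 query heads + 1 key head + 1 value head.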
t18919 = torch.permute(t18913, (0, 3, 1, 2, 4)) # t18919: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18919 = ltorch.permute(t18913, (0, 3, 1, 2, 4)) # t18919: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t18919 = prims.transpose(t18913, (0, 3, 1, 2, 4)) # t18919: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t18913
t18925 = torch.reshape(t18919, (1, 2048, 4672)) # t18925: "cuda:0 bf16[1, 2048, 4672]"
# t18925 = ltorch.reshape(t18919, (1, 2048, 4672)) # t18925: "cuda:0 bf16[1, 2048, 4672]"
# t18925 = prims.reshape(t18919, (1, 2048, 4672)) # t18925: "cuda:0 bf16[1, 2048, 4672]"
del t18919
t18926 = torch.reshape(t18925, (-1, 4672)) # t18926: "cuda:0 bf16[2048, 4672]"
# t18926 = ltorch.reshape(t18925, (-1, 4672)) # t18926: "cuda:0 bf16[2048, 4672]"
# t18926 = prims.reshape(t18925, (2048, 4672)) # t18926: "cuda:0 bf16[2048, 4672]"
del t18925
t18930 = torch.permute(t18926, (1, 0)) # t18930: "cuda:0 bf16[4672, 2048]"
# t18930 = ltorch.permute(t18926, (1, 0)) # t18930: "cuda:0 bf16[4672, 2048]"
# t18930 = prims.transpose(t18926, (1, 0)) # t18930: "cuda:0 bf16[4672, 2048]"
t18932 = torch.matmul(t18930, t18752) # t18932: "cuda:0 bf16[4672, 4544]"
# t18932 = ltorch.matmul(t18930, t18931) # t18932: "cuda:0 bf16[4672, 4544]"
# t18932 = prims.matmul(t18930, t18931) # t18932: "cuda:0 bf16[4672, 4544]"
del t18930, t18752
t18927 = torch.matmul(t18926, t_transformer_h_9_attn_attn_weight) # t18927: "cuda:0 bf16[2048, 4544]"
# t18927 = ltorch.matmul(t18926, t_transformer_h_9_attn_attn_weight) # t18927: "cuda:0 bf16[2048, 4544]"
# t18927 = prims.matmul(t18926, t_transformer_h_9_attn_attn_weight) # t18927: "cuda:0 bf16[2048, 4544]"
del t18926, t_transformer_h_9_attn_attn_weight
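  # Flattened to (2048, 4672), the fused qkv gradient t18926 yields the qkv
  # weight gradient t18932 = grad.T @ t18752 (t18752 appears to be the saved
  # LayerNorm output, the same tensor used for the mlp fc weight grad above) and,
  # via t_transformer_h_9_attn_attn_weight, the gradient t18927 with respect to
  # that normalized input coming from the attention branch.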
t18749 = torch.reshape(t18748, (1, 2048, 4544)) # t18749: "cuda:0 bf16[1, 2048, 4544]"
# t18749 = ltorch.reshape(t18748, (1, 2048, 4544)) # t18749: "cuda:0 bf16[1, 2048, 4544]"
# t18749 = prims.reshape(t18748, (1, 2048, 4544)) # t18749: "cuda:0 bf16[1, 2048, 4544]"
del t18748
t18928 = torch.reshape(t18927, (1, 2048, 4544)) # t18928: "cuda:0 bf16[1, 2048, 4544]"
# t18928 = ltorch.reshape(t18927, (1, 2048, 4544)) # t18928: "cuda:0 bf16[1, 2048, 4544]"
# t18928 = prims.reshape(t18927, (1, 2048, 4544)) # t18928: "cuda:0 bf16[1, 2048, 4544]"
del t18927
[t18941, t18947, t18989] = nvFusion69(i18969, t1259, t1391, t1412, t1427, t1432, t1438, t18706, t18749, t18928)
# t1418 = prims.convert_element_type(t1259, dtypes.float32) # t1418: "cuda:0 f32[1, 2048, 4544]"
# t1413 = prims.convert_element_type(t1412, dtypes.float32) # t1413: "cuda:0 f32[1, 2048, 4544]"
# t1414 = prims.convert_element_type(t1391, dtypes.float32) # t1414: "cuda:0 f32[1, 2048, 4544]"
# t1415 = prims.add(t1413, t1414) # t1415: "cuda:0 f32[1, 2048, 4544]"
# t1419 = prims.add(t1415, t1418) # t1419: "cuda:0 f32[1, 2048, 4544]"
# t1429 = prims.broadcast_in_dim(t1427, [1, 2048, 1], [0, 1]) # t1429: "cuda:0 f32[1, 2048, 1]"
# t1433 = prims.broadcast_in_dim(t1429, (1, 2048, 4544), (0, 1, 2)) # t1433: "cuda:0 f32[1, 2048, 4544]"
# t1435 = prims.sub(t1419, t1433) # t1435: "cuda:0 f32[1, 2048, 4544]"
# t1436 = prims.broadcast_in_dim(t1432, (1, 2048, 4544), (0, 1, 2)) # t1436: "cuda:0 f32[1, 2048, 4544]"
# t1437 = prims.mul(t1435, t1436) # t1437: "cuda:0 f32[1, 2048, 4544]"
# t1439 = prims.convert_element_type(t1438, dtypes.float32) # t1439: "cuda:0 f32[1, 2048, 4544]"
# t18986 = prims.convert_element_type(t18706, dtypes.float32) # t18986: "cuda:0 f32[1, 2048, 4544]"
# t18933 = prims.convert_element_type(t18749, dtypes.float32) # t18933: "cuda:0 f32[1, 2048, 4544]"
# t18934 = prims.convert_element_type(t18928, dtypes.float32) # t18934: "cuda:0 f32[1, 2048, 4544]"
# t18935 = prims.add(t18933, t18934) # t18935: "cuda:0 f32[1, 2048, 4544]"
# t18940 = prims.sum(t18935, (0, 1)) # t18940: "cuda:0 f32[4544]"
# t18941 = prims.convert_element_type(t18940, dtypes.bfloat16) # t18941: "cuda:0 bf16[4544]"
# t18942 = prims.mul(t1439, t18935) # t18942: "cuda:0 f32[1, 2048, 4544]"
# t18943 = prims.mul(t1437, t18935) # t18943: "cuda:0 f32[1, 2048, 4544]"
# t18946 = prims.sum(t18943, (0, 1)) # t18946: "cuda:0 f32[4544]"
# t18947 = prims.convert_element_type(t18946, dtypes.bfloat16) # t18947: "cuda:0 bf16[4544]"
# t18948 = prims.mul(t1436, t18942) # t18948: "cuda:0 f32[1, 2048, 4544]"
# t18949 = prims.mul(t1435, t18942) # t18949: "cuda:0 f32[1, 2048, 4544]"
# t18950 = prims.sum(t18949, (0, 2)) # t18950: "cuda:0 f32[2048]"
# t18951 = prims.broadcast_in_dim(t18950, [1, 2048, 1], [1]) # t18951: "cuda:0 f32[1, 2048, 1]"
# t18952 = prims.neg(t18948) # t18952: "cuda:0 f32[1, 2048, 4544]"
# t18954 = prims.sum(t18952, (0, 2)) # t18954: "cuda:0 f32[2048]"
# t18955 = prims.broadcast_in_dim(t18954, [1, 2048, 1], [1]) # t18955: "cuda:0 f32[1, 2048, 1]"
# t18956 = prims.mul(-0.5, t18951) # t18956: "cuda:0 f32[1, 2048, 1]"
# t18957 = prims.pow(t1432, 3.0) # t18957: "cuda:0 f32[1, 2048, 1]"
# t18958 = prims.mul(t18956, t18957) # t18958: "cuda:0 f32[1, 2048, 1]"
# t18960 = prims.sum(t18955, (0, 2)) # t18960: "cuda:0 f32[2048]"
# t18961 = prims.broadcast_in_dim(t18960, [1, 2048], [1]) # t18961: "cuda:0 f32[1, 2048]"
# t18962 = prims.sum(t18958, (0, 2)) # t18962: "cuda:0 f32[2048]"
# t18963 = prims.broadcast_in_dim(t18962, [1, 2048], [1]) # t18963: "cuda:0 f32[1, 2048]"
# t18966 = prims.broadcast_in_dim(t18961, [1, 2048, 1], [0, 1]) # t18966: "cuda:0 f32[1, 2048, 1]"
# t18967 = prims.broadcast_in_dim(t18966, (1, 2048, 4544), (0, 1, 2)) # t18967: "cuda:0 f32[1, 2048, 4544]"
# t18968 = prims.mul(0.00022007042253521127, t18967) # t18968: "cuda:0 f32[1, 2048, 4544]"
# t18970 = prims.broadcast_in_dim(t18963, [1, 2048, 1], [0, 1]) # t18970: "cuda:0 f32[1, 2048, 1]"
# t18971 = prims.broadcast_in_dim(t18970, (1, 2048, 4544), (0, 1, 2)) # t18971: "cuda:0 f32[1, 2048, 4544]"
# t18973 = prims.broadcast_in_dim(t1427, [1, 2048, 1], [0, 1]) # t18973: "cuda:0 f32[1, 2048, 1]"
# t18974 = prims.broadcast_in_dim(t18973, (1, 2048, 4544), (0, 1, 2)) # t18974: "cuda:0 f32[1, 2048, 4544]"
# t18975 = prims.mul(2.0, t18971) # t18975: "cuda:0 f32[1, 2048, 4544]"
# t18976 = prims.sub(t1419, t18974) # t18976: "cuda:0 f32[1, 2048, 4544]"
# t18977 = prims.mul(t18975, t18976) # t18977: "cuda:0 f32[1, 2048, 4544]"
# f18978 = prims.convert_element_type(i18969, float) # f18978: "float 4544.0"
# t18979 = prims.div(t18977, f18978) # t18979: "cuda:0 f32[1, 2048, 4544]"
# t18980 = prims.add(t18968, t18979) # t18980: "cuda:0 f32[1, 2048, 4544]"
# t18984 = prims.add(t18948, t18980) # t18984: "cuda:0 f32[1, 2048, 4544]"
# t18988 = prims.add(t18986, t18984) # t18988: "cuda:0 f32[1, 2048, 4544]"
# t18989 = prims.convert_element_type(t18988, dtypes.bfloat16) # t18989: "cuda:0 bf16[1, 2048, 4544]"
del i18969, t1259, t1391, t1412, t1427, t1432, t1438, t18706, t18749, t18928
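  # nvFusion69 mirrors the LayerNorm backward at the top of this section, this
  # time for what is presumably transformer.h.9's input norm: t18941 and t18947
  # are the bias and weight grads, and t18989 is the gradient for the residual
  # stream entering h.9 -- the d(input) term plus the pass-through residual
  # gradient t18706 accumulated in.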
t18996 = torch.reshape(t18989, (-1, 4544)) # t18996: "cuda:0 bf16[2048, 4544]"
# t18996 = ltorch.reshape(t18989, (-1, 4544)) # t18996: "cuda:0 bf16[2048, 4544]"
# t18996 = prims.reshape(t18989, (2048, 4544)) # t18996: "cuda:0 bf16[2048, 4544]"
t19000 = torch.permute(t18996, (1, 0)) # t19000: "cuda:0 bf16[4544, 2048]"
# t19000 = ltorch.permute(t18996, (1, 0)) # t19000: "cuda:0 bf16[4544, 2048]"
# t19000 = prims.transpose(t18996, (1, 0)) # t19000: "cuda:0 bf16[4544, 2048]"
t18997 = torch.matmul(t18996, t_transformer_h_8_mlp_proj_weight) # t18997: "cuda:0 bf16[2048, 18176]"
# t18997 = ltorch.matmul(t18996, t_transformer_h_8_mlp_proj_weight) # t18997: "cuda:0 bf16[2048, 18176]"
# t18997 = prims.matmul(t18996, t_transformer_h_8_mlp_proj_weight) # t18997: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_8_mlp_proj_weight
t19002 = torch.matmul(t19000, t19001) # t19002: "cuda:0 bf16[4544, 18176]"
# t19002 = ltorch.matmul(t19000, t19001) # t19002: "cuda:0 bf16[4544, 18176]"
# t19002 = prims.matmul(t19000, t19001) # t19002: "cuda:0 bf16[4544, 18176]"
del t19001
t19038 = torch.matmul(t18996, t_transformer_h_8_attn_proj_weight) # t19038: "cuda:0 bf16[2048, 4544]"
# t19038 = ltorch.matmul(t19037, t_transformer_h_8_attn_proj_weight) # t19038: "cuda:0 bf16[2048, 4544]"
# t19038 = prims.matmul(t19037, t_transformer_h_8_attn_proj_weight) # t19038: "cuda:0 bf16[2048, 4544]"
del t18996, t_transformer_h_8_attn_proj_weight
t19043 = torch.matmul(t19000, t19042) # t19043: "cuda:0 bf16[4544, 4544]"
# t19043 = ltorch.matmul(t19041, t19042) # t19043: "cuda:0 bf16[4544, 4544]"
# t19043 = prims.matmul(t19041, t19042) # t19043: "cuda:0 bf16[4544, 4544]"
del t19000, t19042
t18998 = torch.reshape(t18997, (1, 2048, 18176)) # t18998: "cuda:0 bf16[1, 2048, 18176]"
# t18998 = ltorch.reshape(t18997, (1, 2048, 18176)) # t18998: "cuda:0 bf16[1, 2048, 18176]"
# t18998 = prims.reshape(t18997, (1, 2048, 18176)) # t18998: "cuda:0 bf16[1, 2048, 18176]"
del t18997
t19039 = torch.reshape(t19038, (1, 2048, 4544)) # t19039: "cuda:0 bf16[1, 2048, 4544]"
# t19039 = ltorch.reshape(t19038, (1, 2048, 4544)) # t19039: "cuda:0 bf16[1, 2048, 4544]"
# t19039 = prims.reshape(t19038, (1, 2048, 4544)) # t19039: "cuda:0 bf16[1, 2048, 4544]"
del t19038
t19047 = torch.reshape(t19039, (1, 2048, 71, 64)) # t19047: "cuda:0 bf16[1, 2048, 71, 64]"
# t19047 = ltorch.reshape(t19039, (1, 2048, 71, 64)) # t19047: "cuda:0 bf16[1, 2048, 71, 64]"
# t19047 = prims.reshape(t19039, (1, 2048, 71, 64)) # t19047: "cuda:0 bf16[1, 2048, 71, 64]"
del t19039
t19050 = torch.permute(t19047, (0, 2, 1, 3)) # t19050: "cuda:0 bf16[1, 71, 2048, 64]"
# t19050 = ltorch.permute(t19047, (0, 2, 1, 3)) # t19050: "cuda:0 bf16[1, 71, 2048, 64]"
# t19050 = prims.transpose(t19047, (0, 2, 1, 3)) # t19050: "cuda:0 bf16[1, 71, 2048, 64]"
del t19047
[t19029] = nvFusion70(f575, f577, t1392, t18998)
# t1393 = prims.convert_element_type(t1392, dtypes.float32) # t1393: "cuda:0 f32[1, 2048, 18176]"
# t1395 = prims.div(t1393, 1.4142135623730951) # t1395: "cuda:0 f32[1, 2048, 18176]"
# t1398 = prims.erf(t1395) # t1398: "cuda:0 f32[1, 2048, 18176]"
# t1402 = prims.mul(0.5, t1398) # t1402: "cuda:0 f32[1, 2048, 18176]"
# t1406 = prims.add(0.5, t1402) # t1406: "cuda:0 f32[1, 2048, 18176]"
# t19003 = prims.convert_element_type(t18998, dtypes.float32) # t19003: "cuda:0 f32[1, 2048, 18176]"
# t19004 = prims.mul(t1406, t19003) # t19004: "cuda:0 f32[1, 2048, 18176]"
# t19005 = prims.mul(t1393, t19003) # t19005: "cuda:0 f32[1, 2048, 18176]"
# t19013 = prims.mul(f577, t19005) # t19013: "cuda:0 f32[1, 2048, 18176]"
# t19016 = prims.pow(t1395, 2.0) # t19016: "cuda:0 f32[1, 2048, 18176]"
# t19017 = prims.neg(t19016) # t19017: "cuda:0 f32[1, 2048, 18176]"
# t19018 = prims.exp(t19017) # t19018: "cuda:0 f32[1, 2048, 18176]"
# t19019 = prims.mul(1.1283791670955126, t19018) # t19019: "cuda:0 f32[1, 2048, 18176]"
# t19020 = prims.mul(t19019, t19013) # t19020: "cuda:0 f32[1, 2048, 18176]"
# t19024 = prims.div(t19020, f575) # t19024: "cuda:0 f32[1, 2048, 18176]"
# t19028 = prims.add(t19004, t19024) # t19028: "cuda:0 f32[1, 2048, 18176]"
# t19029 = prims.convert_element_type(t19028, dtypes.bfloat16) # t19029: "cuda:0 bf16[1, 2048, 18176]"
del f575, f577, t1392, t18998
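  # nvFusion70 repeats the erf-GELU backward of nvFusion67, now for the h.8 MLP
  # (saved fc output t1392, incoming gradient t18998).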
t19030 = torch.reshape(t19029, (-1, 18176)) # t19030: "cuda:0 bf16[2048, 18176]"
# t19030 = ltorch.reshape(t19029, (-1, 18176)) # t19030: "cuda:0 bf16[2048, 18176]"
# t19030 = prims.reshape(t19029, (2048, 18176)) # t19030: "cuda:0 bf16[2048, 18176]"
del t19029
t19034 = torch.permute(t19030, (1, 0)) # t19034: "cuda:0 bf16[18176, 2048]"
# t19034 = ltorch.permute(t19030, (1, 0)) # t19034: "cuda:0 bf16[18176, 2048]"
# t19034 = prims.transpose(t19030, (1, 0)) # t19034: "cuda:0 bf16[18176, 2048]"
t19036 = torch.matmul(t19034, t19035) # t19036: "cuda:0 bf16[18176, 4544]"
# t19036 = ltorch.matmul(t19034, t19035) # t19036: "cuda:0 bf16[18176, 4544]"
# t19036 = prims.matmul(t19034, t19035) # t19036: "cuda:0 bf16[18176, 4544]"
del t19034
t19031 = torch.matmul(t19030, t_transformer_h_8_mlp_fc_weight) # t19031: "cuda:0 bf16[2048, 4544]"
# t19031 = ltorch.matmul(t19030, t_transformer_h_8_mlp_fc_weight) # t19031: "cuda:0 bf16[2048, 4544]"
# t19031 = prims.matmul(t19030, t_transformer_h_8_mlp_fc_weight) # t19031: "cuda:0 bf16[2048, 4544]"
del t19030, t_transformer_h_8_mlp_fc_weight
(t19051, t19052, t19053) = cudnn_sdpa_bwd(t19050, t1376, t1379, t1329, None, f566, b567, t1380, t1381, t1382, t1383, scale=f568, cat_grad_qkv=False)
del t19050, t1376, t1379, t1329, f566, b567, t1380, t1381, t1382, t1383, f568
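  # Attention backward for transformer.h.8; same call pattern as the h.9
  # cudnn_sdpa_bwd above.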
t19055 = torch_slice_prim_impl(t19052, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19055: "cuda:0 bf16[1, 71, 2048, 64]"
del t19052
t19059 = torch_slice_prim_impl(t19051, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19059: "cuda:0 bf16[1, 71, 2048, 64]"
del t19051
t19162 = torch.reshape(t19053, (1, 1, 71, 2048, 64)) # t19162: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19162 = ltorch.reshape(t19053, (1, 1, 71, 2048, 64)) # t19162: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19162 = prims.reshape(t19053, (1, 1, 71, 2048, 64)) # t19162: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t19053
[t19196] = nvFusion71(i539, t19055, t19059, t19162, t61, t66)
# t19056 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19056: "cuda:0 bf16[1, 71, 2048, 0]"
# t19057 = prims.pad(t19056, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19057: "cuda:0 bf16[1, 71, 2048, 64]"
# t19060 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19060: "cuda:0 bf16[1, 71, 2048, 0]"
# t19061 = prims.pad(t19060, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19061: "cuda:0 bf16[1, 71, 2048, 64]"
# t19062 = prims.convert_element_type(t19055, dtypes.float32) # t19062: "cuda:0 f32[1, 71, 2048, 64]"
# t19066 = prims.mul(t66, t19062) # t19066: "cuda:0 f32[1, 71, 2048, 64]"
# t19069 = prims.convert_element_type(t19066, dtypes.bfloat16) # t19069: "cuda:0 bf16[1, 71, 2048, 64]"
# t19078 = prims.mul(t61, t19062) # t19078: "cuda:0 f32[1, 71, 2048, 64]"
# t19090 = prims.slice_prim(t19069, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19090: "cuda:0 bf16[1, 71, 2048, 32]"
# t19091 = prims.slice_prim(t19069, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19091: "cuda:0 bf16[1, 71, 2048, 32]"
# t19092 = prims.convert_element_type(t19090, dtypes.float32) # t19092: "cuda:0 f32[1, 71, 2048, 32]"
# t19093 = prims.neg(t19092) # t19093: "cuda:0 f32[1, 71, 2048, 32]"
# t19094 = prims.convert_element_type(t19093, dtypes.bfloat16) # t19094: "cuda:0 bf16[1, 71, 2048, 32]"
# t19095 = prims.pad(t19094, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19095: "cuda:0 bf16[1, 71, 2048, 64]"
# t19097 = prims.convert_element_type(t19095, dtypes.float32) # t19097: "cuda:0 f32[1, 71, 2048, 64]"
# t19098 = prims.add(t19078, t19097) # t19098: "cuda:0 f32[1, 71, 2048, 64]"
# t19100 = prims.pad(t19091, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19100: "cuda:0 bf16[1, 71, 2048, 64]"
# t19102 = prims.convert_element_type(t19100, dtypes.float32) # t19102: "cuda:0 f32[1, 71, 2048, 64]"
# t19103 = prims.add(t19098, t19102) # t19103: "cuda:0 f32[1, 71, 2048, 64]"
# t19104 = prims.convert_element_type(t19103, dtypes.bfloat16) # t19104: "cuda:0 bf16[1, 71, 2048, 64]"
# t19105 = prims.pad(t19104, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t19105: "cuda:0 bf16[1, 71, 2048, 64]"
# t19106 = prims.convert_element_type(t19057, dtypes.float32) # t19106: "cuda:0 f32[1, 71, 2048, 64]"
# t19107 = prims.convert_element_type(t19105, dtypes.float32) # t19107: "cuda:0 f32[1, 71, 2048, 64]"
# t19108 = prims.add(t19106, t19107) # t19108: "cuda:0 f32[1, 71, 2048, 64]"
# t19109 = prims.convert_element_type(t19108, dtypes.bfloat16) # t19109: "cuda:0 bf16[1, 71, 2048, 64]"
# t19110 = prims.convert_element_type(t19059, dtypes.float32) # t19110: "cuda:0 f32[1, 71, 2048, 64]"
# t19114 = prims.mul(t66, t19110) # t19114: "cuda:0 f32[1, 71, 2048, 64]"
# t19117 = prims.convert_element_type(t19114, dtypes.bfloat16) # t19117: "cuda:0 bf16[1, 71, 2048, 64]"
# t19126 = prims.mul(t61, t19110) # t19126: "cuda:0 f32[1, 71, 2048, 64]"
# t19138 = prims.slice_prim(t19117, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19138: "cuda:0 bf16[1, 71, 2048, 32]"
# t19139 = prims.slice_prim(t19117, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19139: "cuda:0 bf16[1, 71, 2048, 32]"
# t19140 = prims.convert_element_type(t19138, dtypes.float32) # t19140: "cuda:0 f32[1, 71, 2048, 32]"
# t19141 = prims.neg(t19140) # t19141: "cuda:0 f32[1, 71, 2048, 32]"
# t19142 = prims.convert_element_type(t19141, dtypes.bfloat16) # t19142: "cuda:0 bf16[1, 71, 2048, 32]"
# t19143 = prims.pad(t19142, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19143: "cuda:0 bf16[1, 71, 2048, 64]"
# t19145 = prims.convert_element_type(t19143, dtypes.float32) # t19145: "cuda:0 f32[1, 71, 2048, 64]"
# t19146 = prims.add(t19126, t19145) # t19146: "cuda:0 f32[1, 71, 2048, 64]"
# t19148 = prims.pad(t19139, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19148: "cuda:0 bf16[1, 71, 2048, 64]"
# t19150 = prims.convert_element_type(t19148, dtypes.float32) # t19150: "cuda:0 f32[1, 71, 2048, 64]"
# t19151 = prims.add(t19146, t19150) # t19151: "cuda:0 f32[1, 71, 2048, 64]"
# t19152 = prims.convert_element_type(t19151, dtypes.bfloat16) # t19152: "cuda:0 bf16[1, 71, 2048, 64]"
# t19153 = prims.pad(t19152, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t19153: "cuda:0 bf16[1, 71, 2048, 64]"
# t19154 = prims.convert_element_type(t19061, dtypes.float32) # t19154: "cuda:0 f32[1, 71, 2048, 64]"
# t19155 = prims.convert_element_type(t19153, dtypes.float32) # t19155: "cuda:0 f32[1, 71, 2048, 64]"
# t19156 = prims.add(t19154, t19155) # t19156: "cuda:0 f32[1, 71, 2048, 64]"
# t19157 = prims.convert_element_type(t19156, dtypes.bfloat16) # t19157: "cuda:0 bf16[1, 71, 2048, 64]"
# t19167 = prims.reshape(t19109, (1, 1, 71, 2048, 64)) # t19167: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19172 = prims.reshape(t19157, (1, 1, 71, 2048, 64)) # t19172: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19178 = prims.convert_element_type(t19162, dtypes.float32) # t19178: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t19179 = prims.sum(t19178, (0, 1, 2)) # t19179: "cuda:0 f32[2048, 64]"
# t19180 = prims.convert_element_type(t19179, dtypes.bfloat16) # t19180: "cuda:0 bf16[2048, 64]"
# t19181 = prims.broadcast_in_dim(t19180, [1, 1, 1, 2048, 64], [3, 4]) # t19181: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t19187 = prims.convert_element_type(t19167, dtypes.float32) # t19187: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t19188 = prims.sum(t19187, (0, 1, 2)) # t19188: "cuda:0 f32[2048, 64]"
# t19189 = prims.convert_element_type(t19188, dtypes.bfloat16) # t19189: "cuda:0 bf16[2048, 64]"
# t19190 = prims.broadcast_in_dim(t19189, [1, 1, 1, 2048, 64], [3, 4]) # t19190: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t19196 = prims.cat((t19172, t19190, t19181), i539) # t19196: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i539, t19055, t19059, t19162
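  # nvFusion71 repeats the RoPE backward / fused-qkv regrouping of nvFusion68,
  # producing the (1, 1, 73, 2048, 64) qkv gradient for h.8.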
t19202 = torch.permute(t19196, (0, 3, 1, 2, 4)) # t19202: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t19202 = ltorch.permute(t19196, (0, 3, 1, 2, 4)) # t19202: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t19202 = prims.transpose(t19196, (0, 3, 1, 2, 4)) # t19202: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t19196
t19208 = torch.reshape(t19202, (1, 2048, 4672)) # t19208: "cuda:0 bf16[1, 2048, 4672]"
# t19208 = ltorch.reshape(t19202, (1, 2048, 4672)) # t19208: "cuda:0 bf16[1, 2048, 4672]"
# t19208 = prims.reshape(t19202, (1, 2048, 4672)) # t19208: "cuda:0 bf16[1, 2048, 4672]"
del t19202
t19209 = torch.reshape(t19208, (-1, 4672)) # t19209: "cuda:0 bf16[2048, 4672]"
# t19209 = ltorch.reshape(t19208, (-1, 4672)) # t19209: "cuda:0 bf16[2048, 4672]"
# t19209 = prims.reshape(t19208, (2048, 4672)) # t19209: "cuda:0 bf16[2048, 4672]"
del t19208
t19213 = torch.permute(t19209, (1, 0)) # t19213: "cuda:0 bf16[4672, 2048]"
# t19213 = ltorch.permute(t19209, (1, 0)) # t19213: "cuda:0 bf16[4672, 2048]"
# t19213 = prims.transpose(t19209, (1, 0)) # t19213: "cuda:0 bf16[4672, 2048]"
t19215 = torch.matmul(t19213, t19035) # t19215: "cuda:0 bf16[4672, 4544]"
# t19215 = ltorch.matmul(t19213, t19214) # t19215: "cuda:0 bf16[4672, 4544]"
# t19215 = prims.matmul(t19213, t19214) # t19215: "cuda:0 bf16[4672, 4544]"
del t19213, t19035
t19210 = torch.matmul(t19209, t_transformer_h_8_attn_attn_weight) # t19210: "cuda:0 bf16[2048, 4544]"
# t19210 = ltorch.matmul(t19209, t_transformer_h_8_attn_attn_weight) # t19210: "cuda:0 bf16[2048, 4544]"
# t19210 = prims.matmul(t19209, t_transformer_h_8_attn_attn_weight) # t19210: "cuda:0 bf16[2048, 4544]"
del t19209, t_transformer_h_8_attn_attn_weight
t19032 = torch.reshape(t19031, (1, 2048, 4544)) # t19032: "cuda:0 bf16[1, 2048, 4544]"
# t19032 = ltorch.reshape(t19031, (1, 2048, 4544)) # t19032: "cuda:0 bf16[1, 2048, 4544]"
# t19032 = prims.reshape(t19031, (1, 2048, 4544)) # t19032: "cuda:0 bf16[1, 2048, 4544]"
del t19031
t19211 = torch.reshape(t19210, (1, 2048, 4544)) # t19211: "cuda:0 bf16[1, 2048, 4544]"
# t19211 = ltorch.reshape(t19210, (1, 2048, 4544)) # t19211: "cuda:0 bf16[1, 2048, 4544]"
# t19211 = prims.reshape(t19210, (1, 2048, 4544)) # t19211: "cuda:0 bf16[1, 2048, 4544]"
del t19210
[t19224, t19230, t19272] = nvFusion72(i19252, t1098, t1230, t1251, t1266, t1271, t1277, t18989, t19032, t19211)
# t1257 = prims.convert_element_type(t1098, dtypes.float32) # t1257: "cuda:0 f32[1, 2048, 4544]"
# t1252 = prims.convert_element_type(t1251, dtypes.float32) # t1252: "cuda:0 f32[1, 2048, 4544]"
# t1253 = prims.convert_element_type(t1230, dtypes.float32) # t1253: "cuda:0 f32[1, 2048, 4544]"
# t1254 = prims.add(t1252, t1253) # t1254: "cuda:0 f32[1, 2048, 4544]"
# t1258 = prims.add(t1254, t1257) # t1258: "cuda:0 f32[1, 2048, 4544]"
# t1268 = prims.broadcast_in_dim(t1266, [1, 2048, 1], [0, 1]) # t1268: "cuda:0 f32[1, 2048, 1]"
# t1272 = prims.broadcast_in_dim(t1268, (1, 2048, 4544), (0, 1, 2)) # t1272: "cuda:0 f32[1, 2048, 4544]"
# t1274 = prims.sub(t1258, t1272) # t1274: "cuda:0 f32[1, 2048, 4544]"
# t1275 = prims.broadcast_in_dim(t1271, (1, 2048, 4544), (0, 1, 2)) # t1275: "cuda:0 f32[1, 2048, 4544]"
# t1276 = prims.mul(t1274, t1275) # t1276: "cuda:0 f32[1, 2048, 4544]"
# t1278 = prims.convert_element_type(t1277, dtypes.float32) # t1278: "cuda:0 f32[1, 2048, 4544]"
# t19269 = prims.convert_element_type(t18989, dtypes.float32) # t19269: "cuda:0 f32[1, 2048, 4544]"
# t19216 = prims.convert_element_type(t19032, dtypes.float32) # t19216: "cuda:0 f32[1, 2048, 4544]"
# t19217 = prims.convert_element_type(t19211, dtypes.float32) # t19217: "cuda:0 f32[1, 2048, 4544]"
# t19218 = prims.add(t19216, t19217) # t19218: "cuda:0 f32[1, 2048, 4544]"
# t19223 = prims.sum(t19218, (0, 1)) # t19223: "cuda:0 f32[4544]"
# t19224 = prims.convert_element_type(t19223, dtypes.bfloat16) # t19224: "cuda:0 bf16[4544]"
# t19225 = prims.mul(t1278, t19218) # t19225: "cuda:0 f32[1, 2048, 4544]"
# t19226 = prims.mul(t1276, t19218) # t19226: "cuda:0 f32[1, 2048, 4544]"
# t19229 = prims.sum(t19226, (0, 1)) # t19229: "cuda:0 f32[4544]"
# t19230 = prims.convert_element_type(t19229, dtypes.bfloat16) # t19230: "cuda:0 bf16[4544]"
# t19231 = prims.mul(t1275, t19225) # t19231: "cuda:0 f32[1, 2048, 4544]"
# t19232 = prims.mul(t1274, t19225) # t19232: "cuda:0 f32[1, 2048, 4544]"
# t19233 = prims.sum(t19232, (0, 2)) # t19233: "cuda:0 f32[2048]"
# t19234 = prims.broadcast_in_dim(t19233, [1, 2048, 1], [1]) # t19234: "cuda:0 f32[1, 2048, 1]"
# t19235 = prims.neg(t19231) # t19235: "cuda:0 f32[1, 2048, 4544]"
# t19237 = prims.sum(t19235, (0, 2)) # t19237: "cuda:0 f32[2048]"
# t19238 = prims.broadcast_in_dim(t19237, [1, 2048, 1], [1]) # t19238: "cuda:0 f32[1, 2048, 1]"
# t19239 = prims.mul(-0.5, t19234) # t19239: "cuda:0 f32[1, 2048, 1]"
# t19240 = prims.pow(t1271, 3.0) # t19240: "cuda:0 f32[1, 2048, 1]"
# t19241 = prims.mul(t19239, t19240) # t19241: "cuda:0 f32[1, 2048, 1]"
# t19243 = prims.sum(t19238, (0, 2)) # t19243: "cuda:0 f32[2048]"
# t19244 = prims.broadcast_in_dim(t19243, [1, 2048], [1]) # t19244: "cuda:0 f32[1, 2048]"
# t19245 = prims.sum(t19241, (0, 2)) # t19245: "cuda:0 f32[2048]"
# t19246 = prims.broadcast_in_dim(t19245, [1, 2048], [1]) # t19246: "cuda:0 f32[1, 2048]"
# t19249 = prims.broadcast_in_dim(t19244, [1, 2048, 1], [0, 1]) # t19249: "cuda:0 f32[1, 2048, 1]"
# t19250 = prims.broadcast_in_dim(t19249, (1, 2048, 4544), (0, 1, 2)) # t19250: "cuda:0 f32[1, 2048, 4544]"
# t19251 = prims.mul(0.00022007042253521127, t19250) # t19251: "cuda:0 f32[1, 2048, 4544]"
# t19253 = prims.broadcast_in_dim(t19246, [1, 2048, 1], [0, 1]) # t19253: "cuda:0 f32[1, 2048, 1]"
# t19254 = prims.broadcast_in_dim(t19253, (1, 2048, 4544), (0, 1, 2)) # t19254: "cuda:0 f32[1, 2048, 4544]"
# t19256 = prims.broadcast_in_dim(t1266, [1, 2048, 1], [0, 1]) # t19256: "cuda:0 f32[1, 2048, 1]"
# t19257 = prims.broadcast_in_dim(t19256, (1, 2048, 4544), (0, 1, 2)) # t19257: "cuda:0 f32[1, 2048, 4544]"
# t19258 = prims.mul(2.0, t19254) # t19258: "cuda:0 f32[1, 2048, 4544]"
# t19259 = prims.sub(t1258, t19257) # t19259: "cuda:0 f32[1, 2048, 4544]"
# t19260 = prims.mul(t19258, t19259) # t19260: "cuda:0 f32[1, 2048, 4544]"
# f19261 = prims.convert_element_type(i19252, float) # f19261: "float 4544.0"
# t19262 = prims.div(t19260, f19261) # t19262: "cuda:0 f32[1, 2048, 4544]"
# t19263 = prims.add(t19251, t19262) # t19263: "cuda:0 f32[1, 2048, 4544]"
# t19267 = prims.add(t19231, t19263) # t19267: "cuda:0 f32[1, 2048, 4544]"
# t19271 = prims.add(t19269, t19267) # t19271: "cuda:0 f32[1, 2048, 4544]"
# t19272 = prims.convert_element_type(t19271, dtypes.bfloat16) # t19272: "cuda:0 bf16[1, 2048, 4544]"
del i19252, t1098, t1230, t1251, t1266, t1271, t1277, t18989, t19032, t19211
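  # nvFusion72 is the input-LayerNorm backward for transformer.h.8 (bias grad
  # t19224, weight grad t19230, residual-stream gradient t19272). From here the
  # same per-layer sequence -- LayerNorm backward, projection matmuls, GELU
  # backward, SDPA backward, RoPE backward, qkv matmuls -- appears to repeat for
  # h.7, h.6 and the earlier blocks.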
t19279 = torch.reshape(t19272, (-1, 4544)) # t19279: "cuda:0 bf16[2048, 4544]"
# t19279 = ltorch.reshape(t19272, (-1, 4544)) # t19279: "cuda:0 bf16[2048, 4544]"
# t19279 = prims.reshape(t19272, (2048, 4544)) # t19279: "cuda:0 bf16[2048, 4544]"
t19283 = torch.permute(t19279, (1, 0)) # t19283: "cuda:0 bf16[4544, 2048]"
# t19283 = ltorch.permute(t19279, (1, 0)) # t19283: "cuda:0 bf16[4544, 2048]"
# t19283 = prims.transpose(t19279, (1, 0)) # t19283: "cuda:0 bf16[4544, 2048]"
t19285 = torch.matmul(t19283, t19284) # t19285: "cuda:0 bf16[4544, 18176]"
# t19285 = ltorch.matmul(t19283, t19284) # t19285: "cuda:0 bf16[4544, 18176]"
# t19285 = prims.matmul(t19283, t19284) # t19285: "cuda:0 bf16[4544, 18176]"
del t19284
t19321 = torch.matmul(t19279, t_transformer_h_7_attn_proj_weight) # t19321: "cuda:0 bf16[2048, 4544]"
# t19321 = ltorch.matmul(t19320, t_transformer_h_7_attn_proj_weight) # t19321: "cuda:0 bf16[2048, 4544]"
# t19321 = prims.matmul(t19320, t_transformer_h_7_attn_proj_weight) # t19321: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_7_attn_proj_weight
t19326 = torch.matmul(t19283, t19325) # t19326: "cuda:0 bf16[4544, 4544]"
# t19326 = ltorch.matmul(t19324, t19325) # t19326: "cuda:0 bf16[4544, 4544]"
# t19326 = prims.matmul(t19324, t19325) # t19326: "cuda:0 bf16[4544, 4544]"
del t19283, t19325
t19280 = torch.matmul(t19279, t_transformer_h_7_mlp_proj_weight) # t19280: "cuda:0 bf16[2048, 18176]"
# t19280 = ltorch.matmul(t19279, t_transformer_h_7_mlp_proj_weight) # t19280: "cuda:0 bf16[2048, 18176]"
# t19280 = prims.matmul(t19279, t_transformer_h_7_mlp_proj_weight) # t19280: "cuda:0 bf16[2048, 18176]"
del t19279, t_transformer_h_7_mlp_proj_weight
t19322 = torch.reshape(t19321, (1, 2048, 4544)) # t19322: "cuda:0 bf16[1, 2048, 4544]"
# t19322 = ltorch.reshape(t19321, (1, 2048, 4544)) # t19322: "cuda:0 bf16[1, 2048, 4544]"
# t19322 = prims.reshape(t19321, (1, 2048, 4544)) # t19322: "cuda:0 bf16[1, 2048, 4544]"
del t19321
t19330 = torch.reshape(t19322, (1, 2048, 71, 64)) # t19330: "cuda:0 bf16[1, 2048, 71, 64]"
# t19330 = ltorch.reshape(t19322, (1, 2048, 71, 64)) # t19330: "cuda:0 bf16[1, 2048, 71, 64]"
# t19330 = prims.reshape(t19322, (1, 2048, 71, 64)) # t19330: "cuda:0 bf16[1, 2048, 71, 64]"
del t19322
t19333 = torch.permute(t19330, (0, 2, 1, 3)) # t19333: "cuda:0 bf16[1, 71, 2048, 64]"
# t19333 = ltorch.permute(t19330, (0, 2, 1, 3)) # t19333: "cuda:0 bf16[1, 71, 2048, 64]"
# t19333 = prims.transpose(t19330, (0, 2, 1, 3)) # t19333: "cuda:0 bf16[1, 71, 2048, 64]"
del t19330
t19281 = torch.reshape(t19280, (1, 2048, 18176)) # t19281: "cuda:0 bf16[1, 2048, 18176]"
# t19281 = ltorch.reshape(t19280, (1, 2048, 18176)) # t19281: "cuda:0 bf16[1, 2048, 18176]"
# t19281 = prims.reshape(t19280, (1, 2048, 18176)) # t19281: "cuda:0 bf16[1, 2048, 18176]"
del t19280
[t19312] = nvFusion73(f511, f513, t1231, t19281)
# t1232 = prims.convert_element_type(t1231, dtypes.float32) # t1232: "cuda:0 f32[1, 2048, 18176]"
# t1234 = prims.div(t1232, 1.4142135623730951) # t1234: "cuda:0 f32[1, 2048, 18176]"
# t1237 = prims.erf(t1234) # t1237: "cuda:0 f32[1, 2048, 18176]"
# t1241 = prims.mul(0.5, t1237) # t1241: "cuda:0 f32[1, 2048, 18176]"
# t1245 = prims.add(0.5, t1241) # t1245: "cuda:0 f32[1, 2048, 18176]"
# t19286 = prims.convert_element_type(t19281, dtypes.float32) # t19286: "cuda:0 f32[1, 2048, 18176]"
# t19287 = prims.mul(t1245, t19286) # t19287: "cuda:0 f32[1, 2048, 18176]"
# t19288 = prims.mul(t1232, t19286) # t19288: "cuda:0 f32[1, 2048, 18176]"
# t19296 = prims.mul(f513, t19288) # t19296: "cuda:0 f32[1, 2048, 18176]"
# t19299 = prims.pow(t1234, 2.0) # t19299: "cuda:0 f32[1, 2048, 18176]"
# t19300 = prims.neg(t19299) # t19300: "cuda:0 f32[1, 2048, 18176]"
# t19301 = prims.exp(t19300) # t19301: "cuda:0 f32[1, 2048, 18176]"
# t19302 = prims.mul(1.1283791670955126, t19301) # t19302: "cuda:0 f32[1, 2048, 18176]"
# t19303 = prims.mul(t19302, t19296) # t19303: "cuda:0 f32[1, 2048, 18176]"
# t19307 = prims.div(t19303, f511) # t19307: "cuda:0 f32[1, 2048, 18176]"
# t19311 = prims.add(t19287, t19307) # t19311: "cuda:0 f32[1, 2048, 18176]"
# t19312 = prims.convert_element_type(t19311, dtypes.bfloat16) # t19312: "cuda:0 bf16[1, 2048, 18176]"
del f511, f513, t1231, t19281
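  # nvFusion73: the same erf-GELU backward, here for the h.7 MLP.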
t19313 = torch.reshape(t19312, (-1, 18176)) # t19313: "cuda:0 bf16[2048, 18176]"
# t19313 = ltorch.reshape(t19312, (-1, 18176)) # t19313: "cuda:0 bf16[2048, 18176]"
# t19313 = prims.reshape(t19312, (2048, 18176)) # t19313: "cuda:0 bf16[2048, 18176]"
del t19312
t19317 = torch.permute(t19313, (1, 0)) # t19317: "cuda:0 bf16[18176, 2048]"
# t19317 = ltorch.permute(t19313, (1, 0)) # t19317: "cuda:0 bf16[18176, 2048]"
# t19317 = prims.transpose(t19313, (1, 0)) # t19317: "cuda:0 bf16[18176, 2048]"
(t19334, t19335, t19336) = cudnn_sdpa_bwd(t19333, t1215, t1218, t1168, None, f502, b503, t1219, t1220, t1221, t1222, scale=f504, cat_grad_qkv=False)
del t19333, t1215, t1218, t1168, f502, b503, t1219, t1220, t1221, t1222, f504
t19319 = torch.matmul(t19317, t19318) # t19319: "cuda:0 bf16[18176, 4544]"
# t19319 = ltorch.matmul(t19317, t19318) # t19319: "cuda:0 bf16[18176, 4544]"
# t19319 = prims.matmul(t19317, t19318) # t19319: "cuda:0 bf16[18176, 4544]"
del t19317
t19314 = torch.matmul(t19313, t_transformer_h_7_mlp_fc_weight) # t19314: "cuda:0 bf16[2048, 4544]"
# t19314 = ltorch.matmul(t19313, t_transformer_h_7_mlp_fc_weight) # t19314: "cuda:0 bf16[2048, 4544]"
# t19314 = prims.matmul(t19313, t_transformer_h_7_mlp_fc_weight) # t19314: "cuda:0 bf16[2048, 4544]"
del t19313, t_transformer_h_7_mlp_fc_weight
t19338 = torch_slice_prim_impl(t19335, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19338: "cuda:0 bf16[1, 71, 2048, 64]"
del t19335
t19342 = torch_slice_prim_impl(t19334, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19342: "cuda:0 bf16[1, 71, 2048, 64]"
del t19334
t19445 = torch.reshape(t19336, (1, 1, 71, 2048, 64)) # t19445: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19445 = ltorch.reshape(t19336, (1, 1, 71, 2048, 64)) # t19445: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19445 = prims.reshape(t19336, (1, 1, 71, 2048, 64)) # t19445: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t19336
[t19479] = nvFusion74(i475, t19338, t19342, t19445, t61, t66)
# t19339 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19339: "cuda:0 bf16[1, 71, 2048, 0]"
# t19340 = prims.pad(t19339, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19340: "cuda:0 bf16[1, 71, 2048, 64]"
# t19343 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19343: "cuda:0 bf16[1, 71, 2048, 0]"
# t19344 = prims.pad(t19343, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19344: "cuda:0 bf16[1, 71, 2048, 64]"
# t19345 = prims.convert_element_type(t19338, dtypes.float32) # t19345: "cuda:0 f32[1, 71, 2048, 64]"
# t19349 = prims.mul(t66, t19345) # t19349: "cuda:0 f32[1, 71, 2048, 64]"
# t19352 = prims.convert_element_type(t19349, dtypes.bfloat16) # t19352: "cuda:0 bf16[1, 71, 2048, 64]"
# t19361 = prims.mul(t61, t19345) # t19361: "cuda:0 f32[1, 71, 2048, 64]"
# t19373 = prims.slice_prim(t19352, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19373: "cuda:0 bf16[1, 71, 2048, 32]"
# t19374 = prims.slice_prim(t19352, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19374: "cuda:0 bf16[1, 71, 2048, 32]"
# t19375 = prims.convert_element_type(t19373, dtypes.float32) # t19375: "cuda:0 f32[1, 71, 2048, 32]"
# t19376 = prims.neg(t19375) # t19376: "cuda:0 f32[1, 71, 2048, 32]"
# t19377 = prims.convert_element_type(t19376, dtypes.bfloat16) # t19377: "cuda:0 bf16[1, 71, 2048, 32]"
# t19378 = prims.pad(t19377, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19378: "cuda:0 bf16[1, 71, 2048, 64]"
# t19380 = prims.convert_element_type(t19378, dtypes.float32) # t19380: "cuda:0 f32[1, 71, 2048, 64]"
# t19381 = prims.add(t19361, t19380) # t19381: "cuda:0 f32[1, 71, 2048, 64]"
# t19383 = prims.pad(t19374, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19383: "cuda:0 bf16[1, 71, 2048, 64]"
# t19385 = prims.convert_element_type(t19383, dtypes.float32) # t19385: "cuda:0 f32[1, 71, 2048, 64]"
# t19386 = prims.add(t19381, t19385) # t19386: "cuda:0 f32[1, 71, 2048, 64]"
# t19387 = prims.convert_element_type(t19386, dtypes.bfloat16) # t19387: "cuda:0 bf16[1, 71, 2048, 64]"
# t19388 = prims.pad(t19387, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t19388: "cuda:0 bf16[1, 71, 2048, 64]"
# t19389 = prims.convert_element_type(t19340, dtypes.float32) # t19389: "cuda:0 f32[1, 71, 2048, 64]"
# t19390 = prims.convert_element_type(t19388, dtypes.float32) # t19390: "cuda:0 f32[1, 71, 2048, 64]"
# t19391 = prims.add(t19389, t19390) # t19391: "cuda:0 f32[1, 71, 2048, 64]"
# t19392 = prims.convert_element_type(t19391, dtypes.bfloat16) # t19392: "cuda:0 bf16[1, 71, 2048, 64]"
# t19393 = prims.convert_element_type(t19342, dtypes.float32) # t19393: "cuda:0 f32[1, 71, 2048, 64]"
# t19397 = prims.mul(t66, t19393) # t19397: "cuda:0 f32[1, 71, 2048, 64]"
# t19400 = prims.convert_element_type(t19397, dtypes.bfloat16) # t19400: "cuda:0 bf16[1, 71, 2048, 64]"
# t19409 = prims.mul(t61, t19393) # t19409: "cuda:0 f32[1, 71, 2048, 64]"
# t19421 = prims.slice_prim(t19400, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19421: "cuda:0 bf16[1, 71, 2048, 32]"
# t19422 = prims.slice_prim(t19400, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19422: "cuda:0 bf16[1, 71, 2048, 32]"
# t19423 = prims.convert_element_type(t19421, dtypes.float32) # t19423: "cuda:0 f32[1, 71, 2048, 32]"
# t19424 = prims.neg(t19423) # t19424: "cuda:0 f32[1, 71, 2048, 32]"
# t19425 = prims.convert_element_type(t19424, dtypes.bfloat16) # t19425: "cuda:0 bf16[1, 71, 2048, 32]"
# t19426 = prims.pad(t19425, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19426: "cuda:0 bf16[1, 71, 2048, 64]"
# t19428 = prims.convert_element_type(t19426, dtypes.float32) # t19428: "cuda:0 f32[1, 71, 2048, 64]"
# t19429 = prims.add(t19409, t19428) # t19429: "cuda:0 f32[1, 71, 2048, 64]"
# t19431 = prims.pad(t19422, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19431: "cuda:0 bf16[1, 71, 2048, 64]"
# t19433 = prims.convert_element_type(t19431, dtypes.float32) # t19433: "cuda:0 f32[1, 71, 2048, 64]"
# t19434 = prims.add(t19429, t19433) # t19434: "cuda:0 f32[1, 71, 2048, 64]"
# t19435 = prims.convert_element_type(t19434, dtypes.bfloat16) # t19435: "cuda:0 bf16[1, 71, 2048, 64]"
# t19436 = prims.pad(t19435, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t19436: "cuda:0 bf16[1, 71, 2048, 64]"
# t19437 = prims.convert_element_type(t19344, dtypes.float32) # t19437: "cuda:0 f32[1, 71, 2048, 64]"
# t19438 = prims.convert_element_type(t19436, dtypes.float32) # t19438: "cuda:0 f32[1, 71, 2048, 64]"
# t19439 = prims.add(t19437, t19438) # t19439: "cuda:0 f32[1, 71, 2048, 64]"
# t19440 = prims.convert_element_type(t19439, dtypes.bfloat16) # t19440: "cuda:0 bf16[1, 71, 2048, 64]"
# t19450 = prims.reshape(t19392, (1, 1, 71, 2048, 64)) # t19450: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19455 = prims.reshape(t19440, (1, 1, 71, 2048, 64)) # t19455: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19461 = prims.convert_element_type(t19445, dtypes.float32) # t19461: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t19462 = prims.sum(t19461, (0, 1, 2)) # t19462: "cuda:0 f32[2048, 64]"
# t19463 = prims.convert_element_type(t19462, dtypes.bfloat16) # t19463: "cuda:0 bf16[2048, 64]"
# t19464 = prims.broadcast_in_dim(t19463, [1, 1, 1, 2048, 64], [3, 4]) # t19464: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t19470 = prims.convert_element_type(t19450, dtypes.float32) # t19470: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t19471 = prims.sum(t19470, (0, 1, 2)) # t19471: "cuda:0 f32[2048, 64]"
# t19472 = prims.convert_element_type(t19471, dtypes.bfloat16) # t19472: "cuda:0 bf16[2048, 64]"
# t19473 = prims.broadcast_in_dim(t19472, [1, 1, 1, 2048, 64], [3, 4]) # t19473: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t19479 = prims.cat((t19455, t19473, t19464), i475) # t19479: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i475, t19338, t19342, t19445
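  # nvFusion74: RoPE backward and fused-qkv gradient assembly for h.7.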
t19485 = torch.permute(t19479, (0, 3, 1, 2, 4)) # t19485: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t19485 = ltorch.permute(t19479, (0, 3, 1, 2, 4)) # t19485: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t19485 = prims.transpose(t19479, (0, 3, 1, 2, 4)) # t19485: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t19479
t19491 = torch.reshape(t19485, (1, 2048, 4672)) # t19491: "cuda:0 bf16[1, 2048, 4672]"
# t19491 = ltorch.reshape(t19485, (1, 2048, 4672)) # t19491: "cuda:0 bf16[1, 2048, 4672]"
# t19491 = prims.reshape(t19485, (1, 2048, 4672)) # t19491: "cuda:0 bf16[1, 2048, 4672]"
del t19485
t19492 = torch.reshape(t19491, (-1, 4672)) # t19492: "cuda:0 bf16[2048, 4672]"
# t19492 = ltorch.reshape(t19491, (-1, 4672)) # t19492: "cuda:0 bf16[2048, 4672]"
# t19492 = prims.reshape(t19491, (2048, 4672)) # t19492: "cuda:0 bf16[2048, 4672]"
del t19491
t19496 = torch.permute(t19492, (1, 0)) # t19496: "cuda:0 bf16[4672, 2048]"
# t19496 = ltorch.permute(t19492, (1, 0)) # t19496: "cuda:0 bf16[4672, 2048]"
# t19496 = prims.transpose(t19492, (1, 0)) # t19496: "cuda:0 bf16[4672, 2048]"
t19498 = torch.matmul(t19496, t19318) # t19498: "cuda:0 bf16[4672, 4544]"
# t19498 = ltorch.matmul(t19496, t19497) # t19498: "cuda:0 bf16[4672, 4544]"
# t19498 = prims.matmul(t19496, t19497) # t19498: "cuda:0 bf16[4672, 4544]"
del t19496, t19318
t19493 = torch.matmul(t19492, t_transformer_h_7_attn_attn_weight) # t19493: "cuda:0 bf16[2048, 4544]"
# t19493 = ltorch.matmul(t19492, t_transformer_h_7_attn_attn_weight) # t19493: "cuda:0 bf16[2048, 4544]"
# t19493 = prims.matmul(t19492, t_transformer_h_7_attn_attn_weight) # t19493: "cuda:0 bf16[2048, 4544]"
del t19492, t_transformer_h_7_attn_attn_weight
t19315 = torch.reshape(t19314, (1, 2048, 4544)) # t19315: "cuda:0 bf16[1, 2048, 4544]"
# t19315 = ltorch.reshape(t19314, (1, 2048, 4544)) # t19315: "cuda:0 bf16[1, 2048, 4544]"
# t19315 = prims.reshape(t19314, (1, 2048, 4544)) # t19315: "cuda:0 bf16[1, 2048, 4544]"
del t19314
t19494 = torch.reshape(t19493, (1, 2048, 4544)) # t19494: "cuda:0 bf16[1, 2048, 4544]"
# t19494 = ltorch.reshape(t19493, (1, 2048, 4544)) # t19494: "cuda:0 bf16[1, 2048, 4544]"
# t19494 = prims.reshape(t19493, (1, 2048, 4544)) # t19494: "cuda:0 bf16[1, 2048, 4544]"
del t19493
[t19507, t19513, t19555] = nvFusion75(i19535, t1069, t1090, t1105, t1110, t1116, t19272, t19315, t19494, t937)
# t1096 = prims.convert_element_type(t937, dtypes.float32) # t1096: "cuda:0 f32[1, 2048, 4544]"
# t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: "cuda:0 f32[1, 2048, 4544]"
# t1092 = prims.convert_element_type(t1069, dtypes.float32) # t1092: "cuda:0 f32[1, 2048, 4544]"
# t1093 = prims.add(t1091, t1092) # t1093: "cuda:0 f32[1, 2048, 4544]"
# t1097 = prims.add(t1093, t1096) # t1097: "cuda:0 f32[1, 2048, 4544]"
# t1107 = prims.broadcast_in_dim(t1105, [1, 2048, 1], [0, 1]) # t1107: "cuda:0 f32[1, 2048, 1]"
# t1111 = prims.broadcast_in_dim(t1107, (1, 2048, 4544), (0, 1, 2)) # t1111: "cuda:0 f32[1, 2048, 4544]"
# t1113 = prims.sub(t1097, t1111) # t1113: "cuda:0 f32[1, 2048, 4544]"
# t1114 = prims.broadcast_in_dim(t1110, (1, 2048, 4544), (0, 1, 2)) # t1114: "cuda:0 f32[1, 2048, 4544]"
# t1115 = prims.mul(t1113, t1114) # t1115: "cuda:0 f32[1, 2048, 4544]"
# t1117 = prims.convert_element_type(t1116, dtypes.float32) # t1117: "cuda:0 f32[1, 2048, 4544]"
# t19552 = prims.convert_element_type(t19272, dtypes.float32) # t19552: "cuda:0 f32[1, 2048, 4544]"
# t19499 = prims.convert_element_type(t19315, dtypes.float32) # t19499: "cuda:0 f32[1, 2048, 4544]"
# t19500 = prims.convert_element_type(t19494, dtypes.float32) # t19500: "cuda:0 f32[1, 2048, 4544]"
# t19501 = prims.add(t19499, t19500) # t19501: "cuda:0 f32[1, 2048, 4544]"
# t19506 = prims.sum(t19501, (0, 1)) # t19506: "cuda:0 f32[4544]"
# t19507 = prims.convert_element_type(t19506, dtypes.bfloat16) # t19507: "cuda:0 bf16[4544]"
# t19508 = prims.mul(t1117, t19501) # t19508: "cuda:0 f32[1, 2048, 4544]"
# t19509 = prims.mul(t1115, t19501) # t19509: "cuda:0 f32[1, 2048, 4544]"
# t19512 = prims.sum(t19509, (0, 1)) # t19512: "cuda:0 f32[4544]"
# t19513 = prims.convert_element_type(t19512, dtypes.bfloat16) # t19513: "cuda:0 bf16[4544]"
# t19514 = prims.mul(t1114, t19508) # t19514: "cuda:0 f32[1, 2048, 4544]"
# t19515 = prims.mul(t1113, t19508) # t19515: "cuda:0 f32[1, 2048, 4544]"
# t19516 = prims.sum(t19515, (0, 2)) # t19516: "cuda:0 f32[2048]"
# t19517 = prims.broadcast_in_dim(t19516, [1, 2048, 1], [1]) # t19517: "cuda:0 f32[1, 2048, 1]"
# t19518 = prims.neg(t19514) # t19518: "cuda:0 f32[1, 2048, 4544]"
# t19520 = prims.sum(t19518, (0, 2)) # t19520: "cuda:0 f32[2048]"
# t19521 = prims.broadcast_in_dim(t19520, [1, 2048, 1], [1]) # t19521: "cuda:0 f32[1, 2048, 1]"
# t19522 = prims.mul(-0.5, t19517) # t19522: "cuda:0 f32[1, 2048, 1]"
# t19523 = prims.pow(t1110, 3.0) # t19523: "cuda:0 f32[1, 2048, 1]"
# t19524 = prims.mul(t19522, t19523) # t19524: "cuda:0 f32[1, 2048, 1]"
# t19526 = prims.sum(t19521, (0, 2)) # t19526: "cuda:0 f32[2048]"
# t19527 = prims.broadcast_in_dim(t19526, [1, 2048], [1]) # t19527: "cuda:0 f32[1, 2048]"
# t19528 = prims.sum(t19524, (0, 2)) # t19528: "cuda:0 f32[2048]"
# t19529 = prims.broadcast_in_dim(t19528, [1, 2048], [1]) # t19529: "cuda:0 f32[1, 2048]"
# t19532 = prims.broadcast_in_dim(t19527, [1, 2048, 1], [0, 1]) # t19532: "cuda:0 f32[1, 2048, 1]"
# t19533 = prims.broadcast_in_dim(t19532, (1, 2048, 4544), (0, 1, 2)) # t19533: "cuda:0 f32[1, 2048, 4544]"
# t19534 = prims.mul(0.00022007042253521127, t19533) # t19534: "cuda:0 f32[1, 2048, 4544]"
# t19536 = prims.broadcast_in_dim(t19529, [1, 2048, 1], [0, 1]) # t19536: "cuda:0 f32[1, 2048, 1]"
# t19537 = prims.broadcast_in_dim(t19536, (1, 2048, 4544), (0, 1, 2)) # t19537: "cuda:0 f32[1, 2048, 4544]"
# t19539 = prims.broadcast_in_dim(t1105, [1, 2048, 1], [0, 1]) # t19539: "cuda:0 f32[1, 2048, 1]"
# t19540 = prims.broadcast_in_dim(t19539, (1, 2048, 4544), (0, 1, 2)) # t19540: "cuda:0 f32[1, 2048, 4544]"
# t19541 = prims.mul(2.0, t19537) # t19541: "cuda:0 f32[1, 2048, 4544]"
# t19542 = prims.sub(t1097, t19540) # t19542: "cuda:0 f32[1, 2048, 4544]"
# t19543 = prims.mul(t19541, t19542) # t19543: "cuda:0 f32[1, 2048, 4544]"
# f19544 = prims.convert_element_type(i19535, float) # f19544: "float 4544.0"
# t19545 = prims.div(t19543, f19544) # t19545: "cuda:0 f32[1, 2048, 4544]"
# t19546 = prims.add(t19534, t19545) # t19546: "cuda:0 f32[1, 2048, 4544]"
# t19550 = prims.add(t19514, t19546) # t19550: "cuda:0 f32[1, 2048, 4544]"
# t19554 = prims.add(t19552, t19550) # t19554: "cuda:0 f32[1, 2048, 4544]"
# t19555 = prims.convert_element_type(t19554, dtypes.bfloat16) # t19555: "cuda:0 bf16[1, 2048, 4544]"
del i19535, t1069, t1090, t1105, t1110, t1116, t19272, t19315, t19494, t937
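  # nvFusion75: input-LayerNorm backward for transformer.h.7; t19555 carries the
  # residual-stream gradient on to the h.6 projections below.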
t19562 = torch.reshape(t19555, (-1, 4544)) # t19562: "cuda:0 bf16[2048, 4544]"
# t19562 = ltorch.reshape(t19555, (-1, 4544)) # t19562: "cuda:0 bf16[2048, 4544]"
# t19562 = prims.reshape(t19555, (2048, 4544)) # t19562: "cuda:0 bf16[2048, 4544]"
t19566 = torch.permute(t19562, (1, 0)) # t19566: "cuda:0 bf16[4544, 2048]"
# t19566 = ltorch.permute(t19562, (1, 0)) # t19566: "cuda:0 bf16[4544, 2048]"
# t19566 = prims.transpose(t19562, (1, 0)) # t19566: "cuda:0 bf16[4544, 2048]"
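# NOTE: in some of the ltorch/prims decomposition comments below (e.g. t19607), the
# operand names are likely the pre-optimization proxies; the executed call uses the
# equivalent deduplicated tensor produced above (here t19566).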
t19609 = torch.matmul(t19566, t19608) # t19609: "cuda:0 bf16[4544, 4544]"
# t19609 = ltorch.matmul(t19607, t19608) # t19609: "cuda:0 bf16[4544, 4544]"
# t19609 = prims.matmul(t19607, t19608) # t19609: "cuda:0 bf16[4544, 4544]"
del t19608
t19563 = torch.matmul(t19562, t_transformer_h_6_mlp_proj_weight) # t19563: "cuda:0 bf16[2048, 18176]"
# t19563 = ltorch.matmul(t19562, t_transformer_h_6_mlp_proj_weight) # t19563: "cuda:0 bf16[2048, 18176]"
# t19563 = prims.matmul(t19562, t_transformer_h_6_mlp_proj_weight) # t19563: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_6_mlp_proj_weight
t19568 = torch.matmul(t19566, t19567) # t19568: "cuda:0 bf16[4544, 18176]"
# t19568 = ltorch.matmul(t19566, t19567) # t19568: "cuda:0 bf16[4544, 18176]"
# t19568 = prims.matmul(t19566, t19567) # t19568: "cuda:0 bf16[4544, 18176]"
del t19566, t19567
t19604 = torch.matmul(t19562, t_transformer_h_6_attn_proj_weight) # t19604: "cuda:0 bf16[2048, 4544]"
# t19604 = ltorch.matmul(t19603, t_transformer_h_6_attn_proj_weight) # t19604: "cuda:0 bf16[2048, 4544]"
# t19604 = prims.matmul(t19603, t_transformer_h_6_attn_proj_weight) # t19604: "cuda:0 bf16[2048, 4544]"
del t19562, t_transformer_h_6_attn_proj_weight
t19564 = torch.reshape(t19563, (1, 2048, 18176)) # t19564: "cuda:0 bf16[1, 2048, 18176]"
# t19564 = ltorch.reshape(t19563, (1, 2048, 18176)) # t19564: "cuda:0 bf16[1, 2048, 18176]"
# t19564 = prims.reshape(t19563, (1, 2048, 18176)) # t19564: "cuda:0 bf16[1, 2048, 18176]"
del t19563
t19605 = torch.reshape(t19604, (1, 2048, 4544)) # t19605: "cuda:0 bf16[1, 2048, 4544]"
# t19605 = ltorch.reshape(t19604, (1, 2048, 4544)) # t19605: "cuda:0 bf16[1, 2048, 4544]"
# t19605 = prims.reshape(t19604, (1, 2048, 4544)) # t19605: "cuda:0 bf16[1, 2048, 4544]"
del t19604
t19613 = torch.reshape(t19605, (1, 2048, 71, 64)) # t19613: "cuda:0 bf16[1, 2048, 71, 64]"
# t19613 = ltorch.reshape(t19605, (1, 2048, 71, 64)) # t19613: "cuda:0 bf16[1, 2048, 71, 64]"
# t19613 = prims.reshape(t19605, (1, 2048, 71, 64)) # t19613: "cuda:0 bf16[1, 2048, 71, 64]"
del t19605
t19616 = torch.permute(t19613, (0, 2, 1, 3)) # t19616: "cuda:0 bf16[1, 71, 2048, 64]"
# t19616 = ltorch.permute(t19613, (0, 2, 1, 3)) # t19616: "cuda:0 bf16[1, 71, 2048, 64]"
# t19616 = prims.transpose(t19613, (0, 2, 1, 3)) # t19616: "cuda:0 bf16[1, 71, 2048, 64]"
del t19613
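# nvFusion76 appears to be the fused backward of the erf-based GELU in h.6.mlp:
# 0.5 * (1 + erf(x / sqrt(2))) is recomputed from the saved fc output t1070, and the
# constant 1.1283791670955126 is 2/sqrt(pi) from d/dx erf(x) = 2/sqrt(pi) * exp(-x^2).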
[t19595] = nvFusion76(f447, f449, t1070, t19564)
# t1071 = prims.convert_element_type(t1070, dtypes.float32) # t1071: "cuda:0 f32[1, 2048, 18176]"
# t1073 = prims.div(t1071, 1.4142135623730951) # t1073: "cuda:0 f32[1, 2048, 18176]"
# t1076 = prims.erf(t1073) # t1076: "cuda:0 f32[1, 2048, 18176]"
# t1080 = prims.mul(0.5, t1076) # t1080: "cuda:0 f32[1, 2048, 18176]"
# t1084 = prims.add(0.5, t1080) # t1084: "cuda:0 f32[1, 2048, 18176]"
# t19569 = prims.convert_element_type(t19564, dtypes.float32) # t19569: "cuda:0 f32[1, 2048, 18176]"
# t19570 = prims.mul(t1084, t19569) # t19570: "cuda:0 f32[1, 2048, 18176]"
# t19571 = prims.mul(t1071, t19569) # t19571: "cuda:0 f32[1, 2048, 18176]"
# t19579 = prims.mul(f449, t19571) # t19579: "cuda:0 f32[1, 2048, 18176]"
# t19582 = prims.pow(t1073, 2.0) # t19582: "cuda:0 f32[1, 2048, 18176]"
# t19583 = prims.neg(t19582) # t19583: "cuda:0 f32[1, 2048, 18176]"
# t19584 = prims.exp(t19583) # t19584: "cuda:0 f32[1, 2048, 18176]"
# t19585 = prims.mul(1.1283791670955126, t19584) # t19585: "cuda:0 f32[1, 2048, 18176]"
# t19586 = prims.mul(t19585, t19579) # t19586: "cuda:0 f32[1, 2048, 18176]"
# t19590 = prims.div(t19586, f447) # t19590: "cuda:0 f32[1, 2048, 18176]"
# t19594 = prims.add(t19570, t19590) # t19594: "cuda:0 f32[1, 2048, 18176]"
# t19595 = prims.convert_element_type(t19594, dtypes.bfloat16) # t19595: "cuda:0 bf16[1, 2048, 18176]"
del f447, f449, t1070, t19564
t19596 = torch.reshape(t19595, (-1, 18176)) # t19596: "cuda:0 bf16[2048, 18176]"
# t19596 = ltorch.reshape(t19595, (-1, 18176)) # t19596: "cuda:0 bf16[2048, 18176]"
# t19596 = prims.reshape(t19595, (2048, 18176)) # t19596: "cuda:0 bf16[2048, 18176]"
del t19595
t19600 = torch.permute(t19596, (1, 0)) # t19600: "cuda:0 bf16[18176, 2048]"
# t19600 = ltorch.permute(t19596, (1, 0)) # t19600: "cuda:0 bf16[18176, 2048]"
# t19600 = prims.transpose(t19596, (1, 0)) # t19600: "cuda:0 bf16[18176, 2048]"
t19602 = torch.matmul(t19600, t19601) # t19602: "cuda:0 bf16[18176, 4544]"
# t19602 = ltorch.matmul(t19600, t19601) # t19602: "cuda:0 bf16[18176, 4544]"
# t19602 = prims.matmul(t19600, t19601) # t19602: "cuda:0 bf16[18176, 4544]"
del t19600
t19597 = torch.matmul(t19596, t_transformer_h_6_mlp_fc_weight) # t19597: "cuda:0 bf16[2048, 4544]"
# t19597 = ltorch.matmul(t19596, t_transformer_h_6_mlp_fc_weight) # t19597: "cuda:0 bf16[2048, 4544]"
# t19597 = prims.matmul(t19596, t_transformer_h_6_mlp_fc_weight) # t19597: "cuda:0 bf16[2048, 4544]"
del t19596, t_transformer_h_6_mlp_fc_weight
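# cudnn_sdpa_bwd likely returns (grad_query, grad_key, grad_value) of the scaled
# dot-product attention for h.6; the full-range slices that follow appear to take
# the rotary portion of the q/k grads, which here spans the whole head dim of 64.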
(t19617, t19618, t19619) = cudnn_sdpa_bwd(t19616, t1054, t1057, t1007, None, f438, b439, t1058, t1059, t1060, t1061, scale=f440, cat_grad_qkv=False)
del t19616, t1054, t1057, t1007, f438, b439, t1058, t1059, t1060, t1061, f440
t19621 = torch_slice_prim_impl(t19618, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19621: "cuda:0 bf16[1, 71, 2048, 64]"
del t19618
t19625 = torch_slice_prim_impl(t19617, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19625: "cuda:0 bf16[1, 71, 2048, 64]"
del t19617
t19728 = torch.reshape(t19619, (1, 1, 71, 2048, 64)) # t19728: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19728 = ltorch.reshape(t19619, (1, 1, 71, 2048, 64)) # t19728: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19728 = prims.reshape(t19619, (1, 1, 71, 2048, 64)) # t19728: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t19619
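# nvFusion77 appears to rotate the q and k grads back through the rotary position
# embedding (using the cached cos/sin tensors t66/t61), sum the single-head k and v
# grads over the 71 broadcast query heads (multi-query attention), and concatenate
# everything into the fused qkv grad layout [1, 1, 73, 2048, 64] (71 q + 1 k + 1 v).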
[t19762] = nvFusion77(i411, t19621, t19625, t19728, t61, t66)
# t19622 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19622: "cuda:0 bf16[1, 71, 2048, 0]"
# t19623 = prims.pad(t19622, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19623: "cuda:0 bf16[1, 71, 2048, 64]"
# t19626 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19626: "cuda:0 bf16[1, 71, 2048, 0]"
# t19627 = prims.pad(t19626, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19627: "cuda:0 bf16[1, 71, 2048, 64]"
# t19628 = prims.convert_element_type(t19621, dtypes.float32) # t19628: "cuda:0 f32[1, 71, 2048, 64]"
# t19632 = prims.mul(t66, t19628) # t19632: "cuda:0 f32[1, 71, 2048, 64]"
# t19635 = prims.convert_element_type(t19632, dtypes.bfloat16) # t19635: "cuda:0 bf16[1, 71, 2048, 64]"
# t19644 = prims.mul(t61, t19628) # t19644: "cuda:0 f32[1, 71, 2048, 64]"
# t19656 = prims.slice_prim(t19635, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19656: "cuda:0 bf16[1, 71, 2048, 32]"
# t19657 = prims.slice_prim(t19635, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19657: "cuda:0 bf16[1, 71, 2048, 32]"
# t19658 = prims.convert_element_type(t19656, dtypes.float32) # t19658: "cuda:0 f32[1, 71, 2048, 32]"
# t19659 = prims.neg(t19658) # t19659: "cuda:0 f32[1, 71, 2048, 32]"
# t19660 = prims.convert_element_type(t19659, dtypes.bfloat16) # t19660: "cuda:0 bf16[1, 71, 2048, 32]"
# t19661 = prims.pad(t19660, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19661: "cuda:0 bf16[1, 71, 2048, 64]"
# t19663 = prims.convert_element_type(t19661, dtypes.float32) # t19663: "cuda:0 f32[1, 71, 2048, 64]"
# t19664 = prims.add(t19644, t19663) # t19664: "cuda:0 f32[1, 71, 2048, 64]"
# t19666 = prims.pad(t19657, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19666: "cuda:0 bf16[1, 71, 2048, 64]"
# t19668 = prims.convert_element_type(t19666, dtypes.float32) # t19668: "cuda:0 f32[1, 71, 2048, 64]"
# t19669 = prims.add(t19664, t19668) # t19669: "cuda:0 f32[1, 71, 2048, 64]"
# t19670 = prims.convert_element_type(t19669, dtypes.bfloat16) # t19670: "cuda:0 bf16[1, 71, 2048, 64]"
# t19671 = prims.pad(t19670, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t19671: "cuda:0 bf16[1, 71, 2048, 64]"
# t19672 = prims.convert_element_type(t19623, dtypes.float32) # t19672: "cuda:0 f32[1, 71, 2048, 64]"
# t19673 = prims.convert_element_type(t19671, dtypes.float32) # t19673: "cuda:0 f32[1, 71, 2048, 64]"
# t19674 = prims.add(t19672, t19673) # t19674: "cuda:0 f32[1, 71, 2048, 64]"
# t19675 = prims.convert_element_type(t19674, dtypes.bfloat16) # t19675: "cuda:0 bf16[1, 71, 2048, 64]"
# t19676 = prims.convert_element_type(t19625, dtypes.float32) # t19676: "cuda:0 f32[1, 71, 2048, 64]"
# t19680 = prims.mul(t66, t19676) # t19680: "cuda:0 f32[1, 71, 2048, 64]"
# t19683 = prims.convert_element_type(t19680, dtypes.bfloat16) # t19683: "cuda:0 bf16[1, 71, 2048, 64]"
# t19692 = prims.mul(t61, t19676) # t19692: "cuda:0 f32[1, 71, 2048, 64]"
# t19704 = prims.slice_prim(t19683, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19704: "cuda:0 bf16[1, 71, 2048, 32]"
# t19705 = prims.slice_prim(t19683, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19705: "cuda:0 bf16[1, 71, 2048, 32]"
# t19706 = prims.convert_element_type(t19704, dtypes.float32) # t19706: "cuda:0 f32[1, 71, 2048, 32]"
# t19707 = prims.neg(t19706) # t19707: "cuda:0 f32[1, 71, 2048, 32]"
# t19708 = prims.convert_element_type(t19707, dtypes.bfloat16) # t19708: "cuda:0 bf16[1, 71, 2048, 32]"
# t19709 = prims.pad(t19708, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19709: "cuda:0 bf16[1, 71, 2048, 64]"
# t19711 = prims.convert_element_type(t19709, dtypes.float32) # t19711: "cuda:0 f32[1, 71, 2048, 64]"
# t19712 = prims.add(t19692, t19711) # t19712: "cuda:0 f32[1, 71, 2048, 64]"
# t19714 = prims.pad(t19705, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19714: "cuda:0 bf16[1, 71, 2048, 64]"
# t19716 = prims.convert_element_type(t19714, dtypes.float32) # t19716: "cuda:0 f32[1, 71, 2048, 64]"
# t19717 = prims.add(t19712, t19716) # t19717: "cuda:0 f32[1, 71, 2048, 64]"
# t19718 = prims.convert_element_type(t19717, dtypes.bfloat16) # t19718: "cuda:0 bf16[1, 71, 2048, 64]"
# t19719 = prims.pad(t19718, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t19719: "cuda:0 bf16[1, 71, 2048, 64]"
# t19720 = prims.convert_element_type(t19627, dtypes.float32) # t19720: "cuda:0 f32[1, 71, 2048, 64]"
# t19721 = prims.convert_element_type(t19719, dtypes.float32) # t19721: "cuda:0 f32[1, 71, 2048, 64]"
# t19722 = prims.add(t19720, t19721) # t19722: "cuda:0 f32[1, 71, 2048, 64]"
# t19723 = prims.convert_element_type(t19722, dtypes.bfloat16) # t19723: "cuda:0 bf16[1, 71, 2048, 64]"
# t19733 = prims.reshape(t19675, (1, 1, 71, 2048, 64)) # t19733: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19738 = prims.reshape(t19723, (1, 1, 71, 2048, 64)) # t19738: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t19744 = prims.convert_element_type(t19728, dtypes.float32) # t19744: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t19745 = prims.sum(t19744, (0, 1, 2)) # t19745: "cuda:0 f32[2048, 64]"
# t19746 = prims.convert_element_type(t19745, dtypes.bfloat16) # t19746: "cuda:0 bf16[2048, 64]"
# t19747 = prims.broadcast_in_dim(t19746, [1, 1, 1, 2048, 64], [3, 4]) # t19747: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t19753 = prims.convert_element_type(t19733, dtypes.float32) # t19753: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t19754 = prims.sum(t19753, (0, 1, 2)) # t19754: "cuda:0 f32[2048, 64]"
# t19755 = prims.convert_element_type(t19754, dtypes.bfloat16) # t19755: "cuda:0 bf16[2048, 64]"
# t19756 = prims.broadcast_in_dim(t19755, [1, 1, 1, 2048, 64], [3, 4]) # t19756: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t19762 = prims.cat((t19738, t19756, t19747), i411) # t19762: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i411, t19621, t19625, t19728
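# The fused qkv grad is then permuted and flattened back to [2048, 4672] (73 * 64),
# so the weight grad of h.6.attn.attn (the fused qkv projection) and the grad of its
# input can again be computed as matmuls against the saved, normalized block input
# t19601 and the qkv weight.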
t19768 = torch.permute(t19762, (0, 3, 1, 2, 4)) # t19768: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t19768 = ltorch.permute(t19762, (0, 3, 1, 2, 4)) # t19768: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t19768 = prims.transpose(t19762, (0, 3, 1, 2, 4)) # t19768: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t19762
t19774 = torch.reshape(t19768, (1, 2048, 4672)) # t19774: "cuda:0 bf16[1, 2048, 4672]"
# t19774 = ltorch.reshape(t19768, (1, 2048, 4672)) # t19774: "cuda:0 bf16[1, 2048, 4672]"
# t19774 = prims.reshape(t19768, (1, 2048, 4672)) # t19774: "cuda:0 bf16[1, 2048, 4672]"
del t19768
t19775 = torch.reshape(t19774, (-1, 4672)) # t19775: "cuda:0 bf16[2048, 4672]"
# t19775 = ltorch.reshape(t19774, (-1, 4672)) # t19775: "cuda:0 bf16[2048, 4672]"
# t19775 = prims.reshape(t19774, (2048, 4672)) # t19775: "cuda:0 bf16[2048, 4672]"
del t19774
t19779 = torch.permute(t19775, (1, 0)) # t19779: "cuda:0 bf16[4672, 2048]"
# t19779 = ltorch.permute(t19775, (1, 0)) # t19779: "cuda:0 bf16[4672, 2048]"
# t19779 = prims.transpose(t19775, (1, 0)) # t19779: "cuda:0 bf16[4672, 2048]"
t19781 = torch.matmul(t19779, t19601) # t19781: "cuda:0 bf16[4672, 4544]"
# t19781 = ltorch.matmul(t19779, t19780) # t19781: "cuda:0 bf16[4672, 4544]"
# t19781 = prims.matmul(t19779, t19780) # t19781: "cuda:0 bf16[4672, 4544]"
del t19779, t19601
t19776 = torch.matmul(t19775, t_transformer_h_6_attn_attn_weight) # t19776: "cuda:0 bf16[2048, 4544]"
# t19776 = ltorch.matmul(t19775, t_transformer_h_6_attn_attn_weight) # t19776: "cuda:0 bf16[2048, 4544]"
# t19776 = prims.matmul(t19775, t_transformer_h_6_attn_attn_weight) # t19776: "cuda:0 bf16[2048, 4544]"
del t19775, t_transformer_h_6_attn_attn_weight
t19598 = torch.reshape(t19597, (1, 2048, 4544)) # t19598: "cuda:0 bf16[1, 2048, 4544]"
# t19598 = ltorch.reshape(t19597, (1, 2048, 4544)) # t19598: "cuda:0 bf16[1, 2048, 4544]"
# t19598 = prims.reshape(t19597, (1, 2048, 4544)) # t19598: "cuda:0 bf16[1, 2048, 4544]"
del t19597
t19777 = torch.reshape(t19776, (1, 2048, 4544)) # t19777: "cuda:0 bf16[1, 2048, 4544]"
# t19777 = ltorch.reshape(t19776, (1, 2048, 4544)) # t19777: "cuda:0 bf16[1, 2048, 4544]"
# t19777 = prims.reshape(t19776, (1, 2048, 4544)) # t19777: "cuda:0 bf16[1, 2048, 4544]"
del t19776
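# nvFusion78 appears to be the fused LayerNorm backward for h.6 (norm dim 4544;
# 0.00022007042253521127 == 1/4544). It accumulates the parallel MLP and attention
# input grads and produces the bias grad t19790, the weight grad t19796, and the
# residual-stream grad t19838 that flows on to block h.5.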
[t19790, t19796, t19838] = nvFusion78(i19818, t19555, t19598, t19777, t776, t908, t929, t944, t949, t955)
# t935 = prims.convert_element_type(t776, dtypes.float32) # t935: "cuda:0 f32[1, 2048, 4544]"
# t930 = prims.convert_element_type(t929, dtypes.float32) # t930: "cuda:0 f32[1, 2048, 4544]"
# t931 = prims.convert_element_type(t908, dtypes.float32) # t931: "cuda:0 f32[1, 2048, 4544]"
# t932 = prims.add(t930, t931) # t932: "cuda:0 f32[1, 2048, 4544]"
# t936 = prims.add(t932, t935) # t936: "cuda:0 f32[1, 2048, 4544]"
# t946 = prims.broadcast_in_dim(t944, [1, 2048, 1], [0, 1]) # t946: "cuda:0 f32[1, 2048, 1]"
# t950 = prims.broadcast_in_dim(t946, (1, 2048, 4544), (0, 1, 2)) # t950: "cuda:0 f32[1, 2048, 4544]"
# t952 = prims.sub(t936, t950) # t952: "cuda:0 f32[1, 2048, 4544]"
# t953 = prims.broadcast_in_dim(t949, (1, 2048, 4544), (0, 1, 2)) # t953: "cuda:0 f32[1, 2048, 4544]"
# t954 = prims.mul(t952, t953) # t954: "cuda:0 f32[1, 2048, 4544]"
# t956 = prims.convert_element_type(t955, dtypes.float32) # t956: "cuda:0 f32[1, 2048, 4544]"
# t19835 = prims.convert_element_type(t19555, dtypes.float32) # t19835: "cuda:0 f32[1, 2048, 4544]"
# t19782 = prims.convert_element_type(t19598, dtypes.float32) # t19782: "cuda:0 f32[1, 2048, 4544]"
# t19783 = prims.convert_element_type(t19777, dtypes.float32) # t19783: "cuda:0 f32[1, 2048, 4544]"
# t19784 = prims.add(t19782, t19783) # t19784: "cuda:0 f32[1, 2048, 4544]"
# t19789 = prims.sum(t19784, (0, 1)) # t19789: "cuda:0 f32[4544]"
# t19790 = prims.convert_element_type(t19789, dtypes.bfloat16) # t19790: "cuda:0 bf16[4544]"
# t19791 = prims.mul(t956, t19784) # t19791: "cuda:0 f32[1, 2048, 4544]"
# t19792 = prims.mul(t954, t19784) # t19792: "cuda:0 f32[1, 2048, 4544]"
# t19795 = prims.sum(t19792, (0, 1)) # t19795: "cuda:0 f32[4544]"
# t19796 = prims.convert_element_type(t19795, dtypes.bfloat16) # t19796: "cuda:0 bf16[4544]"
# t19797 = prims.mul(t953, t19791) # t19797: "cuda:0 f32[1, 2048, 4544]"
# t19798 = prims.mul(t952, t19791) # t19798: "cuda:0 f32[1, 2048, 4544]"
# t19799 = prims.sum(t19798, (0, 2)) # t19799: "cuda:0 f32[2048]"
# t19800 = prims.broadcast_in_dim(t19799, [1, 2048, 1], [1]) # t19800: "cuda:0 f32[1, 2048, 1]"
# t19801 = prims.neg(t19797) # t19801: "cuda:0 f32[1, 2048, 4544]"
# t19803 = prims.sum(t19801, (0, 2)) # t19803: "cuda:0 f32[2048]"
# t19804 = prims.broadcast_in_dim(t19803, [1, 2048, 1], [1]) # t19804: "cuda:0 f32[1, 2048, 1]"
# t19805 = prims.mul(-0.5, t19800) # t19805: "cuda:0 f32[1, 2048, 1]"
# t19806 = prims.pow(t949, 3.0) # t19806: "cuda:0 f32[1, 2048, 1]"
# t19807 = prims.mul(t19805, t19806) # t19807: "cuda:0 f32[1, 2048, 1]"
# t19809 = prims.sum(t19804, (0, 2)) # t19809: "cuda:0 f32[2048]"
# t19810 = prims.broadcast_in_dim(t19809, [1, 2048], [1]) # t19810: "cuda:0 f32[1, 2048]"
# t19811 = prims.sum(t19807, (0, 2)) # t19811: "cuda:0 f32[2048]"
# t19812 = prims.broadcast_in_dim(t19811, [1, 2048], [1]) # t19812: "cuda:0 f32[1, 2048]"
# t19815 = prims.broadcast_in_dim(t19810, [1, 2048, 1], [0, 1]) # t19815: "cuda:0 f32[1, 2048, 1]"
# t19816 = prims.broadcast_in_dim(t19815, (1, 2048, 4544), (0, 1, 2)) # t19816: "cuda:0 f32[1, 2048, 4544]"
# t19817 = prims.mul(0.00022007042253521127, t19816) # t19817: "cuda:0 f32[1, 2048, 4544]"
# t19819 = prims.broadcast_in_dim(t19812, [1, 2048, 1], [0, 1]) # t19819: "cuda:0 f32[1, 2048, 1]"
# t19820 = prims.broadcast_in_dim(t19819, (1, 2048, 4544), (0, 1, 2)) # t19820: "cuda:0 f32[1, 2048, 4544]"
# t19822 = prims.broadcast_in_dim(t944, [1, 2048, 1], [0, 1]) # t19822: "cuda:0 f32[1, 2048, 1]"
# t19823 = prims.broadcast_in_dim(t19822, (1, 2048, 4544), (0, 1, 2)) # t19823: "cuda:0 f32[1, 2048, 4544]"
# t19824 = prims.mul(2.0, t19820) # t19824: "cuda:0 f32[1, 2048, 4544]"
# t19825 = prims.sub(t936, t19823) # t19825: "cuda:0 f32[1, 2048, 4544]"
# t19826 = prims.mul(t19824, t19825) # t19826: "cuda:0 f32[1, 2048, 4544]"
# f19827 = prims.convert_element_type(i19818, float) # f19827: "float 4544.0"
# t19828 = prims.div(t19826, f19827) # t19828: "cuda:0 f32[1, 2048, 4544]"
# t19829 = prims.add(t19817, t19828) # t19829: "cuda:0 f32[1, 2048, 4544]"
# t19833 = prims.add(t19797, t19829) # t19833: "cuda:0 f32[1, 2048, 4544]"
# t19837 = prims.add(t19835, t19833) # t19837: "cuda:0 f32[1, 2048, 4544]"
# t19838 = prims.convert_element_type(t19837, dtypes.bfloat16) # t19838: "cuda:0 bf16[1, 2048, 4544]"
del i19818, t19555, t19598, t19777, t776, t908, t929, t944, t949, t955
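# Same per-block pattern, now for transformer block h.5: mlp.proj / attn.proj
# matmuls, GELU backward in nvFusion79, cudnn SDPA backward, RoPE backward and qkv
# grad concatenation in nvFusion80, LayerNorm backward in nvFusion81.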
t19845 = torch.reshape(t19838, (-1, 4544)) # t19845: "cuda:0 bf16[2048, 4544]"
# t19845 = ltorch.reshape(t19838, (-1, 4544)) # t19845: "cuda:0 bf16[2048, 4544]"
# t19845 = prims.reshape(t19838, (2048, 4544)) # t19845: "cuda:0 bf16[2048, 4544]"
t19849 = torch.permute(t19845, (1, 0)) # t19849: "cuda:0 bf16[4544, 2048]"
# t19849 = ltorch.permute(t19845, (1, 0)) # t19849: "cuda:0 bf16[4544, 2048]"
# t19849 = prims.transpose(t19845, (1, 0)) # t19849: "cuda:0 bf16[4544, 2048]"
t19846 = torch.matmul(t19845, t_transformer_h_5_mlp_proj_weight) # t19846: "cuda:0 bf16[2048, 18176]"
# t19846 = ltorch.matmul(t19845, t_transformer_h_5_mlp_proj_weight) # t19846: "cuda:0 bf16[2048, 18176]"
# t19846 = prims.matmul(t19845, t_transformer_h_5_mlp_proj_weight) # t19846: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_5_mlp_proj_weight
t19851 = torch.matmul(t19849, t19850) # t19851: "cuda:0 bf16[4544, 18176]"
# t19851 = ltorch.matmul(t19849, t19850) # t19851: "cuda:0 bf16[4544, 18176]"
# t19851 = prims.matmul(t19849, t19850) # t19851: "cuda:0 bf16[4544, 18176]"
del t19850
t19887 = torch.matmul(t19845, t_transformer_h_5_attn_proj_weight) # t19887: "cuda:0 bf16[2048, 4544]"
# t19887 = ltorch.matmul(t19886, t_transformer_h_5_attn_proj_weight) # t19887: "cuda:0 bf16[2048, 4544]"
# t19887 = prims.matmul(t19886, t_transformer_h_5_attn_proj_weight) # t19887: "cuda:0 bf16[2048, 4544]"
del t19845, t_transformer_h_5_attn_proj_weight
t19892 = torch.matmul(t19849, t19891) # t19892: "cuda:0 bf16[4544, 4544]"
# t19892 = ltorch.matmul(t19890, t19891) # t19892: "cuda:0 bf16[4544, 4544]"
# t19892 = prims.matmul(t19890, t19891) # t19892: "cuda:0 bf16[4544, 4544]"
del t19849, t19891
t19847 = torch.reshape(t19846, (1, 2048, 18176)) # t19847: "cuda:0 bf16[1, 2048, 18176]"
# t19847 = ltorch.reshape(t19846, (1, 2048, 18176)) # t19847: "cuda:0 bf16[1, 2048, 18176]"
# t19847 = prims.reshape(t19846, (1, 2048, 18176)) # t19847: "cuda:0 bf16[1, 2048, 18176]"
del t19846
t19888 = torch.reshape(t19887, (1, 2048, 4544)) # t19888: "cuda:0 bf16[1, 2048, 4544]"
# t19888 = ltorch.reshape(t19887, (1, 2048, 4544)) # t19888: "cuda:0 bf16[1, 2048, 4544]"
# t19888 = prims.reshape(t19887, (1, 2048, 4544)) # t19888: "cuda:0 bf16[1, 2048, 4544]"
del t19887
t19896 = torch.reshape(t19888, (1, 2048, 71, 64)) # t19896: "cuda:0 bf16[1, 2048, 71, 64]"
# t19896 = ltorch.reshape(t19888, (1, 2048, 71, 64)) # t19896: "cuda:0 bf16[1, 2048, 71, 64]"
# t19896 = prims.reshape(t19888, (1, 2048, 71, 64)) # t19896: "cuda:0 bf16[1, 2048, 71, 64]"
del t19888
t19899 = torch.permute(t19896, (0, 2, 1, 3)) # t19899: "cuda:0 bf16[1, 71, 2048, 64]"
# t19899 = ltorch.permute(t19896, (0, 2, 1, 3)) # t19899: "cuda:0 bf16[1, 71, 2048, 64]"
# t19899 = prims.transpose(t19896, (0, 2, 1, 3)) # t19899: "cuda:0 bf16[1, 71, 2048, 64]"
del t19896
[t19878] = nvFusion79(f383, f385, t19847, t909)
# t910 = prims.convert_element_type(t909, dtypes.float32) # t910: "cuda:0 f32[1, 2048, 18176]"
# t912 = prims.div(t910, 1.4142135623730951) # t912: "cuda:0 f32[1, 2048, 18176]"
# t915 = prims.erf(t912) # t915: "cuda:0 f32[1, 2048, 18176]"
# t919 = prims.mul(0.5, t915) # t919: "cuda:0 f32[1, 2048, 18176]"
# t923 = prims.add(0.5, t919) # t923: "cuda:0 f32[1, 2048, 18176]"
# t19852 = prims.convert_element_type(t19847, dtypes.float32) # t19852: "cuda:0 f32[1, 2048, 18176]"
# t19853 = prims.mul(t923, t19852) # t19853: "cuda:0 f32[1, 2048, 18176]"
# t19854 = prims.mul(t910, t19852) # t19854: "cuda:0 f32[1, 2048, 18176]"
# t19862 = prims.mul(f385, t19854) # t19862: "cuda:0 f32[1, 2048, 18176]"
# t19865 = prims.pow(t912, 2.0) # t19865: "cuda:0 f32[1, 2048, 18176]"
# t19866 = prims.neg(t19865) # t19866: "cuda:0 f32[1, 2048, 18176]"
# t19867 = prims.exp(t19866) # t19867: "cuda:0 f32[1, 2048, 18176]"
# t19868 = prims.mul(1.1283791670955126, t19867) # t19868: "cuda:0 f32[1, 2048, 18176]"
# t19869 = prims.mul(t19868, t19862) # t19869: "cuda:0 f32[1, 2048, 18176]"
# t19873 = prims.div(t19869, f383) # t19873: "cuda:0 f32[1, 2048, 18176]"
# t19877 = prims.add(t19853, t19873) # t19877: "cuda:0 f32[1, 2048, 18176]"
# t19878 = prims.convert_element_type(t19877, dtypes.bfloat16) # t19878: "cuda:0 bf16[1, 2048, 18176]"
del f383, f385, t19847, t909
t19879 = torch.reshape(t19878, (-1, 18176)) # t19879: "cuda:0 bf16[2048, 18176]"
# t19879 = ltorch.reshape(t19878, (-1, 18176)) # t19879: "cuda:0 bf16[2048, 18176]"
# t19879 = prims.reshape(t19878, (2048, 18176)) # t19879: "cuda:0 bf16[2048, 18176]"
del t19878
t19883 = torch.permute(t19879, (1, 0)) # t19883: "cuda:0 bf16[18176, 2048]"
# t19883 = ltorch.permute(t19879, (1, 0)) # t19883: "cuda:0 bf16[18176, 2048]"
# t19883 = prims.transpose(t19879, (1, 0)) # t19883: "cuda:0 bf16[18176, 2048]"
t19885 = torch.matmul(t19883, t19884) # t19885: "cuda:0 bf16[18176, 4544]"
# t19885 = ltorch.matmul(t19883, t19884) # t19885: "cuda:0 bf16[18176, 4544]"
# t19885 = prims.matmul(t19883, t19884) # t19885: "cuda:0 bf16[18176, 4544]"
del t19883
t19880 = torch.matmul(t19879, t_transformer_h_5_mlp_fc_weight) # t19880: "cuda:0 bf16[2048, 4544]"
# t19880 = ltorch.matmul(t19879, t_transformer_h_5_mlp_fc_weight) # t19880: "cuda:0 bf16[2048, 4544]"
# t19880 = prims.matmul(t19879, t_transformer_h_5_mlp_fc_weight) # t19880: "cuda:0 bf16[2048, 4544]"
del t19879, t_transformer_h_5_mlp_fc_weight
(t19900, t19901, t19902) = cudnn_sdpa_bwd(t19899, t893, t896, t846, None, f374, b375, t897, t898, t899, t900, scale=f376, cat_grad_qkv=False)
del t19899, t893, t896, t846, f374, b375, t897, t898, t899, t900, f376
t19904 = torch_slice_prim_impl(t19901, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19904: "cuda:0 bf16[1, 71, 2048, 64]"
del t19901
t19908 = torch_slice_prim_impl(t19900, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19908: "cuda:0 bf16[1, 71, 2048, 64]"
del t19900
t20011 = torch.reshape(t19902, (1, 1, 71, 2048, 64)) # t20011: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20011 = ltorch.reshape(t19902, (1, 1, 71, 2048, 64)) # t20011: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20011 = prims.reshape(t19902, (1, 1, 71, 2048, 64)) # t20011: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t19902
[t20045] = nvFusion80(i347, t19904, t19908, t20011, t61, t66)
# t19905 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19905: "cuda:0 bf16[1, 71, 2048, 0]"
# t19906 = prims.pad(t19905, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19906: "cuda:0 bf16[1, 71, 2048, 64]"
# t19909 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t19909: "cuda:0 bf16[1, 71, 2048, 0]"
# t19910 = prims.pad(t19909, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t19910: "cuda:0 bf16[1, 71, 2048, 64]"
# t19911 = prims.convert_element_type(t19904, dtypes.float32) # t19911: "cuda:0 f32[1, 71, 2048, 64]"
# t19915 = prims.mul(t66, t19911) # t19915: "cuda:0 f32[1, 71, 2048, 64]"
# t19918 = prims.convert_element_type(t19915, dtypes.bfloat16) # t19918: "cuda:0 bf16[1, 71, 2048, 64]"
# t19927 = prims.mul(t61, t19911) # t19927: "cuda:0 f32[1, 71, 2048, 64]"
# t19939 = prims.slice_prim(t19918, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19939: "cuda:0 bf16[1, 71, 2048, 32]"
# t19940 = prims.slice_prim(t19918, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19940: "cuda:0 bf16[1, 71, 2048, 32]"
# t19941 = prims.convert_element_type(t19939, dtypes.float32) # t19941: "cuda:0 f32[1, 71, 2048, 32]"
# t19942 = prims.neg(t19941) # t19942: "cuda:0 f32[1, 71, 2048, 32]"
# t19943 = prims.convert_element_type(t19942, dtypes.bfloat16) # t19943: "cuda:0 bf16[1, 71, 2048, 32]"
# t19944 = prims.pad(t19943, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19944: "cuda:0 bf16[1, 71, 2048, 64]"
# t19946 = prims.convert_element_type(t19944, dtypes.float32) # t19946: "cuda:0 f32[1, 71, 2048, 64]"
# t19947 = prims.add(t19927, t19946) # t19947: "cuda:0 f32[1, 71, 2048, 64]"
# t19949 = prims.pad(t19940, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19949: "cuda:0 bf16[1, 71, 2048, 64]"
# t19951 = prims.convert_element_type(t19949, dtypes.float32) # t19951: "cuda:0 f32[1, 71, 2048, 64]"
# t19952 = prims.add(t19947, t19951) # t19952: "cuda:0 f32[1, 71, 2048, 64]"
# t19953 = prims.convert_element_type(t19952, dtypes.bfloat16) # t19953: "cuda:0 bf16[1, 71, 2048, 64]"
# t19954 = prims.pad(t19953, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t19954: "cuda:0 bf16[1, 71, 2048, 64]"
# t19955 = prims.convert_element_type(t19906, dtypes.float32) # t19955: "cuda:0 f32[1, 71, 2048, 64]"
# t19956 = prims.convert_element_type(t19954, dtypes.float32) # t19956: "cuda:0 f32[1, 71, 2048, 64]"
# t19957 = prims.add(t19955, t19956) # t19957: "cuda:0 f32[1, 71, 2048, 64]"
# t19958 = prims.convert_element_type(t19957, dtypes.bfloat16) # t19958: "cuda:0 bf16[1, 71, 2048, 64]"
# t19959 = prims.convert_element_type(t19908, dtypes.float32) # t19959: "cuda:0 f32[1, 71, 2048, 64]"
# t19963 = prims.mul(t66, t19959) # t19963: "cuda:0 f32[1, 71, 2048, 64]"
# t19966 = prims.convert_element_type(t19963, dtypes.bfloat16) # t19966: "cuda:0 bf16[1, 71, 2048, 64]"
# t19975 = prims.mul(t61, t19959) # t19975: "cuda:0 f32[1, 71, 2048, 64]"
# t19987 = prims.slice_prim(t19966, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t19987: "cuda:0 bf16[1, 71, 2048, 32]"
# t19988 = prims.slice_prim(t19966, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t19988: "cuda:0 bf16[1, 71, 2048, 32]"
# t19989 = prims.convert_element_type(t19987, dtypes.float32) # t19989: "cuda:0 f32[1, 71, 2048, 32]"
# t19990 = prims.neg(t19989) # t19990: "cuda:0 f32[1, 71, 2048, 32]"
# t19991 = prims.convert_element_type(t19990, dtypes.bfloat16) # t19991: "cuda:0 bf16[1, 71, 2048, 32]"
# t19992 = prims.pad(t19991, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t19992: "cuda:0 bf16[1, 71, 2048, 64]"
# t19994 = prims.convert_element_type(t19992, dtypes.float32) # t19994: "cuda:0 f32[1, 71, 2048, 64]"
# t19995 = prims.add(t19975, t19994) # t19995: "cuda:0 f32[1, 71, 2048, 64]"
# t19997 = prims.pad(t19988, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t19997: "cuda:0 bf16[1, 71, 2048, 64]"
# t19999 = prims.convert_element_type(t19997, dtypes.float32) # t19999: "cuda:0 f32[1, 71, 2048, 64]"
# t20000 = prims.add(t19995, t19999) # t20000: "cuda:0 f32[1, 71, 2048, 64]"
# t20001 = prims.convert_element_type(t20000, dtypes.bfloat16) # t20001: "cuda:0 bf16[1, 71, 2048, 64]"
# t20002 = prims.pad(t20001, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t20002: "cuda:0 bf16[1, 71, 2048, 64]"
# t20003 = prims.convert_element_type(t19910, dtypes.float32) # t20003: "cuda:0 f32[1, 71, 2048, 64]"
# t20004 = prims.convert_element_type(t20002, dtypes.float32) # t20004: "cuda:0 f32[1, 71, 2048, 64]"
# t20005 = prims.add(t20003, t20004) # t20005: "cuda:0 f32[1, 71, 2048, 64]"
# t20006 = prims.convert_element_type(t20005, dtypes.bfloat16) # t20006: "cuda:0 bf16[1, 71, 2048, 64]"
# t20016 = prims.reshape(t19958, (1, 1, 71, 2048, 64)) # t20016: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20021 = prims.reshape(t20006, (1, 1, 71, 2048, 64)) # t20021: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20027 = prims.convert_element_type(t20011, dtypes.float32) # t20027: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20028 = prims.sum(t20027, (0, 1, 2)) # t20028: "cuda:0 f32[2048, 64]"
# t20029 = prims.convert_element_type(t20028, dtypes.bfloat16) # t20029: "cuda:0 bf16[2048, 64]"
# t20030 = prims.broadcast_in_dim(t20029, [1, 1, 1, 2048, 64], [3, 4]) # t20030: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20036 = prims.convert_element_type(t20016, dtypes.float32) # t20036: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20037 = prims.sum(t20036, (0, 1, 2)) # t20037: "cuda:0 f32[2048, 64]"
# t20038 = prims.convert_element_type(t20037, dtypes.bfloat16) # t20038: "cuda:0 bf16[2048, 64]"
# t20039 = prims.broadcast_in_dim(t20038, [1, 1, 1, 2048, 64], [3, 4]) # t20039: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20045 = prims.cat((t20021, t20039, t20030), i347) # t20045: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i347, t19904, t19908, t20011
t20051 = torch.permute(t20045, (0, 3, 1, 2, 4)) # t20051: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20051 = ltorch.permute(t20045, (0, 3, 1, 2, 4)) # t20051: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20051 = prims.transpose(t20045, (0, 3, 1, 2, 4)) # t20051: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t20045
t20057 = torch.reshape(t20051, (1, 2048, 4672)) # t20057: "cuda:0 bf16[1, 2048, 4672]"
# t20057 = ltorch.reshape(t20051, (1, 2048, 4672)) # t20057: "cuda:0 bf16[1, 2048, 4672]"
# t20057 = prims.reshape(t20051, (1, 2048, 4672)) # t20057: "cuda:0 bf16[1, 2048, 4672]"
del t20051
t20058 = torch.reshape(t20057, (-1, 4672)) # t20058: "cuda:0 bf16[2048, 4672]"
# t20058 = ltorch.reshape(t20057, (-1, 4672)) # t20058: "cuda:0 bf16[2048, 4672]"
# t20058 = prims.reshape(t20057, (2048, 4672)) # t20058: "cuda:0 bf16[2048, 4672]"
del t20057
t20062 = torch.permute(t20058, (1, 0)) # t20062: "cuda:0 bf16[4672, 2048]"
# t20062 = ltorch.permute(t20058, (1, 0)) # t20062: "cuda:0 bf16[4672, 2048]"
# t20062 = prims.transpose(t20058, (1, 0)) # t20062: "cuda:0 bf16[4672, 2048]"
t20064 = torch.matmul(t20062, t19884) # t20064: "cuda:0 bf16[4672, 4544]"
# t20064 = ltorch.matmul(t20062, t20063) # t20064: "cuda:0 bf16[4672, 4544]"
# t20064 = prims.matmul(t20062, t20063) # t20064: "cuda:0 bf16[4672, 4544]"
del t20062, t19884
t20059 = torch.matmul(t20058, t_transformer_h_5_attn_attn_weight) # t20059: "cuda:0 bf16[2048, 4544]"
# t20059 = ltorch.matmul(t20058, t_transformer_h_5_attn_attn_weight) # t20059: "cuda:0 bf16[2048, 4544]"
# t20059 = prims.matmul(t20058, t_transformer_h_5_attn_attn_weight) # t20059: "cuda:0 bf16[2048, 4544]"
del t20058, t_transformer_h_5_attn_attn_weight
t19881 = torch.reshape(t19880, (1, 2048, 4544)) # t19881: "cuda:0 bf16[1, 2048, 4544]"
# t19881 = ltorch.reshape(t19880, (1, 2048, 4544)) # t19881: "cuda:0 bf16[1, 2048, 4544]"
# t19881 = prims.reshape(t19880, (1, 2048, 4544)) # t19881: "cuda:0 bf16[1, 2048, 4544]"
del t19880
t20060 = torch.reshape(t20059, (1, 2048, 4544)) # t20060: "cuda:0 bf16[1, 2048, 4544]"
# t20060 = ltorch.reshape(t20059, (1, 2048, 4544)) # t20060: "cuda:0 bf16[1, 2048, 4544]"
# t20060 = prims.reshape(t20059, (1, 2048, 4544)) # t20060: "cuda:0 bf16[1, 2048, 4544]"
del t20059
[t20073, t20079, t20121] = nvFusion81(i20101, t19838, t19881, t20060, t615, t747, t768, t783, t788, t794)
# t774 = prims.convert_element_type(t615, dtypes.float32) # t774: "cuda:0 f32[1, 2048, 4544]"
# t769 = prims.convert_element_type(t768, dtypes.float32) # t769: "cuda:0 f32[1, 2048, 4544]"
# t770 = prims.convert_element_type(t747, dtypes.float32) # t770: "cuda:0 f32[1, 2048, 4544]"
# t771 = prims.add(t769, t770) # t771: "cuda:0 f32[1, 2048, 4544]"
# t775 = prims.add(t771, t774) # t775: "cuda:0 f32[1, 2048, 4544]"
# t785 = prims.broadcast_in_dim(t783, [1, 2048, 1], [0, 1]) # t785: "cuda:0 f32[1, 2048, 1]"
# t789 = prims.broadcast_in_dim(t785, (1, 2048, 4544), (0, 1, 2)) # t789: "cuda:0 f32[1, 2048, 4544]"
# t791 = prims.sub(t775, t789) # t791: "cuda:0 f32[1, 2048, 4544]"
# t792 = prims.broadcast_in_dim(t788, (1, 2048, 4544), (0, 1, 2)) # t792: "cuda:0 f32[1, 2048, 4544]"
# t793 = prims.mul(t791, t792) # t793: "cuda:0 f32[1, 2048, 4544]"
# t795 = prims.convert_element_type(t794, dtypes.float32) # t795: "cuda:0 f32[1, 2048, 4544]"
# t20118 = prims.convert_element_type(t19838, dtypes.float32) # t20118: "cuda:0 f32[1, 2048, 4544]"
# t20065 = prims.convert_element_type(t19881, dtypes.float32) # t20065: "cuda:0 f32[1, 2048, 4544]"
# t20066 = prims.convert_element_type(t20060, dtypes.float32) # t20066: "cuda:0 f32[1, 2048, 4544]"
# t20067 = prims.add(t20065, t20066) # t20067: "cuda:0 f32[1, 2048, 4544]"
# t20072 = prims.sum(t20067, (0, 1)) # t20072: "cuda:0 f32[4544]"
# t20073 = prims.convert_element_type(t20072, dtypes.bfloat16) # t20073: "cuda:0 bf16[4544]"
# t20074 = prims.mul(t795, t20067) # t20074: "cuda:0 f32[1, 2048, 4544]"
# t20075 = prims.mul(t793, t20067) # t20075: "cuda:0 f32[1, 2048, 4544]"
# t20078 = prims.sum(t20075, (0, 1)) # t20078: "cuda:0 f32[4544]"
# t20079 = prims.convert_element_type(t20078, dtypes.bfloat16) # t20079: "cuda:0 bf16[4544]"
# t20080 = prims.mul(t792, t20074) # t20080: "cuda:0 f32[1, 2048, 4544]"
# t20081 = prims.mul(t791, t20074) # t20081: "cuda:0 f32[1, 2048, 4544]"
# t20082 = prims.sum(t20081, (0, 2)) # t20082: "cuda:0 f32[2048]"
# t20083 = prims.broadcast_in_dim(t20082, [1, 2048, 1], [1]) # t20083: "cuda:0 f32[1, 2048, 1]"
# t20084 = prims.neg(t20080) # t20084: "cuda:0 f32[1, 2048, 4544]"
# t20086 = prims.sum(t20084, (0, 2)) # t20086: "cuda:0 f32[2048]"
# t20087 = prims.broadcast_in_dim(t20086, [1, 2048, 1], [1]) # t20087: "cuda:0 f32[1, 2048, 1]"
# t20088 = prims.mul(-0.5, t20083) # t20088: "cuda:0 f32[1, 2048, 1]"
# t20089 = prims.pow(t788, 3.0) # t20089: "cuda:0 f32[1, 2048, 1]"
# t20090 = prims.mul(t20088, t20089) # t20090: "cuda:0 f32[1, 2048, 1]"
# t20092 = prims.sum(t20087, (0, 2)) # t20092: "cuda:0 f32[2048]"
# t20093 = prims.broadcast_in_dim(t20092, [1, 2048], [1]) # t20093: "cuda:0 f32[1, 2048]"
# t20094 = prims.sum(t20090, (0, 2)) # t20094: "cuda:0 f32[2048]"
# t20095 = prims.broadcast_in_dim(t20094, [1, 2048], [1]) # t20095: "cuda:0 f32[1, 2048]"
# t20098 = prims.broadcast_in_dim(t20093, [1, 2048, 1], [0, 1]) # t20098: "cuda:0 f32[1, 2048, 1]"
# t20099 = prims.broadcast_in_dim(t20098, (1, 2048, 4544), (0, 1, 2)) # t20099: "cuda:0 f32[1, 2048, 4544]"
# t20100 = prims.mul(0.00022007042253521127, t20099) # t20100: "cuda:0 f32[1, 2048, 4544]"
# t20102 = prims.broadcast_in_dim(t20095, [1, 2048, 1], [0, 1]) # t20102: "cuda:0 f32[1, 2048, 1]"
# t20103 = prims.broadcast_in_dim(t20102, (1, 2048, 4544), (0, 1, 2)) # t20103: "cuda:0 f32[1, 2048, 4544]"
# t20105 = prims.broadcast_in_dim(t783, [1, 2048, 1], [0, 1]) # t20105: "cuda:0 f32[1, 2048, 1]"
# t20106 = prims.broadcast_in_dim(t20105, (1, 2048, 4544), (0, 1, 2)) # t20106: "cuda:0 f32[1, 2048, 4544]"
# t20107 = prims.mul(2.0, t20103) # t20107: "cuda:0 f32[1, 2048, 4544]"
# t20108 = prims.sub(t775, t20106) # t20108: "cuda:0 f32[1, 2048, 4544]"
# t20109 = prims.mul(t20107, t20108) # t20109: "cuda:0 f32[1, 2048, 4544]"
# f20110 = prims.convert_element_type(i20101, float) # f20110: "float 4544.0"
# t20111 = prims.div(t20109, f20110) # t20111: "cuda:0 f32[1, 2048, 4544]"
# t20112 = prims.add(t20100, t20111) # t20112: "cuda:0 f32[1, 2048, 4544]"
# t20116 = prims.add(t20080, t20112) # t20116: "cuda:0 f32[1, 2048, 4544]"
# t20120 = prims.add(t20118, t20116) # t20120: "cuda:0 f32[1, 2048, 4544]"
# t20121 = prims.convert_element_type(t20120, dtypes.bfloat16) # t20121: "cuda:0 bf16[1, 2048, 4544]"
del i20101, t19838, t19881, t20060, t615, t747, t768, t783, t788, t794
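# Block h.4: the identical backward sequence repeats (nvFusion82 through nvFusion84).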
t20128 = torch.reshape(t20121, (-1, 4544)) # t20128: "cuda:0 bf16[2048, 4544]"
# t20128 = ltorch.reshape(t20121, (-1, 4544)) # t20128: "cuda:0 bf16[2048, 4544]"
# t20128 = prims.reshape(t20121, (2048, 4544)) # t20128: "cuda:0 bf16[2048, 4544]"
t20132 = torch.permute(t20128, (1, 0)) # t20132: "cuda:0 bf16[4544, 2048]"
# t20132 = ltorch.permute(t20128, (1, 0)) # t20132: "cuda:0 bf16[4544, 2048]"
# t20132 = prims.transpose(t20128, (1, 0)) # t20132: "cuda:0 bf16[4544, 2048]"
t20129 = torch.matmul(t20128, t_transformer_h_4_mlp_proj_weight) # t20129: "cuda:0 bf16[2048, 18176]"
# t20129 = ltorch.matmul(t20128, t_transformer_h_4_mlp_proj_weight) # t20129: "cuda:0 bf16[2048, 18176]"
# t20129 = prims.matmul(t20128, t_transformer_h_4_mlp_proj_weight) # t20129: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_4_mlp_proj_weight
t20134 = torch.matmul(t20132, t20133) # t20134: "cuda:0 bf16[4544, 18176]"
# t20134 = ltorch.matmul(t20132, t20133) # t20134: "cuda:0 bf16[4544, 18176]"
# t20134 = prims.matmul(t20132, t20133) # t20134: "cuda:0 bf16[4544, 18176]"
del t20133
t20170 = torch.matmul(t20128, t_transformer_h_4_attn_proj_weight) # t20170: "cuda:0 bf16[2048, 4544]"
# t20170 = ltorch.matmul(t20169, t_transformer_h_4_attn_proj_weight) # t20170: "cuda:0 bf16[2048, 4544]"
# t20170 = prims.matmul(t20169, t_transformer_h_4_attn_proj_weight) # t20170: "cuda:0 bf16[2048, 4544]"
del t20128, t_transformer_h_4_attn_proj_weight
t20175 = torch.matmul(t20132, t20174) # t20175: "cuda:0 bf16[4544, 4544]"
# t20175 = ltorch.matmul(t20173, t20174) # t20175: "cuda:0 bf16[4544, 4544]"
# t20175 = prims.matmul(t20173, t20174) # t20175: "cuda:0 bf16[4544, 4544]"
del t20132, t20174
t20130 = torch.reshape(t20129, (1, 2048, 18176)) # t20130: "cuda:0 bf16[1, 2048, 18176]"
# t20130 = ltorch.reshape(t20129, (1, 2048, 18176)) # t20130: "cuda:0 bf16[1, 2048, 18176]"
# t20130 = prims.reshape(t20129, (1, 2048, 18176)) # t20130: "cuda:0 bf16[1, 2048, 18176]"
del t20129
t20171 = torch.reshape(t20170, (1, 2048, 4544)) # t20171: "cuda:0 bf16[1, 2048, 4544]"
# t20171 = ltorch.reshape(t20170, (1, 2048, 4544)) # t20171: "cuda:0 bf16[1, 2048, 4544]"
# t20171 = prims.reshape(t20170, (1, 2048, 4544)) # t20171: "cuda:0 bf16[1, 2048, 4544]"
del t20170
t20179 = torch.reshape(t20171, (1, 2048, 71, 64)) # t20179: "cuda:0 bf16[1, 2048, 71, 64]"
# t20179 = ltorch.reshape(t20171, (1, 2048, 71, 64)) # t20179: "cuda:0 bf16[1, 2048, 71, 64]"
# t20179 = prims.reshape(t20171, (1, 2048, 71, 64)) # t20179: "cuda:0 bf16[1, 2048, 71, 64]"
del t20171
t20182 = torch.permute(t20179, (0, 2, 1, 3)) # t20182: "cuda:0 bf16[1, 71, 2048, 64]"
# t20182 = ltorch.permute(t20179, (0, 2, 1, 3)) # t20182: "cuda:0 bf16[1, 71, 2048, 64]"
# t20182 = prims.transpose(t20179, (0, 2, 1, 3)) # t20182: "cuda:0 bf16[1, 71, 2048, 64]"
del t20179
[t20161] = nvFusion82(f319, f321, t20130, t748)
# t749 = prims.convert_element_type(t748, dtypes.float32) # t749: "cuda:0 f32[1, 2048, 18176]"
# t751 = prims.div(t749, 1.4142135623730951) # t751: "cuda:0 f32[1, 2048, 18176]"
# t754 = prims.erf(t751) # t754: "cuda:0 f32[1, 2048, 18176]"
# t758 = prims.mul(0.5, t754) # t758: "cuda:0 f32[1, 2048, 18176]"
# t762 = prims.add(0.5, t758) # t762: "cuda:0 f32[1, 2048, 18176]"
# t20135 = prims.convert_element_type(t20130, dtypes.float32) # t20135: "cuda:0 f32[1, 2048, 18176]"
# t20136 = prims.mul(t762, t20135) # t20136: "cuda:0 f32[1, 2048, 18176]"
# t20137 = prims.mul(t749, t20135) # t20137: "cuda:0 f32[1, 2048, 18176]"
# t20145 = prims.mul(f321, t20137) # t20145: "cuda:0 f32[1, 2048, 18176]"
# t20148 = prims.pow(t751, 2.0) # t20148: "cuda:0 f32[1, 2048, 18176]"
# t20149 = prims.neg(t20148) # t20149: "cuda:0 f32[1, 2048, 18176]"
# t20150 = prims.exp(t20149) # t20150: "cuda:0 f32[1, 2048, 18176]"
# t20151 = prims.mul(1.1283791670955126, t20150) # t20151: "cuda:0 f32[1, 2048, 18176]"
# t20152 = prims.mul(t20151, t20145) # t20152: "cuda:0 f32[1, 2048, 18176]"
# t20156 = prims.div(t20152, f319) # t20156: "cuda:0 f32[1, 2048, 18176]"
# t20160 = prims.add(t20136, t20156) # t20160: "cuda:0 f32[1, 2048, 18176]"
# t20161 = prims.convert_element_type(t20160, dtypes.bfloat16) # t20161: "cuda:0 bf16[1, 2048, 18176]"
del f319, f321, t20130, t748
t20162 = torch.reshape(t20161, (-1, 18176)) # t20162: "cuda:0 bf16[2048, 18176]"
# t20162 = ltorch.reshape(t20161, (-1, 18176)) # t20162: "cuda:0 bf16[2048, 18176]"
# t20162 = prims.reshape(t20161, (2048, 18176)) # t20162: "cuda:0 bf16[2048, 18176]"
del t20161
t20166 = torch.permute(t20162, (1, 0)) # t20166: "cuda:0 bf16[18176, 2048]"
# t20166 = ltorch.permute(t20162, (1, 0)) # t20166: "cuda:0 bf16[18176, 2048]"
# t20166 = prims.transpose(t20162, (1, 0)) # t20166: "cuda:0 bf16[18176, 2048]"
t20168 = torch.matmul(t20166, t20167) # t20168: "cuda:0 bf16[18176, 4544]"
# t20168 = ltorch.matmul(t20166, t20167) # t20168: "cuda:0 bf16[18176, 4544]"
# t20168 = prims.matmul(t20166, t20167) # t20168: "cuda:0 bf16[18176, 4544]"
del t20166
t20163 = torch.matmul(t20162, t_transformer_h_4_mlp_fc_weight) # t20163: "cuda:0 bf16[2048, 4544]"
# t20163 = ltorch.matmul(t20162, t_transformer_h_4_mlp_fc_weight) # t20163: "cuda:0 bf16[2048, 4544]"
# t20163 = prims.matmul(t20162, t_transformer_h_4_mlp_fc_weight) # t20163: "cuda:0 bf16[2048, 4544]"
del t20162, t_transformer_h_4_mlp_fc_weight
(t20183, t20184, t20185) = cudnn_sdpa_bwd(t20182, t732, t735, t685, None, f310, b311, t736, t737, t738, t739, scale=f312, cat_grad_qkv=False)
del t20182, t732, t735, t685, f310, b311, t736, t737, t738, t739, f312
t20187 = torch_slice_prim_impl(t20184, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20187: "cuda:0 bf16[1, 71, 2048, 64]"
del t20184
t20191 = torch_slice_prim_impl(t20183, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20191: "cuda:0 bf16[1, 71, 2048, 64]"
del t20183
t20294 = torch.reshape(t20185, (1, 1, 71, 2048, 64)) # t20294: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20294 = ltorch.reshape(t20185, (1, 1, 71, 2048, 64)) # t20294: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20294 = prims.reshape(t20185, (1, 1, 71, 2048, 64)) # t20294: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t20185
[t20328] = nvFusion83(i283, t20187, t20191, t20294, t61, t66)
# t20188 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t20188: "cuda:0 bf16[1, 71, 2048, 0]"
# t20189 = prims.pad(t20188, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t20189: "cuda:0 bf16[1, 71, 2048, 64]"
# t20192 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t20192: "cuda:0 bf16[1, 71, 2048, 0]"
# t20193 = prims.pad(t20192, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t20193: "cuda:0 bf16[1, 71, 2048, 64]"
# t20194 = prims.convert_element_type(t20187, dtypes.float32) # t20194: "cuda:0 f32[1, 71, 2048, 64]"
# t20198 = prims.mul(t66, t20194) # t20198: "cuda:0 f32[1, 71, 2048, 64]"
# t20201 = prims.convert_element_type(t20198, dtypes.bfloat16) # t20201: "cuda:0 bf16[1, 71, 2048, 64]"
# t20210 = prims.mul(t61, t20194) # t20210: "cuda:0 f32[1, 71, 2048, 64]"
# t20222 = prims.slice_prim(t20201, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t20222: "cuda:0 bf16[1, 71, 2048, 32]"
# t20223 = prims.slice_prim(t20201, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20223: "cuda:0 bf16[1, 71, 2048, 32]"
# t20224 = prims.convert_element_type(t20222, dtypes.float32) # t20224: "cuda:0 f32[1, 71, 2048, 32]"
# t20225 = prims.neg(t20224) # t20225: "cuda:0 f32[1, 71, 2048, 32]"
# t20226 = prims.convert_element_type(t20225, dtypes.bfloat16) # t20226: "cuda:0 bf16[1, 71, 2048, 32]"
# t20227 = prims.pad(t20226, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t20227: "cuda:0 bf16[1, 71, 2048, 64]"
# t20229 = prims.convert_element_type(t20227, dtypes.float32) # t20229: "cuda:0 f32[1, 71, 2048, 64]"
# t20230 = prims.add(t20210, t20229) # t20230: "cuda:0 f32[1, 71, 2048, 64]"
# t20232 = prims.pad(t20223, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t20232: "cuda:0 bf16[1, 71, 2048, 64]"
# t20234 = prims.convert_element_type(t20232, dtypes.float32) # t20234: "cuda:0 f32[1, 71, 2048, 64]"
# t20235 = prims.add(t20230, t20234) # t20235: "cuda:0 f32[1, 71, 2048, 64]"
# t20236 = prims.convert_element_type(t20235, dtypes.bfloat16) # t20236: "cuda:0 bf16[1, 71, 2048, 64]"
# t20237 = prims.pad(t20236, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t20237: "cuda:0 bf16[1, 71, 2048, 64]"
# t20238 = prims.convert_element_type(t20189, dtypes.float32) # t20238: "cuda:0 f32[1, 71, 2048, 64]"
# t20239 = prims.convert_element_type(t20237, dtypes.float32) # t20239: "cuda:0 f32[1, 71, 2048, 64]"
# t20240 = prims.add(t20238, t20239) # t20240: "cuda:0 f32[1, 71, 2048, 64]"
# t20241 = prims.convert_element_type(t20240, dtypes.bfloat16) # t20241: "cuda:0 bf16[1, 71, 2048, 64]"
# t20242 = prims.convert_element_type(t20191, dtypes.float32) # t20242: "cuda:0 f32[1, 71, 2048, 64]"
# t20246 = prims.mul(t66, t20242) # t20246: "cuda:0 f32[1, 71, 2048, 64]"
# t20249 = prims.convert_element_type(t20246, dtypes.bfloat16) # t20249: "cuda:0 bf16[1, 71, 2048, 64]"
# t20258 = prims.mul(t61, t20242) # t20258: "cuda:0 f32[1, 71, 2048, 64]"
# t20270 = prims.slice_prim(t20249, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t20270: "cuda:0 bf16[1, 71, 2048, 32]"
# t20271 = prims.slice_prim(t20249, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20271: "cuda:0 bf16[1, 71, 2048, 32]"
# t20272 = prims.convert_element_type(t20270, dtypes.float32) # t20272: "cuda:0 f32[1, 71, 2048, 32]"
# t20273 = prims.neg(t20272) # t20273: "cuda:0 f32[1, 71, 2048, 32]"
# t20274 = prims.convert_element_type(t20273, dtypes.bfloat16) # t20274: "cuda:0 bf16[1, 71, 2048, 32]"
# t20275 = prims.pad(t20274, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t20275: "cuda:0 bf16[1, 71, 2048, 64]"
# t20277 = prims.convert_element_type(t20275, dtypes.float32) # t20277: "cuda:0 f32[1, 71, 2048, 64]"
# t20278 = prims.add(t20258, t20277) # t20278: "cuda:0 f32[1, 71, 2048, 64]"
# t20280 = prims.pad(t20271, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t20280: "cuda:0 bf16[1, 71, 2048, 64]"
# t20282 = prims.convert_element_type(t20280, dtypes.float32) # t20282: "cuda:0 f32[1, 71, 2048, 64]"
# t20283 = prims.add(t20278, t20282) # t20283: "cuda:0 f32[1, 71, 2048, 64]"
# t20284 = prims.convert_element_type(t20283, dtypes.bfloat16) # t20284: "cuda:0 bf16[1, 71, 2048, 64]"
# t20285 = prims.pad(t20284, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t20285: "cuda:0 bf16[1, 71, 2048, 64]"
# t20286 = prims.convert_element_type(t20193, dtypes.float32) # t20286: "cuda:0 f32[1, 71, 2048, 64]"
# t20287 = prims.convert_element_type(t20285, dtypes.float32) # t20287: "cuda:0 f32[1, 71, 2048, 64]"
# t20288 = prims.add(t20286, t20287) # t20288: "cuda:0 f32[1, 71, 2048, 64]"
# t20289 = prims.convert_element_type(t20288, dtypes.bfloat16) # t20289: "cuda:0 bf16[1, 71, 2048, 64]"
# t20299 = prims.reshape(t20241, (1, 1, 71, 2048, 64)) # t20299: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20304 = prims.reshape(t20289, (1, 1, 71, 2048, 64)) # t20304: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20310 = prims.convert_element_type(t20294, dtypes.float32) # t20310: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20311 = prims.sum(t20310, (0, 1, 2)) # t20311: "cuda:0 f32[2048, 64]"
# t20312 = prims.convert_element_type(t20311, dtypes.bfloat16) # t20312: "cuda:0 bf16[2048, 64]"
# t20313 = prims.broadcast_in_dim(t20312, [1, 1, 1, 2048, 64], [3, 4]) # t20313: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20319 = prims.convert_element_type(t20299, dtypes.float32) # t20319: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20320 = prims.sum(t20319, (0, 1, 2)) # t20320: "cuda:0 f32[2048, 64]"
# t20321 = prims.convert_element_type(t20320, dtypes.bfloat16) # t20321: "cuda:0 bf16[2048, 64]"
# t20322 = prims.broadcast_in_dim(t20321, [1, 1, 1, 2048, 64], [3, 4]) # t20322: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20328 = prims.cat((t20304, t20322, t20313), i283) # t20328: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i283, t20187, t20191, t20294
t20334 = torch.permute(t20328, (0, 3, 1, 2, 4)) # t20334: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20334 = ltorch.permute(t20328, (0, 3, 1, 2, 4)) # t20334: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20334 = prims.transpose(t20328, (0, 3, 1, 2, 4)) # t20334: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t20328
t20340 = torch.reshape(t20334, (1, 2048, 4672)) # t20340: "cuda:0 bf16[1, 2048, 4672]"
# t20340 = ltorch.reshape(t20334, (1, 2048, 4672)) # t20340: "cuda:0 bf16[1, 2048, 4672]"
# t20340 = prims.reshape(t20334, (1, 2048, 4672)) # t20340: "cuda:0 bf16[1, 2048, 4672]"
del t20334
t20341 = torch.reshape(t20340, (-1, 4672)) # t20341: "cuda:0 bf16[2048, 4672]"
# t20341 = ltorch.reshape(t20340, (-1, 4672)) # t20341: "cuda:0 bf16[2048, 4672]"
# t20341 = prims.reshape(t20340, (2048, 4672)) # t20341: "cuda:0 bf16[2048, 4672]"
del t20340
t20345 = torch.permute(t20341, (1, 0)) # t20345: "cuda:0 bf16[4672, 2048]"
# t20345 = ltorch.permute(t20341, (1, 0)) # t20345: "cuda:0 bf16[4672, 2048]"
# t20345 = prims.transpose(t20341, (1, 0)) # t20345: "cuda:0 bf16[4672, 2048]"
t20347 = torch.matmul(t20345, t20167) # t20347: "cuda:0 bf16[4672, 4544]"
# t20347 = ltorch.matmul(t20345, t20346) # t20347: "cuda:0 bf16[4672, 4544]"
# t20347 = prims.matmul(t20345, t20346) # t20347: "cuda:0 bf16[4672, 4544]"
del t20345, t20167
t20342 = torch.matmul(t20341, t_transformer_h_4_attn_attn_weight) # t20342: "cuda:0 bf16[2048, 4544]"
# t20342 = ltorch.matmul(t20341, t_transformer_h_4_attn_attn_weight) # t20342: "cuda:0 bf16[2048, 4544]"
# t20342 = prims.matmul(t20341, t_transformer_h_4_attn_attn_weight) # t20342: "cuda:0 bf16[2048, 4544]"
del t20341, t_transformer_h_4_attn_attn_weight
t20164 = torch.reshape(t20163, (1, 2048, 4544)) # t20164: "cuda:0 bf16[1, 2048, 4544]"
# t20164 = ltorch.reshape(t20163, (1, 2048, 4544)) # t20164: "cuda:0 bf16[1, 2048, 4544]"
# t20164 = prims.reshape(t20163, (1, 2048, 4544)) # t20164: "cuda:0 bf16[1, 2048, 4544]"
del t20163
t20343 = torch.reshape(t20342, (1, 2048, 4544)) # t20343: "cuda:0 bf16[1, 2048, 4544]"
# t20343 = ltorch.reshape(t20342, (1, 2048, 4544)) # t20343: "cuda:0 bf16[1, 2048, 4544]"
# t20343 = prims.reshape(t20342, (1, 2048, 4544)) # t20343: "cuda:0 bf16[1, 2048, 4544]"
del t20342
[t20356, t20362, t20404] = nvFusion84(i20384, t20121, t20164, t20343, t454, t586, t607, t622, t627, t633)
# t613 = prims.convert_element_type(t454, dtypes.float32) # t613: "cuda:0 f32[1, 2048, 4544]"
# t608 = prims.convert_element_type(t607, dtypes.float32) # t608: "cuda:0 f32[1, 2048, 4544]"
# t609 = prims.convert_element_type(t586, dtypes.float32) # t609: "cuda:0 f32[1, 2048, 4544]"
# t610 = prims.add(t608, t609) # t610: "cuda:0 f32[1, 2048, 4544]"
# t614 = prims.add(t610, t613) # t614: "cuda:0 f32[1, 2048, 4544]"
# t624 = prims.broadcast_in_dim(t622, [1, 2048, 1], [0, 1]) # t624: "cuda:0 f32[1, 2048, 1]"
# t628 = prims.broadcast_in_dim(t624, (1, 2048, 4544), (0, 1, 2)) # t628: "cuda:0 f32[1, 2048, 4544]"
# t630 = prims.sub(t614, t628) # t630: "cuda:0 f32[1, 2048, 4544]"
# t631 = prims.broadcast_in_dim(t627, (1, 2048, 4544), (0, 1, 2)) # t631: "cuda:0 f32[1, 2048, 4544]"
# t632 = prims.mul(t630, t631) # t632: "cuda:0 f32[1, 2048, 4544]"
# t634 = prims.convert_element_type(t633, dtypes.float32) # t634: "cuda:0 f32[1, 2048, 4544]"
# t20401 = prims.convert_element_type(t20121, dtypes.float32) # t20401: "cuda:0 f32[1, 2048, 4544]"
# t20348 = prims.convert_element_type(t20164, dtypes.float32) # t20348: "cuda:0 f32[1, 2048, 4544]"
# t20349 = prims.convert_element_type(t20343, dtypes.float32) # t20349: "cuda:0 f32[1, 2048, 4544]"
# t20350 = prims.add(t20348, t20349) # t20350: "cuda:0 f32[1, 2048, 4544]"
# t20355 = prims.sum(t20350, (0, 1)) # t20355: "cuda:0 f32[4544]"
# t20356 = prims.convert_element_type(t20355, dtypes.bfloat16) # t20356: "cuda:0 bf16[4544]"
# t20357 = prims.mul(t634, t20350) # t20357: "cuda:0 f32[1, 2048, 4544]"
# t20358 = prims.mul(t632, t20350) # t20358: "cuda:0 f32[1, 2048, 4544]"
# t20361 = prims.sum(t20358, (0, 1)) # t20361: "cuda:0 f32[4544]"
# t20362 = prims.convert_element_type(t20361, dtypes.bfloat16) # t20362: "cuda:0 bf16[4544]"
# t20363 = prims.mul(t631, t20357) # t20363: "cuda:0 f32[1, 2048, 4544]"
# t20364 = prims.mul(t630, t20357) # t20364: "cuda:0 f32[1, 2048, 4544]"
# t20365 = prims.sum(t20364, (0, 2)) # t20365: "cuda:0 f32[2048]"
# t20366 = prims.broadcast_in_dim(t20365, [1, 2048, 1], [1]) # t20366: "cuda:0 f32[1, 2048, 1]"
# t20367 = prims.neg(t20363) # t20367: "cuda:0 f32[1, 2048, 4544]"
# t20369 = prims.sum(t20367, (0, 2)) # t20369: "cuda:0 f32[2048]"
# t20370 = prims.broadcast_in_dim(t20369, [1, 2048, 1], [1]) # t20370: "cuda:0 f32[1, 2048, 1]"
# t20371 = prims.mul(-0.5, t20366) # t20371: "cuda:0 f32[1, 2048, 1]"
# t20372 = prims.pow(t627, 3.0) # t20372: "cuda:0 f32[1, 2048, 1]"
# t20373 = prims.mul(t20371, t20372) # t20373: "cuda:0 f32[1, 2048, 1]"
# t20375 = prims.sum(t20370, (0, 2)) # t20375: "cuda:0 f32[2048]"
# t20376 = prims.broadcast_in_dim(t20375, [1, 2048], [1]) # t20376: "cuda:0 f32[1, 2048]"
# t20377 = prims.sum(t20373, (0, 2)) # t20377: "cuda:0 f32[2048]"
# t20378 = prims.broadcast_in_dim(t20377, [1, 2048], [1]) # t20378: "cuda:0 f32[1, 2048]"
# t20381 = prims.broadcast_in_dim(t20376, [1, 2048, 1], [0, 1]) # t20381: "cuda:0 f32[1, 2048, 1]"
# t20382 = prims.broadcast_in_dim(t20381, (1, 2048, 4544), (0, 1, 2)) # t20382: "cuda:0 f32[1, 2048, 4544]"
# t20383 = prims.mul(0.00022007042253521127, t20382) # t20383: "cuda:0 f32[1, 2048, 4544]"
# t20385 = prims.broadcast_in_dim(t20378, [1, 2048, 1], [0, 1]) # t20385: "cuda:0 f32[1, 2048, 1]"
# t20386 = prims.broadcast_in_dim(t20385, (1, 2048, 4544), (0, 1, 2)) # t20386: "cuda:0 f32[1, 2048, 4544]"
# t20388 = prims.broadcast_in_dim(t622, [1, 2048, 1], [0, 1]) # t20388: "cuda:0 f32[1, 2048, 1]"
# t20389 = prims.broadcast_in_dim(t20388, (1, 2048, 4544), (0, 1, 2)) # t20389: "cuda:0 f32[1, 2048, 4544]"
# t20390 = prims.mul(2.0, t20386) # t20390: "cuda:0 f32[1, 2048, 4544]"
# t20391 = prims.sub(t614, t20389) # t20391: "cuda:0 f32[1, 2048, 4544]"
# t20392 = prims.mul(t20390, t20391) # t20392: "cuda:0 f32[1, 2048, 4544]"
# f20393 = prims.convert_element_type(i20384, float) # f20393: "float 4544.0"
# t20394 = prims.div(t20392, f20393) # t20394: "cuda:0 f32[1, 2048, 4544]"
# t20395 = prims.add(t20383, t20394) # t20395: "cuda:0 f32[1, 2048, 4544]"
# t20399 = prims.add(t20363, t20395) # t20399: "cuda:0 f32[1, 2048, 4544]"
# t20403 = prims.add(t20401, t20399) # t20403: "cuda:0 f32[1, 2048, 4544]"
# t20404 = prims.convert_element_type(t20403, dtypes.bfloat16) # t20404: "cuda:0 bf16[1, 2048, 4544]"
del i20384, t20121, t20164, t20343, t454, t586, t607, t622, t627, t633
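# Block h.3: the same backward sequence begins again (GELU backward in nvFusion85).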
t20411 = torch.reshape(t20404, (-1, 4544)) # t20411: "cuda:0 bf16[2048, 4544]"
# t20411 = ltorch.reshape(t20404, (-1, 4544)) # t20411: "cuda:0 bf16[2048, 4544]"
# t20411 = prims.reshape(t20404, (2048, 4544)) # t20411: "cuda:0 bf16[2048, 4544]"
t20415 = torch.permute(t20411, (1, 0)) # t20415: "cuda:0 bf16[4544, 2048]"
# t20415 = ltorch.permute(t20411, (1, 0)) # t20415: "cuda:0 bf16[4544, 2048]"
# t20415 = prims.transpose(t20411, (1, 0)) # t20415: "cuda:0 bf16[4544, 2048]"
t20417 = torch.matmul(t20415, t20416) # t20417: "cuda:0 bf16[4544, 18176]"
# t20417 = ltorch.matmul(t20415, t20416) # t20417: "cuda:0 bf16[4544, 18176]"
# t20417 = prims.matmul(t20415, t20416) # t20417: "cuda:0 bf16[4544, 18176]"
del t20416
t20453 = torch.matmul(t20411, t_transformer_h_3_attn_proj_weight) # t20453: "cuda:0 bf16[2048, 4544]"
# t20453 = ltorch.matmul(t20452, t_transformer_h_3_attn_proj_weight) # t20453: "cuda:0 bf16[2048, 4544]"
# t20453 = prims.matmul(t20452, t_transformer_h_3_attn_proj_weight) # t20453: "cuda:0 bf16[2048, 4544]"
del t_transformer_h_3_attn_proj_weight
t20458 = torch.matmul(t20415, t20457) # t20458: "cuda:0 bf16[4544, 4544]"
# t20458 = ltorch.matmul(t20456, t20457) # t20458: "cuda:0 bf16[4544, 4544]"
# t20458 = prims.matmul(t20456, t20457) # t20458: "cuda:0 bf16[4544, 4544]"
del t20415, t20457
t20412 = torch.matmul(t20411, t_transformer_h_3_mlp_proj_weight) # t20412: "cuda:0 bf16[2048, 18176]"
# t20412 = ltorch.matmul(t20411, t_transformer_h_3_mlp_proj_weight) # t20412: "cuda:0 bf16[2048, 18176]"
# t20412 = prims.matmul(t20411, t_transformer_h_3_mlp_proj_weight) # t20412: "cuda:0 bf16[2048, 18176]"
del t20411, t_transformer_h_3_mlp_proj_weight
t20454 = torch.reshape(t20453, (1, 2048, 4544)) # t20454: "cuda:0 bf16[1, 2048, 4544]"
# t20454 = ltorch.reshape(t20453, (1, 2048, 4544)) # t20454: "cuda:0 bf16[1, 2048, 4544]"
# t20454 = prims.reshape(t20453, (1, 2048, 4544)) # t20454: "cuda:0 bf16[1, 2048, 4544]"
del t20453
t20462 = torch.reshape(t20454, (1, 2048, 71, 64)) # t20462: "cuda:0 bf16[1, 2048, 71, 64]"
# t20462 = ltorch.reshape(t20454, (1, 2048, 71, 64)) # t20462: "cuda:0 bf16[1, 2048, 71, 64]"
# t20462 = prims.reshape(t20454, (1, 2048, 71, 64)) # t20462: "cuda:0 bf16[1, 2048, 71, 64]"
del t20454
t20465 = torch.permute(t20462, (0, 2, 1, 3)) # t20465: "cuda:0 bf16[1, 71, 2048, 64]"
# t20465 = ltorch.permute(t20462, (0, 2, 1, 3)) # t20465: "cuda:0 bf16[1, 71, 2048, 64]"
# t20465 = prims.transpose(t20462, (0, 2, 1, 3)) # t20465: "cuda:0 bf16[1, 71, 2048, 64]"
del t20462
t20413 = torch.reshape(t20412, (1, 2048, 18176)) # t20413: "cuda:0 bf16[1, 2048, 18176]"
# t20413 = ltorch.reshape(t20412, (1, 2048, 18176)) # t20413: "cuda:0 bf16[1, 2048, 18176]"
# t20413 = prims.reshape(t20412, (1, 2048, 18176)) # t20413: "cuda:0 bf16[1, 2048, 18176]"
del t20412
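# nvFusion85: GELU backward for the block-3 MLP. The fusion recomputes the erf-based GELU from
# the saved pre-activation t587 and scales the incoming gradient t20413 by the GELU derivative
# (the exp(-x^2) term), yielding the gradient at the mlp_fc output.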
[t20444] = nvFusion85(f255, f257, t20413, t587)
# t588 = prims.convert_element_type(t587, dtypes.float32) # t588: "cuda:0 f32[1, 2048, 18176]"
# t590 = prims.div(t588, 1.4142135623730951) # t590: "cuda:0 f32[1, 2048, 18176]"
# t593 = prims.erf(t590) # t593: "cuda:0 f32[1, 2048, 18176]"
# t597 = prims.mul(0.5, t593) # t597: "cuda:0 f32[1, 2048, 18176]"
# t601 = prims.add(0.5, t597) # t601: "cuda:0 f32[1, 2048, 18176]"
# t20418 = prims.convert_element_type(t20413, dtypes.float32) # t20418: "cuda:0 f32[1, 2048, 18176]"
# t20419 = prims.mul(t601, t20418) # t20419: "cuda:0 f32[1, 2048, 18176]"
# t20420 = prims.mul(t588, t20418) # t20420: "cuda:0 f32[1, 2048, 18176]"
# t20428 = prims.mul(f257, t20420) # t20428: "cuda:0 f32[1, 2048, 18176]"
# t20431 = prims.pow(t590, 2.0) # t20431: "cuda:0 f32[1, 2048, 18176]"
# t20432 = prims.neg(t20431) # t20432: "cuda:0 f32[1, 2048, 18176]"
# t20433 = prims.exp(t20432) # t20433: "cuda:0 f32[1, 2048, 18176]"
# t20434 = prims.mul(1.1283791670955126, t20433) # t20434: "cuda:0 f32[1, 2048, 18176]"
# t20435 = prims.mul(t20434, t20428) # t20435: "cuda:0 f32[1, 2048, 18176]"
# t20439 = prims.div(t20435, f255) # t20439: "cuda:0 f32[1, 2048, 18176]"
# t20443 = prims.add(t20419, t20439) # t20443: "cuda:0 f32[1, 2048, 18176]"
# t20444 = prims.convert_element_type(t20443, dtypes.bfloat16) # t20444: "cuda:0 bf16[1, 2048, 18176]"
del f255, f257, t20413, t587
t20445 = torch.reshape(t20444, (-1, 18176)) # t20445: "cuda:0 bf16[2048, 18176]"
# t20445 = ltorch.reshape(t20444, (-1, 18176)) # t20445: "cuda:0 bf16[2048, 18176]"
# t20445 = prims.reshape(t20444, (2048, 18176)) # t20445: "cuda:0 bf16[2048, 18176]"
del t20444
t20449 = torch.permute(t20445, (1, 0)) # t20449: "cuda:0 bf16[18176, 2048]"
# t20449 = ltorch.permute(t20445, (1, 0)) # t20449: "cuda:0 bf16[18176, 2048]"
# t20449 = prims.transpose(t20445, (1, 0)) # t20449: "cuda:0 bf16[18176, 2048]"
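# cudnn_sdpa_bwd: scaled dot-product attention backward (cuDNN) for block 3; the three outputs
# are presumably the query, key and value gradients, in that order.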
(t20466, t20467, t20468) = cudnn_sdpa_bwd(t20465, t571, t574, t524, None, f246, b247, t575, t576, t577, t578, scale=f248, cat_grad_qkv=False)
del t20465, t571, t574, t524, f246, b247, t575, t576, t577, t578, f248
t20451 = torch.matmul(t20449, t20450) # t20451: "cuda:0 bf16[18176, 4544]"
# t20451 = ltorch.matmul(t20449, t20450) # t20451: "cuda:0 bf16[18176, 4544]"
# t20451 = prims.matmul(t20449, t20450) # t20451: "cuda:0 bf16[18176, 4544]"
del t20449
t20446 = torch.matmul(t20445, t_transformer_h_3_mlp_fc_weight) # t20446: "cuda:0 bf16[2048, 4544]"
# t20446 = ltorch.matmul(t20445, t_transformer_h_3_mlp_fc_weight) # t20446: "cuda:0 bf16[2048, 4544]"
# t20446 = prims.matmul(t20445, t_transformer_h_3_mlp_fc_weight) # t20446: "cuda:0 bf16[2048, 4544]"
del t20445, t_transformer_h_3_mlp_fc_weight
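# Slice the key/query gradients (the slices are full-width here) and reshape the value gradient
# into the broadcast-head layout ahead of the rotary backward fusion below.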
t20470 = torch_slice_prim_impl(t20467, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20470: "cuda:0 bf16[1, 71, 2048, 64]"
del t20467
t20474 = torch_slice_prim_impl(t20466, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20474: "cuda:0 bf16[1, 71, 2048, 64]"
del t20466
t20577 = torch.reshape(t20468, (1, 1, 71, 2048, 64)) # t20577: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20577 = ltorch.reshape(t20468, (1, 1, 71, 2048, 64)) # t20577: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20577 = prims.reshape(t20468, (1, 1, 71, 2048, 64)) # t20577: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t20468
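# nvFusion86: appears to be the rotary position embedding backward plus qkv-gradient assembly.
# The slice/negate/pad pattern with the cached t61/t66 buffers (the RoPE sin/cos tables) undoes
# the rotation on the query and key gradients; the key and value gradients are then summed over
# the 71 broadcast heads (consistent with multi-query attention) and concatenated with the
# query gradient into the fused qkv layout (73 = 71 query + 1 key + 1 value heads).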
[t20611] = nvFusion86(i219, t20470, t20474, t20577, t61, t66)
# t20471 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t20471: "cuda:0 bf16[1, 71, 2048, 0]"
# t20472 = prims.pad(t20471, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t20472: "cuda:0 bf16[1, 71, 2048, 64]"
# t20475 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t20475: "cuda:0 bf16[1, 71, 2048, 0]"
# t20476 = prims.pad(t20475, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t20476: "cuda:0 bf16[1, 71, 2048, 64]"
# t20477 = prims.convert_element_type(t20470, dtypes.float32) # t20477: "cuda:0 f32[1, 71, 2048, 64]"
# t20481 = prims.mul(t66, t20477) # t20481: "cuda:0 f32[1, 71, 2048, 64]"
# t20484 = prims.convert_element_type(t20481, dtypes.bfloat16) # t20484: "cuda:0 bf16[1, 71, 2048, 64]"
# t20493 = prims.mul(t61, t20477) # t20493: "cuda:0 f32[1, 71, 2048, 64]"
# t20505 = prims.slice_prim(t20484, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t20505: "cuda:0 bf16[1, 71, 2048, 32]"
# t20506 = prims.slice_prim(t20484, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20506: "cuda:0 bf16[1, 71, 2048, 32]"
# t20507 = prims.convert_element_type(t20505, dtypes.float32) # t20507: "cuda:0 f32[1, 71, 2048, 32]"
# t20508 = prims.neg(t20507) # t20508: "cuda:0 f32[1, 71, 2048, 32]"
# t20509 = prims.convert_element_type(t20508, dtypes.bfloat16) # t20509: "cuda:0 bf16[1, 71, 2048, 32]"
# t20510 = prims.pad(t20509, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t20510: "cuda:0 bf16[1, 71, 2048, 64]"
# t20512 = prims.convert_element_type(t20510, dtypes.float32) # t20512: "cuda:0 f32[1, 71, 2048, 64]"
# t20513 = prims.add(t20493, t20512) # t20513: "cuda:0 f32[1, 71, 2048, 64]"
# t20515 = prims.pad(t20506, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t20515: "cuda:0 bf16[1, 71, 2048, 64]"
# t20517 = prims.convert_element_type(t20515, dtypes.float32) # t20517: "cuda:0 f32[1, 71, 2048, 64]"
# t20518 = prims.add(t20513, t20517) # t20518: "cuda:0 f32[1, 71, 2048, 64]"
# t20519 = prims.convert_element_type(t20518, dtypes.bfloat16) # t20519: "cuda:0 bf16[1, 71, 2048, 64]"
# t20520 = prims.pad(t20519, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t20520: "cuda:0 bf16[1, 71, 2048, 64]"
# t20521 = prims.convert_element_type(t20472, dtypes.float32) # t20521: "cuda:0 f32[1, 71, 2048, 64]"
# t20522 = prims.convert_element_type(t20520, dtypes.float32) # t20522: "cuda:0 f32[1, 71, 2048, 64]"
# t20523 = prims.add(t20521, t20522) # t20523: "cuda:0 f32[1, 71, 2048, 64]"
# t20524 = prims.convert_element_type(t20523, dtypes.bfloat16) # t20524: "cuda:0 bf16[1, 71, 2048, 64]"
# t20525 = prims.convert_element_type(t20474, dtypes.float32) # t20525: "cuda:0 f32[1, 71, 2048, 64]"
# t20529 = prims.mul(t66, t20525) # t20529: "cuda:0 f32[1, 71, 2048, 64]"
# t20532 = prims.convert_element_type(t20529, dtypes.bfloat16) # t20532: "cuda:0 bf16[1, 71, 2048, 64]"
# t20541 = prims.mul(t61, t20525) # t20541: "cuda:0 f32[1, 71, 2048, 64]"
# t20553 = prims.slice_prim(t20532, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t20553: "cuda:0 bf16[1, 71, 2048, 32]"
# t20554 = prims.slice_prim(t20532, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20554: "cuda:0 bf16[1, 71, 2048, 32]"
# t20555 = prims.convert_element_type(t20553, dtypes.float32) # t20555: "cuda:0 f32[1, 71, 2048, 32]"
# t20556 = prims.neg(t20555) # t20556: "cuda:0 f32[1, 71, 2048, 32]"
# t20557 = prims.convert_element_type(t20556, dtypes.bfloat16) # t20557: "cuda:0 bf16[1, 71, 2048, 32]"
# t20558 = prims.pad(t20557, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t20558: "cuda:0 bf16[1, 71, 2048, 64]"
# t20560 = prims.convert_element_type(t20558, dtypes.float32) # t20560: "cuda:0 f32[1, 71, 2048, 64]"
# t20561 = prims.add(t20541, t20560) # t20561: "cuda:0 f32[1, 71, 2048, 64]"
# t20563 = prims.pad(t20554, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t20563: "cuda:0 bf16[1, 71, 2048, 64]"
# t20565 = prims.convert_element_type(t20563, dtypes.float32) # t20565: "cuda:0 f32[1, 71, 2048, 64]"
# t20566 = prims.add(t20561, t20565) # t20566: "cuda:0 f32[1, 71, 2048, 64]"
# t20567 = prims.convert_element_type(t20566, dtypes.bfloat16) # t20567: "cuda:0 bf16[1, 71, 2048, 64]"
# t20568 = prims.pad(t20567, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t20568: "cuda:0 bf16[1, 71, 2048, 64]"
# t20569 = prims.convert_element_type(t20476, dtypes.float32) # t20569: "cuda:0 f32[1, 71, 2048, 64]"
# t20570 = prims.convert_element_type(t20568, dtypes.float32) # t20570: "cuda:0 f32[1, 71, 2048, 64]"
# t20571 = prims.add(t20569, t20570) # t20571: "cuda:0 f32[1, 71, 2048, 64]"
# t20572 = prims.convert_element_type(t20571, dtypes.bfloat16) # t20572: "cuda:0 bf16[1, 71, 2048, 64]"
# t20582 = prims.reshape(t20524, (1, 1, 71, 2048, 64)) # t20582: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20587 = prims.reshape(t20572, (1, 1, 71, 2048, 64)) # t20587: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20593 = prims.convert_element_type(t20577, dtypes.float32) # t20593: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20594 = prims.sum(t20593, (0, 1, 2)) # t20594: "cuda:0 f32[2048, 64]"
# t20595 = prims.convert_element_type(t20594, dtypes.bfloat16) # t20595: "cuda:0 bf16[2048, 64]"
# t20596 = prims.broadcast_in_dim(t20595, [1, 1, 1, 2048, 64], [3, 4]) # t20596: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20602 = prims.convert_element_type(t20582, dtypes.float32) # t20602: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20603 = prims.sum(t20602, (0, 1, 2)) # t20603: "cuda:0 f32[2048, 64]"
# t20604 = prims.convert_element_type(t20603, dtypes.bfloat16) # t20604: "cuda:0 bf16[2048, 64]"
# t20605 = prims.broadcast_in_dim(t20604, [1, 1, 1, 2048, 64], [3, 4]) # t20605: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20611 = prims.cat((t20587, t20605, t20596), i219) # t20611: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i219, t20470, t20474, t20577
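# The fused qkv gradient is permuted and flattened to (2048, 4672), then matmul'd with the
# shared LayerNorm output activation (t20450) for the qkv weight gradient and with the qkv
# weight (attn.attn) for the gradient flowing into block 3's input LayerNorm.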
t20617 = torch.permute(t20611, (0, 3, 1, 2, 4)) # t20617: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20617 = ltorch.permute(t20611, (0, 3, 1, 2, 4)) # t20617: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20617 = prims.transpose(t20611, (0, 3, 1, 2, 4)) # t20617: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t20611
t20623 = torch.reshape(t20617, (1, 2048, 4672)) # t20623: "cuda:0 bf16[1, 2048, 4672]"
# t20623 = ltorch.reshape(t20617, (1, 2048, 4672)) # t20623: "cuda:0 bf16[1, 2048, 4672]"
# t20623 = prims.reshape(t20617, (1, 2048, 4672)) # t20623: "cuda:0 bf16[1, 2048, 4672]"
del t20617
t20624 = torch.reshape(t20623, (-1, 4672)) # t20624: "cuda:0 bf16[2048, 4672]"
# t20624 = ltorch.reshape(t20623, (-1, 4672)) # t20624: "cuda:0 bf16[2048, 4672]"
# t20624 = prims.reshape(t20623, (2048, 4672)) # t20624: "cuda:0 bf16[2048, 4672]"
del t20623
t20628 = torch.permute(t20624, (1, 0)) # t20628: "cuda:0 bf16[4672, 2048]"
# t20628 = ltorch.permute(t20624, (1, 0)) # t20628: "cuda:0 bf16[4672, 2048]"
# t20628 = prims.transpose(t20624, (1, 0)) # t20628: "cuda:0 bf16[4672, 2048]"
t20630 = torch.matmul(t20628, t20450) # t20630: "cuda:0 bf16[4672, 4544]"
# t20630 = ltorch.matmul(t20628, t20629) # t20630: "cuda:0 bf16[4672, 4544]"
# t20630 = prims.matmul(t20628, t20629) # t20630: "cuda:0 bf16[4672, 4544]"
del t20628, t20450
t20625 = torch.matmul(t20624, t_transformer_h_3_attn_attn_weight) # t20625: "cuda:0 bf16[2048, 4544]"
# t20625 = ltorch.matmul(t20624, t_transformer_h_3_attn_attn_weight) # t20625: "cuda:0 bf16[2048, 4544]"
# t20625 = prims.matmul(t20624, t_transformer_h_3_attn_attn_weight) # t20625: "cuda:0 bf16[2048, 4544]"
del t20624, t_transformer_h_3_attn_attn_weight
t20447 = torch.reshape(t20446, (1, 2048, 4544)) # t20447: "cuda:0 bf16[1, 2048, 4544]"
# t20447 = ltorch.reshape(t20446, (1, 2048, 4544)) # t20447: "cuda:0 bf16[1, 2048, 4544]"
# t20447 = prims.reshape(t20446, (1, 2048, 4544)) # t20447: "cuda:0 bf16[1, 2048, 4544]"
del t20446
t20626 = torch.reshape(t20625, (1, 2048, 4544)) # t20626: "cuda:0 bf16[1, 2048, 4544]"
# t20626 = ltorch.reshape(t20625, (1, 2048, 4544)) # t20626: "cuda:0 bf16[1, 2048, 4544]"
# t20626 = prims.reshape(t20625, (1, 2048, 4544)) # t20626: "cuda:0 bf16[1, 2048, 4544]"
del t20625
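# nvFusion87: backward through block 3's input LayerNorm. The gradients from the MLP branch
# (t20447) and the attention branch (t20626) are summed first (parallel residual layout); the
# (0, 1) reductions give the LayerNorm bias (t20639) and weight (t20645) gradients, and the
# LayerNorm input gradient is accumulated with t20404 into the residual-stream gradient t20687.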
[t20639, t20645, t20687] = nvFusion87(i20667, t20404, t20447, t20626, t293, t425, t446, t461, t466, t472)
# t452 = prims.convert_element_type(t293, dtypes.float32) # t452: "cuda:0 f32[1, 2048, 4544]"
# t447 = prims.convert_element_type(t446, dtypes.float32) # t447: "cuda:0 f32[1, 2048, 4544]"
# t448 = prims.convert_element_type(t425, dtypes.float32) # t448: "cuda:0 f32[1, 2048, 4544]"
# t449 = prims.add(t447, t448) # t449: "cuda:0 f32[1, 2048, 4544]"
# t453 = prims.add(t449, t452) # t453: "cuda:0 f32[1, 2048, 4544]"
# t463 = prims.broadcast_in_dim(t461, [1, 2048, 1], [0, 1]) # t463: "cuda:0 f32[1, 2048, 1]"
# t467 = prims.broadcast_in_dim(t463, (1, 2048, 4544), (0, 1, 2)) # t467: "cuda:0 f32[1, 2048, 4544]"
# t469 = prims.sub(t453, t467) # t469: "cuda:0 f32[1, 2048, 4544]"
# t470 = prims.broadcast_in_dim(t466, (1, 2048, 4544), (0, 1, 2)) # t470: "cuda:0 f32[1, 2048, 4544]"
# t471 = prims.mul(t469, t470) # t471: "cuda:0 f32[1, 2048, 4544]"
# t473 = prims.convert_element_type(t472, dtypes.float32) # t473: "cuda:0 f32[1, 2048, 4544]"
# t20684 = prims.convert_element_type(t20404, dtypes.float32) # t20684: "cuda:0 f32[1, 2048, 4544]"
# t20631 = prims.convert_element_type(t20447, dtypes.float32) # t20631: "cuda:0 f32[1, 2048, 4544]"
# t20632 = prims.convert_element_type(t20626, dtypes.float32) # t20632: "cuda:0 f32[1, 2048, 4544]"
# t20633 = prims.add(t20631, t20632) # t20633: "cuda:0 f32[1, 2048, 4544]"
# t20638 = prims.sum(t20633, (0, 1)) # t20638: "cuda:0 f32[4544]"
# t20639 = prims.convert_element_type(t20638, dtypes.bfloat16) # t20639: "cuda:0 bf16[4544]"
# t20640 = prims.mul(t473, t20633) # t20640: "cuda:0 f32[1, 2048, 4544]"
# t20641 = prims.mul(t471, t20633) # t20641: "cuda:0 f32[1, 2048, 4544]"
# t20644 = prims.sum(t20641, (0, 1)) # t20644: "cuda:0 f32[4544]"
# t20645 = prims.convert_element_type(t20644, dtypes.bfloat16) # t20645: "cuda:0 bf16[4544]"
# t20646 = prims.mul(t470, t20640) # t20646: "cuda:0 f32[1, 2048, 4544]"
# t20647 = prims.mul(t469, t20640) # t20647: "cuda:0 f32[1, 2048, 4544]"
# t20648 = prims.sum(t20647, (0, 2)) # t20648: "cuda:0 f32[2048]"
# t20649 = prims.broadcast_in_dim(t20648, [1, 2048, 1], [1]) # t20649: "cuda:0 f32[1, 2048, 1]"
# t20650 = prims.neg(t20646) # t20650: "cuda:0 f32[1, 2048, 4544]"
# t20652 = prims.sum(t20650, (0, 2)) # t20652: "cuda:0 f32[2048]"
# t20653 = prims.broadcast_in_dim(t20652, [1, 2048, 1], [1]) # t20653: "cuda:0 f32[1, 2048, 1]"
# t20654 = prims.mul(-0.5, t20649) # t20654: "cuda:0 f32[1, 2048, 1]"
# t20655 = prims.pow(t466, 3.0) # t20655: "cuda:0 f32[1, 2048, 1]"
# t20656 = prims.mul(t20654, t20655) # t20656: "cuda:0 f32[1, 2048, 1]"
# t20658 = prims.sum(t20653, (0, 2)) # t20658: "cuda:0 f32[2048]"
# t20659 = prims.broadcast_in_dim(t20658, [1, 2048], [1]) # t20659: "cuda:0 f32[1, 2048]"
# t20660 = prims.sum(t20656, (0, 2)) # t20660: "cuda:0 f32[2048]"
# t20661 = prims.broadcast_in_dim(t20660, [1, 2048], [1]) # t20661: "cuda:0 f32[1, 2048]"
# t20664 = prims.broadcast_in_dim(t20659, [1, 2048, 1], [0, 1]) # t20664: "cuda:0 f32[1, 2048, 1]"
# t20665 = prims.broadcast_in_dim(t20664, (1, 2048, 4544), (0, 1, 2)) # t20665: "cuda:0 f32[1, 2048, 4544]"
# t20666 = prims.mul(0.00022007042253521127, t20665) # t20666: "cuda:0 f32[1, 2048, 4544]"
# t20668 = prims.broadcast_in_dim(t20661, [1, 2048, 1], [0, 1]) # t20668: "cuda:0 f32[1, 2048, 1]"
# t20669 = prims.broadcast_in_dim(t20668, (1, 2048, 4544), (0, 1, 2)) # t20669: "cuda:0 f32[1, 2048, 4544]"
# t20671 = prims.broadcast_in_dim(t461, [1, 2048, 1], [0, 1]) # t20671: "cuda:0 f32[1, 2048, 1]"
# t20672 = prims.broadcast_in_dim(t20671, (1, 2048, 4544), (0, 1, 2)) # t20672: "cuda:0 f32[1, 2048, 4544]"
# t20673 = prims.mul(2.0, t20669) # t20673: "cuda:0 f32[1, 2048, 4544]"
# t20674 = prims.sub(t453, t20672) # t20674: "cuda:0 f32[1, 2048, 4544]"
# t20675 = prims.mul(t20673, t20674) # t20675: "cuda:0 f32[1, 2048, 4544]"
# f20676 = prims.convert_element_type(i20667, float) # f20676: "float 4544.0"
# t20677 = prims.div(t20675, f20676) # t20677: "cuda:0 f32[1, 2048, 4544]"
# t20678 = prims.add(t20666, t20677) # t20678: "cuda:0 f32[1, 2048, 4544]"
# t20682 = prims.add(t20646, t20678) # t20682: "cuda:0 f32[1, 2048, 4544]"
# t20686 = prims.add(t20684, t20682) # t20686: "cuda:0 f32[1, 2048, 4544]"
# t20687 = prims.convert_element_type(t20686, dtypes.bfloat16) # t20687: "cuda:0 bf16[1, 2048, 4544]"
del i20667, t20404, t20447, t20626, t293, t425, t446, t461, t466, t472
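# --- Transformer block 2 backward ---
# Same sequence as block 3: projection weight/input matmuls, GELU backward (nvFusion88),
# cuDNN SDPA backward, RoPE backward + qkv gradient assembly (nvFusion89), and LayerNorm
# backward with residual accumulation (nvFusion90).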
t20694 = torch.reshape(t20687, (-1, 4544)) # t20694: "cuda:0 bf16[2048, 4544]"
# t20694 = ltorch.reshape(t20687, (-1, 4544)) # t20694: "cuda:0 bf16[2048, 4544]"
# t20694 = prims.reshape(t20687, (2048, 4544)) # t20694: "cuda:0 bf16[2048, 4544]"
t20698 = torch.permute(t20694, (1, 0)) # t20698: "cuda:0 bf16[4544, 2048]"
# t20698 = ltorch.permute(t20694, (1, 0)) # t20698: "cuda:0 bf16[4544, 2048]"
# t20698 = prims.transpose(t20694, (1, 0)) # t20698: "cuda:0 bf16[4544, 2048]"
t20741 = torch.matmul(t20698, t20740) # t20741: "cuda:0 bf16[4544, 4544]"
# t20741 = ltorch.matmul(t20739, t20740) # t20741: "cuda:0 bf16[4544, 4544]"
# t20741 = prims.matmul(t20739, t20740) # t20741: "cuda:0 bf16[4544, 4544]"
del t20740
t20695 = torch.matmul(t20694, t_transformer_h_2_mlp_proj_weight) # t20695: "cuda:0 bf16[2048, 18176]"
# t20695 = ltorch.matmul(t20694, t_transformer_h_2_mlp_proj_weight) # t20695: "cuda:0 bf16[2048, 18176]"
# t20695 = prims.matmul(t20694, t_transformer_h_2_mlp_proj_weight) # t20695: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_2_mlp_proj_weight
t20700 = torch.matmul(t20698, t20699) # t20700: "cuda:0 bf16[4544, 18176]"
# t20700 = ltorch.matmul(t20698, t20699) # t20700: "cuda:0 bf16[4544, 18176]"
# t20700 = prims.matmul(t20698, t20699) # t20700: "cuda:0 bf16[4544, 18176]"
del t20698, t20699
t20736 = torch.matmul(t20694, t_transformer_h_2_attn_proj_weight) # t20736: "cuda:0 bf16[2048, 4544]"
# t20736 = ltorch.matmul(t20735, t_transformer_h_2_attn_proj_weight) # t20736: "cuda:0 bf16[2048, 4544]"
# t20736 = prims.matmul(t20735, t_transformer_h_2_attn_proj_weight) # t20736: "cuda:0 bf16[2048, 4544]"
del t20694, t_transformer_h_2_attn_proj_weight
t20696 = torch.reshape(t20695, (1, 2048, 18176)) # t20696: "cuda:0 bf16[1, 2048, 18176]"
# t20696 = ltorch.reshape(t20695, (1, 2048, 18176)) # t20696: "cuda:0 bf16[1, 2048, 18176]"
# t20696 = prims.reshape(t20695, (1, 2048, 18176)) # t20696: "cuda:0 bf16[1, 2048, 18176]"
del t20695
t20737 = torch.reshape(t20736, (1, 2048, 4544)) # t20737: "cuda:0 bf16[1, 2048, 4544]"
# t20737 = ltorch.reshape(t20736, (1, 2048, 4544)) # t20737: "cuda:0 bf16[1, 2048, 4544]"
# t20737 = prims.reshape(t20736, (1, 2048, 4544)) # t20737: "cuda:0 bf16[1, 2048, 4544]"
del t20736
t20745 = torch.reshape(t20737, (1, 2048, 71, 64)) # t20745: "cuda:0 bf16[1, 2048, 71, 64]"
# t20745 = ltorch.reshape(t20737, (1, 2048, 71, 64)) # t20745: "cuda:0 bf16[1, 2048, 71, 64]"
# t20745 = prims.reshape(t20737, (1, 2048, 71, 64)) # t20745: "cuda:0 bf16[1, 2048, 71, 64]"
del t20737
t20748 = torch.permute(t20745, (0, 2, 1, 3)) # t20748: "cuda:0 bf16[1, 71, 2048, 64]"
# t20748 = ltorch.permute(t20745, (0, 2, 1, 3)) # t20748: "cuda:0 bf16[1, 71, 2048, 64]"
# t20748 = prims.transpose(t20745, (0, 2, 1, 3)) # t20748: "cuda:0 bf16[1, 71, 2048, 64]"
del t20745
[t20727] = nvFusion88(f191, f193, t20696, t426)
# t427 = prims.convert_element_type(t426, dtypes.float32) # t427: "cuda:0 f32[1, 2048, 18176]"
# t429 = prims.div(t427, 1.4142135623730951) # t429: "cuda:0 f32[1, 2048, 18176]"
# t432 = prims.erf(t429) # t432: "cuda:0 f32[1, 2048, 18176]"
# t436 = prims.mul(0.5, t432) # t436: "cuda:0 f32[1, 2048, 18176]"
# t440 = prims.add(0.5, t436) # t440: "cuda:0 f32[1, 2048, 18176]"
# t20701 = prims.convert_element_type(t20696, dtypes.float32) # t20701: "cuda:0 f32[1, 2048, 18176]"
# t20702 = prims.mul(t440, t20701) # t20702: "cuda:0 f32[1, 2048, 18176]"
# t20703 = prims.mul(t427, t20701) # t20703: "cuda:0 f32[1, 2048, 18176]"
# t20711 = prims.mul(f193, t20703) # t20711: "cuda:0 f32[1, 2048, 18176]"
# t20714 = prims.pow(t429, 2.0) # t20714: "cuda:0 f32[1, 2048, 18176]"
# t20715 = prims.neg(t20714) # t20715: "cuda:0 f32[1, 2048, 18176]"
# t20716 = prims.exp(t20715) # t20716: "cuda:0 f32[1, 2048, 18176]"
# t20717 = prims.mul(1.1283791670955126, t20716) # t20717: "cuda:0 f32[1, 2048, 18176]"
# t20718 = prims.mul(t20717, t20711) # t20718: "cuda:0 f32[1, 2048, 18176]"
# t20722 = prims.div(t20718, f191) # t20722: "cuda:0 f32[1, 2048, 18176]"
# t20726 = prims.add(t20702, t20722) # t20726: "cuda:0 f32[1, 2048, 18176]"
# t20727 = prims.convert_element_type(t20726, dtypes.bfloat16) # t20727: "cuda:0 bf16[1, 2048, 18176]"
del f191, f193, t20696, t426
t20728 = torch.reshape(t20727, (-1, 18176)) # t20728: "cuda:0 bf16[2048, 18176]"
# t20728 = ltorch.reshape(t20727, (-1, 18176)) # t20728: "cuda:0 bf16[2048, 18176]"
# t20728 = prims.reshape(t20727, (2048, 18176)) # t20728: "cuda:0 bf16[2048, 18176]"
del t20727
t20732 = torch.permute(t20728, (1, 0)) # t20732: "cuda:0 bf16[18176, 2048]"
# t20732 = ltorch.permute(t20728, (1, 0)) # t20732: "cuda:0 bf16[18176, 2048]"
# t20732 = prims.transpose(t20728, (1, 0)) # t20732: "cuda:0 bf16[18176, 2048]"
t20734 = torch.matmul(t20732, t20733) # t20734: "cuda:0 bf16[18176, 4544]"
# t20734 = ltorch.matmul(t20732, t20733) # t20734: "cuda:0 bf16[18176, 4544]"
# t20734 = prims.matmul(t20732, t20733) # t20734: "cuda:0 bf16[18176, 4544]"
del t20732
t20729 = torch.matmul(t20728, t_transformer_h_2_mlp_fc_weight) # t20729: "cuda:0 bf16[2048, 4544]"
# t20729 = ltorch.matmul(t20728, t_transformer_h_2_mlp_fc_weight) # t20729: "cuda:0 bf16[2048, 4544]"
# t20729 = prims.matmul(t20728, t_transformer_h_2_mlp_fc_weight) # t20729: "cuda:0 bf16[2048, 4544]"
del t20728, t_transformer_h_2_mlp_fc_weight
(t20749, t20750, t20751) = cudnn_sdpa_bwd(t20748, t410, t413, t363, None, f182, b183, t414, t415, t416, t417, scale=f184, cat_grad_qkv=False)
del t20748, t410, t413, t363, f182, b183, t414, t415, t416, t417, f184
t20753 = torch_slice_prim_impl(t20750, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20753: "cuda:0 bf16[1, 71, 2048, 64]"
del t20750
t20757 = torch_slice_prim_impl(t20749, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20757: "cuda:0 bf16[1, 71, 2048, 64]"
del t20749
t20860 = torch.reshape(t20751, (1, 1, 71, 2048, 64)) # t20860: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20860 = ltorch.reshape(t20751, (1, 1, 71, 2048, 64)) # t20860: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20860 = prims.reshape(t20751, (1, 1, 71, 2048, 64)) # t20860: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t20751
[t20894] = nvFusion89(i155, t20753, t20757, t20860, t61, t66)
# t20754 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t20754: "cuda:0 bf16[1, 71, 2048, 0]"
# t20755 = prims.pad(t20754, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t20755: "cuda:0 bf16[1, 71, 2048, 64]"
# t20758 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t20758: "cuda:0 bf16[1, 71, 2048, 0]"
# t20759 = prims.pad(t20758, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t20759: "cuda:0 bf16[1, 71, 2048, 64]"
# t20760 = prims.convert_element_type(t20753, dtypes.float32) # t20760: "cuda:0 f32[1, 71, 2048, 64]"
# t20764 = prims.mul(t66, t20760) # t20764: "cuda:0 f32[1, 71, 2048, 64]"
# t20767 = prims.convert_element_type(t20764, dtypes.bfloat16) # t20767: "cuda:0 bf16[1, 71, 2048, 64]"
# t20776 = prims.mul(t61, t20760) # t20776: "cuda:0 f32[1, 71, 2048, 64]"
# t20788 = prims.slice_prim(t20767, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t20788: "cuda:0 bf16[1, 71, 2048, 32]"
# t20789 = prims.slice_prim(t20767, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20789: "cuda:0 bf16[1, 71, 2048, 32]"
# t20790 = prims.convert_element_type(t20788, dtypes.float32) # t20790: "cuda:0 f32[1, 71, 2048, 32]"
# t20791 = prims.neg(t20790) # t20791: "cuda:0 f32[1, 71, 2048, 32]"
# t20792 = prims.convert_element_type(t20791, dtypes.bfloat16) # t20792: "cuda:0 bf16[1, 71, 2048, 32]"
# t20793 = prims.pad(t20792, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t20793: "cuda:0 bf16[1, 71, 2048, 64]"
# t20795 = prims.convert_element_type(t20793, dtypes.float32) # t20795: "cuda:0 f32[1, 71, 2048, 64]"
# t20796 = prims.add(t20776, t20795) # t20796: "cuda:0 f32[1, 71, 2048, 64]"
# t20798 = prims.pad(t20789, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t20798: "cuda:0 bf16[1, 71, 2048, 64]"
# t20800 = prims.convert_element_type(t20798, dtypes.float32) # t20800: "cuda:0 f32[1, 71, 2048, 64]"
# t20801 = prims.add(t20796, t20800) # t20801: "cuda:0 f32[1, 71, 2048, 64]"
# t20802 = prims.convert_element_type(t20801, dtypes.bfloat16) # t20802: "cuda:0 bf16[1, 71, 2048, 64]"
# t20803 = prims.pad(t20802, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t20803: "cuda:0 bf16[1, 71, 2048, 64]"
# t20804 = prims.convert_element_type(t20755, dtypes.float32) # t20804: "cuda:0 f32[1, 71, 2048, 64]"
# t20805 = prims.convert_element_type(t20803, dtypes.float32) # t20805: "cuda:0 f32[1, 71, 2048, 64]"
# t20806 = prims.add(t20804, t20805) # t20806: "cuda:0 f32[1, 71, 2048, 64]"
# t20807 = prims.convert_element_type(t20806, dtypes.bfloat16) # t20807: "cuda:0 bf16[1, 71, 2048, 64]"
# t20808 = prims.convert_element_type(t20757, dtypes.float32) # t20808: "cuda:0 f32[1, 71, 2048, 64]"
# t20812 = prims.mul(t66, t20808) # t20812: "cuda:0 f32[1, 71, 2048, 64]"
# t20815 = prims.convert_element_type(t20812, dtypes.bfloat16) # t20815: "cuda:0 bf16[1, 71, 2048, 64]"
# t20824 = prims.mul(t61, t20808) # t20824: "cuda:0 f32[1, 71, 2048, 64]"
# t20836 = prims.slice_prim(t20815, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t20836: "cuda:0 bf16[1, 71, 2048, 32]"
# t20837 = prims.slice_prim(t20815, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t20837: "cuda:0 bf16[1, 71, 2048, 32]"
# t20838 = prims.convert_element_type(t20836, dtypes.float32) # t20838: "cuda:0 f32[1, 71, 2048, 32]"
# t20839 = prims.neg(t20838) # t20839: "cuda:0 f32[1, 71, 2048, 32]"
# t20840 = prims.convert_element_type(t20839, dtypes.bfloat16) # t20840: "cuda:0 bf16[1, 71, 2048, 32]"
# t20841 = prims.pad(t20840, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t20841: "cuda:0 bf16[1, 71, 2048, 64]"
# t20843 = prims.convert_element_type(t20841, dtypes.float32) # t20843: "cuda:0 f32[1, 71, 2048, 64]"
# t20844 = prims.add(t20824, t20843) # t20844: "cuda:0 f32[1, 71, 2048, 64]"
# t20846 = prims.pad(t20837, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t20846: "cuda:0 bf16[1, 71, 2048, 64]"
# t20848 = prims.convert_element_type(t20846, dtypes.float32) # t20848: "cuda:0 f32[1, 71, 2048, 64]"
# t20849 = prims.add(t20844, t20848) # t20849: "cuda:0 f32[1, 71, 2048, 64]"
# t20850 = prims.convert_element_type(t20849, dtypes.bfloat16) # t20850: "cuda:0 bf16[1, 71, 2048, 64]"
# t20851 = prims.pad(t20850, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t20851: "cuda:0 bf16[1, 71, 2048, 64]"
# t20852 = prims.convert_element_type(t20759, dtypes.float32) # t20852: "cuda:0 f32[1, 71, 2048, 64]"
# t20853 = prims.convert_element_type(t20851, dtypes.float32) # t20853: "cuda:0 f32[1, 71, 2048, 64]"
# t20854 = prims.add(t20852, t20853) # t20854: "cuda:0 f32[1, 71, 2048, 64]"
# t20855 = prims.convert_element_type(t20854, dtypes.bfloat16) # t20855: "cuda:0 bf16[1, 71, 2048, 64]"
# t20865 = prims.reshape(t20807, (1, 1, 71, 2048, 64)) # t20865: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20870 = prims.reshape(t20855, (1, 1, 71, 2048, 64)) # t20870: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t20876 = prims.convert_element_type(t20860, dtypes.float32) # t20876: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20877 = prims.sum(t20876, (0, 1, 2)) # t20877: "cuda:0 f32[2048, 64]"
# t20878 = prims.convert_element_type(t20877, dtypes.bfloat16) # t20878: "cuda:0 bf16[2048, 64]"
# t20879 = prims.broadcast_in_dim(t20878, [1, 1, 1, 2048, 64], [3, 4]) # t20879: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20885 = prims.convert_element_type(t20865, dtypes.float32) # t20885: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t20886 = prims.sum(t20885, (0, 1, 2)) # t20886: "cuda:0 f32[2048, 64]"
# t20887 = prims.convert_element_type(t20886, dtypes.bfloat16) # t20887: "cuda:0 bf16[2048, 64]"
# t20888 = prims.broadcast_in_dim(t20887, [1, 1, 1, 2048, 64], [3, 4]) # t20888: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t20894 = prims.cat((t20870, t20888, t20879), i155) # t20894: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i155, t20753, t20757, t20860
t20900 = torch.permute(t20894, (0, 3, 1, 2, 4)) # t20900: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20900 = ltorch.permute(t20894, (0, 3, 1, 2, 4)) # t20900: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t20900 = prims.transpose(t20894, (0, 3, 1, 2, 4)) # t20900: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t20894
t20906 = torch.reshape(t20900, (1, 2048, 4672)) # t20906: "cuda:0 bf16[1, 2048, 4672]"
# t20906 = ltorch.reshape(t20900, (1, 2048, 4672)) # t20906: "cuda:0 bf16[1, 2048, 4672]"
# t20906 = prims.reshape(t20900, (1, 2048, 4672)) # t20906: "cuda:0 bf16[1, 2048, 4672]"
del t20900
t20907 = torch.reshape(t20906, (-1, 4672)) # t20907: "cuda:0 bf16[2048, 4672]"
# t20907 = ltorch.reshape(t20906, (-1, 4672)) # t20907: "cuda:0 bf16[2048, 4672]"
# t20907 = prims.reshape(t20906, (2048, 4672)) # t20907: "cuda:0 bf16[2048, 4672]"
del t20906
t20911 = torch.permute(t20907, (1, 0)) # t20911: "cuda:0 bf16[4672, 2048]"
# t20911 = ltorch.permute(t20907, (1, 0)) # t20911: "cuda:0 bf16[4672, 2048]"
# t20911 = prims.transpose(t20907, (1, 0)) # t20911: "cuda:0 bf16[4672, 2048]"
t20913 = torch.matmul(t20911, t20733) # t20913: "cuda:0 bf16[4672, 4544]"
# t20913 = ltorch.matmul(t20911, t20912) # t20913: "cuda:0 bf16[4672, 4544]"
# t20913 = prims.matmul(t20911, t20912) # t20913: "cuda:0 bf16[4672, 4544]"
del t20911, t20733
t20908 = torch.matmul(t20907, t_transformer_h_2_attn_attn_weight) # t20908: "cuda:0 bf16[2048, 4544]"
# t20908 = ltorch.matmul(t20907, t_transformer_h_2_attn_attn_weight) # t20908: "cuda:0 bf16[2048, 4544]"
# t20908 = prims.matmul(t20907, t_transformer_h_2_attn_attn_weight) # t20908: "cuda:0 bf16[2048, 4544]"
del t20907, t_transformer_h_2_attn_attn_weight
t20730 = torch.reshape(t20729, (1, 2048, 4544)) # t20730: "cuda:0 bf16[1, 2048, 4544]"
# t20730 = ltorch.reshape(t20729, (1, 2048, 4544)) # t20730: "cuda:0 bf16[1, 2048, 4544]"
# t20730 = prims.reshape(t20729, (1, 2048, 4544)) # t20730: "cuda:0 bf16[1, 2048, 4544]"
del t20729
t20909 = torch.reshape(t20908, (1, 2048, 4544)) # t20909: "cuda:0 bf16[1, 2048, 4544]"
# t20909 = ltorch.reshape(t20908, (1, 2048, 4544)) # t20909: "cuda:0 bf16[1, 2048, 4544]"
# t20909 = prims.reshape(t20908, (1, 2048, 4544)) # t20909: "cuda:0 bf16[1, 2048, 4544]"
del t20908
[t20922, t20928, t20970] = nvFusion90(i20950, t132, t20687, t20730, t20909, t264, t285, t300, t305, t311)
# t291 = prims.convert_element_type(t132, dtypes.float32) # t291: "cuda:0 f32[1, 2048, 4544]"
# t286 = prims.convert_element_type(t285, dtypes.float32) # t286: "cuda:0 f32[1, 2048, 4544]"
# t287 = prims.convert_element_type(t264, dtypes.float32) # t287: "cuda:0 f32[1, 2048, 4544]"
# t288 = prims.add(t286, t287) # t288: "cuda:0 f32[1, 2048, 4544]"
# t292 = prims.add(t288, t291) # t292: "cuda:0 f32[1, 2048, 4544]"
# t302 = prims.broadcast_in_dim(t300, [1, 2048, 1], [0, 1]) # t302: "cuda:0 f32[1, 2048, 1]"
# t306 = prims.broadcast_in_dim(t302, (1, 2048, 4544), (0, 1, 2)) # t306: "cuda:0 f32[1, 2048, 4544]"
# t308 = prims.sub(t292, t306) # t308: "cuda:0 f32[1, 2048, 4544]"
# t309 = prims.broadcast_in_dim(t305, (1, 2048, 4544), (0, 1, 2)) # t309: "cuda:0 f32[1, 2048, 4544]"
# t310 = prims.mul(t308, t309) # t310: "cuda:0 f32[1, 2048, 4544]"
# t312 = prims.convert_element_type(t311, dtypes.float32) # t312: "cuda:0 f32[1, 2048, 4544]"
# t20967 = prims.convert_element_type(t20687, dtypes.float32) # t20967: "cuda:0 f32[1, 2048, 4544]"
# t20914 = prims.convert_element_type(t20730, dtypes.float32) # t20914: "cuda:0 f32[1, 2048, 4544]"
# t20915 = prims.convert_element_type(t20909, dtypes.float32) # t20915: "cuda:0 f32[1, 2048, 4544]"
# t20916 = prims.add(t20914, t20915) # t20916: "cuda:0 f32[1, 2048, 4544]"
# t20921 = prims.sum(t20916, (0, 1)) # t20921: "cuda:0 f32[4544]"
# t20922 = prims.convert_element_type(t20921, dtypes.bfloat16) # t20922: "cuda:0 bf16[4544]"
# t20923 = prims.mul(t312, t20916) # t20923: "cuda:0 f32[1, 2048, 4544]"
# t20924 = prims.mul(t310, t20916) # t20924: "cuda:0 f32[1, 2048, 4544]"
# t20927 = prims.sum(t20924, (0, 1)) # t20927: "cuda:0 f32[4544]"
# t20928 = prims.convert_element_type(t20927, dtypes.bfloat16) # t20928: "cuda:0 bf16[4544]"
# t20929 = prims.mul(t309, t20923) # t20929: "cuda:0 f32[1, 2048, 4544]"
# t20930 = prims.mul(t308, t20923) # t20930: "cuda:0 f32[1, 2048, 4544]"
# t20931 = prims.sum(t20930, (0, 2)) # t20931: "cuda:0 f32[2048]"
# t20932 = prims.broadcast_in_dim(t20931, [1, 2048, 1], [1]) # t20932: "cuda:0 f32[1, 2048, 1]"
# t20933 = prims.neg(t20929) # t20933: "cuda:0 f32[1, 2048, 4544]"
# t20935 = prims.sum(t20933, (0, 2)) # t20935: "cuda:0 f32[2048]"
# t20936 = prims.broadcast_in_dim(t20935, [1, 2048, 1], [1]) # t20936: "cuda:0 f32[1, 2048, 1]"
# t20937 = prims.mul(-0.5, t20932) # t20937: "cuda:0 f32[1, 2048, 1]"
# t20938 = prims.pow(t305, 3.0) # t20938: "cuda:0 f32[1, 2048, 1]"
# t20939 = prims.mul(t20937, t20938) # t20939: "cuda:0 f32[1, 2048, 1]"
# t20941 = prims.sum(t20936, (0, 2)) # t20941: "cuda:0 f32[2048]"
# t20942 = prims.broadcast_in_dim(t20941, [1, 2048], [1]) # t20942: "cuda:0 f32[1, 2048]"
# t20943 = prims.sum(t20939, (0, 2)) # t20943: "cuda:0 f32[2048]"
# t20944 = prims.broadcast_in_dim(t20943, [1, 2048], [1]) # t20944: "cuda:0 f32[1, 2048]"
# t20947 = prims.broadcast_in_dim(t20942, [1, 2048, 1], [0, 1]) # t20947: "cuda:0 f32[1, 2048, 1]"
# t20948 = prims.broadcast_in_dim(t20947, (1, 2048, 4544), (0, 1, 2)) # t20948: "cuda:0 f32[1, 2048, 4544]"
# t20949 = prims.mul(0.00022007042253521127, t20948) # t20949: "cuda:0 f32[1, 2048, 4544]"
# t20951 = prims.broadcast_in_dim(t20944, [1, 2048, 1], [0, 1]) # t20951: "cuda:0 f32[1, 2048, 1]"
# t20952 = prims.broadcast_in_dim(t20951, (1, 2048, 4544), (0, 1, 2)) # t20952: "cuda:0 f32[1, 2048, 4544]"
# t20954 = prims.broadcast_in_dim(t300, [1, 2048, 1], [0, 1]) # t20954: "cuda:0 f32[1, 2048, 1]"
# t20955 = prims.broadcast_in_dim(t20954, (1, 2048, 4544), (0, 1, 2)) # t20955: "cuda:0 f32[1, 2048, 4544]"
# t20956 = prims.mul(2.0, t20952) # t20956: "cuda:0 f32[1, 2048, 4544]"
# t20957 = prims.sub(t292, t20955) # t20957: "cuda:0 f32[1, 2048, 4544]"
# t20958 = prims.mul(t20956, t20957) # t20958: "cuda:0 f32[1, 2048, 4544]"
# f20959 = prims.convert_element_type(i20950, float) # f20959: "float 4544.0"
# t20960 = prims.div(t20958, f20959) # t20960: "cuda:0 f32[1, 2048, 4544]"
# t20961 = prims.add(t20949, t20960) # t20961: "cuda:0 f32[1, 2048, 4544]"
# t20965 = prims.add(t20929, t20961) # t20965: "cuda:0 f32[1, 2048, 4544]"
# t20969 = prims.add(t20967, t20965) # t20969: "cuda:0 f32[1, 2048, 4544]"
# t20970 = prims.convert_element_type(t20969, dtypes.bfloat16) # t20970: "cuda:0 bf16[1, 2048, 4544]"
del i20950, t132, t20687, t20730, t20909, t264, t285, t300, t305, t311
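# --- Transformer block 1 backward ---
# Same pattern: GELU backward (nvFusion91), cuDNN SDPA backward, RoPE backward + qkv gradient
# assembly (nvFusion92), and LayerNorm backward with residual accumulation (nvFusion93).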
t20977 = torch.reshape(t20970, (-1, 4544)) # t20977: "cuda:0 bf16[2048, 4544]"
# t20977 = ltorch.reshape(t20970, (-1, 4544)) # t20977: "cuda:0 bf16[2048, 4544]"
# t20977 = prims.reshape(t20970, (2048, 4544)) # t20977: "cuda:0 bf16[2048, 4544]"
t20981 = torch.permute(t20977, (1, 0)) # t20981: "cuda:0 bf16[4544, 2048]"
# t20981 = ltorch.permute(t20977, (1, 0)) # t20981: "cuda:0 bf16[4544, 2048]"
# t20981 = prims.transpose(t20977, (1, 0)) # t20981: "cuda:0 bf16[4544, 2048]"
t20978 = torch.matmul(t20977, t_transformer_h_1_mlp_proj_weight) # t20978: "cuda:0 bf16[2048, 18176]"
# t20978 = ltorch.matmul(t20977, t_transformer_h_1_mlp_proj_weight) # t20978: "cuda:0 bf16[2048, 18176]"
# t20978 = prims.matmul(t20977, t_transformer_h_1_mlp_proj_weight) # t20978: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_1_mlp_proj_weight
t20983 = torch.matmul(t20981, t20982) # t20983: "cuda:0 bf16[4544, 18176]"
# t20983 = ltorch.matmul(t20981, t20982) # t20983: "cuda:0 bf16[4544, 18176]"
# t20983 = prims.matmul(t20981, t20982) # t20983: "cuda:0 bf16[4544, 18176]"
del t20982
t21019 = torch.matmul(t20977, t_transformer_h_1_attn_proj_weight) # t21019: "cuda:0 bf16[2048, 4544]"
# t21019 = ltorch.matmul(t21018, t_transformer_h_1_attn_proj_weight) # t21019: "cuda:0 bf16[2048, 4544]"
# t21019 = prims.matmul(t21018, t_transformer_h_1_attn_proj_weight) # t21019: "cuda:0 bf16[2048, 4544]"
del t20977, t_transformer_h_1_attn_proj_weight
t21024 = torch.matmul(t20981, t21023) # t21024: "cuda:0 bf16[4544, 4544]"
# t21024 = ltorch.matmul(t21022, t21023) # t21024: "cuda:0 bf16[4544, 4544]"
# t21024 = prims.matmul(t21022, t21023) # t21024: "cuda:0 bf16[4544, 4544]"
del t20981, t21023
t20979 = torch.reshape(t20978, (1, 2048, 18176)) # t20979: "cuda:0 bf16[1, 2048, 18176]"
# t20979 = ltorch.reshape(t20978, (1, 2048, 18176)) # t20979: "cuda:0 bf16[1, 2048, 18176]"
# t20979 = prims.reshape(t20978, (1, 2048, 18176)) # t20979: "cuda:0 bf16[1, 2048, 18176]"
del t20978
t21020 = torch.reshape(t21019, (1, 2048, 4544)) # t21020: "cuda:0 bf16[1, 2048, 4544]"
# t21020 = ltorch.reshape(t21019, (1, 2048, 4544)) # t21020: "cuda:0 bf16[1, 2048, 4544]"
# t21020 = prims.reshape(t21019, (1, 2048, 4544)) # t21020: "cuda:0 bf16[1, 2048, 4544]"
del t21019
t21028 = torch.reshape(t21020, (1, 2048, 71, 64)) # t21028: "cuda:0 bf16[1, 2048, 71, 64]"
# t21028 = ltorch.reshape(t21020, (1, 2048, 71, 64)) # t21028: "cuda:0 bf16[1, 2048, 71, 64]"
# t21028 = prims.reshape(t21020, (1, 2048, 71, 64)) # t21028: "cuda:0 bf16[1, 2048, 71, 64]"
del t21020
t21031 = torch.permute(t21028, (0, 2, 1, 3)) # t21031: "cuda:0 bf16[1, 71, 2048, 64]"
# t21031 = ltorch.permute(t21028, (0, 2, 1, 3)) # t21031: "cuda:0 bf16[1, 71, 2048, 64]"
# t21031 = prims.transpose(t21028, (0, 2, 1, 3)) # t21031: "cuda:0 bf16[1, 71, 2048, 64]"
del t21028
[t21010] = nvFusion91(f127, f129, t20979, t265)
# t266 = prims.convert_element_type(t265, dtypes.float32) # t266: "cuda:0 f32[1, 2048, 18176]"
# t268 = prims.div(t266, 1.4142135623730951) # t268: "cuda:0 f32[1, 2048, 18176]"
# t271 = prims.erf(t268) # t271: "cuda:0 f32[1, 2048, 18176]"
# t275 = prims.mul(0.5, t271) # t275: "cuda:0 f32[1, 2048, 18176]"
# t279 = prims.add(0.5, t275) # t279: "cuda:0 f32[1, 2048, 18176]"
# t20984 = prims.convert_element_type(t20979, dtypes.float32) # t20984: "cuda:0 f32[1, 2048, 18176]"
# t20985 = prims.mul(t279, t20984) # t20985: "cuda:0 f32[1, 2048, 18176]"
# t20986 = prims.mul(t266, t20984) # t20986: "cuda:0 f32[1, 2048, 18176]"
# t20994 = prims.mul(f129, t20986) # t20994: "cuda:0 f32[1, 2048, 18176]"
# t20997 = prims.pow(t268, 2.0) # t20997: "cuda:0 f32[1, 2048, 18176]"
# t20998 = prims.neg(t20997) # t20998: "cuda:0 f32[1, 2048, 18176]"
# t20999 = prims.exp(t20998) # t20999: "cuda:0 f32[1, 2048, 18176]"
# t21000 = prims.mul(1.1283791670955126, t20999) # t21000: "cuda:0 f32[1, 2048, 18176]"
# t21001 = prims.mul(t21000, t20994) # t21001: "cuda:0 f32[1, 2048, 18176]"
# t21005 = prims.div(t21001, f127) # t21005: "cuda:0 f32[1, 2048, 18176]"
# t21009 = prims.add(t20985, t21005) # t21009: "cuda:0 f32[1, 2048, 18176]"
# t21010 = prims.convert_element_type(t21009, dtypes.bfloat16) # t21010: "cuda:0 bf16[1, 2048, 18176]"
del f127, f129, t20979, t265
t21011 = torch.reshape(t21010, (-1, 18176)) # t21011: "cuda:0 bf16[2048, 18176]"
# t21011 = ltorch.reshape(t21010, (-1, 18176)) # t21011: "cuda:0 bf16[2048, 18176]"
# t21011 = prims.reshape(t21010, (2048, 18176)) # t21011: "cuda:0 bf16[2048, 18176]"
del t21010
t21015 = torch.permute(t21011, (1, 0)) # t21015: "cuda:0 bf16[18176, 2048]"
# t21015 = ltorch.permute(t21011, (1, 0)) # t21015: "cuda:0 bf16[18176, 2048]"
# t21015 = prims.transpose(t21011, (1, 0)) # t21015: "cuda:0 bf16[18176, 2048]"
t21017 = torch.matmul(t21015, t21016) # t21017: "cuda:0 bf16[18176, 4544]"
# t21017 = ltorch.matmul(t21015, t21016) # t21017: "cuda:0 bf16[18176, 4544]"
# t21017 = prims.matmul(t21015, t21016) # t21017: "cuda:0 bf16[18176, 4544]"
del t21015
t21012 = torch.matmul(t21011, t_transformer_h_1_mlp_fc_weight) # t21012: "cuda:0 bf16[2048, 4544]"
# t21012 = ltorch.matmul(t21011, t_transformer_h_1_mlp_fc_weight) # t21012: "cuda:0 bf16[2048, 4544]"
# t21012 = prims.matmul(t21011, t_transformer_h_1_mlp_fc_weight) # t21012: "cuda:0 bf16[2048, 4544]"
del t21011, t_transformer_h_1_mlp_fc_weight
(t21032, t21033, t21034) = cudnn_sdpa_bwd(t21031, t249, t252, t202, None, f118, b119, t253, t254, t255, t256, scale=f120, cat_grad_qkv=False)
del t21031, t249, t252, t202, f118, b119, t253, t254, t255, t256, f120
t21036 = torch_slice_prim_impl(t21033, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21036: "cuda:0 bf16[1, 71, 2048, 64]"
del t21033
t21040 = torch_slice_prim_impl(t21032, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21040: "cuda:0 bf16[1, 71, 2048, 64]"
del t21032
t21143 = torch.reshape(t21034, (1, 1, 71, 2048, 64)) # t21143: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21143 = ltorch.reshape(t21034, (1, 1, 71, 2048, 64)) # t21143: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21143 = prims.reshape(t21034, (1, 1, 71, 2048, 64)) # t21143: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t21034
[t21177] = nvFusion92(i91, t21036, t21040, t21143, t61, t66)
# t21037 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t21037: "cuda:0 bf16[1, 71, 2048, 0]"
# t21038 = prims.pad(t21037, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t21038: "cuda:0 bf16[1, 71, 2048, 64]"
# t21041 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t21041: "cuda:0 bf16[1, 71, 2048, 0]"
# t21042 = prims.pad(t21041, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t21042: "cuda:0 bf16[1, 71, 2048, 64]"
# t21043 = prims.convert_element_type(t21036, dtypes.float32) # t21043: "cuda:0 f32[1, 71, 2048, 64]"
# t21047 = prims.mul(t66, t21043) # t21047: "cuda:0 f32[1, 71, 2048, 64]"
# t21050 = prims.convert_element_type(t21047, dtypes.bfloat16) # t21050: "cuda:0 bf16[1, 71, 2048, 64]"
# t21059 = prims.mul(t61, t21043) # t21059: "cuda:0 f32[1, 71, 2048, 64]"
# t21071 = prims.slice_prim(t21050, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t21071: "cuda:0 bf16[1, 71, 2048, 32]"
# t21072 = prims.slice_prim(t21050, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21072: "cuda:0 bf16[1, 71, 2048, 32]"
# t21073 = prims.convert_element_type(t21071, dtypes.float32) # t21073: "cuda:0 f32[1, 71, 2048, 32]"
# t21074 = prims.neg(t21073) # t21074: "cuda:0 f32[1, 71, 2048, 32]"
# t21075 = prims.convert_element_type(t21074, dtypes.bfloat16) # t21075: "cuda:0 bf16[1, 71, 2048, 32]"
# t21076 = prims.pad(t21075, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t21076: "cuda:0 bf16[1, 71, 2048, 64]"
# t21078 = prims.convert_element_type(t21076, dtypes.float32) # t21078: "cuda:0 f32[1, 71, 2048, 64]"
# t21079 = prims.add(t21059, t21078) # t21079: "cuda:0 f32[1, 71, 2048, 64]"
# t21081 = prims.pad(t21072, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t21081: "cuda:0 bf16[1, 71, 2048, 64]"
# t21083 = prims.convert_element_type(t21081, dtypes.float32) # t21083: "cuda:0 f32[1, 71, 2048, 64]"
# t21084 = prims.add(t21079, t21083) # t21084: "cuda:0 f32[1, 71, 2048, 64]"
# t21085 = prims.convert_element_type(t21084, dtypes.bfloat16) # t21085: "cuda:0 bf16[1, 71, 2048, 64]"
# t21086 = prims.pad(t21085, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t21086: "cuda:0 bf16[1, 71, 2048, 64]"
# t21087 = prims.convert_element_type(t21038, dtypes.float32) # t21087: "cuda:0 f32[1, 71, 2048, 64]"
# t21088 = prims.convert_element_type(t21086, dtypes.float32) # t21088: "cuda:0 f32[1, 71, 2048, 64]"
# t21089 = prims.add(t21087, t21088) # t21089: "cuda:0 f32[1, 71, 2048, 64]"
# t21090 = prims.convert_element_type(t21089, dtypes.bfloat16) # t21090: "cuda:0 bf16[1, 71, 2048, 64]"
# t21091 = prims.convert_element_type(t21040, dtypes.float32) # t21091: "cuda:0 f32[1, 71, 2048, 64]"
# t21095 = prims.mul(t66, t21091) # t21095: "cuda:0 f32[1, 71, 2048, 64]"
# t21098 = prims.convert_element_type(t21095, dtypes.bfloat16) # t21098: "cuda:0 bf16[1, 71, 2048, 64]"
# t21107 = prims.mul(t61, t21091) # t21107: "cuda:0 f32[1, 71, 2048, 64]"
# t21119 = prims.slice_prim(t21098, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t21119: "cuda:0 bf16[1, 71, 2048, 32]"
# t21120 = prims.slice_prim(t21098, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21120: "cuda:0 bf16[1, 71, 2048, 32]"
# t21121 = prims.convert_element_type(t21119, dtypes.float32) # t21121: "cuda:0 f32[1, 71, 2048, 32]"
# t21122 = prims.neg(t21121) # t21122: "cuda:0 f32[1, 71, 2048, 32]"
# t21123 = prims.convert_element_type(t21122, dtypes.bfloat16) # t21123: "cuda:0 bf16[1, 71, 2048, 32]"
# t21124 = prims.pad(t21123, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t21124: "cuda:0 bf16[1, 71, 2048, 64]"
# t21126 = prims.convert_element_type(t21124, dtypes.float32) # t21126: "cuda:0 f32[1, 71, 2048, 64]"
# t21127 = prims.add(t21107, t21126) # t21127: "cuda:0 f32[1, 71, 2048, 64]"
# t21129 = prims.pad(t21120, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t21129: "cuda:0 bf16[1, 71, 2048, 64]"
# t21131 = prims.convert_element_type(t21129, dtypes.float32) # t21131: "cuda:0 f32[1, 71, 2048, 64]"
# t21132 = prims.add(t21127, t21131) # t21132: "cuda:0 f32[1, 71, 2048, 64]"
# t21133 = prims.convert_element_type(t21132, dtypes.bfloat16) # t21133: "cuda:0 bf16[1, 71, 2048, 64]"
# t21134 = prims.pad(t21133, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t21134: "cuda:0 bf16[1, 71, 2048, 64]"
# t21135 = prims.convert_element_type(t21042, dtypes.float32) # t21135: "cuda:0 f32[1, 71, 2048, 64]"
# t21136 = prims.convert_element_type(t21134, dtypes.float32) # t21136: "cuda:0 f32[1, 71, 2048, 64]"
# t21137 = prims.add(t21135, t21136) # t21137: "cuda:0 f32[1, 71, 2048, 64]"
# t21138 = prims.convert_element_type(t21137, dtypes.bfloat16) # t21138: "cuda:0 bf16[1, 71, 2048, 64]"
# t21148 = prims.reshape(t21090, (1, 1, 71, 2048, 64)) # t21148: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21153 = prims.reshape(t21138, (1, 1, 71, 2048, 64)) # t21153: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21159 = prims.convert_element_type(t21143, dtypes.float32) # t21159: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t21160 = prims.sum(t21159, (0, 1, 2)) # t21160: "cuda:0 f32[2048, 64]"
# t21161 = prims.convert_element_type(t21160, dtypes.bfloat16) # t21161: "cuda:0 bf16[2048, 64]"
# t21162 = prims.broadcast_in_dim(t21161, [1, 1, 1, 2048, 64], [3, 4]) # t21162: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t21168 = prims.convert_element_type(t21148, dtypes.float32) # t21168: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t21169 = prims.sum(t21168, (0, 1, 2)) # t21169: "cuda:0 f32[2048, 64]"
# t21170 = prims.convert_element_type(t21169, dtypes.bfloat16) # t21170: "cuda:0 bf16[2048, 64]"
# t21171 = prims.broadcast_in_dim(t21170, [1, 1, 1, 2048, 64], [3, 4]) # t21171: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t21177 = prims.cat((t21153, t21171, t21162), i91) # t21177: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i91, t21036, t21040, t21143
t21183 = torch.permute(t21177, (0, 3, 1, 2, 4)) # t21183: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t21183 = ltorch.permute(t21177, (0, 3, 1, 2, 4)) # t21183: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t21183 = prims.transpose(t21177, (0, 3, 1, 2, 4)) # t21183: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t21177
t21189 = torch.reshape(t21183, (1, 2048, 4672)) # t21189: "cuda:0 bf16[1, 2048, 4672]"
# t21189 = ltorch.reshape(t21183, (1, 2048, 4672)) # t21189: "cuda:0 bf16[1, 2048, 4672]"
# t21189 = prims.reshape(t21183, (1, 2048, 4672)) # t21189: "cuda:0 bf16[1, 2048, 4672]"
del t21183
t21190 = torch.reshape(t21189, (-1, 4672)) # t21190: "cuda:0 bf16[2048, 4672]"
# t21190 = ltorch.reshape(t21189, (-1, 4672)) # t21190: "cuda:0 bf16[2048, 4672]"
# t21190 = prims.reshape(t21189, (2048, 4672)) # t21190: "cuda:0 bf16[2048, 4672]"
del t21189
t21194 = torch.permute(t21190, (1, 0)) # t21194: "cuda:0 bf16[4672, 2048]"
# t21194 = ltorch.permute(t21190, (1, 0)) # t21194: "cuda:0 bf16[4672, 2048]"
# t21194 = prims.transpose(t21190, (1, 0)) # t21194: "cuda:0 bf16[4672, 2048]"
t21196 = torch.matmul(t21194, t21016) # t21196: "cuda:0 bf16[4672, 4544]"
# t21196 = ltorch.matmul(t21194, t21195) # t21196: "cuda:0 bf16[4672, 4544]"
# t21196 = prims.matmul(t21194, t21195) # t21196: "cuda:0 bf16[4672, 4544]"
del t21194, t21016
t21191 = torch.matmul(t21190, t_transformer_h_1_attn_attn_weight) # t21191: "cuda:0 bf16[2048, 4544]"
# t21191 = ltorch.matmul(t21190, t_transformer_h_1_attn_attn_weight) # t21191: "cuda:0 bf16[2048, 4544]"
# t21191 = prims.matmul(t21190, t_transformer_h_1_attn_attn_weight) # t21191: "cuda:0 bf16[2048, 4544]"
del t21190, t_transformer_h_1_attn_attn_weight
t21013 = torch.reshape(t21012, (1, 2048, 4544)) # t21013: "cuda:0 bf16[1, 2048, 4544]"
# t21013 = ltorch.reshape(t21012, (1, 2048, 4544)) # t21013: "cuda:0 bf16[1, 2048, 4544]"
# t21013 = prims.reshape(t21012, (1, 2048, 4544)) # t21013: "cuda:0 bf16[1, 2048, 4544]"
del t21012
t21192 = torch.reshape(t21191, (1, 2048, 4544)) # t21192: "cuda:0 bf16[1, 2048, 4544]"
# t21192 = ltorch.reshape(t21191, (1, 2048, 4544)) # t21192: "cuda:0 bf16[1, 2048, 4544]"
# t21192 = prims.reshape(t21191, (1, 2048, 4544)) # t21192: "cuda:0 bf16[1, 2048, 4544]"
del t21191
[t21205, t21211, t21253] = nvFusion93(i21233, t106, t124, t139, t144, t150, t20970, t21013, t21192, t4)
# t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 2048, 4544]"
# t126 = prims.convert_element_type(t106, dtypes.float32) # t126: "cuda:0 f32[1, 2048, 4544]"
# t127 = prims.add(t125, t126) # t127: "cuda:0 f32[1, 2048, 4544]"
# t130 = prims.convert_element_type(t4, dtypes.float32) # t130: "cuda:0 f32[1, 2048, 4544]"
# t131 = prims.add(t127, t130) # t131: "cuda:0 f32[1, 2048, 4544]"
# t141 = prims.broadcast_in_dim(t139, [1, 2048, 1], [0, 1]) # t141: "cuda:0 f32[1, 2048, 1]"
# t145 = prims.broadcast_in_dim(t141, (1, 2048, 4544), (0, 1, 2)) # t145: "cuda:0 f32[1, 2048, 4544]"
# t147 = prims.sub(t131, t145) # t147: "cuda:0 f32[1, 2048, 4544]"
# t148 = prims.broadcast_in_dim(t144, (1, 2048, 4544), (0, 1, 2)) # t148: "cuda:0 f32[1, 2048, 4544]"
# t149 = prims.mul(t147, t148) # t149: "cuda:0 f32[1, 2048, 4544]"
# t151 = prims.convert_element_type(t150, dtypes.float32) # t151: "cuda:0 f32[1, 2048, 4544]"
# t21250 = prims.convert_element_type(t20970, dtypes.float32) # t21250: "cuda:0 f32[1, 2048, 4544]"
# t21197 = prims.convert_element_type(t21013, dtypes.float32) # t21197: "cuda:0 f32[1, 2048, 4544]"
# t21198 = prims.convert_element_type(t21192, dtypes.float32) # t21198: "cuda:0 f32[1, 2048, 4544]"
# t21199 = prims.add(t21197, t21198) # t21199: "cuda:0 f32[1, 2048, 4544]"
# t21204 = prims.sum(t21199, (0, 1)) # t21204: "cuda:0 f32[4544]"
# t21205 = prims.convert_element_type(t21204, dtypes.bfloat16) # t21205: "cuda:0 bf16[4544]"
# t21206 = prims.mul(t151, t21199) # t21206: "cuda:0 f32[1, 2048, 4544]"
# t21207 = prims.mul(t149, t21199) # t21207: "cuda:0 f32[1, 2048, 4544]"
# t21210 = prims.sum(t21207, (0, 1)) # t21210: "cuda:0 f32[4544]"
# t21211 = prims.convert_element_type(t21210, dtypes.bfloat16) # t21211: "cuda:0 bf16[4544]"
# t21212 = prims.mul(t148, t21206) # t21212: "cuda:0 f32[1, 2048, 4544]"
# t21213 = prims.mul(t147, t21206) # t21213: "cuda:0 f32[1, 2048, 4544]"
# t21214 = prims.sum(t21213, (0, 2)) # t21214: "cuda:0 f32[2048]"
# t21215 = prims.broadcast_in_dim(t21214, [1, 2048, 1], [1]) # t21215: "cuda:0 f32[1, 2048, 1]"
# t21216 = prims.neg(t21212) # t21216: "cuda:0 f32[1, 2048, 4544]"
# t21218 = prims.sum(t21216, (0, 2)) # t21218: "cuda:0 f32[2048]"
# t21219 = prims.broadcast_in_dim(t21218, [1, 2048, 1], [1]) # t21219: "cuda:0 f32[1, 2048, 1]"
# t21220 = prims.mul(-0.5, t21215) # t21220: "cuda:0 f32[1, 2048, 1]"
# t21221 = prims.pow(t144, 3.0) # t21221: "cuda:0 f32[1, 2048, 1]"
# t21222 = prims.mul(t21220, t21221) # t21222: "cuda:0 f32[1, 2048, 1]"
# t21224 = prims.sum(t21219, (0, 2)) # t21224: "cuda:0 f32[2048]"
# t21225 = prims.broadcast_in_dim(t21224, [1, 2048], [1]) # t21225: "cuda:0 f32[1, 2048]"
# t21226 = prims.sum(t21222, (0, 2)) # t21226: "cuda:0 f32[2048]"
# t21227 = prims.broadcast_in_dim(t21226, [1, 2048], [1]) # t21227: "cuda:0 f32[1, 2048]"
# t21230 = prims.broadcast_in_dim(t21225, [1, 2048, 1], [0, 1]) # t21230: "cuda:0 f32[1, 2048, 1]"
# t21231 = prims.broadcast_in_dim(t21230, (1, 2048, 4544), (0, 1, 2)) # t21231: "cuda:0 f32[1, 2048, 4544]"
# t21232 = prims.mul(0.00022007042253521127, t21231) # t21232: "cuda:0 f32[1, 2048, 4544]"
# t21234 = prims.broadcast_in_dim(t21227, [1, 2048, 1], [0, 1]) # t21234: "cuda:0 f32[1, 2048, 1]"
# t21235 = prims.broadcast_in_dim(t21234, (1, 2048, 4544), (0, 1, 2)) # t21235: "cuda:0 f32[1, 2048, 4544]"
# t21237 = prims.broadcast_in_dim(t139, [1, 2048, 1], [0, 1]) # t21237: "cuda:0 f32[1, 2048, 1]"
# t21238 = prims.broadcast_in_dim(t21237, (1, 2048, 4544), (0, 1, 2)) # t21238: "cuda:0 f32[1, 2048, 4544]"
# t21239 = prims.mul(2.0, t21235) # t21239: "cuda:0 f32[1, 2048, 4544]"
# t21240 = prims.sub(t131, t21238) # t21240: "cuda:0 f32[1, 2048, 4544]"
# t21241 = prims.mul(t21239, t21240) # t21241: "cuda:0 f32[1, 2048, 4544]"
# f21242 = prims.convert_element_type(i21233, float) # f21242: "float 4544.0"
# t21243 = prims.div(t21241, f21242) # t21243: "cuda:0 f32[1, 2048, 4544]"
# t21244 = prims.add(t21232, t21243) # t21244: "cuda:0 f32[1, 2048, 4544]"
# t21248 = prims.add(t21212, t21244) # t21248: "cuda:0 f32[1, 2048, 4544]"
# t21252 = prims.add(t21250, t21248) # t21252: "cuda:0 f32[1, 2048, 4544]"
# t21253 = prims.convert_element_type(t21252, dtypes.bfloat16) # t21253: "cuda:0 bf16[1, 2048, 4544]"
del i21233, t106, t124, t139, t144, t150, t20970, t21013, t21192
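# --- Transformer block 0 backward ---
# Same pattern again: GELU backward (nvFusion94), cuDNN SDPA backward, and the RoPE backward /
# qkv gradient assembly (nvFusion95) beginning below.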
t21260 = torch.reshape(t21253, (-1, 4544)) # t21260: "cuda:0 bf16[2048, 4544]"
# t21260 = ltorch.reshape(t21253, (-1, 4544)) # t21260: "cuda:0 bf16[2048, 4544]"
# t21260 = prims.reshape(t21253, (2048, 4544)) # t21260: "cuda:0 bf16[2048, 4544]"
t21264 = torch.permute(t21260, (1, 0)) # t21264: "cuda:0 bf16[4544, 2048]"
# t21264 = ltorch.permute(t21260, (1, 0)) # t21264: "cuda:0 bf16[4544, 2048]"
# t21264 = prims.transpose(t21260, (1, 0)) # t21264: "cuda:0 bf16[4544, 2048]"
t21261 = torch.matmul(t21260, t_transformer_h_0_mlp_proj_weight) # t21261: "cuda:0 bf16[2048, 18176]"
# t21261 = ltorch.matmul(t21260, t_transformer_h_0_mlp_proj_weight) # t21261: "cuda:0 bf16[2048, 18176]"
# t21261 = prims.matmul(t21260, t_transformer_h_0_mlp_proj_weight) # t21261: "cuda:0 bf16[2048, 18176]"
del t_transformer_h_0_mlp_proj_weight
t21266 = torch.matmul(t21264, t21265) # t21266: "cuda:0 bf16[4544, 18176]"
# t21266 = ltorch.matmul(t21264, t21265) # t21266: "cuda:0 bf16[4544, 18176]"
# t21266 = prims.matmul(t21264, t21265) # t21266: "cuda:0 bf16[4544, 18176]"
del t21265
t21302 = torch.matmul(t21260, t_transformer_h_0_attn_proj_weight) # t21302: "cuda:0 bf16[2048, 4544]"
# t21302 = ltorch.matmul(t21301, t_transformer_h_0_attn_proj_weight) # t21302: "cuda:0 bf16[2048, 4544]"
# t21302 = prims.matmul(t21301, t_transformer_h_0_attn_proj_weight) # t21302: "cuda:0 bf16[2048, 4544]"
del t21260, t_transformer_h_0_attn_proj_weight
t21307 = torch.matmul(t21264, t21306) # t21307: "cuda:0 bf16[4544, 4544]"
# t21307 = ltorch.matmul(t21305, t21306) # t21307: "cuda:0 bf16[4544, 4544]"
# t21307 = prims.matmul(t21305, t21306) # t21307: "cuda:0 bf16[4544, 4544]"
del t21264, t21306
t21262 = torch.reshape(t21261, (1, 2048, 18176)) # t21262: "cuda:0 bf16[1, 2048, 18176]"
# t21262 = ltorch.reshape(t21261, (1, 2048, 18176)) # t21262: "cuda:0 bf16[1, 2048, 18176]"
# t21262 = prims.reshape(t21261, (1, 2048, 18176)) # t21262: "cuda:0 bf16[1, 2048, 18176]"
del t21261
t21303 = torch.reshape(t21302, (1, 2048, 4544)) # t21303: "cuda:0 bf16[1, 2048, 4544]"
# t21303 = ltorch.reshape(t21302, (1, 2048, 4544)) # t21303: "cuda:0 bf16[1, 2048, 4544]"
# t21303 = prims.reshape(t21302, (1, 2048, 4544)) # t21303: "cuda:0 bf16[1, 2048, 4544]"
del t21302
t21311 = torch.reshape(t21303, (1, 2048, 71, 64)) # t21311: "cuda:0 bf16[1, 2048, 71, 64]"
# t21311 = ltorch.reshape(t21303, (1, 2048, 71, 64)) # t21311: "cuda:0 bf16[1, 2048, 71, 64]"
# t21311 = prims.reshape(t21303, (1, 2048, 71, 64)) # t21311: "cuda:0 bf16[1, 2048, 71, 64]"
del t21303
t21314 = torch.permute(t21311, (0, 2, 1, 3)) # t21314: "cuda:0 bf16[1, 71, 2048, 64]"
# t21314 = ltorch.permute(t21311, (0, 2, 1, 3)) # t21314: "cuda:0 bf16[1, 71, 2048, 64]"
# t21314 = prims.transpose(t21311, (0, 2, 1, 3)) # t21314: "cuda:0 bf16[1, 71, 2048, 64]"
del t21311
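# nvFusion94: GELU backward for the block-0 MLP. It recomputes Phi(x) and phi(x)
# from the saved fc activation t107 and scales the incoming gradient t21262 by
# GELU'(x) = Phi(x) + x * phi(x), producing t21293.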
[t21293] = nvFusion94(f63, f65, t107, t21262)
# t108 = prims.convert_element_type(t107, dtypes.float32) # t108: "cuda:0 f32[1, 2048, 18176]"
# t109 = prims.div(t108, 1.4142135623730951) # t109: "cuda:0 f32[1, 2048, 18176]"
# t112 = prims.erf(t109) # t112: "cuda:0 f32[1, 2048, 18176]"
# t115 = prims.mul(0.5, t112) # t115: "cuda:0 f32[1, 2048, 18176]"
# t118 = prims.add(0.5, t115) # t118: "cuda:0 f32[1, 2048, 18176]"
# t21267 = prims.convert_element_type(t21262, dtypes.float32) # t21267: "cuda:0 f32[1, 2048, 18176]"
# t21268 = prims.mul(t118, t21267) # t21268: "cuda:0 f32[1, 2048, 18176]"
# t21269 = prims.mul(t108, t21267) # t21269: "cuda:0 f32[1, 2048, 18176]"
# t21277 = prims.mul(f65, t21269) # t21277: "cuda:0 f32[1, 2048, 18176]"
# t21280 = prims.pow(t109, 2.0) # t21280: "cuda:0 f32[1, 2048, 18176]"
# t21281 = prims.neg(t21280) # t21281: "cuda:0 f32[1, 2048, 18176]"
# t21282 = prims.exp(t21281) # t21282: "cuda:0 f32[1, 2048, 18176]"
# t21283 = prims.mul(1.1283791670955126, t21282) # t21283: "cuda:0 f32[1, 2048, 18176]"
# t21284 = prims.mul(t21283, t21277) # t21284: "cuda:0 f32[1, 2048, 18176]"
# t21288 = prims.div(t21284, f63) # t21288: "cuda:0 f32[1, 2048, 18176]"
# t21292 = prims.add(t21268, t21288) # t21292: "cuda:0 f32[1, 2048, 18176]"
# t21293 = prims.convert_element_type(t21292, dtypes.bfloat16) # t21293: "cuda:0 bf16[1, 2048, 18176]"
del f63, f65, t107, t21262
t21294 = torch.reshape(t21293, (-1, 18176)) # t21294: "cuda:0 bf16[2048, 18176]"
# t21294 = ltorch.reshape(t21293, (-1, 18176)) # t21294: "cuda:0 bf16[2048, 18176]"
# t21294 = prims.reshape(t21293, (2048, 18176)) # t21294: "cuda:0 bf16[2048, 18176]"
del t21293
t21298 = torch.permute(t21294, (1, 0)) # t21298: "cuda:0 bf16[18176, 2048]"
# t21298 = ltorch.permute(t21294, (1, 0)) # t21298: "cuda:0 bf16[18176, 2048]"
# t21298 = prims.transpose(t21294, (1, 0)) # t21298: "cuda:0 bf16[18176, 2048]"
t21300 = torch.matmul(t21298, t21299) # t21300: "cuda:0 bf16[18176, 4544]"
# t21300 = ltorch.matmul(t21298, t21299) # t21300: "cuda:0 bf16[18176, 4544]"
# t21300 = prims.matmul(t21298, t21299) # t21300: "cuda:0 bf16[18176, 4544]"
del t21298
t21295 = torch.matmul(t21294, t_transformer_h_0_mlp_fc_weight) # t21295: "cuda:0 bf16[2048, 4544]"
# t21295 = ltorch.matmul(t21294, t_transformer_h_0_mlp_fc_weight) # t21295: "cuda:0 bf16[2048, 4544]"
# t21295 = prims.matmul(t21294, t_transformer_h_0_mlp_fc_weight) # t21295: "cuda:0 bf16[2048, 4544]"
del t21294, t_transformer_h_0_mlp_fc_weight
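# cuDNN scaled-dot-product-attention backward for block 0: t21314 is the gradient
# of the attention output; t96/t99/t51 are the saved q/k/v and t100-t103 the saved
# softmax statistics, giving the query/key/value gradients (t21315, t21316, t21317).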
(t21315, t21316, t21317) = cudnn_sdpa_bwd(t21314, t96, t99, t51, None, f54, b55, t100, t101, t102, t103, scale=f56, cat_grad_qkv=False)
del t21314, t96, t99, t51, f54, b55, t100, t101, t102, t103, f56
t21319 = torch_slice_prim_impl(t21316, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21319: "cuda:0 bf16[1, 71, 2048, 64]"
del t21316
t21323 = torch_slice_prim_impl(t21315, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21323: "cuda:0 bf16[1, 71, 2048, 64]"
del t21315
t21426 = torch.reshape(t21317, (1, 1, 71, 2048, 64)) # t21426: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21426 = ltorch.reshape(t21317, (1, 1, 71, 2048, 64)) # t21426: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21426 = prims.reshape(t21317, (1, 1, 71, 2048, 64)) # t21426: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t21317
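# nvFusion95: backward of the rotary position embedding applied to q and k
# (un-rotate the halves and multiply by the saved tables t61/t66), then sum the
# per-head k and v gradients back onto the single shared key/value head and
# concatenate with the query gradients into the fused-QKV layout (73 = 71 q + 1 k + 1 v).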
[t21460] = nvFusion95(i27, t21319, t21323, t21426, t61, t66)
# t21320 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t21320: "cuda:0 bf16[1, 71, 2048, 0]"
# t21321 = prims.pad(t21320, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t21321: "cuda:0 bf16[1, 71, 2048, 64]"
# t21324 = prims.full([1, 71, 2048, 0], 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t21324: "cuda:0 bf16[1, 71, 2048, 0]"
# t21325 = prims.pad(t21324, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t21325: "cuda:0 bf16[1, 71, 2048, 64]"
# t21326 = prims.convert_element_type(t21319, dtypes.float32) # t21326: "cuda:0 f32[1, 71, 2048, 64]"
# t21330 = prims.mul(t66, t21326) # t21330: "cuda:0 f32[1, 71, 2048, 64]"
# t21333 = prims.convert_element_type(t21330, dtypes.bfloat16) # t21333: "cuda:0 bf16[1, 71, 2048, 64]"
# t21342 = prims.mul(t61, t21326) # t21342: "cuda:0 f32[1, 71, 2048, 64]"
# t21354 = prims.slice_prim(t21333, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t21354: "cuda:0 bf16[1, 71, 2048, 32]"
# t21355 = prims.slice_prim(t21333, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21355: "cuda:0 bf16[1, 71, 2048, 32]"
# t21356 = prims.convert_element_type(t21354, dtypes.float32) # t21356: "cuda:0 f32[1, 71, 2048, 32]"
# t21357 = prims.neg(t21356) # t21357: "cuda:0 f32[1, 71, 2048, 32]"
# t21358 = prims.convert_element_type(t21357, dtypes.bfloat16) # t21358: "cuda:0 bf16[1, 71, 2048, 32]"
# t21359 = prims.pad(t21358, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t21359: "cuda:0 bf16[1, 71, 2048, 64]"
# t21361 = prims.convert_element_type(t21359, dtypes.float32) # t21361: "cuda:0 f32[1, 71, 2048, 64]"
# t21362 = prims.add(t21342, t21361) # t21362: "cuda:0 f32[1, 71, 2048, 64]"
# t21364 = prims.pad(t21355, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t21364: "cuda:0 bf16[1, 71, 2048, 64]"
# t21366 = prims.convert_element_type(t21364, dtypes.float32) # t21366: "cuda:0 f32[1, 71, 2048, 64]"
# t21367 = prims.add(t21362, t21366) # t21367: "cuda:0 f32[1, 71, 2048, 64]"
# t21368 = prims.convert_element_type(t21367, dtypes.bfloat16) # t21368: "cuda:0 bf16[1, 71, 2048, 64]"
# t21369 = prims.pad(t21368, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t21369: "cuda:0 bf16[1, 71, 2048, 64]"
# t21370 = prims.convert_element_type(t21321, dtypes.float32) # t21370: "cuda:0 f32[1, 71, 2048, 64]"
# t21371 = prims.convert_element_type(t21369, dtypes.float32) # t21371: "cuda:0 f32[1, 71, 2048, 64]"
# t21372 = prims.add(t21370, t21371) # t21372: "cuda:0 f32[1, 71, 2048, 64]"
# t21373 = prims.convert_element_type(t21372, dtypes.bfloat16) # t21373: "cuda:0 bf16[1, 71, 2048, 64]"
# t21374 = prims.convert_element_type(t21323, dtypes.float32) # t21374: "cuda:0 f32[1, 71, 2048, 64]"
# t21378 = prims.mul(t66, t21374) # t21378: "cuda:0 f32[1, 71, 2048, 64]"
# t21381 = prims.convert_element_type(t21378, dtypes.bfloat16) # t21381: "cuda:0 bf16[1, 71, 2048, 64]"
# t21390 = prims.mul(t61, t21374) # t21390: "cuda:0 f32[1, 71, 2048, 64]"
# t21402 = prims.slice_prim(t21381, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t21402: "cuda:0 bf16[1, 71, 2048, 32]"
# t21403 = prims.slice_prim(t21381, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t21403: "cuda:0 bf16[1, 71, 2048, 32]"
# t21404 = prims.convert_element_type(t21402, dtypes.float32) # t21404: "cuda:0 f32[1, 71, 2048, 32]"
# t21405 = prims.neg(t21404) # t21405: "cuda:0 f32[1, 71, 2048, 32]"
# t21406 = prims.convert_element_type(t21405, dtypes.bfloat16) # t21406: "cuda:0 bf16[1, 71, 2048, 32]"
# t21407 = prims.pad(t21406, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (32, 0, 0))) # t21407: "cuda:0 bf16[1, 71, 2048, 64]"
# t21409 = prims.convert_element_type(t21407, dtypes.float32) # t21409: "cuda:0 f32[1, 71, 2048, 64]"
# t21410 = prims.add(t21390, t21409) # t21410: "cuda:0 f32[1, 71, 2048, 64]"
# t21412 = prims.pad(t21403, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 32, 0))) # t21412: "cuda:0 bf16[1, 71, 2048, 64]"
# t21414 = prims.convert_element_type(t21412, dtypes.float32) # t21414: "cuda:0 f32[1, 71, 2048, 64]"
# t21415 = prims.add(t21410, t21414) # t21415: "cuda:0 f32[1, 71, 2048, 64]"
# t21416 = prims.convert_element_type(t21415, dtypes.bfloat16) # t21416: "cuda:0 bf16[1, 71, 2048, 64]"
# t21417 = prims.pad(t21416, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))) # t21417: "cuda:0 bf16[1, 71, 2048, 64]"
# t21418 = prims.convert_element_type(t21325, dtypes.float32) # t21418: "cuda:0 f32[1, 71, 2048, 64]"
# t21419 = prims.convert_element_type(t21417, dtypes.float32) # t21419: "cuda:0 f32[1, 71, 2048, 64]"
# t21420 = prims.add(t21418, t21419) # t21420: "cuda:0 f32[1, 71, 2048, 64]"
# t21421 = prims.convert_element_type(t21420, dtypes.bfloat16) # t21421: "cuda:0 bf16[1, 71, 2048, 64]"
# t21431 = prims.reshape(t21373, (1, 1, 71, 2048, 64)) # t21431: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21436 = prims.reshape(t21421, (1, 1, 71, 2048, 64)) # t21436: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t21442 = prims.convert_element_type(t21426, dtypes.float32) # t21442: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t21443 = prims.sum(t21442, (0, 1, 2)) # t21443: "cuda:0 f32[2048, 64]"
# t21444 = prims.convert_element_type(t21443, dtypes.bfloat16) # t21444: "cuda:0 bf16[2048, 64]"
# t21445 = prims.broadcast_in_dim(t21444, [1, 1, 1, 2048, 64], [3, 4]) # t21445: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t21451 = prims.convert_element_type(t21431, dtypes.float32) # t21451: "cuda:0 f32[1, 1, 71, 2048, 64]"
# t21452 = prims.sum(t21451, (0, 1, 2)) # t21452: "cuda:0 f32[2048, 64]"
# t21453 = prims.convert_element_type(t21452, dtypes.bfloat16) # t21453: "cuda:0 bf16[2048, 64]"
# t21454 = prims.broadcast_in_dim(t21453, [1, 1, 1, 2048, 64], [3, 4]) # t21454: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t21460 = prims.cat((t21436, t21454, t21445), i27) # t21460: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del i27, t21319, t21323, t21426, t61, t66
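# Bring the fused-QKV gradient back to (2048, 4672), matching the layout of the
# attn.attn projection, then compute that projection's weight gradient (t21479)
# and input gradient (t21474).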
t21466 = torch.permute(t21460, (0, 3, 1, 2, 4)) # t21466: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t21466 = ltorch.permute(t21460, (0, 3, 1, 2, 4)) # t21466: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t21466 = prims.transpose(t21460, (0, 3, 1, 2, 4)) # t21466: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t21460
t21472 = torch.reshape(t21466, (1, 2048, 4672)) # t21472: "cuda:0 bf16[1, 2048, 4672]"
# t21472 = ltorch.reshape(t21466, (1, 2048, 4672)) # t21472: "cuda:0 bf16[1, 2048, 4672]"
# t21472 = prims.reshape(t21466, (1, 2048, 4672)) # t21472: "cuda:0 bf16[1, 2048, 4672]"
del t21466
t21473 = torch.reshape(t21472, (-1, 4672)) # t21473: "cuda:0 bf16[2048, 4672]"
# t21473 = ltorch.reshape(t21472, (-1, 4672)) # t21473: "cuda:0 bf16[2048, 4672]"
# t21473 = prims.reshape(t21472, (2048, 4672)) # t21473: "cuda:0 bf16[2048, 4672]"
del t21472
t21477 = torch.permute(t21473, (1, 0)) # t21477: "cuda:0 bf16[4672, 2048]"
# t21477 = ltorch.permute(t21473, (1, 0)) # t21477: "cuda:0 bf16[4672, 2048]"
# t21477 = prims.transpose(t21473, (1, 0)) # t21477: "cuda:0 bf16[4672, 2048]"
t21479 = torch.matmul(t21477, t21299) # t21479: "cuda:0 bf16[4672, 4544]"
# t21479 = ltorch.matmul(t21477, t21299)  # t21479: "cuda:0 bf16[4672, 4544]"
# t21479 = prims.matmul(t21477, t21299)  # t21479: "cuda:0 bf16[4672, 4544]"
del t21477, t21299
t21474 = torch.matmul(t21473, t_transformer_h_0_attn_attn_weight) # t21474: "cuda:0 bf16[2048, 4544]"
# t21474 = ltorch.matmul(t21473, t_transformer_h_0_attn_attn_weight) # t21474: "cuda:0 bf16[2048, 4544]"
# t21474 = prims.matmul(t21473, t_transformer_h_0_attn_attn_weight) # t21474: "cuda:0 bf16[2048, 4544]"
del t21473, t_transformer_h_0_attn_attn_weight
t21296 = torch.reshape(t21295, (1, 2048, 4544)) # t21296: "cuda:0 bf16[1, 2048, 4544]"
# t21296 = ltorch.reshape(t21295, (1, 2048, 4544)) # t21296: "cuda:0 bf16[1, 2048, 4544]"
# t21296 = prims.reshape(t21295, (1, 2048, 4544)) # t21296: "cuda:0 bf16[1, 2048, 4544]"
del t21295
t21475 = torch.reshape(t21474, (1, 2048, 4544)) # t21475: "cuda:0 bf16[1, 2048, 4544]"
# t21475 = ltorch.reshape(t21474, (1, 2048, 4544)) # t21475: "cuda:0 bf16[1, 2048, 4544]"
# t21475 = prims.reshape(t21474, (1, 2048, 4544)) # t21475: "cuda:0 bf16[1, 2048, 4544]"
del t21474
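# nvFusion96: backward of block 0's input LayerNorm (norm_1). Using the saved
# mean/rstd (t9, t13) it reduces the bias gradient t21488 and weight gradient
# t21494, and adds the resulting input gradient to the residual-stream gradient
# t21253, giving t21536, the gradient flowing into the token embedding.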
[t21488, t21494, t21536] = nvFusion96(i21516, t13, t19, t21253, t21296, t21475, t4, t9)
# t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 2048, 4544]"
# t11 = prims.broadcast_in_dim(t9, [1, 2048, 1], [0, 1]) # t11: "cuda:0 f32[1, 2048, 1]"
# t14 = prims.broadcast_in_dim(t11, (1, 2048, 4544), (0, 1, 2)) # t14: "cuda:0 f32[1, 2048, 4544]"
# t16 = prims.sub(t5, t14) # t16: "cuda:0 f32[1, 2048, 4544]"
# t17 = prims.broadcast_in_dim(t13, (1, 2048, 4544), (0, 1, 2)) # t17: "cuda:0 f32[1, 2048, 4544]"
# t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 2048, 4544]"
# t20 = prims.convert_element_type(t19, dtypes.float32) # t20: "cuda:0 f32[1, 2048, 4544]"
# t21533 = prims.convert_element_type(t21253, dtypes.float32) # t21533: "cuda:0 f32[1, 2048, 4544]"
# t21480 = prims.convert_element_type(t21296, dtypes.float32) # t21480: "cuda:0 f32[1, 2048, 4544]"
# t21481 = prims.convert_element_type(t21475, dtypes.float32) # t21481: "cuda:0 f32[1, 2048, 4544]"
# t21482 = prims.add(t21480, t21481) # t21482: "cuda:0 f32[1, 2048, 4544]"
# t21487 = prims.sum(t21482, (0, 1)) # t21487: "cuda:0 f32[4544]"
# t21488 = prims.convert_element_type(t21487, dtypes.bfloat16) # t21488: "cuda:0 bf16[4544]"
# t21489 = prims.mul(t20, t21482) # t21489: "cuda:0 f32[1, 2048, 4544]"
# t21490 = prims.mul(t18, t21482) # t21490: "cuda:0 f32[1, 2048, 4544]"
# t21493 = prims.sum(t21490, (0, 1)) # t21493: "cuda:0 f32[4544]"
# t21494 = prims.convert_element_type(t21493, dtypes.bfloat16) # t21494: "cuda:0 bf16[4544]"
# t21495 = prims.mul(t17, t21489) # t21495: "cuda:0 f32[1, 2048, 4544]"
# t21496 = prims.mul(t16, t21489) # t21496: "cuda:0 f32[1, 2048, 4544]"
# t21497 = prims.sum(t21496, (0, 2)) # t21497: "cuda:0 f32[2048]"
# t21498 = prims.broadcast_in_dim(t21497, [1, 2048, 1], [1]) # t21498: "cuda:0 f32[1, 2048, 1]"
# t21499 = prims.neg(t21495) # t21499: "cuda:0 f32[1, 2048, 4544]"
# t21501 = prims.sum(t21499, (0, 2)) # t21501: "cuda:0 f32[2048]"
# t21502 = prims.broadcast_in_dim(t21501, [1, 2048, 1], [1]) # t21502: "cuda:0 f32[1, 2048, 1]"
# t21503 = prims.mul(-0.5, t21498) # t21503: "cuda:0 f32[1, 2048, 1]"
# t21504 = prims.pow(t13, 3.0) # t21504: "cuda:0 f32[1, 2048, 1]"
# t21505 = prims.mul(t21503, t21504) # t21505: "cuda:0 f32[1, 2048, 1]"
# t21507 = prims.sum(t21502, (0, 2)) # t21507: "cuda:0 f32[2048]"
# t21508 = prims.broadcast_in_dim(t21507, [1, 2048], [1]) # t21508: "cuda:0 f32[1, 2048]"
# t21509 = prims.sum(t21505, (0, 2)) # t21509: "cuda:0 f32[2048]"
# t21510 = prims.broadcast_in_dim(t21509, [1, 2048], [1]) # t21510: "cuda:0 f32[1, 2048]"
# t21513 = prims.broadcast_in_dim(t21508, [1, 2048, 1], [0, 1]) # t21513: "cuda:0 f32[1, 2048, 1]"
# t21514 = prims.broadcast_in_dim(t21513, (1, 2048, 4544), (0, 1, 2)) # t21514: "cuda:0 f32[1, 2048, 4544]"
# t21515 = prims.mul(0.00022007042253521127, t21514) # t21515: "cuda:0 f32[1, 2048, 4544]"
# t21517 = prims.broadcast_in_dim(t21510, [1, 2048, 1], [0, 1]) # t21517: "cuda:0 f32[1, 2048, 1]"
# t21518 = prims.broadcast_in_dim(t21517, (1, 2048, 4544), (0, 1, 2)) # t21518: "cuda:0 f32[1, 2048, 4544]"
# t21520 = prims.broadcast_in_dim(t9, [1, 2048, 1], [0, 1]) # t21520: "cuda:0 f32[1, 2048, 1]"
# t21521 = prims.broadcast_in_dim(t21520, (1, 2048, 4544), (0, 1, 2)) # t21521: "cuda:0 f32[1, 2048, 4544]"
# t21522 = prims.mul(2.0, t21518) # t21522: "cuda:0 f32[1, 2048, 4544]"
# t21523 = prims.sub(t5, t21521) # t21523: "cuda:0 f32[1, 2048, 4544]"
# t21524 = prims.mul(t21522, t21523) # t21524: "cuda:0 f32[1, 2048, 4544]"
# f21525 = prims.convert_element_type(i21516, float) # f21525: "float 4544.0"
# t21526 = prims.div(t21524, f21525) # t21526: "cuda:0 f32[1, 2048, 4544]"
# t21527 = prims.add(t21515, t21526) # t21527: "cuda:0 f32[1, 2048, 4544]"
# t21531 = prims.add(t21495, t21527) # t21531: "cuda:0 f32[1, 2048, 4544]"
# t21535 = prims.add(t21533, t21531) # t21535: "cuda:0 f32[1, 2048, 4544]"
# t21536 = prims.convert_element_type(t21535, dtypes.bfloat16) # t21536: "cuda:0 bf16[1, 2048, 4544]"
del i21516, t13, t19, t21253, t21296, t21475, t4, t9
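# Embedding backward: scatter the hidden-state gradient t21536 into the
# 65024 x 4544 word-embedding table, accumulating over every position that used
# the same token id in idx.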
t21537 = torch.ops.aten.embedding_backward(t21536, idx, i0, -1, b1, b2)  # t21537: "cuda:0 bf16[65024, 4544]"
# t21537 = ltorch.embedding_backward(t21536, idx, i0, -1, b1, b2) # t21537: "cuda:0 bf16[65024, 4544]"
# t21537 = prims.embedding_backward(t21536, idx, i0, -1, b1, b2) # t21537: "cuda:0 bf16[65024, 4544]"
del t21536, idx, i0, b1, b2
return (None, None, t12439, None, t21479, t21307, t21300, t21266, t21488, t21494, t21196, t21024, t21017, t20983, t21205, t21211, t20913, t20741, t20734, t20700, t20922, t20928, t20630, t20458, t20451, t20417, t20639, t20645, t20347, t20175, t20168, t20134, t20356, t20362, t20064, t19892, t19885, t19851, t20073, t20079, t19781, t19609, t19602, t19568, t19790, t19796, t19498, t19326, t19319, t19285, t19507, t19513, t19215, t19043, t19036, t19002, t19224, t19230, t18932, t18760, t18753, t18719, t18941, t18947, t18649, t18477, t18470, t18436, t18658, t18664, t18366, t18194, t18187, t18153, t18375, t18381, t18083, t17911, t17904, t17870, t18092, t18098, t17800, t17628, t17621, t17587, t17809, t17815, t17517, t17345, t17338, t17304, t17526, t17532, t17234, t17062, t17055, t17021, t17243, t17249, t16951, t16779, t16772, t16738, t16960, t16966, t16668, t16496, t16489, t16455, t16677, t16683, t16385, t16213, t16206, t16172, t16394, t16400, t16102, t15930, t15923, t15889, t16111, t16117, t15819, t15647, t15640, t15606, t15828, t15834, t15536, t15364, t15357, t15323, t15545, t15551, t15253, t15081, t15074, t15040, t15262, t15268, t14970, t14798, t14791, t14757, t14979, t14985, t14687, t14515, t14508, t14474, t14696, t14702, t14404, t14232, t14225, t14191, t14413, t14419, t14121, t13949, t13942, t13908, t14130, t14136, t13838, t13666, t13659, t13625, t13847, t13853, t13555, t13383, t13376, t13342, t13564, t13570, t13272, t13100, t13093, t13059, t13281, t13287, t12989, t12817, t12810, t12776, t12998, t13004, t12706, t12542, t12535, t12501, t12715, t12721, t12444, t12450, t21537)
# This file has been truncated, but you can view the full file.
# Constructed by Delete Last Used (took 10 milliseconds)
import torch
from torch import Tensor
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_bias, t_transformer_h_0_norm_1_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_bias, t_transformer_h_1_norm_1_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_bias, t_transformer_h_2_norm_1_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_bias, t_transformer_h_3_norm_1_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_bias, t_transformer_h_4_norm_1_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_bias, t_transformer_h_5_norm_1_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_bias, t_transformer_h_6_norm_1_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_bias, t_transformer_h_7_norm_1_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_bias, t_transformer_h_8_norm_1_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_bias, t_transformer_h_9_norm_1_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_bias, t_transformer_h_10_norm_1_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_bias, t_transformer_h_11_norm_1_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_bias, t_transformer_h_12_norm_1_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_bias, t_transformer_h_13_norm_1_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_bias, t_transformer_h_14_norm_1_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_bias, t_transformer_h_15_norm_1_weight, t_transformer_h_16_attn_attn_weight, t_transformer_h_16_attn_proj_weight, t_transformer_h_16_mlp_fc_weight, t_transformer_h_16_mlp_proj_weight, t_transformer_h_16_norm_1_bias, 
t_transformer_h_16_norm_1_weight, t_transformer_h_17_attn_attn_weight, t_transformer_h_17_attn_proj_weight, t_transformer_h_17_mlp_fc_weight, t_transformer_h_17_mlp_proj_weight, t_transformer_h_17_norm_1_bias, t_transformer_h_17_norm_1_weight, t_transformer_h_18_attn_attn_weight, t_transformer_h_18_attn_proj_weight, t_transformer_h_18_mlp_fc_weight, t_transformer_h_18_mlp_proj_weight, t_transformer_h_18_norm_1_bias, t_transformer_h_18_norm_1_weight, t_transformer_h_19_attn_attn_weight, t_transformer_h_19_attn_proj_weight, t_transformer_h_19_mlp_fc_weight, t_transformer_h_19_mlp_proj_weight, t_transformer_h_19_norm_1_bias, t_transformer_h_19_norm_1_weight, t_transformer_h_20_attn_attn_weight, t_transformer_h_20_attn_proj_weight, t_transformer_h_20_mlp_fc_weight, t_transformer_h_20_mlp_proj_weight, t_transformer_h_20_norm_1_bias, t_transformer_h_20_norm_1_weight, t_transformer_h_21_attn_attn_weight, t_transformer_h_21_attn_proj_weight, t_transformer_h_21_mlp_fc_weight, t_transformer_h_21_mlp_proj_weight, t_transformer_h_21_norm_1_bias, t_transformer_h_21_norm_1_weight, t_transformer_h_22_attn_attn_weight, t_transformer_h_22_attn_proj_weight, t_transformer_h_22_mlp_fc_weight, t_transformer_h_22_mlp_proj_weight, t_transformer_h_22_norm_1_bias, t_transformer_h_22_norm_1_weight, t_transformer_h_23_attn_attn_weight, t_transformer_h_23_attn_proj_weight, t_transformer_h_23_mlp_fc_weight, t_transformer_h_23_mlp_proj_weight, t_transformer_h_23_norm_1_bias, t_transformer_h_23_norm_1_weight, t_transformer_h_24_attn_attn_weight, t_transformer_h_24_attn_proj_weight, t_transformer_h_24_mlp_fc_weight, t_transformer_h_24_mlp_proj_weight, t_transformer_h_24_norm_1_bias, t_transformer_h_24_norm_1_weight, t_transformer_h_25_attn_attn_weight, t_transformer_h_25_attn_proj_weight, t_transformer_h_25_mlp_fc_weight, t_transformer_h_25_mlp_proj_weight, t_transformer_h_25_norm_1_bias, t_transformer_h_25_norm_1_weight, t_transformer_h_26_attn_attn_weight, t_transformer_h_26_attn_proj_weight, t_transformer_h_26_mlp_fc_weight, t_transformer_h_26_mlp_proj_weight, t_transformer_h_26_norm_1_bias, t_transformer_h_26_norm_1_weight, t_transformer_h_27_attn_attn_weight, t_transformer_h_27_attn_proj_weight, t_transformer_h_27_mlp_fc_weight, t_transformer_h_27_mlp_proj_weight, t_transformer_h_27_norm_1_bias, t_transformer_h_27_norm_1_weight, t_transformer_h_28_attn_attn_weight, t_transformer_h_28_attn_proj_weight, t_transformer_h_28_mlp_fc_weight, t_transformer_h_28_mlp_proj_weight, t_transformer_h_28_norm_1_bias, t_transformer_h_28_norm_1_weight, t_transformer_h_29_attn_attn_weight, t_transformer_h_29_attn_proj_weight, t_transformer_h_29_mlp_fc_weight, t_transformer_h_29_mlp_proj_weight, t_transformer_h_29_norm_1_bias, t_transformer_h_29_norm_1_weight, t_transformer_h_30_attn_attn_weight, t_transformer_h_30_attn_proj_weight, t_transformer_h_30_mlp_fc_weight, t_transformer_h_30_mlp_proj_weight, t_transformer_h_30_norm_1_bias, t_transformer_h_30_norm_1_weight, t_transformer_h_31_attn_attn_weight, t_transformer_h_31_attn_proj_weight, t_transformer_h_31_mlp_fc_weight, t_transformer_h_31_mlp_proj_weight, t_transformer_h_31_norm_1_bias, t_transformer_h_31_norm_1_weight, t_transformer_ln_f_bias, t_transformer_ln_f_weight, t_transformer_wte_weight):
# idx: "cuda:0 i64[1, 2048]"
# tos1: "cuda:0 bf16[2048, 64]"
# t_lm_head_weight: "cuda:0 bf16[65024, 4544]"
# t_sin: "cuda:0 bf16[2048, 64]"
# t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_0_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_0_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_1_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_1_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_1_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_1_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_1_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_2_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_2_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_2_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_2_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_2_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_3_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_3_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_3_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_3_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_3_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_4_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_4_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_4_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_4_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_4_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_5_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_5_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_5_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_5_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_5_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_6_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_6_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_6_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_6_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_6_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_7_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_7_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_7_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_7_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_7_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_8_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_8_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_8_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_8_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_8_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_9_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_9_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_9_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_9_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_9_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_10_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_10_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_10_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_10_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_10_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_11_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_11_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_11_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_11_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_11_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_12_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_12_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_12_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_12_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_12_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_13_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_13_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_13_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_13_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_13_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_14_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_14_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_14_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_14_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_14_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_15_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_15_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_15_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_15_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_15_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_16_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_16_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_16_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_16_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_16_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_16_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_17_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_17_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_17_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_17_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_17_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_17_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_18_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_18_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_18_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_18_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_18_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_18_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_19_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_19_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_19_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_19_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_19_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_19_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_20_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_20_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_20_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_20_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_20_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_20_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_21_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_21_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_21_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_21_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_21_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_21_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_22_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_22_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_22_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_22_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_22_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_22_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_23_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_23_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_23_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_23_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_23_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_23_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_24_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_24_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_24_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_24_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_24_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_24_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_25_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_25_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_25_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_25_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_25_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_25_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_26_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_26_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_26_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_26_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_26_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_26_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_27_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_27_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_27_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_27_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_27_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_27_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_28_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_28_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_28_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_28_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_28_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_28_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_29_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_29_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_29_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_29_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_29_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_29_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_30_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_30_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_30_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_30_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_30_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_30_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_h_31_attn_attn_weight: "cuda:0 bf16[4672, 4544]"
# t_transformer_h_31_attn_proj_weight: "cuda:0 bf16[4544, 4544]"
# t_transformer_h_31_mlp_fc_weight: "cuda:0 bf16[18176, 4544]"
# t_transformer_h_31_mlp_proj_weight: "cuda:0 bf16[4544, 18176]"
# t_transformer_h_31_norm_1_bias: "cuda:0 bf16[4544]"
# t_transformer_h_31_norm_1_weight: "cuda:0 bf16[4544]"
# t_transformer_ln_f_bias: "cuda:0 bf16[4544]"
# t_transformer_ln_f_weight: "cuda:0 bf16[4544]"
# t_transformer_wte_weight: "cuda:0 bf16[65024, 4544]"
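# Token embedding lookup: idx (1, 2048 token ids) indexes the 65024 x 4544
# wte table, giving the initial hidden states t4.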
t4 = torch.nn.functional.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4544]"
# t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 2048, 4544]"
# t5149 = ltorch.reshape(idx, [2048]) # t5149: "cuda:0 i64[2048]"
# t5149 = prims.reshape(idx, (2048,)) # t5149: "cuda:0 i64[2048]"
# t5150 = prims.take(t_transformer_wte_weight, t5149, 0) # t5150: "cuda:0 bf16[2048, 4544]"
# t4 = ltorch.reshape(t5150, [1, 2048, 4544]) # t4: "cuda:0 bf16[1, 2048, 4544]"
# t4 = prims.reshape(t5150, (1, 2048, 4544)) # t4: "cuda:0 bf16[1, 2048, 4544]"
t5281 = torch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t5281: "cuda:0 bf16[1, 4544]"
# t5281 = ltorch.unsqueeze(t_transformer_h_0_norm_1_weight, 0) # t5281: "cuda:0 bf16[1, 4544]"
# t5281 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, [1, 4544], [1]) # t5281: "cuda:0 bf16[1, 4544]"
t5282 = torch.unsqueeze(t5281, 1) # t5282: "cuda:0 bf16[1, 1, 4544]"
# t5282 = ltorch.unsqueeze(t5281, 1) # t5282: "cuda:0 bf16[1, 1, 4544]"
# t5282 = prims.broadcast_in_dim(t5281, [1, 1, 4544], [0, 2]) # t5282: "cuda:0 bf16[1, 1, 4544]"
del t5281
t19 = Tensor.expand(t5282, (1, 2048, 4544)) # t19: "cuda:0 bf16[1, 2048, 4544]"
# t19 = ltorch.expand(t5282, (1, 2048, 4544)) # t19: "cuda:0 bf16[1, 2048, 4544]"
# t19 = prims.broadcast_in_dim(t5282, (1, 2048, 4544), (0, 1, 2)) # t19: "cuda:0 bf16[1, 2048, 4544]"
del t5282
t5284 = torch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t5284: "cuda:0 bf16[1, 4544]"
# t5284 = ltorch.unsqueeze(t_transformer_h_0_norm_1_bias, 0) # t5284: "cuda:0 bf16[1, 4544]"
# t5284 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_bias, [1, 4544], [1]) # t5284: "cuda:0 bf16[1, 4544]"
t5285 = torch.unsqueeze(t5284, 1) # t5285: "cuda:0 bf16[1, 1, 4544]"
# t5285 = ltorch.unsqueeze(t5284, 1) # t5285: "cuda:0 bf16[1, 1, 4544]"
# t5285 = prims.broadcast_in_dim(t5284, [1, 1, 4544], [0, 2]) # t5285: "cuda:0 bf16[1, 1, 4544]"
del t5284
t22 = Tensor.expand(t5285, (1, 2048, 4544)) # t22: "cuda:0 bf16[1, 2048, 4544]"
# t22 = ltorch.expand(t5285, (1, 2048, 4544)) # t22: "cuda:0 bf16[1, 2048, 4544]"
# t22 = prims.broadcast_in_dim(t5285, (1, 2048, 4544), (0, 1, 2)) # t22: "cuda:0 bf16[1, 2048, 4544]"
del t5285
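# nvFusion0: block 0's input LayerNorm fused into one kernel: var_mean over the
# hidden dimension, normalize, scale by the norm_1 weight (t19) and shift by the
# norm_1 bias (t22). The mean (t9) and rstd (t13) are kept for the backward pass.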
[t13, t25, t9] = nvFusion0(t19, t22, t4)
# t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 2048, 4544]"
# (t8, t9) = prims.var_mean(t5, (2,), correction=0)
# t10 = prims.broadcast_in_dim(t8, [1, 2048, 1], [0, 1]) # t10: "cuda:0 f32[1, 2048, 1]"
# t11 = prims.broadcast_in_dim(t9, [1, 2048, 1], [0, 1]) # t11: "cuda:0 f32[1, 2048, 1]"
# t12 = prims.add(t10, 1e-05) # t12: "cuda:0 f32[1, 2048, 1]"
# t13 = prims.rsqrt(t12) # t13: "cuda:0 f32[1, 2048, 1]"
# t14 = prims.broadcast_in_dim(t11, (1, 2048, 4544), (0, 1, 2)) # t14: "cuda:0 f32[1, 2048, 4544]"
# t16 = prims.sub(t5, t14) # t16: "cuda:0 f32[1, 2048, 4544]"
# t17 = prims.broadcast_in_dim(t13, (1, 2048, 4544), (0, 1, 2)) # t17: "cuda:0 f32[1, 2048, 4544]"
# t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 2048, 4544]"
# t20 = prims.convert_element_type(t19, dtypes.float32) # t20: "cuda:0 f32[1, 2048, 4544]"
# t21 = prims.mul(t18, t20) # t21: "cuda:0 f32[1, 2048, 4544]"
# t23 = prims.convert_element_type(t22, dtypes.float32) # t23: "cuda:0 f32[1, 2048, 4544]"
# t24 = prims.add(t21, t23) # t24: "cuda:0 f32[1, 2048, 4544]"
# t25 = prims.convert_element_type(t24, dtypes.bfloat16) # t25: "cuda:0 bf16[1, 2048, 4544]"
del t22
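# Parallel attention/MLP block: the fused QKV projection (t26) and the MLP
# up-projection (t107) both consume the same LayerNorm output t25.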
t26 = torch.nn.functional.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 4672]"
# t26 = ltorch.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 4672]"
# t26 = prims.linear(t25, t_transformer_h_0_attn_attn_weight, None) # t26: "cuda:0 bf16[1, 2048, 4672]"
t107 = torch.nn.functional.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
# t107 = ltorch.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
# t107 = prims.linear(t25, t_transformer_h_0_mlp_fc_weight, None) # t107: "cuda:0 bf16[1, 2048, 18176]"
t0 = torch_slice_prim_impl(tos1, [0, 0], [2048, 64], [1, 1]) # t0: "cuda:0 bf16[2048, 64]"
t1 = torch_slice_prim_impl(t_sin, [0, 0], [2048, 64], [1, 1]) # t1: "cuda:0 bf16[2048, 64]"
t27 = torch.reshape(t26, (1, 2048, 1, 73, 64)) # t27: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t27 = ltorch.reshape(t26, (1, 2048, 1, 73, 64)) # t27: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t27 = prims.reshape(t26, (1, 2048, 1, 73, 64)) # t27: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t26
t28 = torch.permute(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t28 = ltorch.permute(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t28 = prims.transpose(t27, (0, 2, 3, 1, 4)) # t28: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t27
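# Multi-query attention: split the fused QKV tensor into 71 query heads plus a
# single key head and a single value head, then broadcast the k/v heads across
# all 71 query heads.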
(t29, t30, t31) = torch.split(t28, (71, 1, 1), 2)
# (t29, t30, t31) = ltorch.split(t28, (71, 1, 1), 2)
# t29 = prims.slice_prim(t28, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t29: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t30 = prims.slice_prim(t28, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t30: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t31 = prims.slice_prim(t28, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t31: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t28
t32 = Tensor.expand(t30, (1, 1, 71, 2048, 64)) # t32: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t32 = ltorch.expand(t30, (1, 1, 71, 2048, 64)) # t32: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t32 = prims.broadcast_in_dim(t30, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t32: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t30
t38 = Tensor.expand(t31, (1, 1, 71, 2048, 64)) # t38: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t38 = ltorch.expand(t31, (1, 1, 71, 2048, 64)) # t38: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t38 = prims.broadcast_in_dim(t31, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t38: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t31
t39 = torch.reshape(t29, (1, 71, 2048, 64)) # t39: "cuda:0 bf16[1, 71, 2048, 64]"
# t39 = ltorch.reshape(t29, (1, 71, 2048, 64)) # t39: "cuda:0 bf16[1, 71, 2048, 64]"
# t39 = prims.reshape(t29, (1, 71, 2048, 64)) # t39: "cuda:0 bf16[1, 71, 2048, 64]"
del t29
t45 = torch.reshape(t32, (1, 71, 2048, 64)) # t45: "cuda:0 bf16[1, 71, 2048, 64]"
# t45 = ltorch.reshape(t32, (1, 71, 2048, 64)) # t45: "cuda:0 bf16[1, 71, 2048, 64]"
# t45 = prims.reshape(t32, (1, 71, 2048, 64)) # t45: "cuda:0 bf16[1, 71, 2048, 64]"
del t32
t51 = torch.reshape(t38, (1, 71, 2048, 64)) # t51: "cuda:0 bf16[1, 71, 2048, 64]"
# t51 = ltorch.reshape(t38, (1, 71, 2048, 64)) # t51: "cuda:0 bf16[1, 71, 2048, 64]"
# t51 = prims.reshape(t38, (1, 71, 2048, 64)) # t51: "cuda:0 bf16[1, 71, 2048, 64]"
del t38
t52 = torch_slice_prim_impl(t39, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t52: "cuda:0 bf16[1, 71, 2048, 64]"
t53 = torch_slice_prim_impl(t52, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t53: "cuda:0 bf16[1, 71, 2048, 32]"
t54 = torch_slice_prim_impl(t52, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 71, 2048, 32]"
t5302 = torch.unsqueeze(t0, 0) # t5302: "cuda:0 bf16[1, 2048, 64]"
# t5302 = ltorch.unsqueeze(t0, 0) # t5302: "cuda:0 bf16[1, 2048, 64]"
# t5302 = prims.broadcast_in_dim(t0, [1, 2048, 64], [1, 2]) # t5302: "cuda:0 bf16[1, 2048, 64]"
del t0
t5303 = torch.unsqueeze(t5302, 1) # t5303: "cuda:0 bf16[1, 1, 2048, 64]"
# t5303 = ltorch.unsqueeze(t5302, 1) # t5303: "cuda:0 bf16[1, 1, 2048, 64]"
# t5303 = prims.broadcast_in_dim(t5302, [1, 1, 2048, 64], [0, 2, 3]) # t5303: "cuda:0 bf16[1, 1, 2048, 64]"
del t5302
t59 = Tensor.expand(t5303, (1, 71, 2048, 64)) # t59: "cuda:0 bf16[1, 71, 2048, 64]"
# t59 = ltorch.expand(t5303, (1, 71, 2048, 64)) # t59: "cuda:0 bf16[1, 71, 2048, 64]"
# t59 = prims.broadcast_in_dim(t5303, (1, 71, 2048, 64), (0, 1, 2, 3)) # t59: "cuda:0 bf16[1, 71, 2048, 64]"
del t5303
t5305 = torch.unsqueeze(t1, 0) # t5305: "cuda:0 bf16[1, 2048, 64]"
# t5305 = ltorch.unsqueeze(t1, 0) # t5305: "cuda:0 bf16[1, 2048, 64]"
# t5305 = prims.broadcast_in_dim(t1, [1, 2048, 64], [1, 2]) # t5305: "cuda:0 bf16[1, 2048, 64]"
del t1
t5306 = torch.unsqueeze(t5305, 1) # t5306: "cuda:0 bf16[1, 1, 2048, 64]"
# t5306 = ltorch.unsqueeze(t5305, 1) # t5306: "cuda:0 bf16[1, 1, 2048, 64]"
# t5306 = prims.broadcast_in_dim(t5305, [1, 1, 2048, 64], [0, 2, 3]) # t5306: "cuda:0 bf16[1, 1, 2048, 64]"
del t5305
t64 = Tensor.expand(t5306, (1, 71, 2048, 64)) # t64: "cuda:0 bf16[1, 71, 2048, 64]"
# t64 = ltorch.expand(t5306, (1, 71, 2048, 64)) # t64: "cuda:0 bf16[1, 71, 2048, 64]"
# t64 = prims.broadcast_in_dim(t5306, (1, 71, 2048, 64), (0, 1, 2, 3)) # t64: "cuda:0 bf16[1, 71, 2048, 64]"
del t5306
t73 = torch_slice_prim_impl(t45, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t73: "cuda:0 bf16[1, 71, 2048, 64]"
t74 = torch_slice_prim_impl(t73, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t74: "cuda:0 bf16[1, 71, 2048, 32]"
t75 = torch_slice_prim_impl(t73, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t75: "cuda:0 bf16[1, 71, 2048, 32]"
t95 = torch_slice_prim_impl(t39, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t95: "cuda:0 bf16[1, 71, 2048, 0]"
del t39
t97 = torch_slice_prim_impl(t45, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t97: "cuda:0 bf16[1, 71, 2048, 0]"
del t45
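# nvFusion1: apply the rotary position embedding to q (t52) and k (t73) using the
# broadcast tables t59 (from tos1) and t64 (from t_sin), producing t96 and t99;
# the same kernel also computes the GELU of the MLP fc output t107, giving t123.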
[t123, t61, t66, t96, t99] = nvFusion1(t107, t52, t53, t54, t59, t64, t73, t74, t75, t95, t97)
# t55 = prims.convert_element_type(t54, dtypes.float32) # t55: "cuda:0 f32[1, 71, 2048, 32]"
# t56 = prims.neg(t55) # t56: "cuda:0 f32[1, 71, 2048, 32]"
# t57 = prims.convert_element_type(t56, dtypes.bfloat16) # t57: "cuda:0 bf16[1, 71, 2048, 32]"
# t58 = prims.cat((t57, t53), -1) # t58: "cuda:0 bf16[1, 71, 2048, 64]"
# t60 = prims.convert_element_type(t52, dtypes.float32) # t60: "cuda:0 f32[1, 71, 2048, 64]"
# t61 = prims.convert_element_type(t59, dtypes.float32) # t61: "cuda:0 f32[1, 71, 2048, 64]"
# t62 = prims.mul(t60, t61) # t62: "cuda:0 f32[1, 71, 2048, 64]"
# t65 = prims.convert_element_type(t58, dtypes.float32) # t65: "cuda:0 f32[1, 71, 2048, 64]"
# t66 = prims.convert_element_type(t64, dtypes.float32) # t66: "cuda:0 f32[1, 71, 2048, 64]"
# t67 = prims.mul(t65, t66) # t67: "cuda:0 f32[1, 71, 2048, 64]"
# t71 = prims.add(t62, t67) # t71: "cuda:0 f32[1, 71, 2048, 64]"
# t72 = prims.convert_element_type(t71, dtypes.bfloat16) # t72: "cuda:0 bf16[1, 71, 2048, 64]"
# t76 = prims.convert_element_type(t75, dtypes.float32) # t76: "cuda:0 f32[1, 71, 2048, 32]"
# t77 = prims.neg(t76) # t77: "cuda:0 f32[1, 71, 2048, 32]"
# t78 = prims.convert_element_type(t77, dtypes.bfloat16) # t78: "cuda:0 bf16[1, 71, 2048, 32]"
# t80 = prims.cat((t78, t74), -1) # t80: "cuda:0 bf16[1, 71, 2048, 64]"
# t82 = prims.convert_element_type(t73, dtypes.float32) # t82: "cuda:0 f32[1, 71, 2048, 64]"
# t84 = prims.mul(t82, t61) # t84: "cuda:0 f32[1, 71, 2048, 64]"
# t87 = prims.convert_element_type(t80, dtypes.float32) # t87: "cuda:0 f32[1, 71, 2048, 64]"
# t89 = prims.mul(t87, t66) # t89: "cuda:0 f32[1, 71, 2048, 64]"
# t93 = prims.add(t84, t89) # t93: "cuda:0 f32[1, 71, 2048, 64]"
# t94 = prims.convert_element_type(t93, dtypes.bfloat16) # t94: "cuda:0 bf16[1, 71, 2048, 64]"
# t96 = prims.cat((t72, t95), -1) # t96: "cuda:0 bf16[1, 71, 2048, 64]"
# t99 = prims.cat((t94, t97), -1) # t99: "cuda:0 bf16[1, 71, 2048, 64]"
# t108 = prims.convert_element_type(t107, dtypes.float32) # t108: "cuda:0 f32[1, 2048, 18176]"
# t109 = prims.div(t108, 1.4142135623730951) # t109: "cuda:0 f32[1, 2048, 18176]"
# t112 = prims.erf(t109) # t112: "cuda:0 f32[1, 2048, 18176]"
# t115 = prims.mul(0.5, t112) # t115: "cuda:0 f32[1, 2048, 18176]"
# t118 = prims.add(0.5, t115) # t118: "cuda:0 f32[1, 2048, 18176]"
# t122 = prims.mul(t108, t118) # t122: "cuda:0 f32[1, 2048, 18176]"
# t123 = prims.convert_element_type(t122, dtypes.bfloat16) # t123: "cuda:0 bf16[1, 2048, 18176]"
del t52, t53, t54, t59, t64, t73, t74, t75, t95, t97
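# cuDNN flash-attention forward: causal, dropout 0.0, scale 0.125 = 1/sqrt(64)
# (head dim 64). t100 is the attention output; t101-t103 are auxiliary tensors
# (softmax statistics / RNG state) saved for the backward pass.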
(t100, t101, t102, t103) = cudnn_sdpa_fwd(t96, t99, t51, None, 0.0, True, scale=0.125)
t124 = torch.nn.functional.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4544]"
# t124 = ltorch.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4544]"
# t124 = prims.linear(t123, t_transformer_h_0_mlp_proj_weight, None) # t124: "cuda:0 bf16[1, 2048, 4544]"
t104 = torch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 71, 64]"
# t104 = ltorch.permute(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 71, 64]"
# t104 = prims.transpose(t100, (0, 2, 1, 3)) # t104: "cuda:0 bf16[1, 2048, 71, 64]"
t105 = torch.reshape(t104, (1, 2048, 4544)) # t105: "cuda:0 bf16[1, 2048, 4544]"
# t105 = ltorch.reshape(t104, (1, 2048, 4544)) # t105: "cuda:0 bf16[1, 2048, 4544]"
# t105 = prims.reshape(t104, (1, 2048, 4544)) # t105: "cuda:0 bf16[1, 2048, 4544]"
del t104
t106 = torch.nn.functional.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4544]"
# t106 = ltorch.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4544]"
# t106 = prims.linear(t105, t_transformer_h_0_attn_proj_weight, None) # t106: "cuda:0 bf16[1, 2048, 4544]"
t5315 = torch.unsqueeze(t_transformer_h_1_norm_1_weight, 0) # t5315: "cuda:0 bf16[1, 4544]"
# t5315 = ltorch.unsqueeze(t_transformer_h_1_norm_1_weight, 0) # t5315: "cuda:0 bf16[1, 4544]"
# t5315 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_weight, [1, 4544], [1]) # t5315: "cuda:0 bf16[1, 4544]"
t5316 = torch.unsqueeze(t5315, 1) # t5316: "cuda:0 bf16[1, 1, 4544]"
# t5316 = ltorch.unsqueeze(t5315, 1) # t5316: "cuda:0 bf16[1, 1, 4544]"
# t5316 = prims.broadcast_in_dim(t5315, [1, 1, 4544], [0, 2]) # t5316: "cuda:0 bf16[1, 1, 4544]"
del t5315
t150 = Tensor.expand(t5316, (1, 2048, 4544)) # t150: "cuda:0 bf16[1, 2048, 4544]"
# t150 = ltorch.expand(t5316, (1, 2048, 4544)) # t150: "cuda:0 bf16[1, 2048, 4544]"
# t150 = prims.broadcast_in_dim(t5316, (1, 2048, 4544), (0, 1, 2)) # t150: "cuda:0 bf16[1, 2048, 4544]"
del t5316
t5318 = torch.unsqueeze(t_transformer_h_1_norm_1_bias, 0) # t5318: "cuda:0 bf16[1, 4544]"
# t5318 = ltorch.unsqueeze(t_transformer_h_1_norm_1_bias, 0) # t5318: "cuda:0 bf16[1, 4544]"
# t5318 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_bias, [1, 4544], [1]) # t5318: "cuda:0 bf16[1, 4544]"
t5319 = torch.unsqueeze(t5318, 1) # t5319: "cuda:0 bf16[1, 1, 4544]"
# t5319 = ltorch.unsqueeze(t5318, 1) # t5319: "cuda:0 bf16[1, 1, 4544]"
# t5319 = prims.broadcast_in_dim(t5318, [1, 1, 4544], [0, 2]) # t5319: "cuda:0 bf16[1, 1, 4544]"
del t5318
t153 = Tensor.expand(t5319, (1, 2048, 4544)) # t153: "cuda:0 bf16[1, 2048, 4544]"
# t153 = ltorch.expand(t5319, (1, 2048, 4544)) # t153: "cuda:0 bf16[1, 2048, 4544]"
# t153 = prims.broadcast_in_dim(t5319, (1, 2048, 4544), (0, 1, 2)) # t153: "cuda:0 bf16[1, 2048, 4544]"
del t5319
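# nvFusion2: add the attention output (t106), the MLP output (t124) and the
# residual input (t4) to form block 0's output t132, and apply block 1's input
# LayerNorm in the same kernel, producing t156 (mean t139 and rstd t144 are saved).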
[t132, t139, t144, t156] = nvFusion2(t106, t124, t150, t153, t4)
# t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 2048, 4544]"
# t126 = prims.convert_element_type(t106, dtypes.float32) # t126: "cuda:0 f32[1, 2048, 4544]"
# t127 = prims.add(t125, t126) # t127: "cuda:0 f32[1, 2048, 4544]"
# t130 = prims.convert_element_type(t4, dtypes.float32) # t130: "cuda:0 f32[1, 2048, 4544]"
# t131 = prims.add(t127, t130) # t131: "cuda:0 f32[1, 2048, 4544]"
# t132 = prims.convert_element_type(t131, dtypes.bfloat16) # t132: "cuda:0 bf16[1, 2048, 4544]"
# (t138, t139) = prims.var_mean(t131, (2,), correction=0)
# t140 = prims.broadcast_in_dim(t138, [1, 2048, 1], [0, 1]) # t140: "cuda:0 f32[1, 2048, 1]"
# t141 = prims.broadcast_in_dim(t139, [1, 2048, 1], [0, 1]) # t141: "cuda:0 f32[1, 2048, 1]"
# t143 = prims.add(t140, 1e-05) # t143: "cuda:0 f32[1, 2048, 1]"
# t144 = prims.rsqrt(t143) # t144: "cuda:0 f32[1, 2048, 1]"
# t145 = prims.broadcast_in_dim(t141, (1, 2048, 4544), (0, 1, 2)) # t145: "cuda:0 f32[1, 2048, 4544]"
# t147 = prims.sub(t131, t145) # t147: "cuda:0 f32[1, 2048, 4544]"
# t148 = prims.broadcast_in_dim(t144, (1, 2048, 4544), (0, 1, 2)) # t148: "cuda:0 f32[1, 2048, 4544]"
# t149 = prims.mul(t147, t148) # t149: "cuda:0 f32[1, 2048, 4544]"
# t151 = prims.convert_element_type(t150, dtypes.float32) # t151: "cuda:0 f32[1, 2048, 4544]"
# t152 = prims.mul(t149, t151) # t152: "cuda:0 f32[1, 2048, 4544]"
# t154 = prims.convert_element_type(t153, dtypes.float32) # t154: "cuda:0 f32[1, 2048, 4544]"
# t155 = prims.add(t152, t154) # t155: "cuda:0 f32[1, 2048, 4544]"
# t156 = prims.convert_element_type(t155, dtypes.bfloat16) # t156: "cuda:0 bf16[1, 2048, 4544]"
del t153
t157 = torch.nn.functional.linear(t156, t_transformer_h_1_attn_attn_weight, None) # t157: "cuda:0 bf16[1, 2048, 4672]"
# t157 = ltorch.linear(t156, t_transformer_h_1_attn_attn_weight, None) # t157: "cuda:0 bf16[1, 2048, 4672]"
# t157 = prims.linear(t156, t_transformer_h_1_attn_attn_weight, None) # t157: "cuda:0 bf16[1, 2048, 4672]"
t265 = torch.nn.functional.linear(t156, t_transformer_h_1_mlp_fc_weight, None) # t265: "cuda:0 bf16[1, 2048, 18176]"
# t265 = ltorch.linear(t156, t_transformer_h_1_mlp_fc_weight, None) # t265: "cuda:0 bf16[1, 2048, 18176]"
# t265 = prims.linear(t156, t_transformer_h_1_mlp_fc_weight, None) # t265: "cuda:0 bf16[1, 2048, 18176]"
t163 = torch.reshape(t157, (1, 2048, 1, 73, 64)) # t163: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t163 = ltorch.reshape(t157, (1, 2048, 1, 73, 64)) # t163: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t163 = prims.reshape(t157, (1, 2048, 1, 73, 64)) # t163: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t157
t169 = torch.permute(t163, (0, 2, 3, 1, 4)) # t169: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t169 = ltorch.permute(t163, (0, 2, 3, 1, 4)) # t169: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t169 = prims.transpose(t163, (0, 2, 3, 1, 4)) # t169: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t163
(t170, t171, t172) = torch.split(t169, (71, 1, 1), 2)
# (t170, t171, t172) = ltorch.split(t169, (71, 1, 1), 2)
# t170 = prims.slice_prim(t169, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t170: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t171 = prims.slice_prim(t169, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t171: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t172 = prims.slice_prim(t169, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t172: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t169
t178 = Tensor.expand(t171, (1, 1, 71, 2048, 64)) # t178: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t178 = ltorch.expand(t171, (1, 1, 71, 2048, 64)) # t178: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t178 = prims.broadcast_in_dim(t171, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t178: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t171
t184 = Tensor.expand(t172, (1, 1, 71, 2048, 64)) # t184: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t184 = ltorch.expand(t172, (1, 1, 71, 2048, 64)) # t184: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t184 = prims.broadcast_in_dim(t172, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t184: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t172
t190 = torch.reshape(t170, (1, 71, 2048, 64)) # t190: "cuda:0 bf16[1, 71, 2048, 64]"
# t190 = ltorch.reshape(t170, (1, 71, 2048, 64)) # t190: "cuda:0 bf16[1, 71, 2048, 64]"
# t190 = prims.reshape(t170, (1, 71, 2048, 64)) # t190: "cuda:0 bf16[1, 71, 2048, 64]"
del t170
t196 = torch.reshape(t178, (1, 71, 2048, 64)) # t196: "cuda:0 bf16[1, 71, 2048, 64]"
# t196 = ltorch.reshape(t178, (1, 71, 2048, 64)) # t196: "cuda:0 bf16[1, 71, 2048, 64]"
# t196 = prims.reshape(t178, (1, 71, 2048, 64)) # t196: "cuda:0 bf16[1, 71, 2048, 64]"
del t178
t202 = torch.reshape(t184, (1, 71, 2048, 64)) # t202: "cuda:0 bf16[1, 71, 2048, 64]"
# t202 = ltorch.reshape(t184, (1, 71, 2048, 64)) # t202: "cuda:0 bf16[1, 71, 2048, 64]"
# t202 = prims.reshape(t184, (1, 71, 2048, 64)) # t202: "cuda:0 bf16[1, 71, 2048, 64]"
del t184
t203 = torch_slice_prim_impl(t190, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t203: "cuda:0 bf16[1, 71, 2048, 64]"
t204 = torch_slice_prim_impl(t203, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t204: "cuda:0 bf16[1, 71, 2048, 32]"
t205 = torch_slice_prim_impl(t203, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t205: "cuda:0 bf16[1, 71, 2048, 32]"
t225 = torch_slice_prim_impl(t196, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t225: "cuda:0 bf16[1, 71, 2048, 64]"
t226 = torch_slice_prim_impl(t225, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t226: "cuda:0 bf16[1, 71, 2048, 32]"
t227 = torch_slice_prim_impl(t225, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t227: "cuda:0 bf16[1, 71, 2048, 32]"
t247 = torch_slice_prim_impl(t190, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t247: "cuda:0 bf16[1, 71, 2048, 0]"
del t190
t250 = torch_slice_prim_impl(t196, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t250: "cuda:0 bf16[1, 71, 2048, 0]"
del t196
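# The slices above cut each 64-dim query/key head into two 32-dim halves for the rotary
# rotation (t204/t205 for the query, t226/t227 for the key). t247 and t250 are zero-width
# remainders (shape [..., 0]): all 64 head dimensions are rotated here, so the "unrotated"
# tail concatenated back in nvFusion3 is empty.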
[t249, t252, t284] = nvFusion3(t203, t204, t205, t225, t226, t227, t247, t250, t265, t61, t66)
# t206 = prims.convert_element_type(t205, dtypes.float32) # t206: "cuda:0 f32[1, 71, 2048, 32]"
# t207 = prims.neg(t206) # t207: "cuda:0 f32[1, 71, 2048, 32]"
# t208 = prims.convert_element_type(t207, dtypes.bfloat16) # t208: "cuda:0 bf16[1, 71, 2048, 32]"
# t210 = prims.cat((t208, t204), -1) # t210: "cuda:0 bf16[1, 71, 2048, 64]"
# t212 = prims.convert_element_type(t203, dtypes.float32) # t212: "cuda:0 f32[1, 71, 2048, 64]"
# t214 = prims.mul(t212, t61) # t214: "cuda:0 f32[1, 71, 2048, 64]"
# t217 = prims.convert_element_type(t210, dtypes.float32) # t217: "cuda:0 f32[1, 71, 2048, 64]"
# t219 = prims.mul(t217, t66) # t219: "cuda:0 f32[1, 71, 2048, 64]"
# t223 = prims.add(t214, t219) # t223: "cuda:0 f32[1, 71, 2048, 64]"
# t224 = prims.convert_element_type(t223, dtypes.bfloat16) # t224: "cuda:0 bf16[1, 71, 2048, 64]"
# t228 = prims.convert_element_type(t227, dtypes.float32) # t228: "cuda:0 f32[1, 71, 2048, 32]"
# t229 = prims.neg(t228) # t229: "cuda:0 f32[1, 71, 2048, 32]"
# t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: "cuda:0 bf16[1, 71, 2048, 32]"
# t232 = prims.cat((t230, t226), -1) # t232: "cuda:0 bf16[1, 71, 2048, 64]"
# t234 = prims.convert_element_type(t225, dtypes.float32) # t234: "cuda:0 f32[1, 71, 2048, 64]"
# t236 = prims.mul(t234, t61) # t236: "cuda:0 f32[1, 71, 2048, 64]"
# t239 = prims.convert_element_type(t232, dtypes.float32) # t239: "cuda:0 f32[1, 71, 2048, 64]"
# t241 = prims.mul(t239, t66) # t241: "cuda:0 f32[1, 71, 2048, 64]"
# t245 = prims.add(t236, t241) # t245: "cuda:0 f32[1, 71, 2048, 64]"
# t246 = prims.convert_element_type(t245, dtypes.bfloat16) # t246: "cuda:0 bf16[1, 71, 2048, 64]"
# t249 = prims.cat((t224, t247), -1) # t249: "cuda:0 bf16[1, 71, 2048, 64]"
# t252 = prims.cat((t246, t250), -1) # t252: "cuda:0 bf16[1, 71, 2048, 64]"
# t266 = prims.convert_element_type(t265, dtypes.float32) # t266: "cuda:0 f32[1, 2048, 18176]"
# t268 = prims.div(t266, 1.4142135623730951) # t268: "cuda:0 f32[1, 2048, 18176]"
# t271 = prims.erf(t268) # t271: "cuda:0 f32[1, 2048, 18176]"
# t275 = prims.mul(0.5, t271) # t275: "cuda:0 f32[1, 2048, 18176]"
# t279 = prims.add(0.5, t275) # t279: "cuda:0 f32[1, 2048, 18176]"
# t283 = prims.mul(t266, t279) # t283: "cuda:0 f32[1, 2048, 18176]"
# t284 = prims.convert_element_type(t283, dtypes.bfloat16) # t284: "cuda:0 bf16[1, 2048, 18176]"
del t203, t204, t205, t225, t226, t227, t247, t250
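# nvFusion3 fuses two independent computations into one nvFuser region:
#   * RoPE for the block-1 query (t203 -> t249) and key (t225 -> t252): negate the upper
#     half, concatenate it in front of the lower half, then combine
#     x * t61 + rotate_half(x) * t66 in fp32 (t61/t66 are presumably the cos/sin caches
#     built earlier in the trace).
#   * The erf-form GELU of the MLP up-projection t265:
#     gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2))), giving t284.
# A rough eager-mode sketch of the RoPE part (assuming q is [1, 71, 2048, 64] bf16 and
# cos/sin are the fp32 caches):
#   x1, x2 = q[..., :32], q[..., 32:]
#   rotated = torch.cat((-x2, x1), dim=-1)
#   q_roped = (q.float() * cos + rotated.float() * sin).to(torch.bfloat16)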
(t253, t254, t255, t256) = cudnn_sdpa_fwd(t249, t252, t202, None, 0.0, True, scale=0.125)
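# cudnn_sdpa_fwd runs the cuDNN fused scaled-dot-product attention for block 1:
# t249/t252 are the roped query/key, t202 is the value, dropout is 0.0, is_causal=True,
# and scale=0.125 = 1/sqrt(64) for the 64-dim heads. t253 is the attention output; the
# remaining outputs are bookkeeping tensors saved for the cuDNN SDPA backward.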
t285 = torch.nn.functional.linear(t284, t_transformer_h_1_mlp_proj_weight, None) # t285: "cuda:0 bf16[1, 2048, 4544]"
# t285 = ltorch.linear(t284, t_transformer_h_1_mlp_proj_weight, None) # t285: "cuda:0 bf16[1, 2048, 4544]"
# t285 = prims.linear(t284, t_transformer_h_1_mlp_proj_weight, None) # t285: "cuda:0 bf16[1, 2048, 4544]"
t259 = torch.permute(t253, (0, 2, 1, 3)) # t259: "cuda:0 bf16[1, 2048, 71, 64]"
# t259 = ltorch.permute(t253, (0, 2, 1, 3)) # t259: "cuda:0 bf16[1, 2048, 71, 64]"
# t259 = prims.transpose(t253, (0, 2, 1, 3)) # t259: "cuda:0 bf16[1, 2048, 71, 64]"
t263 = torch.reshape(t259, (1, 2048, 4544)) # t263: "cuda:0 bf16[1, 2048, 4544]"
# t263 = ltorch.reshape(t259, (1, 2048, 4544)) # t263: "cuda:0 bf16[1, 2048, 4544]"
# t263 = prims.reshape(t259, (1, 2048, 4544)) # t263: "cuda:0 bf16[1, 2048, 4544]"
del t259
t264 = torch.nn.functional.linear(t263, t_transformer_h_1_attn_proj_weight, None) # t264: "cuda:0 bf16[1, 2048, 4544]"
# t264 = ltorch.linear(t263, t_transformer_h_1_attn_proj_weight, None) # t264: "cuda:0 bf16[1, 2048, 4544]"
# t264 = prims.linear(t263, t_transformer_h_1_attn_proj_weight, None) # t264: "cuda:0 bf16[1, 2048, 4544]"
t5341 = torch.unsqueeze(t_transformer_h_2_norm_1_weight, 0) # t5341: "cuda:0 bf16[1, 4544]"
# t5341 = ltorch.unsqueeze(t_transformer_h_2_norm_1_weight, 0) # t5341: "cuda:0 bf16[1, 4544]"
# t5341 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_weight, [1, 4544], [1]) # t5341: "cuda:0 bf16[1, 4544]"
t5342 = torch.unsqueeze(t5341, 1) # t5342: "cuda:0 bf16[1, 1, 4544]"
# t5342 = ltorch.unsqueeze(t5341, 1) # t5342: "cuda:0 bf16[1, 1, 4544]"
# t5342 = prims.broadcast_in_dim(t5341, [1, 1, 4544], [0, 2]) # t5342: "cuda:0 bf16[1, 1, 4544]"
del t5341
t311 = Tensor.expand(t5342, (1, 2048, 4544)) # t311: "cuda:0 bf16[1, 2048, 4544]"
# t311 = ltorch.expand(t5342, (1, 2048, 4544)) # t311: "cuda:0 bf16[1, 2048, 4544]"
# t311 = prims.broadcast_in_dim(t5342, (1, 2048, 4544), (0, 1, 2)) # t311: "cuda:0 bf16[1, 2048, 4544]"
del t5342
t5344 = torch.unsqueeze(t_transformer_h_2_norm_1_bias, 0) # t5344: "cuda:0 bf16[1, 4544]"
# t5344 = ltorch.unsqueeze(t_transformer_h_2_norm_1_bias, 0) # t5344: "cuda:0 bf16[1, 4544]"
# t5344 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_bias, [1, 4544], [1]) # t5344: "cuda:0 bf16[1, 4544]"
t5345 = torch.unsqueeze(t5344, 1) # t5345: "cuda:0 bf16[1, 1, 4544]"
# t5345 = ltorch.unsqueeze(t5344, 1) # t5345: "cuda:0 bf16[1, 1, 4544]"
# t5345 = prims.broadcast_in_dim(t5344, [1, 1, 4544], [0, 2]) # t5345: "cuda:0 bf16[1, 1, 4544]"
del t5344
t314 = Tensor.expand(t5345, (1, 2048, 4544)) # t314: "cuda:0 bf16[1, 2048, 4544]"
# t314 = ltorch.expand(t5345, (1, 2048, 4544)) # t314: "cuda:0 bf16[1, 2048, 4544]"
# t314 = prims.broadcast_in_dim(t5345, (1, 2048, 4544), (0, 1, 2)) # t314: "cuda:0 bf16[1, 2048, 4544]"
del t5345
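# The unsqueeze/unsqueeze/expand pairs above broadcast the block-2 LayerNorm weight
# (t_transformer_h_2_norm_1_weight) and bias from shape [4544] to the activation shape
# (1, 2048, 4544) as t311 and t314, ready to be consumed inside nvFusion4.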
[t293, t300, t305, t317] = nvFusion4(t132, t264, t285, t311, t314)
# t291 = prims.convert_element_type(t132, dtypes.float32) # t291: "cuda:0 f32[1, 2048, 4544]"
# t286 = prims.convert_element_type(t285, dtypes.float32) # t286: "cuda:0 f32[1, 2048, 4544]"
# t287 = prims.convert_element_type(t264, dtypes.float32) # t287: "cuda:0 f32[1, 2048, 4544]"
# t288 = prims.add(t286, t287) # t288: "cuda:0 f32[1, 2048, 4544]"
# t292 = prims.add(t288, t291) # t292: "cuda:0 f32[1, 2048, 4544]"
# t293 = prims.convert_element_type(t292, dtypes.bfloat16) # t293: "cuda:0 bf16[1, 2048, 4544]"
# (t299, t300) = prims.var_mean(t292, (2,), correction=0)
# t301 = prims.broadcast_in_dim(t299, [1, 2048, 1], [0, 1]) # t301: "cuda:0 f32[1, 2048, 1]"
# t302 = prims.broadcast_in_dim(t300, [1, 2048, 1], [0, 1]) # t302: "cuda:0 f32[1, 2048, 1]"
# t304 = prims.add(t301, 1e-05) # t304: "cuda:0 f32[1, 2048, 1]"
# t305 = prims.rsqrt(t304) # t305: "cuda:0 f32[1, 2048, 1]"
# t306 = prims.broadcast_in_dim(t302, (1, 2048, 4544), (0, 1, 2)) # t306: "cuda:0 f32[1, 2048, 4544]"
# t308 = prims.sub(t292, t306) # t308: "cuda:0 f32[1, 2048, 4544]"
# t309 = prims.broadcast_in_dim(t305, (1, 2048, 4544), (0, 1, 2)) # t309: "cuda:0 f32[1, 2048, 4544]"
# t310 = prims.mul(t308, t309) # t310: "cuda:0 f32[1, 2048, 4544]"
# t312 = prims.convert_element_type(t311, dtypes.float32) # t312: "cuda:0 f32[1, 2048, 4544]"
# t313 = prims.mul(t310, t312) # t313: "cuda:0 f32[1, 2048, 4544]"
# t315 = prims.convert_element_type(t314, dtypes.float32) # t315: "cuda:0 f32[1, 2048, 4544]"
# t316 = prims.add(t313, t315) # t316: "cuda:0 f32[1, 2048, 4544]"
# t317 = prims.convert_element_type(t316, dtypes.bfloat16) # t317: "cuda:0 bf16[1, 2048, 4544]"
del t314
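# nvFusion4 closes block 1 and opens block 2 in a single fused region: it adds the MLP
# projection output (t285) and the attention projection output (t264) to the block input
# (t132) in fp32 (parallel-residual layout), keeps the new residual stream as t293, and
# immediately applies the block-2 LayerNorm (var_mean, rsqrt, scale by t311, shift by
# t314) to produce t317. The mean t300 and rsqrt t305 are kept for the backward pass.
# Blocks 2 through 6 below repeat the same per-block pattern (fused QKV projection,
# multi-query split/expand, fused RoPE + GELU, cuDNN SDPA, output/MLP projections,
# fused residual + LayerNorm) with their own weights and nvFusion regions.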
t426 = torch.nn.functional.linear(t317, t_transformer_h_2_mlp_fc_weight, None) # t426: "cuda:0 bf16[1, 2048, 18176]"
# t426 = ltorch.linear(t317, t_transformer_h_2_mlp_fc_weight, None) # t426: "cuda:0 bf16[1, 2048, 18176]"
# t426 = prims.linear(t317, t_transformer_h_2_mlp_fc_weight, None) # t426: "cuda:0 bf16[1, 2048, 18176]"
t318 = torch.nn.functional.linear(t317, t_transformer_h_2_attn_attn_weight, None) # t318: "cuda:0 bf16[1, 2048, 4672]"
# t318 = ltorch.linear(t317, t_transformer_h_2_attn_attn_weight, None) # t318: "cuda:0 bf16[1, 2048, 4672]"
# t318 = prims.linear(t317, t_transformer_h_2_attn_attn_weight, None) # t318: "cuda:0 bf16[1, 2048, 4672]"
t324 = torch.reshape(t318, (1, 2048, 1, 73, 64)) # t324: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t324 = ltorch.reshape(t318, (1, 2048, 1, 73, 64)) # t324: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t324 = prims.reshape(t318, (1, 2048, 1, 73, 64)) # t324: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t318
t330 = torch.permute(t324, (0, 2, 3, 1, 4)) # t330: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t330 = ltorch.permute(t324, (0, 2, 3, 1, 4)) # t330: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t330 = prims.transpose(t324, (0, 2, 3, 1, 4)) # t330: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t324
(t331, t332, t333) = torch.split(t330, (71, 1, 1), 2)
# (t331, t332, t333) = ltorch.split(t330, (71, 1, 1), 2)
# t331 = prims.slice_prim(t330, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t331: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t332 = prims.slice_prim(t330, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t332: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t333 = prims.slice_prim(t330, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t333: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t330
t339 = Tensor.expand(t332, (1, 1, 71, 2048, 64)) # t339: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t339 = ltorch.expand(t332, (1, 1, 71, 2048, 64)) # t339: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t339 = prims.broadcast_in_dim(t332, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t339: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t332
t345 = Tensor.expand(t333, (1, 1, 71, 2048, 64)) # t345: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t345 = ltorch.expand(t333, (1, 1, 71, 2048, 64)) # t345: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t345 = prims.broadcast_in_dim(t333, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t345: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t333
t351 = torch.reshape(t331, (1, 71, 2048, 64)) # t351: "cuda:0 bf16[1, 71, 2048, 64]"
# t351 = ltorch.reshape(t331, (1, 71, 2048, 64)) # t351: "cuda:0 bf16[1, 71, 2048, 64]"
# t351 = prims.reshape(t331, (1, 71, 2048, 64)) # t351: "cuda:0 bf16[1, 71, 2048, 64]"
del t331
t357 = torch.reshape(t339, (1, 71, 2048, 64)) # t357: "cuda:0 bf16[1, 71, 2048, 64]"
# t357 = ltorch.reshape(t339, (1, 71, 2048, 64)) # t357: "cuda:0 bf16[1, 71, 2048, 64]"
# t357 = prims.reshape(t339, (1, 71, 2048, 64)) # t357: "cuda:0 bf16[1, 71, 2048, 64]"
del t339
t363 = torch.reshape(t345, (1, 71, 2048, 64)) # t363: "cuda:0 bf16[1, 71, 2048, 64]"
# t363 = ltorch.reshape(t345, (1, 71, 2048, 64)) # t363: "cuda:0 bf16[1, 71, 2048, 64]"
# t363 = prims.reshape(t345, (1, 71, 2048, 64)) # t363: "cuda:0 bf16[1, 71, 2048, 64]"
del t345
t364 = torch_slice_prim_impl(t351, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t364: "cuda:0 bf16[1, 71, 2048, 64]"
t365 = torch_slice_prim_impl(t364, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t365: "cuda:0 bf16[1, 71, 2048, 32]"
t366 = torch_slice_prim_impl(t364, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t366: "cuda:0 bf16[1, 71, 2048, 32]"
t386 = torch_slice_prim_impl(t357, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t386: "cuda:0 bf16[1, 71, 2048, 64]"
t387 = torch_slice_prim_impl(t386, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t387: "cuda:0 bf16[1, 71, 2048, 32]"
t388 = torch_slice_prim_impl(t386, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t388: "cuda:0 bf16[1, 71, 2048, 32]"
t408 = torch_slice_prim_impl(t351, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t408: "cuda:0 bf16[1, 71, 2048, 0]"
del t351
t411 = torch_slice_prim_impl(t357, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t411: "cuda:0 bf16[1, 71, 2048, 0]"
del t357
[t410, t413, t445] = nvFusion5(t364, t365, t366, t386, t387, t388, t408, t411, t426, t61, t66)
# t427 = prims.convert_element_type(t426, dtypes.float32) # t427: "cuda:0 f32[1, 2048, 18176]"
# t429 = prims.div(t427, 1.4142135623730951) # t429: "cuda:0 f32[1, 2048, 18176]"
# t432 = prims.erf(t429) # t432: "cuda:0 f32[1, 2048, 18176]"
# t436 = prims.mul(0.5, t432) # t436: "cuda:0 f32[1, 2048, 18176]"
# t440 = prims.add(0.5, t436) # t440: "cuda:0 f32[1, 2048, 18176]"
# t444 = prims.mul(t427, t440) # t444: "cuda:0 f32[1, 2048, 18176]"
# t445 = prims.convert_element_type(t444, dtypes.bfloat16) # t445: "cuda:0 bf16[1, 2048, 18176]"
# t367 = prims.convert_element_type(t366, dtypes.float32) # t367: "cuda:0 f32[1, 71, 2048, 32]"
# t368 = prims.neg(t367) # t368: "cuda:0 f32[1, 71, 2048, 32]"
# t369 = prims.convert_element_type(t368, dtypes.bfloat16) # t369: "cuda:0 bf16[1, 71, 2048, 32]"
# t371 = prims.cat((t369, t365), -1) # t371: "cuda:0 bf16[1, 71, 2048, 64]"
# t373 = prims.convert_element_type(t364, dtypes.float32) # t373: "cuda:0 f32[1, 71, 2048, 64]"
# t375 = prims.mul(t373, t61) # t375: "cuda:0 f32[1, 71, 2048, 64]"
# t378 = prims.convert_element_type(t371, dtypes.float32) # t378: "cuda:0 f32[1, 71, 2048, 64]"
# t380 = prims.mul(t378, t66) # t380: "cuda:0 f32[1, 71, 2048, 64]"
# t384 = prims.add(t375, t380) # t384: "cuda:0 f32[1, 71, 2048, 64]"
# t385 = prims.convert_element_type(t384, dtypes.bfloat16) # t385: "cuda:0 bf16[1, 71, 2048, 64]"
# t389 = prims.convert_element_type(t388, dtypes.float32) # t389: "cuda:0 f32[1, 71, 2048, 32]"
# t390 = prims.neg(t389) # t390: "cuda:0 f32[1, 71, 2048, 32]"
# t391 = prims.convert_element_type(t390, dtypes.bfloat16) # t391: "cuda:0 bf16[1, 71, 2048, 32]"
# t393 = prims.cat((t391, t387), -1) # t393: "cuda:0 bf16[1, 71, 2048, 64]"
# t395 = prims.convert_element_type(t386, dtypes.float32) # t395: "cuda:0 f32[1, 71, 2048, 64]"
# t397 = prims.mul(t395, t61) # t397: "cuda:0 f32[1, 71, 2048, 64]"
# t400 = prims.convert_element_type(t393, dtypes.float32) # t400: "cuda:0 f32[1, 71, 2048, 64]"
# t402 = prims.mul(t400, t66) # t402: "cuda:0 f32[1, 71, 2048, 64]"
# t406 = prims.add(t397, t402) # t406: "cuda:0 f32[1, 71, 2048, 64]"
# t407 = prims.convert_element_type(t406, dtypes.bfloat16) # t407: "cuda:0 bf16[1, 71, 2048, 64]"
# t410 = prims.cat((t385, t408), -1) # t410: "cuda:0 bf16[1, 71, 2048, 64]"
# t413 = prims.cat((t407, t411), -1) # t413: "cuda:0 bf16[1, 71, 2048, 64]"
del t364, t365, t366, t386, t387, t388, t408, t411
t446 = torch.nn.functional.linear(t445, t_transformer_h_2_mlp_proj_weight, None) # t446: "cuda:0 bf16[1, 2048, 4544]"
# t446 = ltorch.linear(t445, t_transformer_h_2_mlp_proj_weight, None) # t446: "cuda:0 bf16[1, 2048, 4544]"
# t446 = prims.linear(t445, t_transformer_h_2_mlp_proj_weight, None) # t446: "cuda:0 bf16[1, 2048, 4544]"
(t414, t415, t416, t417) = cudnn_sdpa_fwd(t410, t413, t363, None, 0.0, True, scale=0.125)
t420 = torch.permute(t414, (0, 2, 1, 3)) # t420: "cuda:0 bf16[1, 2048, 71, 64]"
# t420 = ltorch.permute(t414, (0, 2, 1, 3)) # t420: "cuda:0 bf16[1, 2048, 71, 64]"
# t420 = prims.transpose(t414, (0, 2, 1, 3)) # t420: "cuda:0 bf16[1, 2048, 71, 64]"
t424 = torch.reshape(t420, (1, 2048, 4544)) # t424: "cuda:0 bf16[1, 2048, 4544]"
# t424 = ltorch.reshape(t420, (1, 2048, 4544)) # t424: "cuda:0 bf16[1, 2048, 4544]"
# t424 = prims.reshape(t420, (1, 2048, 4544)) # t424: "cuda:0 bf16[1, 2048, 4544]"
del t420
t425 = torch.nn.functional.linear(t424, t_transformer_h_2_attn_proj_weight, None) # t425: "cuda:0 bf16[1, 2048, 4544]"
# t425 = ltorch.linear(t424, t_transformer_h_2_attn_proj_weight, None) # t425: "cuda:0 bf16[1, 2048, 4544]"
# t425 = prims.linear(t424, t_transformer_h_2_attn_proj_weight, None) # t425: "cuda:0 bf16[1, 2048, 4544]"
t5367 = torch.unsqueeze(t_transformer_h_3_norm_1_weight, 0) # t5367: "cuda:0 bf16[1, 4544]"
# t5367 = ltorch.unsqueeze(t_transformer_h_3_norm_1_weight, 0) # t5367: "cuda:0 bf16[1, 4544]"
# t5367 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_weight, [1, 4544], [1]) # t5367: "cuda:0 bf16[1, 4544]"
t5368 = torch.unsqueeze(t5367, 1) # t5368: "cuda:0 bf16[1, 1, 4544]"
# t5368 = ltorch.unsqueeze(t5367, 1) # t5368: "cuda:0 bf16[1, 1, 4544]"
# t5368 = prims.broadcast_in_dim(t5367, [1, 1, 4544], [0, 2]) # t5368: "cuda:0 bf16[1, 1, 4544]"
del t5367
t472 = Tensor.expand(t5368, (1, 2048, 4544)) # t472: "cuda:0 bf16[1, 2048, 4544]"
# t472 = ltorch.expand(t5368, (1, 2048, 4544)) # t472: "cuda:0 bf16[1, 2048, 4544]"
# t472 = prims.broadcast_in_dim(t5368, (1, 2048, 4544), (0, 1, 2)) # t472: "cuda:0 bf16[1, 2048, 4544]"
del t5368
t5370 = torch.unsqueeze(t_transformer_h_3_norm_1_bias, 0) # t5370: "cuda:0 bf16[1, 4544]"
# t5370 = ltorch.unsqueeze(t_transformer_h_3_norm_1_bias, 0) # t5370: "cuda:0 bf16[1, 4544]"
# t5370 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_bias, [1, 4544], [1]) # t5370: "cuda:0 bf16[1, 4544]"
t5371 = torch.unsqueeze(t5370, 1) # t5371: "cuda:0 bf16[1, 1, 4544]"
# t5371 = ltorch.unsqueeze(t5370, 1) # t5371: "cuda:0 bf16[1, 1, 4544]"
# t5371 = prims.broadcast_in_dim(t5370, [1, 1, 4544], [0, 2]) # t5371: "cuda:0 bf16[1, 1, 4544]"
del t5370
t475 = Tensor.expand(t5371, (1, 2048, 4544)) # t475: "cuda:0 bf16[1, 2048, 4544]"
# t475 = ltorch.expand(t5371, (1, 2048, 4544)) # t475: "cuda:0 bf16[1, 2048, 4544]"
# t475 = prims.broadcast_in_dim(t5371, (1, 2048, 4544), (0, 1, 2)) # t475: "cuda:0 bf16[1, 2048, 4544]"
del t5371
[t454, t461, t466, t478] = nvFusion6(t293, t425, t446, t472, t475)
# t452 = prims.convert_element_type(t293, dtypes.float32) # t452: "cuda:0 f32[1, 2048, 4544]"
# t447 = prims.convert_element_type(t446, dtypes.float32) # t447: "cuda:0 f32[1, 2048, 4544]"
# t448 = prims.convert_element_type(t425, dtypes.float32) # t448: "cuda:0 f32[1, 2048, 4544]"
# t449 = prims.add(t447, t448) # t449: "cuda:0 f32[1, 2048, 4544]"
# t453 = prims.add(t449, t452) # t453: "cuda:0 f32[1, 2048, 4544]"
# t454 = prims.convert_element_type(t453, dtypes.bfloat16) # t454: "cuda:0 bf16[1, 2048, 4544]"
# (t460, t461) = prims.var_mean(t453, (2,), correction=0)
# t462 = prims.broadcast_in_dim(t460, [1, 2048, 1], [0, 1]) # t462: "cuda:0 f32[1, 2048, 1]"
# t463 = prims.broadcast_in_dim(t461, [1, 2048, 1], [0, 1]) # t463: "cuda:0 f32[1, 2048, 1]"
# t465 = prims.add(t462, 1e-05) # t465: "cuda:0 f32[1, 2048, 1]"
# t466 = prims.rsqrt(t465) # t466: "cuda:0 f32[1, 2048, 1]"
# t467 = prims.broadcast_in_dim(t463, (1, 2048, 4544), (0, 1, 2)) # t467: "cuda:0 f32[1, 2048, 4544]"
# t469 = prims.sub(t453, t467) # t469: "cuda:0 f32[1, 2048, 4544]"
# t470 = prims.broadcast_in_dim(t466, (1, 2048, 4544), (0, 1, 2)) # t470: "cuda:0 f32[1, 2048, 4544]"
# t471 = prims.mul(t469, t470) # t471: "cuda:0 f32[1, 2048, 4544]"
# t473 = prims.convert_element_type(t472, dtypes.float32) # t473: "cuda:0 f32[1, 2048, 4544]"
# t474 = prims.mul(t471, t473) # t474: "cuda:0 f32[1, 2048, 4544]"
# t476 = prims.convert_element_type(t475, dtypes.float32) # t476: "cuda:0 f32[1, 2048, 4544]"
# t477 = prims.add(t474, t476) # t477: "cuda:0 f32[1, 2048, 4544]"
# t478 = prims.convert_element_type(t477, dtypes.bfloat16) # t478: "cuda:0 bf16[1, 2048, 4544]"
del t475
t587 = torch.nn.functional.linear(t478, t_transformer_h_3_mlp_fc_weight, None) # t587: "cuda:0 bf16[1, 2048, 18176]"
# t587 = ltorch.linear(t478, t_transformer_h_3_mlp_fc_weight, None) # t587: "cuda:0 bf16[1, 2048, 18176]"
# t587 = prims.linear(t478, t_transformer_h_3_mlp_fc_weight, None) # t587: "cuda:0 bf16[1, 2048, 18176]"
t479 = torch.nn.functional.linear(t478, t_transformer_h_3_attn_attn_weight, None) # t479: "cuda:0 bf16[1, 2048, 4672]"
# t479 = ltorch.linear(t478, t_transformer_h_3_attn_attn_weight, None) # t479: "cuda:0 bf16[1, 2048, 4672]"
# t479 = prims.linear(t478, t_transformer_h_3_attn_attn_weight, None) # t479: "cuda:0 bf16[1, 2048, 4672]"
t485 = torch.reshape(t479, (1, 2048, 1, 73, 64)) # t485: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t485 = ltorch.reshape(t479, (1, 2048, 1, 73, 64)) # t485: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t485 = prims.reshape(t479, (1, 2048, 1, 73, 64)) # t485: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t479
t491 = torch.permute(t485, (0, 2, 3, 1, 4)) # t491: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t491 = ltorch.permute(t485, (0, 2, 3, 1, 4)) # t491: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t491 = prims.transpose(t485, (0, 2, 3, 1, 4)) # t491: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t485
(t492, t493, t494) = torch.split(t491, (71, 1, 1), 2)
# (t492, t493, t494) = ltorch.split(t491, (71, 1, 1), 2)
# t492 = prims.slice_prim(t491, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t492: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t493 = prims.slice_prim(t491, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t493: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t494 = prims.slice_prim(t491, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t494: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t491
t500 = Tensor.expand(t493, (1, 1, 71, 2048, 64)) # t500: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t500 = ltorch.expand(t493, (1, 1, 71, 2048, 64)) # t500: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t500 = prims.broadcast_in_dim(t493, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t500: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t493
t506 = Tensor.expand(t494, (1, 1, 71, 2048, 64)) # t506: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t506 = ltorch.expand(t494, (1, 1, 71, 2048, 64)) # t506: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t506 = prims.broadcast_in_dim(t494, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t506: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t494
t512 = torch.reshape(t492, (1, 71, 2048, 64)) # t512: "cuda:0 bf16[1, 71, 2048, 64]"
# t512 = ltorch.reshape(t492, (1, 71, 2048, 64)) # t512: "cuda:0 bf16[1, 71, 2048, 64]"
# t512 = prims.reshape(t492, (1, 71, 2048, 64)) # t512: "cuda:0 bf16[1, 71, 2048, 64]"
del t492
t518 = torch.reshape(t500, (1, 71, 2048, 64)) # t518: "cuda:0 bf16[1, 71, 2048, 64]"
# t518 = ltorch.reshape(t500, (1, 71, 2048, 64)) # t518: "cuda:0 bf16[1, 71, 2048, 64]"
# t518 = prims.reshape(t500, (1, 71, 2048, 64)) # t518: "cuda:0 bf16[1, 71, 2048, 64]"
del t500
t524 = torch.reshape(t506, (1, 71, 2048, 64)) # t524: "cuda:0 bf16[1, 71, 2048, 64]"
# t524 = ltorch.reshape(t506, (1, 71, 2048, 64)) # t524: "cuda:0 bf16[1, 71, 2048, 64]"
# t524 = prims.reshape(t506, (1, 71, 2048, 64)) # t524: "cuda:0 bf16[1, 71, 2048, 64]"
del t506
t525 = torch_slice_prim_impl(t512, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t525: "cuda:0 bf16[1, 71, 2048, 64]"
t526 = torch_slice_prim_impl(t525, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t526: "cuda:0 bf16[1, 71, 2048, 32]"
t527 = torch_slice_prim_impl(t525, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t527: "cuda:0 bf16[1, 71, 2048, 32]"
t547 = torch_slice_prim_impl(t518, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t547: "cuda:0 bf16[1, 71, 2048, 64]"
t548 = torch_slice_prim_impl(t547, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t548: "cuda:0 bf16[1, 71, 2048, 32]"
t549 = torch_slice_prim_impl(t547, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t549: "cuda:0 bf16[1, 71, 2048, 32]"
t569 = torch_slice_prim_impl(t512, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t569: "cuda:0 bf16[1, 71, 2048, 0]"
del t512
t572 = torch_slice_prim_impl(t518, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t572: "cuda:0 bf16[1, 71, 2048, 0]"
del t518
[t571, t574, t606] = nvFusion7(t525, t526, t527, t547, t548, t549, t569, t572, t587, t61, t66)
# t588 = prims.convert_element_type(t587, dtypes.float32) # t588: "cuda:0 f32[1, 2048, 18176]"
# t590 = prims.div(t588, 1.4142135623730951) # t590: "cuda:0 f32[1, 2048, 18176]"
# t593 = prims.erf(t590) # t593: "cuda:0 f32[1, 2048, 18176]"
# t597 = prims.mul(0.5, t593) # t597: "cuda:0 f32[1, 2048, 18176]"
# t601 = prims.add(0.5, t597) # t601: "cuda:0 f32[1, 2048, 18176]"
# t605 = prims.mul(t588, t601) # t605: "cuda:0 f32[1, 2048, 18176]"
# t606 = prims.convert_element_type(t605, dtypes.bfloat16) # t606: "cuda:0 bf16[1, 2048, 18176]"
# t528 = prims.convert_element_type(t527, dtypes.float32) # t528: "cuda:0 f32[1, 71, 2048, 32]"
# t529 = prims.neg(t528) # t529: "cuda:0 f32[1, 71, 2048, 32]"
# t530 = prims.convert_element_type(t529, dtypes.bfloat16) # t530: "cuda:0 bf16[1, 71, 2048, 32]"
# t532 = prims.cat((t530, t526), -1) # t532: "cuda:0 bf16[1, 71, 2048, 64]"
# t534 = prims.convert_element_type(t525, dtypes.float32) # t534: "cuda:0 f32[1, 71, 2048, 64]"
# t536 = prims.mul(t534, t61) # t536: "cuda:0 f32[1, 71, 2048, 64]"
# t539 = prims.convert_element_type(t532, dtypes.float32) # t539: "cuda:0 f32[1, 71, 2048, 64]"
# t541 = prims.mul(t539, t66) # t541: "cuda:0 f32[1, 71, 2048, 64]"
# t545 = prims.add(t536, t541) # t545: "cuda:0 f32[1, 71, 2048, 64]"
# t546 = prims.convert_element_type(t545, dtypes.bfloat16) # t546: "cuda:0 bf16[1, 71, 2048, 64]"
# t550 = prims.convert_element_type(t549, dtypes.float32) # t550: "cuda:0 f32[1, 71, 2048, 32]"
# t551 = prims.neg(t550) # t551: "cuda:0 f32[1, 71, 2048, 32]"
# t552 = prims.convert_element_type(t551, dtypes.bfloat16) # t552: "cuda:0 bf16[1, 71, 2048, 32]"
# t554 = prims.cat((t552, t548), -1) # t554: "cuda:0 bf16[1, 71, 2048, 64]"
# t556 = prims.convert_element_type(t547, dtypes.float32) # t556: "cuda:0 f32[1, 71, 2048, 64]"
# t558 = prims.mul(t556, t61) # t558: "cuda:0 f32[1, 71, 2048, 64]"
# t561 = prims.convert_element_type(t554, dtypes.float32) # t561: "cuda:0 f32[1, 71, 2048, 64]"
# t563 = prims.mul(t561, t66) # t563: "cuda:0 f32[1, 71, 2048, 64]"
# t567 = prims.add(t558, t563) # t567: "cuda:0 f32[1, 71, 2048, 64]"
# t568 = prims.convert_element_type(t567, dtypes.bfloat16) # t568: "cuda:0 bf16[1, 71, 2048, 64]"
# t571 = prims.cat((t546, t569), -1) # t571: "cuda:0 bf16[1, 71, 2048, 64]"
# t574 = prims.cat((t568, t572), -1) # t574: "cuda:0 bf16[1, 71, 2048, 64]"
del t525, t526, t527, t547, t548, t549, t569, t572
t607 = torch.nn.functional.linear(t606, t_transformer_h_3_mlp_proj_weight, None) # t607: "cuda:0 bf16[1, 2048, 4544]"
# t607 = ltorch.linear(t606, t_transformer_h_3_mlp_proj_weight, None) # t607: "cuda:0 bf16[1, 2048, 4544]"
# t607 = prims.linear(t606, t_transformer_h_3_mlp_proj_weight, None) # t607: "cuda:0 bf16[1, 2048, 4544]"
(t575, t576, t577, t578) = cudnn_sdpa_fwd(t571, t574, t524, None, 0.0, True, scale=0.125)
t581 = torch.permute(t575, (0, 2, 1, 3)) # t581: "cuda:0 bf16[1, 2048, 71, 64]"
# t581 = ltorch.permute(t575, (0, 2, 1, 3)) # t581: "cuda:0 bf16[1, 2048, 71, 64]"
# t581 = prims.transpose(t575, (0, 2, 1, 3)) # t581: "cuda:0 bf16[1, 2048, 71, 64]"
t585 = torch.reshape(t581, (1, 2048, 4544)) # t585: "cuda:0 bf16[1, 2048, 4544]"
# t585 = ltorch.reshape(t581, (1, 2048, 4544)) # t585: "cuda:0 bf16[1, 2048, 4544]"
# t585 = prims.reshape(t581, (1, 2048, 4544)) # t585: "cuda:0 bf16[1, 2048, 4544]"
del t581
t586 = torch.nn.functional.linear(t585, t_transformer_h_3_attn_proj_weight, None) # t586: "cuda:0 bf16[1, 2048, 4544]"
# t586 = ltorch.linear(t585, t_transformer_h_3_attn_proj_weight, None) # t586: "cuda:0 bf16[1, 2048, 4544]"
# t586 = prims.linear(t585, t_transformer_h_3_attn_proj_weight, None) # t586: "cuda:0 bf16[1, 2048, 4544]"
t5393 = torch.unsqueeze(t_transformer_h_4_norm_1_weight, 0) # t5393: "cuda:0 bf16[1, 4544]"
# t5393 = ltorch.unsqueeze(t_transformer_h_4_norm_1_weight, 0) # t5393: "cuda:0 bf16[1, 4544]"
# t5393 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_weight, [1, 4544], [1]) # t5393: "cuda:0 bf16[1, 4544]"
t5394 = torch.unsqueeze(t5393, 1) # t5394: "cuda:0 bf16[1, 1, 4544]"
# t5394 = ltorch.unsqueeze(t5393, 1) # t5394: "cuda:0 bf16[1, 1, 4544]"
# t5394 = prims.broadcast_in_dim(t5393, [1, 1, 4544], [0, 2]) # t5394: "cuda:0 bf16[1, 1, 4544]"
del t5393
t633 = Tensor.expand(t5394, (1, 2048, 4544)) # t633: "cuda:0 bf16[1, 2048, 4544]"
# t633 = ltorch.expand(t5394, (1, 2048, 4544)) # t633: "cuda:0 bf16[1, 2048, 4544]"
# t633 = prims.broadcast_in_dim(t5394, (1, 2048, 4544), (0, 1, 2)) # t633: "cuda:0 bf16[1, 2048, 4544]"
del t5394
t5396 = torch.unsqueeze(t_transformer_h_4_norm_1_bias, 0) # t5396: "cuda:0 bf16[1, 4544]"
# t5396 = ltorch.unsqueeze(t_transformer_h_4_norm_1_bias, 0) # t5396: "cuda:0 bf16[1, 4544]"
# t5396 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_bias, [1, 4544], [1]) # t5396: "cuda:0 bf16[1, 4544]"
t5397 = torch.unsqueeze(t5396, 1) # t5397: "cuda:0 bf16[1, 1, 4544]"
# t5397 = ltorch.unsqueeze(t5396, 1) # t5397: "cuda:0 bf16[1, 1, 4544]"
# t5397 = prims.broadcast_in_dim(t5396, [1, 1, 4544], [0, 2]) # t5397: "cuda:0 bf16[1, 1, 4544]"
del t5396
t636 = Tensor.expand(t5397, (1, 2048, 4544)) # t636: "cuda:0 bf16[1, 2048, 4544]"
# t636 = ltorch.expand(t5397, (1, 2048, 4544)) # t636: "cuda:0 bf16[1, 2048, 4544]"
# t636 = prims.broadcast_in_dim(t5397, (1, 2048, 4544), (0, 1, 2)) # t636: "cuda:0 bf16[1, 2048, 4544]"
del t5397
[t615, t622, t627, t639] = nvFusion8(t454, t586, t607, t633, t636)
# t613 = prims.convert_element_type(t454, dtypes.float32) # t613: "cuda:0 f32[1, 2048, 4544]"
# t608 = prims.convert_element_type(t607, dtypes.float32) # t608: "cuda:0 f32[1, 2048, 4544]"
# t609 = prims.convert_element_type(t586, dtypes.float32) # t609: "cuda:0 f32[1, 2048, 4544]"
# t610 = prims.add(t608, t609) # t610: "cuda:0 f32[1, 2048, 4544]"
# t614 = prims.add(t610, t613) # t614: "cuda:0 f32[1, 2048, 4544]"
# t615 = prims.convert_element_type(t614, dtypes.bfloat16) # t615: "cuda:0 bf16[1, 2048, 4544]"
# (t621, t622) = prims.var_mean(t614, (2,), correction=0)
# t623 = prims.broadcast_in_dim(t621, [1, 2048, 1], [0, 1]) # t623: "cuda:0 f32[1, 2048, 1]"
# t624 = prims.broadcast_in_dim(t622, [1, 2048, 1], [0, 1]) # t624: "cuda:0 f32[1, 2048, 1]"
# t626 = prims.add(t623, 1e-05) # t626: "cuda:0 f32[1, 2048, 1]"
# t627 = prims.rsqrt(t626) # t627: "cuda:0 f32[1, 2048, 1]"
# t628 = prims.broadcast_in_dim(t624, (1, 2048, 4544), (0, 1, 2)) # t628: "cuda:0 f32[1, 2048, 4544]"
# t630 = prims.sub(t614, t628) # t630: "cuda:0 f32[1, 2048, 4544]"
# t631 = prims.broadcast_in_dim(t627, (1, 2048, 4544), (0, 1, 2)) # t631: "cuda:0 f32[1, 2048, 4544]"
# t632 = prims.mul(t630, t631) # t632: "cuda:0 f32[1, 2048, 4544]"
# t634 = prims.convert_element_type(t633, dtypes.float32) # t634: "cuda:0 f32[1, 2048, 4544]"
# t635 = prims.mul(t632, t634) # t635: "cuda:0 f32[1, 2048, 4544]"
# t637 = prims.convert_element_type(t636, dtypes.float32) # t637: "cuda:0 f32[1, 2048, 4544]"
# t638 = prims.add(t635, t637) # t638: "cuda:0 f32[1, 2048, 4544]"
# t639 = prims.convert_element_type(t638, dtypes.bfloat16) # t639: "cuda:0 bf16[1, 2048, 4544]"
del t636
t748 = torch.nn.functional.linear(t639, t_transformer_h_4_mlp_fc_weight, None) # t748: "cuda:0 bf16[1, 2048, 18176]"
# t748 = ltorch.linear(t639, t_transformer_h_4_mlp_fc_weight, None) # t748: "cuda:0 bf16[1, 2048, 18176]"
# t748 = prims.linear(t639, t_transformer_h_4_mlp_fc_weight, None) # t748: "cuda:0 bf16[1, 2048, 18176]"
t640 = torch.nn.functional.linear(t639, t_transformer_h_4_attn_attn_weight, None) # t640: "cuda:0 bf16[1, 2048, 4672]"
# t640 = ltorch.linear(t639, t_transformer_h_4_attn_attn_weight, None) # t640: "cuda:0 bf16[1, 2048, 4672]"
# t640 = prims.linear(t639, t_transformer_h_4_attn_attn_weight, None) # t640: "cuda:0 bf16[1, 2048, 4672]"
t646 = torch.reshape(t640, (1, 2048, 1, 73, 64)) # t646: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t646 = ltorch.reshape(t640, (1, 2048, 1, 73, 64)) # t646: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t646 = prims.reshape(t640, (1, 2048, 1, 73, 64)) # t646: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t640
t652 = torch.permute(t646, (0, 2, 3, 1, 4)) # t652: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t652 = ltorch.permute(t646, (0, 2, 3, 1, 4)) # t652: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t652 = prims.transpose(t646, (0, 2, 3, 1, 4)) # t652: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t646
(t653, t654, t655) = torch.split(t652, (71, 1, 1), 2)
# (t653, t654, t655) = ltorch.split(t652, (71, 1, 1), 2)
# t653 = prims.slice_prim(t652, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t653: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t654 = prims.slice_prim(t652, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t654: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t655 = prims.slice_prim(t652, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t655: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t652
t661 = Tensor.expand(t654, (1, 1, 71, 2048, 64)) # t661: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t661 = ltorch.expand(t654, (1, 1, 71, 2048, 64)) # t661: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t661 = prims.broadcast_in_dim(t654, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t661: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t654
t667 = Tensor.expand(t655, (1, 1, 71, 2048, 64)) # t667: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t667 = ltorch.expand(t655, (1, 1, 71, 2048, 64)) # t667: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t667 = prims.broadcast_in_dim(t655, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t667: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t655
t673 = torch.reshape(t653, (1, 71, 2048, 64)) # t673: "cuda:0 bf16[1, 71, 2048, 64]"
# t673 = ltorch.reshape(t653, (1, 71, 2048, 64)) # t673: "cuda:0 bf16[1, 71, 2048, 64]"
# t673 = prims.reshape(t653, (1, 71, 2048, 64)) # t673: "cuda:0 bf16[1, 71, 2048, 64]"
del t653
t679 = torch.reshape(t661, (1, 71, 2048, 64)) # t679: "cuda:0 bf16[1, 71, 2048, 64]"
# t679 = ltorch.reshape(t661, (1, 71, 2048, 64)) # t679: "cuda:0 bf16[1, 71, 2048, 64]"
# t679 = prims.reshape(t661, (1, 71, 2048, 64)) # t679: "cuda:0 bf16[1, 71, 2048, 64]"
del t661
t685 = torch.reshape(t667, (1, 71, 2048, 64)) # t685: "cuda:0 bf16[1, 71, 2048, 64]"
# t685 = ltorch.reshape(t667, (1, 71, 2048, 64)) # t685: "cuda:0 bf16[1, 71, 2048, 64]"
# t685 = prims.reshape(t667, (1, 71, 2048, 64)) # t685: "cuda:0 bf16[1, 71, 2048, 64]"
del t667
t686 = torch_slice_prim_impl(t673, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t686: "cuda:0 bf16[1, 71, 2048, 64]"
t687 = torch_slice_prim_impl(t686, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t687: "cuda:0 bf16[1, 71, 2048, 32]"
t688 = torch_slice_prim_impl(t686, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t688: "cuda:0 bf16[1, 71, 2048, 32]"
t708 = torch_slice_prim_impl(t679, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t708: "cuda:0 bf16[1, 71, 2048, 64]"
t709 = torch_slice_prim_impl(t708, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t709: "cuda:0 bf16[1, 71, 2048, 32]"
t710 = torch_slice_prim_impl(t708, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t710: "cuda:0 bf16[1, 71, 2048, 32]"
t730 = torch_slice_prim_impl(t673, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t730: "cuda:0 bf16[1, 71, 2048, 0]"
del t673
t733 = torch_slice_prim_impl(t679, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t733: "cuda:0 bf16[1, 71, 2048, 0]"
del t679
[t732, t735, t767] = nvFusion9(t61, t66, t686, t687, t688, t708, t709, t710, t730, t733, t748)
# t749 = prims.convert_element_type(t748, dtypes.float32) # t749: "cuda:0 f32[1, 2048, 18176]"
# t751 = prims.div(t749, 1.4142135623730951) # t751: "cuda:0 f32[1, 2048, 18176]"
# t754 = prims.erf(t751) # t754: "cuda:0 f32[1, 2048, 18176]"
# t758 = prims.mul(0.5, t754) # t758: "cuda:0 f32[1, 2048, 18176]"
# t762 = prims.add(0.5, t758) # t762: "cuda:0 f32[1, 2048, 18176]"
# t766 = prims.mul(t749, t762) # t766: "cuda:0 f32[1, 2048, 18176]"
# t767 = prims.convert_element_type(t766, dtypes.bfloat16) # t767: "cuda:0 bf16[1, 2048, 18176]"
# t689 = prims.convert_element_type(t688, dtypes.float32) # t689: "cuda:0 f32[1, 71, 2048, 32]"
# t690 = prims.neg(t689) # t690: "cuda:0 f32[1, 71, 2048, 32]"
# t691 = prims.convert_element_type(t690, dtypes.bfloat16) # t691: "cuda:0 bf16[1, 71, 2048, 32]"
# t693 = prims.cat((t691, t687), -1) # t693: "cuda:0 bf16[1, 71, 2048, 64]"
# t695 = prims.convert_element_type(t686, dtypes.float32) # t695: "cuda:0 f32[1, 71, 2048, 64]"
# t697 = prims.mul(t695, t61) # t697: "cuda:0 f32[1, 71, 2048, 64]"
# t700 = prims.convert_element_type(t693, dtypes.float32) # t700: "cuda:0 f32[1, 71, 2048, 64]"
# t702 = prims.mul(t700, t66) # t702: "cuda:0 f32[1, 71, 2048, 64]"
# t706 = prims.add(t697, t702) # t706: "cuda:0 f32[1, 71, 2048, 64]"
# t707 = prims.convert_element_type(t706, dtypes.bfloat16) # t707: "cuda:0 bf16[1, 71, 2048, 64]"
# t711 = prims.convert_element_type(t710, dtypes.float32) # t711: "cuda:0 f32[1, 71, 2048, 32]"
# t712 = prims.neg(t711) # t712: "cuda:0 f32[1, 71, 2048, 32]"
# t713 = prims.convert_element_type(t712, dtypes.bfloat16) # t713: "cuda:0 bf16[1, 71, 2048, 32]"
# t715 = prims.cat((t713, t709), -1) # t715: "cuda:0 bf16[1, 71, 2048, 64]"
# t717 = prims.convert_element_type(t708, dtypes.float32) # t717: "cuda:0 f32[1, 71, 2048, 64]"
# t719 = prims.mul(t717, t61) # t719: "cuda:0 f32[1, 71, 2048, 64]"
# t722 = prims.convert_element_type(t715, dtypes.float32) # t722: "cuda:0 f32[1, 71, 2048, 64]"
# t724 = prims.mul(t722, t66) # t724: "cuda:0 f32[1, 71, 2048, 64]"
# t728 = prims.add(t719, t724) # t728: "cuda:0 f32[1, 71, 2048, 64]"
# t729 = prims.convert_element_type(t728, dtypes.bfloat16) # t729: "cuda:0 bf16[1, 71, 2048, 64]"
# t732 = prims.cat((t707, t730), -1) # t732: "cuda:0 bf16[1, 71, 2048, 64]"
# t735 = prims.cat((t729, t733), -1) # t735: "cuda:0 bf16[1, 71, 2048, 64]"
del t686, t687, t688, t708, t709, t710, t730, t733
t768 = torch.nn.functional.linear(t767, t_transformer_h_4_mlp_proj_weight, None) # t768: "cuda:0 bf16[1, 2048, 4544]"
# t768 = ltorch.linear(t767, t_transformer_h_4_mlp_proj_weight, None) # t768: "cuda:0 bf16[1, 2048, 4544]"
# t768 = prims.linear(t767, t_transformer_h_4_mlp_proj_weight, None) # t768: "cuda:0 bf16[1, 2048, 4544]"
(t736, t737, t738, t739) = cudnn_sdpa_fwd(t732, t735, t685, None, 0.0, True, scale=0.125)
t742 = torch.permute(t736, (0, 2, 1, 3)) # t742: "cuda:0 bf16[1, 2048, 71, 64]"
# t742 = ltorch.permute(t736, (0, 2, 1, 3)) # t742: "cuda:0 bf16[1, 2048, 71, 64]"
# t742 = prims.transpose(t736, (0, 2, 1, 3)) # t742: "cuda:0 bf16[1, 2048, 71, 64]"
t746 = torch.reshape(t742, (1, 2048, 4544)) # t746: "cuda:0 bf16[1, 2048, 4544]"
# t746 = ltorch.reshape(t742, (1, 2048, 4544)) # t746: "cuda:0 bf16[1, 2048, 4544]"
# t746 = prims.reshape(t742, (1, 2048, 4544)) # t746: "cuda:0 bf16[1, 2048, 4544]"
del t742
t747 = torch.nn.functional.linear(t746, t_transformer_h_4_attn_proj_weight, None) # t747: "cuda:0 bf16[1, 2048, 4544]"
# t747 = ltorch.linear(t746, t_transformer_h_4_attn_proj_weight, None) # t747: "cuda:0 bf16[1, 2048, 4544]"
# t747 = prims.linear(t746, t_transformer_h_4_attn_proj_weight, None) # t747: "cuda:0 bf16[1, 2048, 4544]"
t5419 = torch.unsqueeze(t_transformer_h_5_norm_1_weight, 0) # t5419: "cuda:0 bf16[1, 4544]"
# t5419 = ltorch.unsqueeze(t_transformer_h_5_norm_1_weight, 0) # t5419: "cuda:0 bf16[1, 4544]"
# t5419 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_weight, [1, 4544], [1]) # t5419: "cuda:0 bf16[1, 4544]"
t5420 = torch.unsqueeze(t5419, 1) # t5420: "cuda:0 bf16[1, 1, 4544]"
# t5420 = ltorch.unsqueeze(t5419, 1) # t5420: "cuda:0 bf16[1, 1, 4544]"
# t5420 = prims.broadcast_in_dim(t5419, [1, 1, 4544], [0, 2]) # t5420: "cuda:0 bf16[1, 1, 4544]"
del t5419
t794 = Tensor.expand(t5420, (1, 2048, 4544)) # t794: "cuda:0 bf16[1, 2048, 4544]"
# t794 = ltorch.expand(t5420, (1, 2048, 4544)) # t794: "cuda:0 bf16[1, 2048, 4544]"
# t794 = prims.broadcast_in_dim(t5420, (1, 2048, 4544), (0, 1, 2)) # t794: "cuda:0 bf16[1, 2048, 4544]"
del t5420
t5422 = torch.unsqueeze(t_transformer_h_5_norm_1_bias, 0) # t5422: "cuda:0 bf16[1, 4544]"
# t5422 = ltorch.unsqueeze(t_transformer_h_5_norm_1_bias, 0) # t5422: "cuda:0 bf16[1, 4544]"
# t5422 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_bias, [1, 4544], [1]) # t5422: "cuda:0 bf16[1, 4544]"
t5423 = torch.unsqueeze(t5422, 1) # t5423: "cuda:0 bf16[1, 1, 4544]"
# t5423 = ltorch.unsqueeze(t5422, 1) # t5423: "cuda:0 bf16[1, 1, 4544]"
# t5423 = prims.broadcast_in_dim(t5422, [1, 1, 4544], [0, 2]) # t5423: "cuda:0 bf16[1, 1, 4544]"
del t5422
t797 = Tensor.expand(t5423, (1, 2048, 4544)) # t797: "cuda:0 bf16[1, 2048, 4544]"
# t797 = ltorch.expand(t5423, (1, 2048, 4544)) # t797: "cuda:0 bf16[1, 2048, 4544]"
# t797 = prims.broadcast_in_dim(t5423, (1, 2048, 4544), (0, 1, 2)) # t797: "cuda:0 bf16[1, 2048, 4544]"
del t5423
[t776, t783, t788, t800] = nvFusion10(t615, t747, t768, t794, t797)
# t774 = prims.convert_element_type(t615, dtypes.float32) # t774: "cuda:0 f32[1, 2048, 4544]"
# t769 = prims.convert_element_type(t768, dtypes.float32) # t769: "cuda:0 f32[1, 2048, 4544]"
# t770 = prims.convert_element_type(t747, dtypes.float32) # t770: "cuda:0 f32[1, 2048, 4544]"
# t771 = prims.add(t769, t770) # t771: "cuda:0 f32[1, 2048, 4544]"
# t775 = prims.add(t771, t774) # t775: "cuda:0 f32[1, 2048, 4544]"
# t776 = prims.convert_element_type(t775, dtypes.bfloat16) # t776: "cuda:0 bf16[1, 2048, 4544]"
# (t782, t783) = prims.var_mean(t775, (2,), correction=0)
# t784 = prims.broadcast_in_dim(t782, [1, 2048, 1], [0, 1]) # t784: "cuda:0 f32[1, 2048, 1]"
# t785 = prims.broadcast_in_dim(t783, [1, 2048, 1], [0, 1]) # t785: "cuda:0 f32[1, 2048, 1]"
# t787 = prims.add(t784, 1e-05) # t787: "cuda:0 f32[1, 2048, 1]"
# t788 = prims.rsqrt(t787) # t788: "cuda:0 f32[1, 2048, 1]"
# t789 = prims.broadcast_in_dim(t785, (1, 2048, 4544), (0, 1, 2)) # t789: "cuda:0 f32[1, 2048, 4544]"
# t791 = prims.sub(t775, t789) # t791: "cuda:0 f32[1, 2048, 4544]"
# t792 = prims.broadcast_in_dim(t788, (1, 2048, 4544), (0, 1, 2)) # t792: "cuda:0 f32[1, 2048, 4544]"
# t793 = prims.mul(t791, t792) # t793: "cuda:0 f32[1, 2048, 4544]"
# t795 = prims.convert_element_type(t794, dtypes.float32) # t795: "cuda:0 f32[1, 2048, 4544]"
# t796 = prims.mul(t793, t795) # t796: "cuda:0 f32[1, 2048, 4544]"
# t798 = prims.convert_element_type(t797, dtypes.float32) # t798: "cuda:0 f32[1, 2048, 4544]"
# t799 = prims.add(t796, t798) # t799: "cuda:0 f32[1, 2048, 4544]"
# t800 = prims.convert_element_type(t799, dtypes.bfloat16) # t800: "cuda:0 bf16[1, 2048, 4544]"
del t797
t909 = torch.nn.functional.linear(t800, t_transformer_h_5_mlp_fc_weight, None) # t909: "cuda:0 bf16[1, 2048, 18176]"
# t909 = ltorch.linear(t800, t_transformer_h_5_mlp_fc_weight, None) # t909: "cuda:0 bf16[1, 2048, 18176]"
# t909 = prims.linear(t800, t_transformer_h_5_mlp_fc_weight, None) # t909: "cuda:0 bf16[1, 2048, 18176]"
t801 = torch.nn.functional.linear(t800, t_transformer_h_5_attn_attn_weight, None) # t801: "cuda:0 bf16[1, 2048, 4672]"
# t801 = ltorch.linear(t800, t_transformer_h_5_attn_attn_weight, None) # t801: "cuda:0 bf16[1, 2048, 4672]"
# t801 = prims.linear(t800, t_transformer_h_5_attn_attn_weight, None) # t801: "cuda:0 bf16[1, 2048, 4672]"
t807 = torch.reshape(t801, (1, 2048, 1, 73, 64)) # t807: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t807 = ltorch.reshape(t801, (1, 2048, 1, 73, 64)) # t807: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t807 = prims.reshape(t801, (1, 2048, 1, 73, 64)) # t807: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t801
t813 = torch.permute(t807, (0, 2, 3, 1, 4)) # t813: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t813 = ltorch.permute(t807, (0, 2, 3, 1, 4)) # t813: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t813 = prims.transpose(t807, (0, 2, 3, 1, 4)) # t813: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t807
(t814, t815, t816) = torch.split(t813, (71, 1, 1), 2)
# (t814, t815, t816) = ltorch.split(t813, (71, 1, 1), 2)
# t814 = prims.slice_prim(t813, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t814: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t815 = prims.slice_prim(t813, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t815: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t816 = prims.slice_prim(t813, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t816: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t813
t822 = Tensor.expand(t815, (1, 1, 71, 2048, 64)) # t822: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t822 = ltorch.expand(t815, (1, 1, 71, 2048, 64)) # t822: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t822 = prims.broadcast_in_dim(t815, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t822: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t815
t828 = Tensor.expand(t816, (1, 1, 71, 2048, 64)) # t828: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t828 = ltorch.expand(t816, (1, 1, 71, 2048, 64)) # t828: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t828 = prims.broadcast_in_dim(t816, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t828: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t816
t834 = torch.reshape(t814, (1, 71, 2048, 64)) # t834: "cuda:0 bf16[1, 71, 2048, 64]"
# t834 = ltorch.reshape(t814, (1, 71, 2048, 64)) # t834: "cuda:0 bf16[1, 71, 2048, 64]"
# t834 = prims.reshape(t814, (1, 71, 2048, 64)) # t834: "cuda:0 bf16[1, 71, 2048, 64]"
del t814
t840 = torch.reshape(t822, (1, 71, 2048, 64)) # t840: "cuda:0 bf16[1, 71, 2048, 64]"
# t840 = ltorch.reshape(t822, (1, 71, 2048, 64)) # t840: "cuda:0 bf16[1, 71, 2048, 64]"
# t840 = prims.reshape(t822, (1, 71, 2048, 64)) # t840: "cuda:0 bf16[1, 71, 2048, 64]"
del t822
t846 = torch.reshape(t828, (1, 71, 2048, 64)) # t846: "cuda:0 bf16[1, 71, 2048, 64]"
# t846 = ltorch.reshape(t828, (1, 71, 2048, 64)) # t846: "cuda:0 bf16[1, 71, 2048, 64]"
# t846 = prims.reshape(t828, (1, 71, 2048, 64)) # t846: "cuda:0 bf16[1, 71, 2048, 64]"
del t828
t847 = torch_slice_prim_impl(t834, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t847: "cuda:0 bf16[1, 71, 2048, 64]"
t848 = torch_slice_prim_impl(t847, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t848: "cuda:0 bf16[1, 71, 2048, 32]"
t849 = torch_slice_prim_impl(t847, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t849: "cuda:0 bf16[1, 71, 2048, 32]"
t869 = torch_slice_prim_impl(t840, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t869: "cuda:0 bf16[1, 71, 2048, 64]"
t870 = torch_slice_prim_impl(t869, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t870: "cuda:0 bf16[1, 71, 2048, 32]"
t871 = torch_slice_prim_impl(t869, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t871: "cuda:0 bf16[1, 71, 2048, 32]"
t891 = torch_slice_prim_impl(t834, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t891: "cuda:0 bf16[1, 71, 2048, 0]"
del t834
t894 = torch_slice_prim_impl(t840, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t894: "cuda:0 bf16[1, 71, 2048, 0]"
del t840
[t893, t896, t928] = nvFusion11(t61, t66, t847, t848, t849, t869, t870, t871, t891, t894, t909)
# t910 = prims.convert_element_type(t909, dtypes.float32) # t910: "cuda:0 f32[1, 2048, 18176]"
# t912 = prims.div(t910, 1.4142135623730951) # t912: "cuda:0 f32[1, 2048, 18176]"
# t915 = prims.erf(t912) # t915: "cuda:0 f32[1, 2048, 18176]"
# t919 = prims.mul(0.5, t915) # t919: "cuda:0 f32[1, 2048, 18176]"
# t923 = prims.add(0.5, t919) # t923: "cuda:0 f32[1, 2048, 18176]"
# t927 = prims.mul(t910, t923) # t927: "cuda:0 f32[1, 2048, 18176]"
# t928 = prims.convert_element_type(t927, dtypes.bfloat16) # t928: "cuda:0 bf16[1, 2048, 18176]"
# t850 = prims.convert_element_type(t849, dtypes.float32) # t850: "cuda:0 f32[1, 71, 2048, 32]"
# t851 = prims.neg(t850) # t851: "cuda:0 f32[1, 71, 2048, 32]"
# t852 = prims.convert_element_type(t851, dtypes.bfloat16) # t852: "cuda:0 bf16[1, 71, 2048, 32]"
# t854 = prims.cat((t852, t848), -1) # t854: "cuda:0 bf16[1, 71, 2048, 64]"
# t856 = prims.convert_element_type(t847, dtypes.float32) # t856: "cuda:0 f32[1, 71, 2048, 64]"
# t858 = prims.mul(t856, t61) # t858: "cuda:0 f32[1, 71, 2048, 64]"
# t861 = prims.convert_element_type(t854, dtypes.float32) # t861: "cuda:0 f32[1, 71, 2048, 64]"
# t863 = prims.mul(t861, t66) # t863: "cuda:0 f32[1, 71, 2048, 64]"
# t867 = prims.add(t858, t863) # t867: "cuda:0 f32[1, 71, 2048, 64]"
# t868 = prims.convert_element_type(t867, dtypes.bfloat16) # t868: "cuda:0 bf16[1, 71, 2048, 64]"
# t872 = prims.convert_element_type(t871, dtypes.float32) # t872: "cuda:0 f32[1, 71, 2048, 32]"
# t873 = prims.neg(t872) # t873: "cuda:0 f32[1, 71, 2048, 32]"
# t874 = prims.convert_element_type(t873, dtypes.bfloat16) # t874: "cuda:0 bf16[1, 71, 2048, 32]"
# t876 = prims.cat((t874, t870), -1) # t876: "cuda:0 bf16[1, 71, 2048, 64]"
# t878 = prims.convert_element_type(t869, dtypes.float32) # t878: "cuda:0 f32[1, 71, 2048, 64]"
# t880 = prims.mul(t878, t61) # t880: "cuda:0 f32[1, 71, 2048, 64]"
# t883 = prims.convert_element_type(t876, dtypes.float32) # t883: "cuda:0 f32[1, 71, 2048, 64]"
# t885 = prims.mul(t883, t66) # t885: "cuda:0 f32[1, 71, 2048, 64]"
# t889 = prims.add(t880, t885) # t889: "cuda:0 f32[1, 71, 2048, 64]"
# t890 = prims.convert_element_type(t889, dtypes.bfloat16) # t890: "cuda:0 bf16[1, 71, 2048, 64]"
# t893 = prims.cat((t868, t891), -1) # t893: "cuda:0 bf16[1, 71, 2048, 64]"
# t896 = prims.cat((t890, t894), -1) # t896: "cuda:0 bf16[1, 71, 2048, 64]"
del t847, t848, t849, t869, t870, t871, t891, t894
t929 = torch.nn.functional.linear(t928, t_transformer_h_5_mlp_proj_weight, None) # t929: "cuda:0 bf16[1, 2048, 4544]"
# t929 = ltorch.linear(t928, t_transformer_h_5_mlp_proj_weight, None) # t929: "cuda:0 bf16[1, 2048, 4544]"
# t929 = prims.linear(t928, t_transformer_h_5_mlp_proj_weight, None) # t929: "cuda:0 bf16[1, 2048, 4544]"
(t897, t898, t899, t900) = cudnn_sdpa_fwd(t893, t896, t846, None, 0.0, True, scale=0.125)
t903 = torch.permute(t897, (0, 2, 1, 3)) # t903: "cuda:0 bf16[1, 2048, 71, 64]"
# t903 = ltorch.permute(t897, (0, 2, 1, 3)) # t903: "cuda:0 bf16[1, 2048, 71, 64]"
# t903 = prims.transpose(t897, (0, 2, 1, 3)) # t903: "cuda:0 bf16[1, 2048, 71, 64]"
t907 = torch.reshape(t903, (1, 2048, 4544)) # t907: "cuda:0 bf16[1, 2048, 4544]"
# t907 = ltorch.reshape(t903, (1, 2048, 4544)) # t907: "cuda:0 bf16[1, 2048, 4544]"
# t907 = prims.reshape(t903, (1, 2048, 4544)) # t907: "cuda:0 bf16[1, 2048, 4544]"
del t903
t908 = torch.nn.functional.linear(t907, t_transformer_h_5_attn_proj_weight, None) # t908: "cuda:0 bf16[1, 2048, 4544]"
# t908 = ltorch.linear(t907, t_transformer_h_5_attn_proj_weight, None) # t908: "cuda:0 bf16[1, 2048, 4544]"
# t908 = prims.linear(t907, t_transformer_h_5_attn_proj_weight, None) # t908: "cuda:0 bf16[1, 2048, 4544]"
t5445 = torch.unsqueeze(t_transformer_h_6_norm_1_weight, 0) # t5445: "cuda:0 bf16[1, 4544]"
# t5445 = ltorch.unsqueeze(t_transformer_h_6_norm_1_weight, 0) # t5445: "cuda:0 bf16[1, 4544]"
# t5445 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_weight, [1, 4544], [1]) # t5445: "cuda:0 bf16[1, 4544]"
t5446 = torch.unsqueeze(t5445, 1) # t5446: "cuda:0 bf16[1, 1, 4544]"
# t5446 = ltorch.unsqueeze(t5445, 1) # t5446: "cuda:0 bf16[1, 1, 4544]"
# t5446 = prims.broadcast_in_dim(t5445, [1, 1, 4544], [0, 2]) # t5446: "cuda:0 bf16[1, 1, 4544]"
del t5445
t955 = Tensor.expand(t5446, (1, 2048, 4544)) # t955: "cuda:0 bf16[1, 2048, 4544]"
# t955 = ltorch.expand(t5446, (1, 2048, 4544)) # t955: "cuda:0 bf16[1, 2048, 4544]"
# t955 = prims.broadcast_in_dim(t5446, (1, 2048, 4544), (0, 1, 2)) # t955: "cuda:0 bf16[1, 2048, 4544]"
del t5446
t5448 = torch.unsqueeze(t_transformer_h_6_norm_1_bias, 0) # t5448: "cuda:0 bf16[1, 4544]"
# t5448 = ltorch.unsqueeze(t_transformer_h_6_norm_1_bias, 0) # t5448: "cuda:0 bf16[1, 4544]"
# t5448 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_bias, [1, 4544], [1]) # t5448: "cuda:0 bf16[1, 4544]"
t5449 = torch.unsqueeze(t5448, 1) # t5449: "cuda:0 bf16[1, 1, 4544]"
# t5449 = ltorch.unsqueeze(t5448, 1) # t5449: "cuda:0 bf16[1, 1, 4544]"
# t5449 = prims.broadcast_in_dim(t5448, [1, 1, 4544], [0, 2]) # t5449: "cuda:0 bf16[1, 1, 4544]"
del t5448
t958 = Tensor.expand(t5449, (1, 2048, 4544)) # t958: "cuda:0 bf16[1, 2048, 4544]"
# t958 = ltorch.expand(t5449, (1, 2048, 4544)) # t958: "cuda:0 bf16[1, 2048, 4544]"
# t958 = prims.broadcast_in_dim(t5449, (1, 2048, 4544), (0, 1, 2)) # t958: "cuda:0 bf16[1, 2048, 4544]"
del t5449
[t937, t944, t949, t961] = nvFusion12(t776, t908, t929, t955, t958)
# t935 = prims.convert_element_type(t776, dtypes.float32) # t935: "cuda:0 f32[1, 2048, 4544]"
# t930 = prims.convert_element_type(t929, dtypes.float32) # t930: "cuda:0 f32[1, 2048, 4544]"
# t931 = prims.convert_element_type(t908, dtypes.float32) # t931: "cuda:0 f32[1, 2048, 4544]"
# t932 = prims.add(t930, t931) # t932: "cuda:0 f32[1, 2048, 4544]"
# t936 = prims.add(t932, t935) # t936: "cuda:0 f32[1, 2048, 4544]"
# t937 = prims.convert_element_type(t936, dtypes.bfloat16) # t937: "cuda:0 bf16[1, 2048, 4544]"
# (t943, t944) = prims.var_mean(t936, (2,), correction=0)
# t945 = prims.broadcast_in_dim(t943, [1, 2048, 1], [0, 1]) # t945: "cuda:0 f32[1, 2048, 1]"
# t946 = prims.broadcast_in_dim(t944, [1, 2048, 1], [0, 1]) # t946: "cuda:0 f32[1, 2048, 1]"
# t948 = prims.add(t945, 1e-05) # t948: "cuda:0 f32[1, 2048, 1]"
# t949 = prims.rsqrt(t948) # t949: "cuda:0 f32[1, 2048, 1]"
# t950 = prims.broadcast_in_dim(t946, (1, 2048, 4544), (0, 1, 2)) # t950: "cuda:0 f32[1, 2048, 4544]"
# t952 = prims.sub(t936, t950) # t952: "cuda:0 f32[1, 2048, 4544]"
# t953 = prims.broadcast_in_dim(t949, (1, 2048, 4544), (0, 1, 2)) # t953: "cuda:0 f32[1, 2048, 4544]"
# t954 = prims.mul(t952, t953) # t954: "cuda:0 f32[1, 2048, 4544]"
# t956 = prims.convert_element_type(t955, dtypes.float32) # t956: "cuda:0 f32[1, 2048, 4544]"
# t957 = prims.mul(t954, t956) # t957: "cuda:0 f32[1, 2048, 4544]"
# t959 = prims.convert_element_type(t958, dtypes.float32) # t959: "cuda:0 f32[1, 2048, 4544]"
# t960 = prims.add(t957, t959) # t960: "cuda:0 f32[1, 2048, 4544]"
# t961 = prims.convert_element_type(t960, dtypes.bfloat16) # t961: "cuda:0 bf16[1, 2048, 4544]"
del t958
t1070 = torch.nn.functional.linear(t961, t_transformer_h_6_mlp_fc_weight, None) # t1070: "cuda:0 bf16[1, 2048, 18176]"
# t1070 = ltorch.linear(t961, t_transformer_h_6_mlp_fc_weight, None) # t1070: "cuda:0 bf16[1, 2048, 18176]"
# t1070 = prims.linear(t961, t_transformer_h_6_mlp_fc_weight, None) # t1070: "cuda:0 bf16[1, 2048, 18176]"
t962 = torch.nn.functional.linear(t961, t_transformer_h_6_attn_attn_weight, None) # t962: "cuda:0 bf16[1, 2048, 4672]"
# t962 = ltorch.linear(t961, t_transformer_h_6_attn_attn_weight, None) # t962: "cuda:0 bf16[1, 2048, 4672]"
# t962 = prims.linear(t961, t_transformer_h_6_attn_attn_weight, None) # t962: "cuda:0 bf16[1, 2048, 4672]"
t968 = torch.reshape(t962, (1, 2048, 1, 73, 64)) # t968: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t968 = ltorch.reshape(t962, (1, 2048, 1, 73, 64)) # t968: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t968 = prims.reshape(t962, (1, 2048, 1, 73, 64)) # t968: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t962
t974 = torch.permute(t968, (0, 2, 3, 1, 4)) # t974: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t974 = ltorch.permute(t968, (0, 2, 3, 1, 4)) # t974: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t974 = prims.transpose(t968, (0, 2, 3, 1, 4)) # t974: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t968
(t975, t976, t977) = torch.split(t974, (71, 1, 1), 2)
# (t975, t976, t977) = ltorch.split(t974, (71, 1, 1), 2)
# t975 = prims.slice_prim(t974, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t975: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t976 = prims.slice_prim(t974, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t976: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t977 = prims.slice_prim(t974, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t977: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t974
t983 = Tensor.expand(t976, (1, 1, 71, 2048, 64)) # t983: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t983 = ltorch.expand(t976, (1, 1, 71, 2048, 64)) # t983: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t983 = prims.broadcast_in_dim(t976, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t983: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t976
t989 = Tensor.expand(t977, (1, 1, 71, 2048, 64)) # t989: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t989 = ltorch.expand(t977, (1, 1, 71, 2048, 64)) # t989: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t989 = prims.broadcast_in_dim(t977, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t989: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t977
t995 = torch.reshape(t975, (1, 71, 2048, 64)) # t995: "cuda:0 bf16[1, 71, 2048, 64]"
# t995 = ltorch.reshape(t975, (1, 71, 2048, 64)) # t995: "cuda:0 bf16[1, 71, 2048, 64]"
# t995 = prims.reshape(t975, (1, 71, 2048, 64)) # t995: "cuda:0 bf16[1, 71, 2048, 64]"
del t975
t1001 = torch.reshape(t983, (1, 71, 2048, 64)) # t1001: "cuda:0 bf16[1, 71, 2048, 64]"
# t1001 = ltorch.reshape(t983, (1, 71, 2048, 64)) # t1001: "cuda:0 bf16[1, 71, 2048, 64]"
# t1001 = prims.reshape(t983, (1, 71, 2048, 64)) # t1001: "cuda:0 bf16[1, 71, 2048, 64]"
del t983
t1007 = torch.reshape(t989, (1, 71, 2048, 64)) # t1007: "cuda:0 bf16[1, 71, 2048, 64]"
# t1007 = ltorch.reshape(t989, (1, 71, 2048, 64)) # t1007: "cuda:0 bf16[1, 71, 2048, 64]"
# t1007 = prims.reshape(t989, (1, 71, 2048, 64)) # t1007: "cuda:0 bf16[1, 71, 2048, 64]"
del t989
t1008 = torch_slice_prim_impl(t995, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1008: "cuda:0 bf16[1, 71, 2048, 64]"
t1009 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1009: "cuda:0 bf16[1, 71, 2048, 32]"
t1010 = torch_slice_prim_impl(t1008, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1010: "cuda:0 bf16[1, 71, 2048, 32]"
t1030 = torch_slice_prim_impl(t1001, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1030: "cuda:0 bf16[1, 71, 2048, 64]"
t1031 = torch_slice_prim_impl(t1030, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1031: "cuda:0 bf16[1, 71, 2048, 32]"
t1032 = torch_slice_prim_impl(t1030, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1032: "cuda:0 bf16[1, 71, 2048, 32]"
t1052 = torch_slice_prim_impl(t995, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1052: "cuda:0 bf16[1, 71, 2048, 0]"
del t995
t1055 = torch_slice_prim_impl(t1001, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1055: "cuda:0 bf16[1, 71, 2048, 0]"
del t1001
[t1054, t1057, t1089] = nvFusion13(t1008, t1009, t1010, t1030, t1031, t1032, t1052, t1055, t1070, t61, t66)
# t1071 = prims.convert_element_type(t1070, dtypes.float32) # t1071: "cuda:0 f32[1, 2048, 18176]"
# t1073 = prims.div(t1071, 1.4142135623730951) # t1073: "cuda:0 f32[1, 2048, 18176]"
# t1076 = prims.erf(t1073) # t1076: "cuda:0 f32[1, 2048, 18176]"
# t1080 = prims.mul(0.5, t1076) # t1080: "cuda:0 f32[1, 2048, 18176]"
# t1084 = prims.add(0.5, t1080) # t1084: "cuda:0 f32[1, 2048, 18176]"
# t1088 = prims.mul(t1071, t1084) # t1088: "cuda:0 f32[1, 2048, 18176]"
# t1089 = prims.convert_element_type(t1088, dtypes.bfloat16) # t1089: "cuda:0 bf16[1, 2048, 18176]"
# t1011 = prims.convert_element_type(t1010, dtypes.float32) # t1011: "cuda:0 f32[1, 71, 2048, 32]"
# t1012 = prims.neg(t1011) # t1012: "cuda:0 f32[1, 71, 2048, 32]"
# t1013 = prims.convert_element_type(t1012, dtypes.bfloat16) # t1013: "cuda:0 bf16[1, 71, 2048, 32]"
# t1015 = prims.cat((t1013, t1009), -1) # t1015: "cuda:0 bf16[1, 71, 2048, 64]"
# t1017 = prims.convert_element_type(t1008, dtypes.float32) # t1017: "cuda:0 f32[1, 71, 2048, 64]"
# t1019 = prims.mul(t1017, t61) # t1019: "cuda:0 f32[1, 71, 2048, 64]"
# t1022 = prims.convert_element_type(t1015, dtypes.float32) # t1022: "cuda:0 f32[1, 71, 2048, 64]"
# t1024 = prims.mul(t1022, t66) # t1024: "cuda:0 f32[1, 71, 2048, 64]"
# t1028 = prims.add(t1019, t1024) # t1028: "cuda:0 f32[1, 71, 2048, 64]"
# t1029 = prims.convert_element_type(t1028, dtypes.bfloat16) # t1029: "cuda:0 bf16[1, 71, 2048, 64]"
# t1033 = prims.convert_element_type(t1032, dtypes.float32) # t1033: "cuda:0 f32[1, 71, 2048, 32]"
# t1034 = prims.neg(t1033) # t1034: "cuda:0 f32[1, 71, 2048, 32]"
# t1035 = prims.convert_element_type(t1034, dtypes.bfloat16) # t1035: "cuda:0 bf16[1, 71, 2048, 32]"
# t1037 = prims.cat((t1035, t1031), -1) # t1037: "cuda:0 bf16[1, 71, 2048, 64]"
# t1039 = prims.convert_element_type(t1030, dtypes.float32) # t1039: "cuda:0 f32[1, 71, 2048, 64]"
# t1041 = prims.mul(t1039, t61) # t1041: "cuda:0 f32[1, 71, 2048, 64]"
# t1044 = prims.convert_element_type(t1037, dtypes.float32) # t1044: "cuda:0 f32[1, 71, 2048, 64]"
# t1046 = prims.mul(t1044, t66) # t1046: "cuda:0 f32[1, 71, 2048, 64]"
# t1050 = prims.add(t1041, t1046) # t1050: "cuda:0 f32[1, 71, 2048, 64]"
# t1051 = prims.convert_element_type(t1050, dtypes.bfloat16) # t1051: "cuda:0 bf16[1, 71, 2048, 64]"
# t1054 = prims.cat((t1029, t1052), -1) # t1054: "cuda:0 bf16[1, 71, 2048, 64]"
# t1057 = prims.cat((t1051, t1055), -1) # t1057: "cuda:0 bf16[1, 71, 2048, 64]"
del t1008, t1009, t1010, t1030, t1031, t1032, t1052, t1055
t1090 = torch.nn.functional.linear(t1089, t_transformer_h_6_mlp_proj_weight, None) # t1090: "cuda:0 bf16[1, 2048, 4544]"
# t1090 = ltorch.linear(t1089, t_transformer_h_6_mlp_proj_weight, None) # t1090: "cuda:0 bf16[1, 2048, 4544]"
# t1090 = prims.linear(t1089, t_transformer_h_6_mlp_proj_weight, None) # t1090: "cuda:0 bf16[1, 2048, 4544]"
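# NOTE: cudnn_sdpa_fwd runs scaled dot-product attention on the rotated q/k and
# the value tensor with a causal mask (the True flag), dropout 0.0 and
# scale=0.125, i.e. 1/sqrt(64) for the 64-dim heads. The extra outputs are
# presumably the softmax statistics/seeds kept for the backward pass.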
(t1058, t1059, t1060, t1061) = cudnn_sdpa_fwd(t1054, t1057, t1007, None, 0.0, True, scale=0.125)
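# NOTE: the attention output (1, 71, 2048, 64) is transposed back to
# (1, 2048, 71, 64) and the heads are merged into the model width 71 * 64 = 4544
# before the output projection.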
t1064 = torch.permute(t1058, (0, 2, 1, 3)) # t1064: "cuda:0 bf16[1, 2048, 71, 64]"
# t1064 = ltorch.permute(t1058, (0, 2, 1, 3)) # t1064: "cuda:0 bf16[1, 2048, 71, 64]"
# t1064 = prims.transpose(t1058, (0, 2, 1, 3)) # t1064: "cuda:0 bf16[1, 2048, 71, 64]"
t1068 = torch.reshape(t1064, (1, 2048, 4544)) # t1068: "cuda:0 bf16[1, 2048, 4544]"
# t1068 = ltorch.reshape(t1064, (1, 2048, 4544)) # t1068: "cuda:0 bf16[1, 2048, 4544]"
# t1068 = prims.reshape(t1064, (1, 2048, 4544)) # t1068: "cuda:0 bf16[1, 2048, 4544]"
del t1064
t1069 = torch.nn.functional.linear(t1068, t_transformer_h_6_attn_proj_weight, None) # t1069: "cuda:0 bf16[1, 2048, 4544]"
# t1069 = ltorch.linear(t1068, t_transformer_h_6_attn_proj_weight, None) # t1069: "cuda:0 bf16[1, 2048, 4544]"
# t1069 = prims.linear(t1068, t_transformer_h_6_attn_proj_weight, None) # t1069: "cuda:0 bf16[1, 2048, 4544]"
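# ---- transformer block 7 (same structure as the preceding blocks) ----
# NOTE: norm_1 weight/bias are unsqueezed and expanded to (1, 2048, 4544) so the
# fused LayerNorm below can apply them elementwise.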
t5471 = torch.unsqueeze(t_transformer_h_7_norm_1_weight, 0) # t5471: "cuda:0 bf16[1, 4544]"
# t5471 = ltorch.unsqueeze(t_transformer_h_7_norm_1_weight, 0) # t5471: "cuda:0 bf16[1, 4544]"
# t5471 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_weight, [1, 4544], [1]) # t5471: "cuda:0 bf16[1, 4544]"
t5472 = torch.unsqueeze(t5471, 1) # t5472: "cuda:0 bf16[1, 1, 4544]"
# t5472 = ltorch.unsqueeze(t5471, 1) # t5472: "cuda:0 bf16[1, 1, 4544]"
# t5472 = prims.broadcast_in_dim(t5471, [1, 1, 4544], [0, 2]) # t5472: "cuda:0 bf16[1, 1, 4544]"
del t5471
t1116 = Tensor.expand(t5472, (1, 2048, 4544)) # t1116: "cuda:0 bf16[1, 2048, 4544]"
# t1116 = ltorch.expand(t5472, (1, 2048, 4544)) # t1116: "cuda:0 bf16[1, 2048, 4544]"
# t1116 = prims.broadcast_in_dim(t5472, (1, 2048, 4544), (0, 1, 2)) # t1116: "cuda:0 bf16[1, 2048, 4544]"
del t5472
t5474 = torch.unsqueeze(t_transformer_h_7_norm_1_bias, 0) # t5474: "cuda:0 bf16[1, 4544]"
# t5474 = ltorch.unsqueeze(t_transformer_h_7_norm_1_bias, 0) # t5474: "cuda:0 bf16[1, 4544]"
# t5474 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_bias, [1, 4544], [1]) # t5474: "cuda:0 bf16[1, 4544]"
t5475 = torch.unsqueeze(t5474, 1) # t5475: "cuda:0 bf16[1, 1, 4544]"
# t5475 = ltorch.unsqueeze(t5474, 1) # t5475: "cuda:0 bf16[1, 1, 4544]"
# t5475 = prims.broadcast_in_dim(t5474, [1, 1, 4544], [0, 2]) # t5475: "cuda:0 bf16[1, 1, 4544]"
del t5474
t1119 = Tensor.expand(t5475, (1, 2048, 4544)) # t1119: "cuda:0 bf16[1, 2048, 4544]"
# t1119 = ltorch.expand(t5475, (1, 2048, 4544)) # t1119: "cuda:0 bf16[1, 2048, 4544]"
# t1119 = prims.broadcast_in_dim(t5475, (1, 2048, 4544), (0, 1, 2)) # t1119: "cuda:0 bf16[1, 2048, 4544]"
del t5475
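# NOTE: nvFusion14 fuses the parallel residual update of the previous block,
# x = attn_out + mlp_out + residual, with block 7's input LayerNorm,
# y = (x - mean(x)) / sqrt(var(x) + 1e-5) * weight + bias, computed in fp32 and
# cast back to bf16. Attention and MLP both read the same normalized input
# (a parallel attention/MLP block); the same fusion pattern repeats for every
# layer below.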
[t1098, t1105, t1110, t1122] = nvFusion14(t1069, t1090, t1116, t1119, t937)
# t1096 = prims.convert_element_type(t937, dtypes.float32) # t1096: "cuda:0 f32[1, 2048, 4544]"
# t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: "cuda:0 f32[1, 2048, 4544]"
# t1092 = prims.convert_element_type(t1069, dtypes.float32) # t1092: "cuda:0 f32[1, 2048, 4544]"
# t1093 = prims.add(t1091, t1092) # t1093: "cuda:0 f32[1, 2048, 4544]"
# t1097 = prims.add(t1093, t1096) # t1097: "cuda:0 f32[1, 2048, 4544]"
# t1098 = prims.convert_element_type(t1097, dtypes.bfloat16) # t1098: "cuda:0 bf16[1, 2048, 4544]"
# (t1104, t1105) = prims.var_mean(t1097, (2,), correction=0)
# t1106 = prims.broadcast_in_dim(t1104, [1, 2048, 1], [0, 1]) # t1106: "cuda:0 f32[1, 2048, 1]"
# t1107 = prims.broadcast_in_dim(t1105, [1, 2048, 1], [0, 1]) # t1107: "cuda:0 f32[1, 2048, 1]"
# t1109 = prims.add(t1106, 1e-05) # t1109: "cuda:0 f32[1, 2048, 1]"
# t1110 = prims.rsqrt(t1109) # t1110: "cuda:0 f32[1, 2048, 1]"
# t1111 = prims.broadcast_in_dim(t1107, (1, 2048, 4544), (0, 1, 2)) # t1111: "cuda:0 f32[1, 2048, 4544]"
# t1113 = prims.sub(t1097, t1111) # t1113: "cuda:0 f32[1, 2048, 4544]"
# t1114 = prims.broadcast_in_dim(t1110, (1, 2048, 4544), (0, 1, 2)) # t1114: "cuda:0 f32[1, 2048, 4544]"
# t1115 = prims.mul(t1113, t1114) # t1115: "cuda:0 f32[1, 2048, 4544]"
# t1117 = prims.convert_element_type(t1116, dtypes.float32) # t1117: "cuda:0 f32[1, 2048, 4544]"
# t1118 = prims.mul(t1115, t1117) # t1118: "cuda:0 f32[1, 2048, 4544]"
# t1120 = prims.convert_element_type(t1119, dtypes.float32) # t1120: "cuda:0 f32[1, 2048, 4544]"
# t1121 = prims.add(t1118, t1120) # t1121: "cuda:0 f32[1, 2048, 4544]"
# t1122 = prims.convert_element_type(t1121, dtypes.bfloat16) # t1122: "cuda:0 bf16[1, 2048, 4544]"
del t1119
t1231 = torch.nn.functional.linear(t1122, t_transformer_h_7_mlp_fc_weight, None) # t1231: "cuda:0 bf16[1, 2048, 18176]"
# t1231 = ltorch.linear(t1122, t_transformer_h_7_mlp_fc_weight, None) # t1231: "cuda:0 bf16[1, 2048, 18176]"
# t1231 = prims.linear(t1122, t_transformer_h_7_mlp_fc_weight, None) # t1231: "cuda:0 bf16[1, 2048, 18176]"
t1123 = torch.nn.functional.linear(t1122, t_transformer_h_7_attn_attn_weight, None) # t1123: "cuda:0 bf16[1, 2048, 4672]"
# t1123 = ltorch.linear(t1122, t_transformer_h_7_attn_attn_weight, None) # t1123: "cuda:0 bf16[1, 2048, 4672]"
# t1123 = prims.linear(t1122, t_transformer_h_7_attn_attn_weight, None) # t1123: "cuda:0 bf16[1, 2048, 4672]"
t1129 = torch.reshape(t1123, (1, 2048, 1, 73, 64)) # t1129: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1129 = ltorch.reshape(t1123, (1, 2048, 1, 73, 64)) # t1129: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1129 = prims.reshape(t1123, (1, 2048, 1, 73, 64)) # t1129: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t1123
t1135 = torch.permute(t1129, (0, 2, 3, 1, 4)) # t1135: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1135 = ltorch.permute(t1129, (0, 2, 3, 1, 4)) # t1135: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1135 = prims.transpose(t1129, (0, 2, 3, 1, 4)) # t1135: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t1129
(t1136, t1137, t1138) = torch.split(t1135, (71, 1, 1), 2)
# (t1136, t1137, t1138) = ltorch.split(t1135, (71, 1, 1), 2)
# t1136 = prims.slice_prim(t1135, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t1136: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1137 = prims.slice_prim(t1135, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t1137: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t1138 = prims.slice_prim(t1135, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t1138: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t1135
t1144 = Tensor.expand(t1137, (1, 1, 71, 2048, 64)) # t1144: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1144 = ltorch.expand(t1137, (1, 1, 71, 2048, 64)) # t1144: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1144 = prims.broadcast_in_dim(t1137, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1144: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1137
t1150 = Tensor.expand(t1138, (1, 1, 71, 2048, 64)) # t1150: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1150 = ltorch.expand(t1138, (1, 1, 71, 2048, 64)) # t1150: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1150 = prims.broadcast_in_dim(t1138, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1150: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1138
t1156 = torch.reshape(t1136, (1, 71, 2048, 64)) # t1156: "cuda:0 bf16[1, 71, 2048, 64]"
# t1156 = ltorch.reshape(t1136, (1, 71, 2048, 64)) # t1156: "cuda:0 bf16[1, 71, 2048, 64]"
# t1156 = prims.reshape(t1136, (1, 71, 2048, 64)) # t1156: "cuda:0 bf16[1, 71, 2048, 64]"
del t1136
t1162 = torch.reshape(t1144, (1, 71, 2048, 64)) # t1162: "cuda:0 bf16[1, 71, 2048, 64]"
# t1162 = ltorch.reshape(t1144, (1, 71, 2048, 64)) # t1162: "cuda:0 bf16[1, 71, 2048, 64]"
# t1162 = prims.reshape(t1144, (1, 71, 2048, 64)) # t1162: "cuda:0 bf16[1, 71, 2048, 64]"
del t1144
t1168 = torch.reshape(t1150, (1, 71, 2048, 64)) # t1168: "cuda:0 bf16[1, 71, 2048, 64]"
# t1168 = ltorch.reshape(t1150, (1, 71, 2048, 64)) # t1168: "cuda:0 bf16[1, 71, 2048, 64]"
# t1168 = prims.reshape(t1150, (1, 71, 2048, 64)) # t1168: "cuda:0 bf16[1, 71, 2048, 64]"
del t1150
t1169 = torch_slice_prim_impl(t1156, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1169: "cuda:0 bf16[1, 71, 2048, 64]"
t1170 = torch_slice_prim_impl(t1169, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1170: "cuda:0 bf16[1, 71, 2048, 32]"
t1171 = torch_slice_prim_impl(t1169, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1171: "cuda:0 bf16[1, 71, 2048, 32]"
t1191 = torch_slice_prim_impl(t1162, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1191: "cuda:0 bf16[1, 71, 2048, 64]"
t1192 = torch_slice_prim_impl(t1191, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1192: "cuda:0 bf16[1, 71, 2048, 32]"
t1193 = torch_slice_prim_impl(t1191, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1193: "cuda:0 bf16[1, 71, 2048, 32]"
t1213 = torch_slice_prim_impl(t1156, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1213: "cuda:0 bf16[1, 71, 2048, 0]"
del t1156
t1216 = torch_slice_prim_impl(t1162, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1216: "cuda:0 bf16[1, 71, 2048, 0]"
del t1162
[t1215, t1218, t1250] = nvFusion15(t1169, t1170, t1171, t1191, t1192, t1193, t1213, t1216, t1231, t61, t66)
# t1232 = prims.convert_element_type(t1231, dtypes.float32) # t1232: "cuda:0 f32[1, 2048, 18176]"
# t1234 = prims.div(t1232, 1.4142135623730951) # t1234: "cuda:0 f32[1, 2048, 18176]"
# t1237 = prims.erf(t1234) # t1237: "cuda:0 f32[1, 2048, 18176]"
# t1241 = prims.mul(0.5, t1237) # t1241: "cuda:0 f32[1, 2048, 18176]"
# t1245 = prims.add(0.5, t1241) # t1245: "cuda:0 f32[1, 2048, 18176]"
# t1249 = prims.mul(t1232, t1245) # t1249: "cuda:0 f32[1, 2048, 18176]"
# t1250 = prims.convert_element_type(t1249, dtypes.bfloat16) # t1250: "cuda:0 bf16[1, 2048, 18176]"
# t1172 = prims.convert_element_type(t1171, dtypes.float32) # t1172: "cuda:0 f32[1, 71, 2048, 32]"
# t1173 = prims.neg(t1172) # t1173: "cuda:0 f32[1, 71, 2048, 32]"
# t1174 = prims.convert_element_type(t1173, dtypes.bfloat16) # t1174: "cuda:0 bf16[1, 71, 2048, 32]"
# t1176 = prims.cat((t1174, t1170), -1) # t1176: "cuda:0 bf16[1, 71, 2048, 64]"
# t1178 = prims.convert_element_type(t1169, dtypes.float32) # t1178: "cuda:0 f32[1, 71, 2048, 64]"
# t1180 = prims.mul(t1178, t61) # t1180: "cuda:0 f32[1, 71, 2048, 64]"
# t1183 = prims.convert_element_type(t1176, dtypes.float32) # t1183: "cuda:0 f32[1, 71, 2048, 64]"
# t1185 = prims.mul(t1183, t66) # t1185: "cuda:0 f32[1, 71, 2048, 64]"
# t1189 = prims.add(t1180, t1185) # t1189: "cuda:0 f32[1, 71, 2048, 64]"
# t1190 = prims.convert_element_type(t1189, dtypes.bfloat16) # t1190: "cuda:0 bf16[1, 71, 2048, 64]"
# t1194 = prims.convert_element_type(t1193, dtypes.float32) # t1194: "cuda:0 f32[1, 71, 2048, 32]"
# t1195 = prims.neg(t1194) # t1195: "cuda:0 f32[1, 71, 2048, 32]"
# t1196 = prims.convert_element_type(t1195, dtypes.bfloat16) # t1196: "cuda:0 bf16[1, 71, 2048, 32]"
# t1198 = prims.cat((t1196, t1192), -1) # t1198: "cuda:0 bf16[1, 71, 2048, 64]"
# t1200 = prims.convert_element_type(t1191, dtypes.float32) # t1200: "cuda:0 f32[1, 71, 2048, 64]"
# t1202 = prims.mul(t1200, t61) # t1202: "cuda:0 f32[1, 71, 2048, 64]"
# t1205 = prims.convert_element_type(t1198, dtypes.float32) # t1205: "cuda:0 f32[1, 71, 2048, 64]"
# t1207 = prims.mul(t1205, t66) # t1207: "cuda:0 f32[1, 71, 2048, 64]"
# t1211 = prims.add(t1202, t1207) # t1211: "cuda:0 f32[1, 71, 2048, 64]"
# t1212 = prims.convert_element_type(t1211, dtypes.bfloat16) # t1212: "cuda:0 bf16[1, 71, 2048, 64]"
# t1215 = prims.cat((t1190, t1213), -1) # t1215: "cuda:0 bf16[1, 71, 2048, 64]"
# t1218 = prims.cat((t1212, t1216), -1) # t1218: "cuda:0 bf16[1, 71, 2048, 64]"
del t1169, t1170, t1171, t1191, t1192, t1193, t1213, t1216
t1251 = torch.nn.functional.linear(t1250, t_transformer_h_7_mlp_proj_weight, None) # t1251: "cuda:0 bf16[1, 2048, 4544]"
# t1251 = ltorch.linear(t1250, t_transformer_h_7_mlp_proj_weight, None) # t1251: "cuda:0 bf16[1, 2048, 4544]"
# t1251 = prims.linear(t1250, t_transformer_h_7_mlp_proj_weight, None) # t1251: "cuda:0 bf16[1, 2048, 4544]"
(t1219, t1220, t1221, t1222) = cudnn_sdpa_fwd(t1215, t1218, t1168, None, 0.0, True, scale=0.125)
t1225 = torch.permute(t1219, (0, 2, 1, 3)) # t1225: "cuda:0 bf16[1, 2048, 71, 64]"
# t1225 = ltorch.permute(t1219, (0, 2, 1, 3)) # t1225: "cuda:0 bf16[1, 2048, 71, 64]"
# t1225 = prims.transpose(t1219, (0, 2, 1, 3)) # t1225: "cuda:0 bf16[1, 2048, 71, 64]"
t1229 = torch.reshape(t1225, (1, 2048, 4544)) # t1229: "cuda:0 bf16[1, 2048, 4544]"
# t1229 = ltorch.reshape(t1225, (1, 2048, 4544)) # t1229: "cuda:0 bf16[1, 2048, 4544]"
# t1229 = prims.reshape(t1225, (1, 2048, 4544)) # t1229: "cuda:0 bf16[1, 2048, 4544]"
del t1225
t1230 = torch.nn.functional.linear(t1229, t_transformer_h_7_attn_proj_weight, None) # t1230: "cuda:0 bf16[1, 2048, 4544]"
# t1230 = ltorch.linear(t1229, t_transformer_h_7_attn_proj_weight, None) # t1230: "cuda:0 bf16[1, 2048, 4544]"
# t1230 = prims.linear(t1229, t_transformer_h_7_attn_proj_weight, None) # t1230: "cuda:0 bf16[1, 2048, 4544]"
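# ---- transformer block 8 (same pattern as block 7) ----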
t5497 = torch.unsqueeze(t_transformer_h_8_norm_1_weight, 0) # t5497: "cuda:0 bf16[1, 4544]"
# t5497 = ltorch.unsqueeze(t_transformer_h_8_norm_1_weight, 0) # t5497: "cuda:0 bf16[1, 4544]"
# t5497 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_weight, [1, 4544], [1]) # t5497: "cuda:0 bf16[1, 4544]"
t5498 = torch.unsqueeze(t5497, 1) # t5498: "cuda:0 bf16[1, 1, 4544]"
# t5498 = ltorch.unsqueeze(t5497, 1) # t5498: "cuda:0 bf16[1, 1, 4544]"
# t5498 = prims.broadcast_in_dim(t5497, [1, 1, 4544], [0, 2]) # t5498: "cuda:0 bf16[1, 1, 4544]"
del t5497
t1277 = Tensor.expand(t5498, (1, 2048, 4544)) # t1277: "cuda:0 bf16[1, 2048, 4544]"
# t1277 = ltorch.expand(t5498, (1, 2048, 4544)) # t1277: "cuda:0 bf16[1, 2048, 4544]"
# t1277 = prims.broadcast_in_dim(t5498, (1, 2048, 4544), (0, 1, 2)) # t1277: "cuda:0 bf16[1, 2048, 4544]"
del t5498
t5500 = torch.unsqueeze(t_transformer_h_8_norm_1_bias, 0) # t5500: "cuda:0 bf16[1, 4544]"
# t5500 = ltorch.unsqueeze(t_transformer_h_8_norm_1_bias, 0) # t5500: "cuda:0 bf16[1, 4544]"
# t5500 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_bias, [1, 4544], [1]) # t5500: "cuda:0 bf16[1, 4544]"
t5501 = torch.unsqueeze(t5500, 1) # t5501: "cuda:0 bf16[1, 1, 4544]"
# t5501 = ltorch.unsqueeze(t5500, 1) # t5501: "cuda:0 bf16[1, 1, 4544]"
# t5501 = prims.broadcast_in_dim(t5500, [1, 1, 4544], [0, 2]) # t5501: "cuda:0 bf16[1, 1, 4544]"
del t5500
t1280 = Tensor.expand(t5501, (1, 2048, 4544)) # t1280: "cuda:0 bf16[1, 2048, 4544]"
# t1280 = ltorch.expand(t5501, (1, 2048, 4544)) # t1280: "cuda:0 bf16[1, 2048, 4544]"
# t1280 = prims.broadcast_in_dim(t5501, (1, 2048, 4544), (0, 1, 2)) # t1280: "cuda:0 bf16[1, 2048, 4544]"
del t5501
[t1259, t1266, t1271, t1283] = nvFusion16(t1098, t1230, t1251, t1277, t1280)
# t1257 = prims.convert_element_type(t1098, dtypes.float32) # t1257: "cuda:0 f32[1, 2048, 4544]"
# t1252 = prims.convert_element_type(t1251, dtypes.float32) # t1252: "cuda:0 f32[1, 2048, 4544]"
# t1253 = prims.convert_element_type(t1230, dtypes.float32) # t1253: "cuda:0 f32[1, 2048, 4544]"
# t1254 = prims.add(t1252, t1253) # t1254: "cuda:0 f32[1, 2048, 4544]"
# t1258 = prims.add(t1254, t1257) # t1258: "cuda:0 f32[1, 2048, 4544]"
# t1259 = prims.convert_element_type(t1258, dtypes.bfloat16) # t1259: "cuda:0 bf16[1, 2048, 4544]"
# (t1265, t1266) = prims.var_mean(t1258, (2,), correction=0)
# t1267 = prims.broadcast_in_dim(t1265, [1, 2048, 1], [0, 1]) # t1267: "cuda:0 f32[1, 2048, 1]"
# t1268 = prims.broadcast_in_dim(t1266, [1, 2048, 1], [0, 1]) # t1268: "cuda:0 f32[1, 2048, 1]"
# t1270 = prims.add(t1267, 1e-05) # t1270: "cuda:0 f32[1, 2048, 1]"
# t1271 = prims.rsqrt(t1270) # t1271: "cuda:0 f32[1, 2048, 1]"
# t1272 = prims.broadcast_in_dim(t1268, (1, 2048, 4544), (0, 1, 2)) # t1272: "cuda:0 f32[1, 2048, 4544]"
# t1274 = prims.sub(t1258, t1272) # t1274: "cuda:0 f32[1, 2048, 4544]"
# t1275 = prims.broadcast_in_dim(t1271, (1, 2048, 4544), (0, 1, 2)) # t1275: "cuda:0 f32[1, 2048, 4544]"
# t1276 = prims.mul(t1274, t1275) # t1276: "cuda:0 f32[1, 2048, 4544]"
# t1278 = prims.convert_element_type(t1277, dtypes.float32) # t1278: "cuda:0 f32[1, 2048, 4544]"
# t1279 = prims.mul(t1276, t1278) # t1279: "cuda:0 f32[1, 2048, 4544]"
# t1281 = prims.convert_element_type(t1280, dtypes.float32) # t1281: "cuda:0 f32[1, 2048, 4544]"
# t1282 = prims.add(t1279, t1281) # t1282: "cuda:0 f32[1, 2048, 4544]"
# t1283 = prims.convert_element_type(t1282, dtypes.bfloat16) # t1283: "cuda:0 bf16[1, 2048, 4544]"
del t1280
t1392 = torch.nn.functional.linear(t1283, t_transformer_h_8_mlp_fc_weight, None) # t1392: "cuda:0 bf16[1, 2048, 18176]"
# t1392 = ltorch.linear(t1283, t_transformer_h_8_mlp_fc_weight, None) # t1392: "cuda:0 bf16[1, 2048, 18176]"
# t1392 = prims.linear(t1283, t_transformer_h_8_mlp_fc_weight, None) # t1392: "cuda:0 bf16[1, 2048, 18176]"
t1284 = torch.nn.functional.linear(t1283, t_transformer_h_8_attn_attn_weight, None) # t1284: "cuda:0 bf16[1, 2048, 4672]"
# t1284 = ltorch.linear(t1283, t_transformer_h_8_attn_attn_weight, None) # t1284: "cuda:0 bf16[1, 2048, 4672]"
# t1284 = prims.linear(t1283, t_transformer_h_8_attn_attn_weight, None) # t1284: "cuda:0 bf16[1, 2048, 4672]"
t1290 = torch.reshape(t1284, (1, 2048, 1, 73, 64)) # t1290: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1290 = ltorch.reshape(t1284, (1, 2048, 1, 73, 64)) # t1290: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1290 = prims.reshape(t1284, (1, 2048, 1, 73, 64)) # t1290: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t1284
t1296 = torch.permute(t1290, (0, 2, 3, 1, 4)) # t1296: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1296 = ltorch.permute(t1290, (0, 2, 3, 1, 4)) # t1296: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1296 = prims.transpose(t1290, (0, 2, 3, 1, 4)) # t1296: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t1290
(t1297, t1298, t1299) = torch.split(t1296, (71, 1, 1), 2)
# (t1297, t1298, t1299) = ltorch.split(t1296, (71, 1, 1), 2)
# t1297 = prims.slice_prim(t1296, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t1297: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1298 = prims.slice_prim(t1296, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t1298: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t1299 = prims.slice_prim(t1296, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t1299: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t1296
t1305 = Tensor.expand(t1298, (1, 1, 71, 2048, 64)) # t1305: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1305 = ltorch.expand(t1298, (1, 1, 71, 2048, 64)) # t1305: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1305 = prims.broadcast_in_dim(t1298, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1305: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1298
t1311 = Tensor.expand(t1299, (1, 1, 71, 2048, 64)) # t1311: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1311 = ltorch.expand(t1299, (1, 1, 71, 2048, 64)) # t1311: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1311 = prims.broadcast_in_dim(t1299, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1311: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1299
t1317 = torch.reshape(t1297, (1, 71, 2048, 64)) # t1317: "cuda:0 bf16[1, 71, 2048, 64]"
# t1317 = ltorch.reshape(t1297, (1, 71, 2048, 64)) # t1317: "cuda:0 bf16[1, 71, 2048, 64]"
# t1317 = prims.reshape(t1297, (1, 71, 2048, 64)) # t1317: "cuda:0 bf16[1, 71, 2048, 64]"
del t1297
t1323 = torch.reshape(t1305, (1, 71, 2048, 64)) # t1323: "cuda:0 bf16[1, 71, 2048, 64]"
# t1323 = ltorch.reshape(t1305, (1, 71, 2048, 64)) # t1323: "cuda:0 bf16[1, 71, 2048, 64]"
# t1323 = prims.reshape(t1305, (1, 71, 2048, 64)) # t1323: "cuda:0 bf16[1, 71, 2048, 64]"
del t1305
t1329 = torch.reshape(t1311, (1, 71, 2048, 64)) # t1329: "cuda:0 bf16[1, 71, 2048, 64]"
# t1329 = ltorch.reshape(t1311, (1, 71, 2048, 64)) # t1329: "cuda:0 bf16[1, 71, 2048, 64]"
# t1329 = prims.reshape(t1311, (1, 71, 2048, 64)) # t1329: "cuda:0 bf16[1, 71, 2048, 64]"
del t1311
t1330 = torch_slice_prim_impl(t1317, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1330: "cuda:0 bf16[1, 71, 2048, 64]"
t1331 = torch_slice_prim_impl(t1330, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1331: "cuda:0 bf16[1, 71, 2048, 32]"
t1332 = torch_slice_prim_impl(t1330, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1332: "cuda:0 bf16[1, 71, 2048, 32]"
t1352 = torch_slice_prim_impl(t1323, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1352: "cuda:0 bf16[1, 71, 2048, 64]"
t1353 = torch_slice_prim_impl(t1352, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1353: "cuda:0 bf16[1, 71, 2048, 32]"
t1354 = torch_slice_prim_impl(t1352, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1354: "cuda:0 bf16[1, 71, 2048, 32]"
t1374 = torch_slice_prim_impl(t1317, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1374: "cuda:0 bf16[1, 71, 2048, 0]"
del t1317
t1377 = torch_slice_prim_impl(t1323, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1377: "cuda:0 bf16[1, 71, 2048, 0]"
del t1323
[t1376, t1379, t1411] = nvFusion17(t1330, t1331, t1332, t1352, t1353, t1354, t1374, t1377, t1392, t61, t66)
# t1393 = prims.convert_element_type(t1392, dtypes.float32) # t1393: "cuda:0 f32[1, 2048, 18176]"
# t1395 = prims.div(t1393, 1.4142135623730951) # t1395: "cuda:0 f32[1, 2048, 18176]"
# t1398 = prims.erf(t1395) # t1398: "cuda:0 f32[1, 2048, 18176]"
# t1402 = prims.mul(0.5, t1398) # t1402: "cuda:0 f32[1, 2048, 18176]"
# t1406 = prims.add(0.5, t1402) # t1406: "cuda:0 f32[1, 2048, 18176]"
# t1410 = prims.mul(t1393, t1406) # t1410: "cuda:0 f32[1, 2048, 18176]"
# t1411 = prims.convert_element_type(t1410, dtypes.bfloat16) # t1411: "cuda:0 bf16[1, 2048, 18176]"
# t1333 = prims.convert_element_type(t1332, dtypes.float32) # t1333: "cuda:0 f32[1, 71, 2048, 32]"
# t1334 = prims.neg(t1333) # t1334: "cuda:0 f32[1, 71, 2048, 32]"
# t1335 = prims.convert_element_type(t1334, dtypes.bfloat16) # t1335: "cuda:0 bf16[1, 71, 2048, 32]"
# t1337 = prims.cat((t1335, t1331), -1) # t1337: "cuda:0 bf16[1, 71, 2048, 64]"
# t1339 = prims.convert_element_type(t1330, dtypes.float32) # t1339: "cuda:0 f32[1, 71, 2048, 64]"
# t1341 = prims.mul(t1339, t61) # t1341: "cuda:0 f32[1, 71, 2048, 64]"
# t1344 = prims.convert_element_type(t1337, dtypes.float32) # t1344: "cuda:0 f32[1, 71, 2048, 64]"
# t1346 = prims.mul(t1344, t66) # t1346: "cuda:0 f32[1, 71, 2048, 64]"
# t1350 = prims.add(t1341, t1346) # t1350: "cuda:0 f32[1, 71, 2048, 64]"
# t1351 = prims.convert_element_type(t1350, dtypes.bfloat16) # t1351: "cuda:0 bf16[1, 71, 2048, 64]"
# t1355 = prims.convert_element_type(t1354, dtypes.float32) # t1355: "cuda:0 f32[1, 71, 2048, 32]"
# t1356 = prims.neg(t1355) # t1356: "cuda:0 f32[1, 71, 2048, 32]"
# t1357 = prims.convert_element_type(t1356, dtypes.bfloat16) # t1357: "cuda:0 bf16[1, 71, 2048, 32]"
# t1359 = prims.cat((t1357, t1353), -1) # t1359: "cuda:0 bf16[1, 71, 2048, 64]"
# t1361 = prims.convert_element_type(t1352, dtypes.float32) # t1361: "cuda:0 f32[1, 71, 2048, 64]"
# t1363 = prims.mul(t1361, t61) # t1363: "cuda:0 f32[1, 71, 2048, 64]"
# t1366 = prims.convert_element_type(t1359, dtypes.float32) # t1366: "cuda:0 f32[1, 71, 2048, 64]"
# t1368 = prims.mul(t1366, t66) # t1368: "cuda:0 f32[1, 71, 2048, 64]"
# t1372 = prims.add(t1363, t1368) # t1372: "cuda:0 f32[1, 71, 2048, 64]"
# t1373 = prims.convert_element_type(t1372, dtypes.bfloat16) # t1373: "cuda:0 bf16[1, 71, 2048, 64]"
# t1376 = prims.cat((t1351, t1374), -1) # t1376: "cuda:0 bf16[1, 71, 2048, 64]"
# t1379 = prims.cat((t1373, t1377), -1) # t1379: "cuda:0 bf16[1, 71, 2048, 64]"
del t1330, t1331, t1332, t1352, t1353, t1354, t1374, t1377
t1412 = torch.nn.functional.linear(t1411, t_transformer_h_8_mlp_proj_weight, None) # t1412: "cuda:0 bf16[1, 2048, 4544]"
# t1412 = ltorch.linear(t1411, t_transformer_h_8_mlp_proj_weight, None) # t1412: "cuda:0 bf16[1, 2048, 4544]"
# t1412 = prims.linear(t1411, t_transformer_h_8_mlp_proj_weight, None) # t1412: "cuda:0 bf16[1, 2048, 4544]"
(t1380, t1381, t1382, t1383) = cudnn_sdpa_fwd(t1376, t1379, t1329, None, 0.0, True, scale=0.125)
t1386 = torch.permute(t1380, (0, 2, 1, 3)) # t1386: "cuda:0 bf16[1, 2048, 71, 64]"
# t1386 = ltorch.permute(t1380, (0, 2, 1, 3)) # t1386: "cuda:0 bf16[1, 2048, 71, 64]"
# t1386 = prims.transpose(t1380, (0, 2, 1, 3)) # t1386: "cuda:0 bf16[1, 2048, 71, 64]"
t1390 = torch.reshape(t1386, (1, 2048, 4544)) # t1390: "cuda:0 bf16[1, 2048, 4544]"
# t1390 = ltorch.reshape(t1386, (1, 2048, 4544)) # t1390: "cuda:0 bf16[1, 2048, 4544]"
# t1390 = prims.reshape(t1386, (1, 2048, 4544)) # t1390: "cuda:0 bf16[1, 2048, 4544]"
del t1386
t1391 = torch.nn.functional.linear(t1390, t_transformer_h_8_attn_proj_weight, None) # t1391: "cuda:0 bf16[1, 2048, 4544]"
# t1391 = ltorch.linear(t1390, t_transformer_h_8_attn_proj_weight, None) # t1391: "cuda:0 bf16[1, 2048, 4544]"
# t1391 = prims.linear(t1390, t_transformer_h_8_attn_proj_weight, None) # t1391: "cuda:0 bf16[1, 2048, 4544]"
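# ---- transformer block 9 (same pattern as block 7) ----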
t5523 = torch.unsqueeze(t_transformer_h_9_norm_1_weight, 0) # t5523: "cuda:0 bf16[1, 4544]"
# t5523 = ltorch.unsqueeze(t_transformer_h_9_norm_1_weight, 0) # t5523: "cuda:0 bf16[1, 4544]"
# t5523 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_weight, [1, 4544], [1]) # t5523: "cuda:0 bf16[1, 4544]"
t5524 = torch.unsqueeze(t5523, 1) # t5524: "cuda:0 bf16[1, 1, 4544]"
# t5524 = ltorch.unsqueeze(t5523, 1) # t5524: "cuda:0 bf16[1, 1, 4544]"
# t5524 = prims.broadcast_in_dim(t5523, [1, 1, 4544], [0, 2]) # t5524: "cuda:0 bf16[1, 1, 4544]"
del t5523
t1438 = Tensor.expand(t5524, (1, 2048, 4544)) # t1438: "cuda:0 bf16[1, 2048, 4544]"
# t1438 = ltorch.expand(t5524, (1, 2048, 4544)) # t1438: "cuda:0 bf16[1, 2048, 4544]"
# t1438 = prims.broadcast_in_dim(t5524, (1, 2048, 4544), (0, 1, 2)) # t1438: "cuda:0 bf16[1, 2048, 4544]"
del t5524
t5526 = torch.unsqueeze(t_transformer_h_9_norm_1_bias, 0) # t5526: "cuda:0 bf16[1, 4544]"
# t5526 = ltorch.unsqueeze(t_transformer_h_9_norm_1_bias, 0) # t5526: "cuda:0 bf16[1, 4544]"
# t5526 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_bias, [1, 4544], [1]) # t5526: "cuda:0 bf16[1, 4544]"
t5527 = torch.unsqueeze(t5526, 1) # t5527: "cuda:0 bf16[1, 1, 4544]"
# t5527 = ltorch.unsqueeze(t5526, 1) # t5527: "cuda:0 bf16[1, 1, 4544]"
# t5527 = prims.broadcast_in_dim(t5526, [1, 1, 4544], [0, 2]) # t5527: "cuda:0 bf16[1, 1, 4544]"
del t5526
t1441 = Tensor.expand(t5527, (1, 2048, 4544)) # t1441: "cuda:0 bf16[1, 2048, 4544]"
# t1441 = ltorch.expand(t5527, (1, 2048, 4544)) # t1441: "cuda:0 bf16[1, 2048, 4544]"
# t1441 = prims.broadcast_in_dim(t5527, (1, 2048, 4544), (0, 1, 2)) # t1441: "cuda:0 bf16[1, 2048, 4544]"
del t5527
[t1420, t1427, t1432, t1444] = nvFusion18(t1259, t1391, t1412, t1438, t1441)
# t1418 = prims.convert_element_type(t1259, dtypes.float32) # t1418: "cuda:0 f32[1, 2048, 4544]"
# t1413 = prims.convert_element_type(t1412, dtypes.float32) # t1413: "cuda:0 f32[1, 2048, 4544]"
# t1414 = prims.convert_element_type(t1391, dtypes.float32) # t1414: "cuda:0 f32[1, 2048, 4544]"
# t1415 = prims.add(t1413, t1414) # t1415: "cuda:0 f32[1, 2048, 4544]"
# t1419 = prims.add(t1415, t1418) # t1419: "cuda:0 f32[1, 2048, 4544]"
# t1420 = prims.convert_element_type(t1419, dtypes.bfloat16) # t1420: "cuda:0 bf16[1, 2048, 4544]"
# (t1426, t1427) = prims.var_mean(t1419, (2,), correction=0)
# t1428 = prims.broadcast_in_dim(t1426, [1, 2048, 1], [0, 1]) # t1428: "cuda:0 f32[1, 2048, 1]"
# t1429 = prims.broadcast_in_dim(t1427, [1, 2048, 1], [0, 1]) # t1429: "cuda:0 f32[1, 2048, 1]"
# t1431 = prims.add(t1428, 1e-05) # t1431: "cuda:0 f32[1, 2048, 1]"
# t1432 = prims.rsqrt(t1431) # t1432: "cuda:0 f32[1, 2048, 1]"
# t1433 = prims.broadcast_in_dim(t1429, (1, 2048, 4544), (0, 1, 2)) # t1433: "cuda:0 f32[1, 2048, 4544]"
# t1435 = prims.sub(t1419, t1433) # t1435: "cuda:0 f32[1, 2048, 4544]"
# t1436 = prims.broadcast_in_dim(t1432, (1, 2048, 4544), (0, 1, 2)) # t1436: "cuda:0 f32[1, 2048, 4544]"
# t1437 = prims.mul(t1435, t1436) # t1437: "cuda:0 f32[1, 2048, 4544]"
# t1439 = prims.convert_element_type(t1438, dtypes.float32) # t1439: "cuda:0 f32[1, 2048, 4544]"
# t1440 = prims.mul(t1437, t1439) # t1440: "cuda:0 f32[1, 2048, 4544]"
# t1442 = prims.convert_element_type(t1441, dtypes.float32) # t1442: "cuda:0 f32[1, 2048, 4544]"
# t1443 = prims.add(t1440, t1442) # t1443: "cuda:0 f32[1, 2048, 4544]"
# t1444 = prims.convert_element_type(t1443, dtypes.bfloat16) # t1444: "cuda:0 bf16[1, 2048, 4544]"
del t1441
t1553 = torch.nn.functional.linear(t1444, t_transformer_h_9_mlp_fc_weight, None) # t1553: "cuda:0 bf16[1, 2048, 18176]"
# t1553 = ltorch.linear(t1444, t_transformer_h_9_mlp_fc_weight, None) # t1553: "cuda:0 bf16[1, 2048, 18176]"
# t1553 = prims.linear(t1444, t_transformer_h_9_mlp_fc_weight, None) # t1553: "cuda:0 bf16[1, 2048, 18176]"
t1445 = torch.nn.functional.linear(t1444, t_transformer_h_9_attn_attn_weight, None) # t1445: "cuda:0 bf16[1, 2048, 4672]"
# t1445 = ltorch.linear(t1444, t_transformer_h_9_attn_attn_weight, None) # t1445: "cuda:0 bf16[1, 2048, 4672]"
# t1445 = prims.linear(t1444, t_transformer_h_9_attn_attn_weight, None) # t1445: "cuda:0 bf16[1, 2048, 4672]"
t1451 = torch.reshape(t1445, (1, 2048, 1, 73, 64)) # t1451: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1451 = ltorch.reshape(t1445, (1, 2048, 1, 73, 64)) # t1451: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1451 = prims.reshape(t1445, (1, 2048, 1, 73, 64)) # t1451: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t1445
t1457 = torch.permute(t1451, (0, 2, 3, 1, 4)) # t1457: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1457 = ltorch.permute(t1451, (0, 2, 3, 1, 4)) # t1457: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1457 = prims.transpose(t1451, (0, 2, 3, 1, 4)) # t1457: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t1451
(t1458, t1459, t1460) = torch.split(t1457, (71, 1, 1), 2)
# (t1458, t1459, t1460) = ltorch.split(t1457, (71, 1, 1), 2)
# t1458 = prims.slice_prim(t1457, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t1458: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1459 = prims.slice_prim(t1457, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t1459: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t1460 = prims.slice_prim(t1457, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t1460: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t1457
t1466 = Tensor.expand(t1459, (1, 1, 71, 2048, 64)) # t1466: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1466 = ltorch.expand(t1459, (1, 1, 71, 2048, 64)) # t1466: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1466 = prims.broadcast_in_dim(t1459, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1466: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1459
t1472 = Tensor.expand(t1460, (1, 1, 71, 2048, 64)) # t1472: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1472 = ltorch.expand(t1460, (1, 1, 71, 2048, 64)) # t1472: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1472 = prims.broadcast_in_dim(t1460, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1472: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1460
t1478 = torch.reshape(t1458, (1, 71, 2048, 64)) # t1478: "cuda:0 bf16[1, 71, 2048, 64]"
# t1478 = ltorch.reshape(t1458, (1, 71, 2048, 64)) # t1478: "cuda:0 bf16[1, 71, 2048, 64]"
# t1478 = prims.reshape(t1458, (1, 71, 2048, 64)) # t1478: "cuda:0 bf16[1, 71, 2048, 64]"
del t1458
t1484 = torch.reshape(t1466, (1, 71, 2048, 64)) # t1484: "cuda:0 bf16[1, 71, 2048, 64]"
# t1484 = ltorch.reshape(t1466, (1, 71, 2048, 64)) # t1484: "cuda:0 bf16[1, 71, 2048, 64]"
# t1484 = prims.reshape(t1466, (1, 71, 2048, 64)) # t1484: "cuda:0 bf16[1, 71, 2048, 64]"
del t1466
t1490 = torch.reshape(t1472, (1, 71, 2048, 64)) # t1490: "cuda:0 bf16[1, 71, 2048, 64]"
# t1490 = ltorch.reshape(t1472, (1, 71, 2048, 64)) # t1490: "cuda:0 bf16[1, 71, 2048, 64]"
# t1490 = prims.reshape(t1472, (1, 71, 2048, 64)) # t1490: "cuda:0 bf16[1, 71, 2048, 64]"
del t1472
t1491 = torch_slice_prim_impl(t1478, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1491: "cuda:0 bf16[1, 71, 2048, 64]"
t1492 = torch_slice_prim_impl(t1491, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1492: "cuda:0 bf16[1, 71, 2048, 32]"
t1493 = torch_slice_prim_impl(t1491, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1493: "cuda:0 bf16[1, 71, 2048, 32]"
t1513 = torch_slice_prim_impl(t1484, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1513: "cuda:0 bf16[1, 71, 2048, 64]"
t1514 = torch_slice_prim_impl(t1513, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1514: "cuda:0 bf16[1, 71, 2048, 32]"
t1515 = torch_slice_prim_impl(t1513, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1515: "cuda:0 bf16[1, 71, 2048, 32]"
t1535 = torch_slice_prim_impl(t1478, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1535: "cuda:0 bf16[1, 71, 2048, 0]"
del t1478
t1538 = torch_slice_prim_impl(t1484, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1538: "cuda:0 bf16[1, 71, 2048, 0]"
del t1484
[t1537, t1540, t1572] = nvFusion19(t1491, t1492, t1493, t1513, t1514, t1515, t1535, t1538, t1553, t61, t66)
# t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: "cuda:0 f32[1, 2048, 18176]"
# t1556 = prims.div(t1554, 1.4142135623730951) # t1556: "cuda:0 f32[1, 2048, 18176]"
# t1559 = prims.erf(t1556) # t1559: "cuda:0 f32[1, 2048, 18176]"
# t1563 = prims.mul(0.5, t1559) # t1563: "cuda:0 f32[1, 2048, 18176]"
# t1567 = prims.add(0.5, t1563) # t1567: "cuda:0 f32[1, 2048, 18176]"
# t1571 = prims.mul(t1554, t1567) # t1571: "cuda:0 f32[1, 2048, 18176]"
# t1572 = prims.convert_element_type(t1571, dtypes.bfloat16) # t1572: "cuda:0 bf16[1, 2048, 18176]"
# t1494 = prims.convert_element_type(t1493, dtypes.float32) # t1494: "cuda:0 f32[1, 71, 2048, 32]"
# t1495 = prims.neg(t1494) # t1495: "cuda:0 f32[1, 71, 2048, 32]"
# t1496 = prims.convert_element_type(t1495, dtypes.bfloat16) # t1496: "cuda:0 bf16[1, 71, 2048, 32]"
# t1498 = prims.cat((t1496, t1492), -1) # t1498: "cuda:0 bf16[1, 71, 2048, 64]"
# t1500 = prims.convert_element_type(t1491, dtypes.float32) # t1500: "cuda:0 f32[1, 71, 2048, 64]"
# t1502 = prims.mul(t1500, t61) # t1502: "cuda:0 f32[1, 71, 2048, 64]"
# t1505 = prims.convert_element_type(t1498, dtypes.float32) # t1505: "cuda:0 f32[1, 71, 2048, 64]"
# t1507 = prims.mul(t1505, t66) # t1507: "cuda:0 f32[1, 71, 2048, 64]"
# t1511 = prims.add(t1502, t1507) # t1511: "cuda:0 f32[1, 71, 2048, 64]"
# t1512 = prims.convert_element_type(t1511, dtypes.bfloat16) # t1512: "cuda:0 bf16[1, 71, 2048, 64]"
# t1516 = prims.convert_element_type(t1515, dtypes.float32) # t1516: "cuda:0 f32[1, 71, 2048, 32]"
# t1517 = prims.neg(t1516) # t1517: "cuda:0 f32[1, 71, 2048, 32]"
# t1518 = prims.convert_element_type(t1517, dtypes.bfloat16) # t1518: "cuda:0 bf16[1, 71, 2048, 32]"
# t1520 = prims.cat((t1518, t1514), -1) # t1520: "cuda:0 bf16[1, 71, 2048, 64]"
# t1522 = prims.convert_element_type(t1513, dtypes.float32) # t1522: "cuda:0 f32[1, 71, 2048, 64]"
# t1524 = prims.mul(t1522, t61) # t1524: "cuda:0 f32[1, 71, 2048, 64]"
# t1527 = prims.convert_element_type(t1520, dtypes.float32) # t1527: "cuda:0 f32[1, 71, 2048, 64]"
# t1529 = prims.mul(t1527, t66) # t1529: "cuda:0 f32[1, 71, 2048, 64]"
# t1533 = prims.add(t1524, t1529) # t1533: "cuda:0 f32[1, 71, 2048, 64]"
# t1534 = prims.convert_element_type(t1533, dtypes.bfloat16) # t1534: "cuda:0 bf16[1, 71, 2048, 64]"
# t1537 = prims.cat((t1512, t1535), -1) # t1537: "cuda:0 bf16[1, 71, 2048, 64]"
# t1540 = prims.cat((t1534, t1538), -1) # t1540: "cuda:0 bf16[1, 71, 2048, 64]"
del t1491, t1492, t1493, t1513, t1514, t1515, t1535, t1538
t1573 = torch.nn.functional.linear(t1572, t_transformer_h_9_mlp_proj_weight, None) # t1573: "cuda:0 bf16[1, 2048, 4544]"
# t1573 = ltorch.linear(t1572, t_transformer_h_9_mlp_proj_weight, None) # t1573: "cuda:0 bf16[1, 2048, 4544]"
# t1573 = prims.linear(t1572, t_transformer_h_9_mlp_proj_weight, None) # t1573: "cuda:0 bf16[1, 2048, 4544]"
(t1541, t1542, t1543, t1544) = cudnn_sdpa_fwd(t1537, t1540, t1490, None, 0.0, True, scale=0.125)
t1547 = torch.permute(t1541, (0, 2, 1, 3)) # t1547: "cuda:0 bf16[1, 2048, 71, 64]"
# t1547 = ltorch.permute(t1541, (0, 2, 1, 3)) # t1547: "cuda:0 bf16[1, 2048, 71, 64]"
# t1547 = prims.transpose(t1541, (0, 2, 1, 3)) # t1547: "cuda:0 bf16[1, 2048, 71, 64]"
t1551 = torch.reshape(t1547, (1, 2048, 4544)) # t1551: "cuda:0 bf16[1, 2048, 4544]"
# t1551 = ltorch.reshape(t1547, (1, 2048, 4544)) # t1551: "cuda:0 bf16[1, 2048, 4544]"
# t1551 = prims.reshape(t1547, (1, 2048, 4544)) # t1551: "cuda:0 bf16[1, 2048, 4544]"
del t1547
t1552 = torch.nn.functional.linear(t1551, t_transformer_h_9_attn_proj_weight, None) # t1552: "cuda:0 bf16[1, 2048, 4544]"
# t1552 = ltorch.linear(t1551, t_transformer_h_9_attn_proj_weight, None) # t1552: "cuda:0 bf16[1, 2048, 4544]"
# t1552 = prims.linear(t1551, t_transformer_h_9_attn_proj_weight, None) # t1552: "cuda:0 bf16[1, 2048, 4544]"
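# ---- transformer block 10 (same pattern as block 7) ----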
t5549 = torch.unsqueeze(t_transformer_h_10_norm_1_weight, 0) # t5549: "cuda:0 bf16[1, 4544]"
# t5549 = ltorch.unsqueeze(t_transformer_h_10_norm_1_weight, 0) # t5549: "cuda:0 bf16[1, 4544]"
# t5549 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_weight, [1, 4544], [1]) # t5549: "cuda:0 bf16[1, 4544]"
t5550 = torch.unsqueeze(t5549, 1) # t5550: "cuda:0 bf16[1, 1, 4544]"
# t5550 = ltorch.unsqueeze(t5549, 1) # t5550: "cuda:0 bf16[1, 1, 4544]"
# t5550 = prims.broadcast_in_dim(t5549, [1, 1, 4544], [0, 2]) # t5550: "cuda:0 bf16[1, 1, 4544]"
del t5549
t1599 = Tensor.expand(t5550, (1, 2048, 4544)) # t1599: "cuda:0 bf16[1, 2048, 4544]"
# t1599 = ltorch.expand(t5550, (1, 2048, 4544)) # t1599: "cuda:0 bf16[1, 2048, 4544]"
# t1599 = prims.broadcast_in_dim(t5550, (1, 2048, 4544), (0, 1, 2)) # t1599: "cuda:0 bf16[1, 2048, 4544]"
del t5550
t5552 = torch.unsqueeze(t_transformer_h_10_norm_1_bias, 0) # t5552: "cuda:0 bf16[1, 4544]"
# t5552 = ltorch.unsqueeze(t_transformer_h_10_norm_1_bias, 0) # t5552: "cuda:0 bf16[1, 4544]"
# t5552 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_bias, [1, 4544], [1]) # t5552: "cuda:0 bf16[1, 4544]"
t5553 = torch.unsqueeze(t5552, 1) # t5553: "cuda:0 bf16[1, 1, 4544]"
# t5553 = ltorch.unsqueeze(t5552, 1) # t5553: "cuda:0 bf16[1, 1, 4544]"
# t5553 = prims.broadcast_in_dim(t5552, [1, 1, 4544], [0, 2]) # t5553: "cuda:0 bf16[1, 1, 4544]"
del t5552
t1602 = Tensor.expand(t5553, (1, 2048, 4544)) # t1602: "cuda:0 bf16[1, 2048, 4544]"
# t1602 = ltorch.expand(t5553, (1, 2048, 4544)) # t1602: "cuda:0 bf16[1, 2048, 4544]"
# t1602 = prims.broadcast_in_dim(t5553, (1, 2048, 4544), (0, 1, 2)) # t1602: "cuda:0 bf16[1, 2048, 4544]"
del t5553
[t1581, t1588, t1593, t1605] = nvFusion20(t1420, t1552, t1573, t1599, t1602)
# t1579 = prims.convert_element_type(t1420, dtypes.float32) # t1579: "cuda:0 f32[1, 2048, 4544]"
# t1574 = prims.convert_element_type(t1573, dtypes.float32) # t1574: "cuda:0 f32[1, 2048, 4544]"
# t1575 = prims.convert_element_type(t1552, dtypes.float32) # t1575: "cuda:0 f32[1, 2048, 4544]"
# t1576 = prims.add(t1574, t1575) # t1576: "cuda:0 f32[1, 2048, 4544]"
# t1580 = prims.add(t1576, t1579) # t1580: "cuda:0 f32[1, 2048, 4544]"
# t1581 = prims.convert_element_type(t1580, dtypes.bfloat16) # t1581: "cuda:0 bf16[1, 2048, 4544]"
# (t1587, t1588) = prims.var_mean(t1580, (2,), correction=0)
# t1589 = prims.broadcast_in_dim(t1587, [1, 2048, 1], [0, 1]) # t1589: "cuda:0 f32[1, 2048, 1]"
# t1590 = prims.broadcast_in_dim(t1588, [1, 2048, 1], [0, 1]) # t1590: "cuda:0 f32[1, 2048, 1]"
# t1592 = prims.add(t1589, 1e-05) # t1592: "cuda:0 f32[1, 2048, 1]"
# t1593 = prims.rsqrt(t1592) # t1593: "cuda:0 f32[1, 2048, 1]"
# t1594 = prims.broadcast_in_dim(t1590, (1, 2048, 4544), (0, 1, 2)) # t1594: "cuda:0 f32[1, 2048, 4544]"
# t1596 = prims.sub(t1580, t1594) # t1596: "cuda:0 f32[1, 2048, 4544]"
# t1597 = prims.broadcast_in_dim(t1593, (1, 2048, 4544), (0, 1, 2)) # t1597: "cuda:0 f32[1, 2048, 4544]"
# t1598 = prims.mul(t1596, t1597) # t1598: "cuda:0 f32[1, 2048, 4544]"
# t1600 = prims.convert_element_type(t1599, dtypes.float32) # t1600: "cuda:0 f32[1, 2048, 4544]"
# t1601 = prims.mul(t1598, t1600) # t1601: "cuda:0 f32[1, 2048, 4544]"
# t1603 = prims.convert_element_type(t1602, dtypes.float32) # t1603: "cuda:0 f32[1, 2048, 4544]"
# t1604 = prims.add(t1601, t1603) # t1604: "cuda:0 f32[1, 2048, 4544]"
# t1605 = prims.convert_element_type(t1604, dtypes.bfloat16) # t1605: "cuda:0 bf16[1, 2048, 4544]"
del t1602
t1606 = torch.nn.functional.linear(t1605, t_transformer_h_10_attn_attn_weight, None) # t1606: "cuda:0 bf16[1, 2048, 4672]"
# t1606 = ltorch.linear(t1605, t_transformer_h_10_attn_attn_weight, None) # t1606: "cuda:0 bf16[1, 2048, 4672]"
# t1606 = prims.linear(t1605, t_transformer_h_10_attn_attn_weight, None) # t1606: "cuda:0 bf16[1, 2048, 4672]"
t1714 = torch.nn.functional.linear(t1605, t_transformer_h_10_mlp_fc_weight, None) # t1714: "cuda:0 bf16[1, 2048, 18176]"
# t1714 = ltorch.linear(t1605, t_transformer_h_10_mlp_fc_weight, None) # t1714: "cuda:0 bf16[1, 2048, 18176]"
# t1714 = prims.linear(t1605, t_transformer_h_10_mlp_fc_weight, None) # t1714: "cuda:0 bf16[1, 2048, 18176]"
t1612 = torch.reshape(t1606, (1, 2048, 1, 73, 64)) # t1612: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1612 = ltorch.reshape(t1606, (1, 2048, 1, 73, 64)) # t1612: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1612 = prims.reshape(t1606, (1, 2048, 1, 73, 64)) # t1612: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t1606
t1618 = torch.permute(t1612, (0, 2, 3, 1, 4)) # t1618: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1618 = ltorch.permute(t1612, (0, 2, 3, 1, 4)) # t1618: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1618 = prims.transpose(t1612, (0, 2, 3, 1, 4)) # t1618: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t1612
(t1619, t1620, t1621) = torch.split(t1618, (71, 1, 1), 2)
# (t1619, t1620, t1621) = ltorch.split(t1618, (71, 1, 1), 2)
# t1619 = prims.slice_prim(t1618, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t1619: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1620 = prims.slice_prim(t1618, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t1620: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t1621 = prims.slice_prim(t1618, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t1621: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t1618
t1627 = Tensor.expand(t1620, (1, 1, 71, 2048, 64)) # t1627: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1627 = ltorch.expand(t1620, (1, 1, 71, 2048, 64)) # t1627: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1627 = prims.broadcast_in_dim(t1620, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1627: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1620
t1633 = Tensor.expand(t1621, (1, 1, 71, 2048, 64)) # t1633: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1633 = ltorch.expand(t1621, (1, 1, 71, 2048, 64)) # t1633: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1633 = prims.broadcast_in_dim(t1621, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1633: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1621
t1639 = torch.reshape(t1619, (1, 71, 2048, 64)) # t1639: "cuda:0 bf16[1, 71, 2048, 64]"
# t1639 = ltorch.reshape(t1619, (1, 71, 2048, 64)) # t1639: "cuda:0 bf16[1, 71, 2048, 64]"
# t1639 = prims.reshape(t1619, (1, 71, 2048, 64)) # t1639: "cuda:0 bf16[1, 71, 2048, 64]"
del t1619
t1645 = torch.reshape(t1627, (1, 71, 2048, 64)) # t1645: "cuda:0 bf16[1, 71, 2048, 64]"
# t1645 = ltorch.reshape(t1627, (1, 71, 2048, 64)) # t1645: "cuda:0 bf16[1, 71, 2048, 64]"
# t1645 = prims.reshape(t1627, (1, 71, 2048, 64)) # t1645: "cuda:0 bf16[1, 71, 2048, 64]"
del t1627
t1651 = torch.reshape(t1633, (1, 71, 2048, 64)) # t1651: "cuda:0 bf16[1, 71, 2048, 64]"
# t1651 = ltorch.reshape(t1633, (1, 71, 2048, 64)) # t1651: "cuda:0 bf16[1, 71, 2048, 64]"
# t1651 = prims.reshape(t1633, (1, 71, 2048, 64)) # t1651: "cuda:0 bf16[1, 71, 2048, 64]"
del t1633
t1652 = torch_slice_prim_impl(t1639, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1652: "cuda:0 bf16[1, 71, 2048, 64]"
t1653 = torch_slice_prim_impl(t1652, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1653: "cuda:0 bf16[1, 71, 2048, 32]"
t1654 = torch_slice_prim_impl(t1652, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1654: "cuda:0 bf16[1, 71, 2048, 32]"
t1674 = torch_slice_prim_impl(t1645, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1674: "cuda:0 bf16[1, 71, 2048, 64]"
t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1675: "cuda:0 bf16[1, 71, 2048, 32]"
t1676 = torch_slice_prim_impl(t1674, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1676: "cuda:0 bf16[1, 71, 2048, 32]"
t1696 = torch_slice_prim_impl(t1639, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1696: "cuda:0 bf16[1, 71, 2048, 0]"
del t1639
t1699 = torch_slice_prim_impl(t1645, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1699: "cuda:0 bf16[1, 71, 2048, 0]"
del t1645
[t1698, t1701, t1733] = nvFusion21(t1652, t1653, t1654, t1674, t1675, t1676, t1696, t1699, t1714, t61, t66)
# t1655 = prims.convert_element_type(t1654, dtypes.float32) # t1655: "cuda:0 f32[1, 71, 2048, 32]"
# t1656 = prims.neg(t1655) # t1656: "cuda:0 f32[1, 71, 2048, 32]"
# t1657 = prims.convert_element_type(t1656, dtypes.bfloat16) # t1657: "cuda:0 bf16[1, 71, 2048, 32]"
# t1659 = prims.cat((t1657, t1653), -1) # t1659: "cuda:0 bf16[1, 71, 2048, 64]"
# t1661 = prims.convert_element_type(t1652, dtypes.float32) # t1661: "cuda:0 f32[1, 71, 2048, 64]"
# t1663 = prims.mul(t1661, t61) # t1663: "cuda:0 f32[1, 71, 2048, 64]"
# t1666 = prims.convert_element_type(t1659, dtypes.float32) # t1666: "cuda:0 f32[1, 71, 2048, 64]"
# t1668 = prims.mul(t1666, t66) # t1668: "cuda:0 f32[1, 71, 2048, 64]"
# t1672 = prims.add(t1663, t1668) # t1672: "cuda:0 f32[1, 71, 2048, 64]"
# t1673 = prims.convert_element_type(t1672, dtypes.bfloat16) # t1673: "cuda:0 bf16[1, 71, 2048, 64]"
# t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: "cuda:0 f32[1, 71, 2048, 32]"
# t1678 = prims.neg(t1677) # t1678: "cuda:0 f32[1, 71, 2048, 32]"
# t1679 = prims.convert_element_type(t1678, dtypes.bfloat16) # t1679: "cuda:0 bf16[1, 71, 2048, 32]"
# t1681 = prims.cat((t1679, t1675), -1) # t1681: "cuda:0 bf16[1, 71, 2048, 64]"
# t1683 = prims.convert_element_type(t1674, dtypes.float32) # t1683: "cuda:0 f32[1, 71, 2048, 64]"
# t1685 = prims.mul(t1683, t61) # t1685: "cuda:0 f32[1, 71, 2048, 64]"
# t1688 = prims.convert_element_type(t1681, dtypes.float32) # t1688: "cuda:0 f32[1, 71, 2048, 64]"
# t1690 = prims.mul(t1688, t66) # t1690: "cuda:0 f32[1, 71, 2048, 64]"
# t1694 = prims.add(t1685, t1690) # t1694: "cuda:0 f32[1, 71, 2048, 64]"
# t1695 = prims.convert_element_type(t1694, dtypes.bfloat16) # t1695: "cuda:0 bf16[1, 71, 2048, 64]"
# t1698 = prims.cat((t1673, t1696), -1) # t1698: "cuda:0 bf16[1, 71, 2048, 64]"
# t1701 = prims.cat((t1695, t1699), -1) # t1701: "cuda:0 bf16[1, 71, 2048, 64]"
# t1715 = prims.convert_element_type(t1714, dtypes.float32) # t1715: "cuda:0 f32[1, 2048, 18176]"
# t1717 = prims.div(t1715, 1.4142135623730951) # t1717: "cuda:0 f32[1, 2048, 18176]"
# t1720 = prims.erf(t1717) # t1720: "cuda:0 f32[1, 2048, 18176]"
# t1724 = prims.mul(0.5, t1720) # t1724: "cuda:0 f32[1, 2048, 18176]"
# t1728 = prims.add(0.5, t1724) # t1728: "cuda:0 f32[1, 2048, 18176]"
# t1732 = prims.mul(t1715, t1728) # t1732: "cuda:0 f32[1, 2048, 18176]"
# t1733 = prims.convert_element_type(t1732, dtypes.bfloat16) # t1733: "cuda:0 bf16[1, 2048, 18176]"
del t1652, t1653, t1654, t1674, t1675, t1676, t1696, t1699
(t1702, t1703, t1704, t1705) = cudnn_sdpa_fwd(t1698, t1701, t1651, None, 0.0, True, scale=0.125)
t1734 = torch.nn.functional.linear(t1733, t_transformer_h_10_mlp_proj_weight, None) # t1734: "cuda:0 bf16[1, 2048, 4544]"
# t1734 = ltorch.linear(t1733, t_transformer_h_10_mlp_proj_weight, None) # t1734: "cuda:0 bf16[1, 2048, 4544]"
# t1734 = prims.linear(t1733, t_transformer_h_10_mlp_proj_weight, None) # t1734: "cuda:0 bf16[1, 2048, 4544]"
t1708 = torch.permute(t1702, (0, 2, 1, 3)) # t1708: "cuda:0 bf16[1, 2048, 71, 64]"
# t1708 = ltorch.permute(t1702, (0, 2, 1, 3)) # t1708: "cuda:0 bf16[1, 2048, 71, 64]"
# t1708 = prims.transpose(t1702, (0, 2, 1, 3)) # t1708: "cuda:0 bf16[1, 2048, 71, 64]"
t1712 = torch.reshape(t1708, (1, 2048, 4544)) # t1712: "cuda:0 bf16[1, 2048, 4544]"
# t1712 = ltorch.reshape(t1708, (1, 2048, 4544)) # t1712: "cuda:0 bf16[1, 2048, 4544]"
# t1712 = prims.reshape(t1708, (1, 2048, 4544)) # t1712: "cuda:0 bf16[1, 2048, 4544]"
del t1708
t1713 = torch.nn.functional.linear(t1712, t_transformer_h_10_attn_proj_weight, None) # t1713: "cuda:0 bf16[1, 2048, 4544]"
# t1713 = ltorch.linear(t1712, t_transformer_h_10_attn_proj_weight, None) # t1713: "cuda:0 bf16[1, 2048, 4544]"
# t1713 = prims.linear(t1712, t_transformer_h_10_attn_proj_weight, None) # t1713: "cuda:0 bf16[1, 2048, 4544]"
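# ---- transformer block 11 (same pattern as block 7) ----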
t5575 = torch.unsqueeze(t_transformer_h_11_norm_1_weight, 0) # t5575: "cuda:0 bf16[1, 4544]"
# t5575 = ltorch.unsqueeze(t_transformer_h_11_norm_1_weight, 0) # t5575: "cuda:0 bf16[1, 4544]"
# t5575 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_weight, [1, 4544], [1]) # t5575: "cuda:0 bf16[1, 4544]"
t5576 = torch.unsqueeze(t5575, 1) # t5576: "cuda:0 bf16[1, 1, 4544]"
# t5576 = ltorch.unsqueeze(t5575, 1) # t5576: "cuda:0 bf16[1, 1, 4544]"
# t5576 = prims.broadcast_in_dim(t5575, [1, 1, 4544], [0, 2]) # t5576: "cuda:0 bf16[1, 1, 4544]"
del t5575
t1760 = Tensor.expand(t5576, (1, 2048, 4544)) # t1760: "cuda:0 bf16[1, 2048, 4544]"
# t1760 = ltorch.expand(t5576, (1, 2048, 4544)) # t1760: "cuda:0 bf16[1, 2048, 4544]"
# t1760 = prims.broadcast_in_dim(t5576, (1, 2048, 4544), (0, 1, 2)) # t1760: "cuda:0 bf16[1, 2048, 4544]"
del t5576
t5578 = torch.unsqueeze(t_transformer_h_11_norm_1_bias, 0) # t5578: "cuda:0 bf16[1, 4544]"
# t5578 = ltorch.unsqueeze(t_transformer_h_11_norm_1_bias, 0) # t5578: "cuda:0 bf16[1, 4544]"
# t5578 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_bias, [1, 4544], [1]) # t5578: "cuda:0 bf16[1, 4544]"
t5579 = torch.unsqueeze(t5578, 1) # t5579: "cuda:0 bf16[1, 1, 4544]"
# t5579 = ltorch.unsqueeze(t5578, 1) # t5579: "cuda:0 bf16[1, 1, 4544]"
# t5579 = prims.broadcast_in_dim(t5578, [1, 1, 4544], [0, 2]) # t5579: "cuda:0 bf16[1, 1, 4544]"
del t5578
t1763 = Tensor.expand(t5579, (1, 2048, 4544)) # t1763: "cuda:0 bf16[1, 2048, 4544]"
# t1763 = ltorch.expand(t5579, (1, 2048, 4544)) # t1763: "cuda:0 bf16[1, 2048, 4544]"
# t1763 = prims.broadcast_in_dim(t5579, (1, 2048, 4544), (0, 1, 2)) # t1763: "cuda:0 bf16[1, 2048, 4544]"
del t5579
[t1742, t1749, t1754, t1766] = nvFusion22(t1581, t1713, t1734, t1760, t1763)
# t1740 = prims.convert_element_type(t1581, dtypes.float32) # t1740: "cuda:0 f32[1, 2048, 4544]"
# t1735 = prims.convert_element_type(t1734, dtypes.float32) # t1735: "cuda:0 f32[1, 2048, 4544]"
# t1736 = prims.convert_element_type(t1713, dtypes.float32) # t1736: "cuda:0 f32[1, 2048, 4544]"
# t1737 = prims.add(t1735, t1736) # t1737: "cuda:0 f32[1, 2048, 4544]"
# t1741 = prims.add(t1737, t1740) # t1741: "cuda:0 f32[1, 2048, 4544]"
# t1742 = prims.convert_element_type(t1741, dtypes.bfloat16) # t1742: "cuda:0 bf16[1, 2048, 4544]"
# (t1748, t1749) = prims.var_mean(t1741, (2,), correction=0)
# t1750 = prims.broadcast_in_dim(t1748, [1, 2048, 1], [0, 1]) # t1750: "cuda:0 f32[1, 2048, 1]"
# t1751 = prims.broadcast_in_dim(t1749, [1, 2048, 1], [0, 1]) # t1751: "cuda:0 f32[1, 2048, 1]"
# t1753 = prims.add(t1750, 1e-05) # t1753: "cuda:0 f32[1, 2048, 1]"
# t1754 = prims.rsqrt(t1753) # t1754: "cuda:0 f32[1, 2048, 1]"
# t1755 = prims.broadcast_in_dim(t1751, (1, 2048, 4544), (0, 1, 2)) # t1755: "cuda:0 f32[1, 2048, 4544]"
# t1757 = prims.sub(t1741, t1755) # t1757: "cuda:0 f32[1, 2048, 4544]"
# t1758 = prims.broadcast_in_dim(t1754, (1, 2048, 4544), (0, 1, 2)) # t1758: "cuda:0 f32[1, 2048, 4544]"
# t1759 = prims.mul(t1757, t1758) # t1759: "cuda:0 f32[1, 2048, 4544]"
# t1761 = prims.convert_element_type(t1760, dtypes.float32) # t1761: "cuda:0 f32[1, 2048, 4544]"
# t1762 = prims.mul(t1759, t1761) # t1762: "cuda:0 f32[1, 2048, 4544]"
# t1764 = prims.convert_element_type(t1763, dtypes.float32) # t1764: "cuda:0 f32[1, 2048, 4544]"
# t1765 = prims.add(t1762, t1764) # t1765: "cuda:0 f32[1, 2048, 4544]"
# t1766 = prims.convert_element_type(t1765, dtypes.bfloat16) # t1766: "cuda:0 bf16[1, 2048, 4544]"
del t1763
t1875 = torch.nn.functional.linear(t1766, t_transformer_h_11_mlp_fc_weight, None) # t1875: "cuda:0 bf16[1, 2048, 18176]"
# t1875 = ltorch.linear(t1766, t_transformer_h_11_mlp_fc_weight, None) # t1875: "cuda:0 bf16[1, 2048, 18176]"
# t1875 = prims.linear(t1766, t_transformer_h_11_mlp_fc_weight, None) # t1875: "cuda:0 bf16[1, 2048, 18176]"
t1767 = torch.nn.functional.linear(t1766, t_transformer_h_11_attn_attn_weight, None) # t1767: "cuda:0 bf16[1, 2048, 4672]"
# t1767 = ltorch.linear(t1766, t_transformer_h_11_attn_attn_weight, None) # t1767: "cuda:0 bf16[1, 2048, 4672]"
# t1767 = prims.linear(t1766, t_transformer_h_11_attn_attn_weight, None) # t1767: "cuda:0 bf16[1, 2048, 4672]"
t1773 = torch.reshape(t1767, (1, 2048, 1, 73, 64)) # t1773: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1773 = ltorch.reshape(t1767, (1, 2048, 1, 73, 64)) # t1773: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1773 = prims.reshape(t1767, (1, 2048, 1, 73, 64)) # t1773: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t1767
t1779 = torch.permute(t1773, (0, 2, 3, 1, 4)) # t1779: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1779 = ltorch.permute(t1773, (0, 2, 3, 1, 4)) # t1779: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1779 = prims.transpose(t1773, (0, 2, 3, 1, 4)) # t1779: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t1773
(t1780, t1781, t1782) = torch.split(t1779, (71, 1, 1), 2)
# (t1780, t1781, t1782) = ltorch.split(t1779, (71, 1, 1), 2)
# t1780 = prims.slice_prim(t1779, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t1780: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1781 = prims.slice_prim(t1779, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t1781: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t1782 = prims.slice_prim(t1779, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t1782: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t1779
t1788 = Tensor.expand(t1781, (1, 1, 71, 2048, 64)) # t1788: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1788 = ltorch.expand(t1781, (1, 1, 71, 2048, 64)) # t1788: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1788 = prims.broadcast_in_dim(t1781, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1788: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1781
t1794 = Tensor.expand(t1782, (1, 1, 71, 2048, 64)) # t1794: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1794 = ltorch.expand(t1782, (1, 1, 71, 2048, 64)) # t1794: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1794 = prims.broadcast_in_dim(t1782, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1794: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1782
t1800 = torch.reshape(t1780, (1, 71, 2048, 64)) # t1800: "cuda:0 bf16[1, 71, 2048, 64]"
# t1800 = ltorch.reshape(t1780, (1, 71, 2048, 64)) # t1800: "cuda:0 bf16[1, 71, 2048, 64]"
# t1800 = prims.reshape(t1780, (1, 71, 2048, 64)) # t1800: "cuda:0 bf16[1, 71, 2048, 64]"
del t1780
t1806 = torch.reshape(t1788, (1, 71, 2048, 64)) # t1806: "cuda:0 bf16[1, 71, 2048, 64]"
# t1806 = ltorch.reshape(t1788, (1, 71, 2048, 64)) # t1806: "cuda:0 bf16[1, 71, 2048, 64]"
# t1806 = prims.reshape(t1788, (1, 71, 2048, 64)) # t1806: "cuda:0 bf16[1, 71, 2048, 64]"
del t1788
t1812 = torch.reshape(t1794, (1, 71, 2048, 64)) # t1812: "cuda:0 bf16[1, 71, 2048, 64]"
# t1812 = ltorch.reshape(t1794, (1, 71, 2048, 64)) # t1812: "cuda:0 bf16[1, 71, 2048, 64]"
# t1812 = prims.reshape(t1794, (1, 71, 2048, 64)) # t1812: "cuda:0 bf16[1, 71, 2048, 64]"
del t1794
t1813 = torch_slice_prim_impl(t1800, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1813: "cuda:0 bf16[1, 71, 2048, 64]"
t1814 = torch_slice_prim_impl(t1813, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1814: "cuda:0 bf16[1, 71, 2048, 32]"
t1815 = torch_slice_prim_impl(t1813, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1815: "cuda:0 bf16[1, 71, 2048, 32]"
t1835 = torch_slice_prim_impl(t1806, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1835: "cuda:0 bf16[1, 71, 2048, 64]"
t1836 = torch_slice_prim_impl(t1835, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1836: "cuda:0 bf16[1, 71, 2048, 32]"
t1837 = torch_slice_prim_impl(t1835, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1837: "cuda:0 bf16[1, 71, 2048, 32]"
t1857 = torch_slice_prim_impl(t1800, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1857: "cuda:0 bf16[1, 71, 2048, 0]"
del t1800
t1860 = torch_slice_prim_impl(t1806, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t1860: "cuda:0 bf16[1, 71, 2048, 0]"
del t1806
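# Note: the slices above split each 64-wide head into two 32-wide halves for
# the "rotate half" step of rotary position embedding (RoPE); the zero-width
# slices (t1857, t1860) are the pass-through portion of the head, which is
# empty here because RoPE covers the full head dimension. Judging from the
# prims listed below, nvFusion23 appears to fuse (a) RoPE applied to the
# query and key (negate the second half, concatenate, then roughly
# x * t61 + rotate_half(x) * t66, where t61/t66 are presumably the cached
# cos/sin tables) with (b) the exact erf-based GELU of the MLP fc output
# t1875, i.e. roughly 0.5 * x * (1 + erf(x / 1.4142135623730951)), computed
# in float32 and cast back to bfloat16.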
[t1859, t1862, t1894] = nvFusion23(t1813, t1814, t1815, t1835, t1836, t1837, t1857, t1860, t1875, t61, t66)
# t1876 = prims.convert_element_type(t1875, dtypes.float32) # t1876: "cuda:0 f32[1, 2048, 18176]"
# t1878 = prims.div(t1876, 1.4142135623730951) # t1878: "cuda:0 f32[1, 2048, 18176]"
# t1881 = prims.erf(t1878) # t1881: "cuda:0 f32[1, 2048, 18176]"
# t1885 = prims.mul(0.5, t1881) # t1885: "cuda:0 f32[1, 2048, 18176]"
# t1889 = prims.add(0.5, t1885) # t1889: "cuda:0 f32[1, 2048, 18176]"
# t1893 = prims.mul(t1876, t1889) # t1893: "cuda:0 f32[1, 2048, 18176]"
# t1894 = prims.convert_element_type(t1893, dtypes.bfloat16) # t1894: "cuda:0 bf16[1, 2048, 18176]"
# t1816 = prims.convert_element_type(t1815, dtypes.float32) # t1816: "cuda:0 f32[1, 71, 2048, 32]"
# t1817 = prims.neg(t1816) # t1817: "cuda:0 f32[1, 71, 2048, 32]"
# t1818 = prims.convert_element_type(t1817, dtypes.bfloat16) # t1818: "cuda:0 bf16[1, 71, 2048, 32]"
# t1820 = prims.cat((t1818, t1814), -1) # t1820: "cuda:0 bf16[1, 71, 2048, 64]"
# t1822 = prims.convert_element_type(t1813, dtypes.float32) # t1822: "cuda:0 f32[1, 71, 2048, 64]"
# t1824 = prims.mul(t1822, t61) # t1824: "cuda:0 f32[1, 71, 2048, 64]"
# t1827 = prims.convert_element_type(t1820, dtypes.float32) # t1827: "cuda:0 f32[1, 71, 2048, 64]"
# t1829 = prims.mul(t1827, t66) # t1829: "cuda:0 f32[1, 71, 2048, 64]"
# t1833 = prims.add(t1824, t1829) # t1833: "cuda:0 f32[1, 71, 2048, 64]"
# t1834 = prims.convert_element_type(t1833, dtypes.bfloat16) # t1834: "cuda:0 bf16[1, 71, 2048, 64]"
# t1838 = prims.convert_element_type(t1837, dtypes.float32) # t1838: "cuda:0 f32[1, 71, 2048, 32]"
# t1839 = prims.neg(t1838) # t1839: "cuda:0 f32[1, 71, 2048, 32]"
# t1840 = prims.convert_element_type(t1839, dtypes.bfloat16) # t1840: "cuda:0 bf16[1, 71, 2048, 32]"
# t1842 = prims.cat((t1840, t1836), -1) # t1842: "cuda:0 bf16[1, 71, 2048, 64]"
# t1844 = prims.convert_element_type(t1835, dtypes.float32) # t1844: "cuda:0 f32[1, 71, 2048, 64]"
# t1846 = prims.mul(t1844, t61) # t1846: "cuda:0 f32[1, 71, 2048, 64]"
# t1849 = prims.convert_element_type(t1842, dtypes.float32) # t1849: "cuda:0 f32[1, 71, 2048, 64]"
# t1851 = prims.mul(t1849, t66) # t1851: "cuda:0 f32[1, 71, 2048, 64]"
# t1855 = prims.add(t1846, t1851) # t1855: "cuda:0 f32[1, 71, 2048, 64]"
# t1856 = prims.convert_element_type(t1855, dtypes.bfloat16) # t1856: "cuda:0 bf16[1, 71, 2048, 64]"
# t1859 = prims.cat((t1834, t1857), -1) # t1859: "cuda:0 bf16[1, 71, 2048, 64]"
# t1862 = prims.cat((t1856, t1860), -1) # t1862: "cuda:0 bf16[1, 71, 2048, 64]"
del t1813, t1814, t1815, t1835, t1836, t1837, t1857, t1860
t1895 = torch.nn.functional.linear(t1894, t_transformer_h_11_mlp_proj_weight, None) # t1895: "cuda:0 bf16[1, 2048, 4544]"
# t1895 = ltorch.linear(t1894, t_transformer_h_11_mlp_proj_weight, None) # t1895: "cuda:0 bf16[1, 2048, 4544]"
# t1895 = prims.linear(t1894, t_transformer_h_11_mlp_proj_weight, None) # t1895: "cuda:0 bf16[1, 2048, 4544]"
(t1863, t1864, t1865, t1866) = cudnn_sdpa_fwd(t1859, t1862, t1812, None, 0.0, True, scale=0.125)
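# Note: cudnn_sdpa_fwd runs scaled dot-product attention through cuDNN with
# query=t1859, key=t1862, value=t1812, no attention mask, dropout_p=0.0,
# is_causal=True and scale=0.125 (i.e. 1/sqrt(64), the head dimension).
# t1863 is the attention output; the remaining outputs (t1864-t1866) are
# saved for the backward pass (presumably softmax statistics and the dropout
# RNG seed/offset).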
t1869 = torch.permute(t1863, (0, 2, 1, 3)) # t1869: "cuda:0 bf16[1, 2048, 71, 64]"
# t1869 = ltorch.permute(t1863, (0, 2, 1, 3)) # t1869: "cuda:0 bf16[1, 2048, 71, 64]"
# t1869 = prims.transpose(t1863, (0, 2, 1, 3)) # t1869: "cuda:0 bf16[1, 2048, 71, 64]"
t1873 = torch.reshape(t1869, (1, 2048, 4544)) # t1873: "cuda:0 bf16[1, 2048, 4544]"
# t1873 = ltorch.reshape(t1869, (1, 2048, 4544)) # t1873: "cuda:0 bf16[1, 2048, 4544]"
# t1873 = prims.reshape(t1869, (1, 2048, 4544)) # t1873: "cuda:0 bf16[1, 2048, 4544]"
del t1869
t1874 = torch.nn.functional.linear(t1873, t_transformer_h_11_attn_proj_weight, None) # t1874: "cuda:0 bf16[1, 2048, 4544]"
# t1874 = ltorch.linear(t1873, t_transformer_h_11_attn_proj_weight, None) # t1874: "cuda:0 bf16[1, 2048, 4544]"
# t1874 = prims.linear(t1873, t_transformer_h_11_attn_proj_weight, None) # t1874: "cuda:0 bf16[1, 2048, 4544]"
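# Note: the unsqueeze/expand chain below broadcasts the *next* block's
# (transformer.h.12) norm_1 weight and bias from [4544] up to the activation
# shape [1, 2048, 4544]; the trace hoists this ahead of the residual/LayerNorm
# fusion that consumes it.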
t5601 = torch.unsqueeze(t_transformer_h_12_norm_1_weight, 0) # t5601: "cuda:0 bf16[1, 4544]"
# t5601 = ltorch.unsqueeze(t_transformer_h_12_norm_1_weight, 0) # t5601: "cuda:0 bf16[1, 4544]"
# t5601 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_weight, [1, 4544], [1]) # t5601: "cuda:0 bf16[1, 4544]"
t5602 = torch.unsqueeze(t5601, 1) # t5602: "cuda:0 bf16[1, 1, 4544]"
# t5602 = ltorch.unsqueeze(t5601, 1) # t5602: "cuda:0 bf16[1, 1, 4544]"
# t5602 = prims.broadcast_in_dim(t5601, [1, 1, 4544], [0, 2]) # t5602: "cuda:0 bf16[1, 1, 4544]"
del t5601
t1921 = Tensor.expand(t5602, (1, 2048, 4544)) # t1921: "cuda:0 bf16[1, 2048, 4544]"
# t1921 = ltorch.expand(t5602, (1, 2048, 4544)) # t1921: "cuda:0 bf16[1, 2048, 4544]"
# t1921 = prims.broadcast_in_dim(t5602, (1, 2048, 4544), (0, 1, 2)) # t1921: "cuda:0 bf16[1, 2048, 4544]"
del t5602
t5604 = torch.unsqueeze(t_transformer_h_12_norm_1_bias, 0) # t5604: "cuda:0 bf16[1, 4544]"
# t5604 = ltorch.unsqueeze(t_transformer_h_12_norm_1_bias, 0) # t5604: "cuda:0 bf16[1, 4544]"
# t5604 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_bias, [1, 4544], [1]) # t5604: "cuda:0 bf16[1, 4544]"
t5605 = torch.unsqueeze(t5604, 1) # t5605: "cuda:0 bf16[1, 1, 4544]"
# t5605 = ltorch.unsqueeze(t5604, 1) # t5605: "cuda:0 bf16[1, 1, 4544]"
# t5605 = prims.broadcast_in_dim(t5604, [1, 1, 4544], [0, 2]) # t5605: "cuda:0 bf16[1, 1, 4544]"
del t5604
t1924 = Tensor.expand(t5605, (1, 2048, 4544)) # t1924: "cuda:0 bf16[1, 2048, 4544]"
# t1924 = ltorch.expand(t5605, (1, 2048, 4544)) # t1924: "cuda:0 bf16[1, 2048, 4544]"
# t1924 = prims.broadcast_in_dim(t5605, (1, 2048, 4544), (0, 1, 2)) # t1924: "cuda:0 bf16[1, 2048, 4544]"
del t5605
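# Note: judging from the prims below, nvFusion24 fuses the parallel-residual
# add (MLP projection t1895 + attention projection t1874 + block input t1742)
# with the next block's LayerNorm: var_mean over the last dimension,
# rsqrt(var + 1e-05), normalize, then scale by the broadcast weight t1921 and
# shift by the bias t1924, casting back to bfloat16. It returns both t1903
# (the updated residual stream) and t1927 (the normalized input to block 12),
# along with the mean t1910 and rsqrt t1915 needed for the backward pass.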
[t1903, t1910, t1915, t1927] = nvFusion24(t1742, t1874, t1895, t1921, t1924)
# t1901 = prims.convert_element_type(t1742, dtypes.float32) # t1901: "cuda:0 f32[1, 2048, 4544]"
# t1896 = prims.convert_element_type(t1895, dtypes.float32) # t1896: "cuda:0 f32[1, 2048, 4544]"
# t1897 = prims.convert_element_type(t1874, dtypes.float32) # t1897: "cuda:0 f32[1, 2048, 4544]"
# t1898 = prims.add(t1896, t1897) # t1898: "cuda:0 f32[1, 2048, 4544]"
# t1902 = prims.add(t1898, t1901) # t1902: "cuda:0 f32[1, 2048, 4544]"
# t1903 = prims.convert_element_type(t1902, dtypes.bfloat16) # t1903: "cuda:0 bf16[1, 2048, 4544]"
# (t1909, t1910) = prims.var_mean(t1902, (2,), correction=0)
# t1911 = prims.broadcast_in_dim(t1909, [1, 2048, 1], [0, 1]) # t1911: "cuda:0 f32[1, 2048, 1]"
# t1912 = prims.broadcast_in_dim(t1910, [1, 2048, 1], [0, 1]) # t1912: "cuda:0 f32[1, 2048, 1]"
# t1914 = prims.add(t1911, 1e-05) # t1914: "cuda:0 f32[1, 2048, 1]"
# t1915 = prims.rsqrt(t1914) # t1915: "cuda:0 f32[1, 2048, 1]"
# t1916 = prims.broadcast_in_dim(t1912, (1, 2048, 4544), (0, 1, 2)) # t1916: "cuda:0 f32[1, 2048, 4544]"
# t1918 = prims.sub(t1902, t1916) # t1918: "cuda:0 f32[1, 2048, 4544]"
# t1919 = prims.broadcast_in_dim(t1915, (1, 2048, 4544), (0, 1, 2)) # t1919: "cuda:0 f32[1, 2048, 4544]"
# t1920 = prims.mul(t1918, t1919) # t1920: "cuda:0 f32[1, 2048, 4544]"
# t1922 = prims.convert_element_type(t1921, dtypes.float32) # t1922: "cuda:0 f32[1, 2048, 4544]"
# t1923 = prims.mul(t1920, t1922) # t1923: "cuda:0 f32[1, 2048, 4544]"
# t1925 = prims.convert_element_type(t1924, dtypes.float32) # t1925: "cuda:0 f32[1, 2048, 4544]"
# t1926 = prims.add(t1923, t1925) # t1926: "cuda:0 f32[1, 2048, 4544]"
# t1927 = prims.convert_element_type(t1926, dtypes.bfloat16) # t1927: "cuda:0 bf16[1, 2048, 4544]"
del t1924
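# Note: from here on the trace repeats the same per-block pattern (fused QKV
# projection, multi-query split/expand, RoPE, cuDNN SDPA, attention and MLP
# projections, and the fused residual add + LayerNorm) for transformer blocks
# h.12 through h.16 via nvFusion25 .. nvFusion33.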
t2036 = torch.nn.functional.linear(t1927, t_transformer_h_12_mlp_fc_weight, None) # t2036: "cuda:0 bf16[1, 2048, 18176]"
# t2036 = ltorch.linear(t1927, t_transformer_h_12_mlp_fc_weight, None) # t2036: "cuda:0 bf16[1, 2048, 18176]"
# t2036 = prims.linear(t1927, t_transformer_h_12_mlp_fc_weight, None) # t2036: "cuda:0 bf16[1, 2048, 18176]"
t1928 = torch.nn.functional.linear(t1927, t_transformer_h_12_attn_attn_weight, None) # t1928: "cuda:0 bf16[1, 2048, 4672]"
# t1928 = ltorch.linear(t1927, t_transformer_h_12_attn_attn_weight, None) # t1928: "cuda:0 bf16[1, 2048, 4672]"
# t1928 = prims.linear(t1927, t_transformer_h_12_attn_attn_weight, None) # t1928: "cuda:0 bf16[1, 2048, 4672]"
t1934 = torch.reshape(t1928, (1, 2048, 1, 73, 64)) # t1934: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1934 = ltorch.reshape(t1928, (1, 2048, 1, 73, 64)) # t1934: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t1934 = prims.reshape(t1928, (1, 2048, 1, 73, 64)) # t1934: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t1928
t1940 = torch.permute(t1934, (0, 2, 3, 1, 4)) # t1940: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1940 = ltorch.permute(t1934, (0, 2, 3, 1, 4)) # t1940: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t1940 = prims.transpose(t1934, (0, 2, 3, 1, 4)) # t1940: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t1934
(t1941, t1942, t1943) = torch.split(t1940, (71, 1, 1), 2)
# (t1941, t1942, t1943) = ltorch.split(t1940, (71, 1, 1), 2)
# t1941 = prims.slice_prim(t1940, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t1941: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1942 = prims.slice_prim(t1940, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t1942: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t1943 = prims.slice_prim(t1940, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t1943: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t1940
t1949 = Tensor.expand(t1942, (1, 1, 71, 2048, 64)) # t1949: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1949 = ltorch.expand(t1942, (1, 1, 71, 2048, 64)) # t1949: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1949 = prims.broadcast_in_dim(t1942, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1949: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1942
t1955 = Tensor.expand(t1943, (1, 1, 71, 2048, 64)) # t1955: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1955 = ltorch.expand(t1943, (1, 1, 71, 2048, 64)) # t1955: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t1955 = prims.broadcast_in_dim(t1943, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t1955: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t1943
t1961 = torch.reshape(t1941, (1, 71, 2048, 64)) # t1961: "cuda:0 bf16[1, 71, 2048, 64]"
# t1961 = ltorch.reshape(t1941, (1, 71, 2048, 64)) # t1961: "cuda:0 bf16[1, 71, 2048, 64]"
# t1961 = prims.reshape(t1941, (1, 71, 2048, 64)) # t1961: "cuda:0 bf16[1, 71, 2048, 64]"
del t1941
t1967 = torch.reshape(t1949, (1, 71, 2048, 64)) # t1967: "cuda:0 bf16[1, 71, 2048, 64]"
# t1967 = ltorch.reshape(t1949, (1, 71, 2048, 64)) # t1967: "cuda:0 bf16[1, 71, 2048, 64]"
# t1967 = prims.reshape(t1949, (1, 71, 2048, 64)) # t1967: "cuda:0 bf16[1, 71, 2048, 64]"
del t1949
t1973 = torch.reshape(t1955, (1, 71, 2048, 64)) # t1973: "cuda:0 bf16[1, 71, 2048, 64]"
# t1973 = ltorch.reshape(t1955, (1, 71, 2048, 64)) # t1973: "cuda:0 bf16[1, 71, 2048, 64]"
# t1973 = prims.reshape(t1955, (1, 71, 2048, 64)) # t1973: "cuda:0 bf16[1, 71, 2048, 64]"
del t1955
t1974 = torch_slice_prim_impl(t1961, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1974: "cuda:0 bf16[1, 71, 2048, 64]"
t1975 = torch_slice_prim_impl(t1974, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1975: "cuda:0 bf16[1, 71, 2048, 32]"
t1976 = torch_slice_prim_impl(t1974, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1976: "cuda:0 bf16[1, 71, 2048, 32]"
t1996 = torch_slice_prim_impl(t1967, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1996: "cuda:0 bf16[1, 71, 2048, 64]"
t1997 = torch_slice_prim_impl(t1996, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t1997: "cuda:0 bf16[1, 71, 2048, 32]"
t1998 = torch_slice_prim_impl(t1996, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t1998: "cuda:0 bf16[1, 71, 2048, 32]"
t2018 = torch_slice_prim_impl(t1961, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2018: "cuda:0 bf16[1, 71, 2048, 0]"
del t1961
t2021 = torch_slice_prim_impl(t1967, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2021: "cuda:0 bf16[1, 71, 2048, 0]"
del t1967
[t2020, t2023, t2055] = nvFusion25(t1974, t1975, t1976, t1996, t1997, t1998, t2018, t2021, t2036, t61, t66)
# t2037 = prims.convert_element_type(t2036, dtypes.float32) # t2037: "cuda:0 f32[1, 2048, 18176]"
# t2039 = prims.div(t2037, 1.4142135623730951) # t2039: "cuda:0 f32[1, 2048, 18176]"
# t2042 = prims.erf(t2039) # t2042: "cuda:0 f32[1, 2048, 18176]"
# t2046 = prims.mul(0.5, t2042) # t2046: "cuda:0 f32[1, 2048, 18176]"
# t2050 = prims.add(0.5, t2046) # t2050: "cuda:0 f32[1, 2048, 18176]"
# t2054 = prims.mul(t2037, t2050) # t2054: "cuda:0 f32[1, 2048, 18176]"
# t2055 = prims.convert_element_type(t2054, dtypes.bfloat16) # t2055: "cuda:0 bf16[1, 2048, 18176]"
# t1977 = prims.convert_element_type(t1976, dtypes.float32) # t1977: "cuda:0 f32[1, 71, 2048, 32]"
# t1978 = prims.neg(t1977) # t1978: "cuda:0 f32[1, 71, 2048, 32]"
# t1979 = prims.convert_element_type(t1978, dtypes.bfloat16) # t1979: "cuda:0 bf16[1, 71, 2048, 32]"
# t1981 = prims.cat((t1979, t1975), -1) # t1981: "cuda:0 bf16[1, 71, 2048, 64]"
# t1983 = prims.convert_element_type(t1974, dtypes.float32) # t1983: "cuda:0 f32[1, 71, 2048, 64]"
# t1985 = prims.mul(t1983, t61) # t1985: "cuda:0 f32[1, 71, 2048, 64]"
# t1988 = prims.convert_element_type(t1981, dtypes.float32) # t1988: "cuda:0 f32[1, 71, 2048, 64]"
# t1990 = prims.mul(t1988, t66) # t1990: "cuda:0 f32[1, 71, 2048, 64]"
# t1994 = prims.add(t1985, t1990) # t1994: "cuda:0 f32[1, 71, 2048, 64]"
# t1995 = prims.convert_element_type(t1994, dtypes.bfloat16) # t1995: "cuda:0 bf16[1, 71, 2048, 64]"
# t1999 = prims.convert_element_type(t1998, dtypes.float32) # t1999: "cuda:0 f32[1, 71, 2048, 32]"
# t2000 = prims.neg(t1999) # t2000: "cuda:0 f32[1, 71, 2048, 32]"
# t2001 = prims.convert_element_type(t2000, dtypes.bfloat16) # t2001: "cuda:0 bf16[1, 71, 2048, 32]"
# t2003 = prims.cat((t2001, t1997), -1) # t2003: "cuda:0 bf16[1, 71, 2048, 64]"
# t2005 = prims.convert_element_type(t1996, dtypes.float32) # t2005: "cuda:0 f32[1, 71, 2048, 64]"
# t2007 = prims.mul(t2005, t61) # t2007: "cuda:0 f32[1, 71, 2048, 64]"
# t2010 = prims.convert_element_type(t2003, dtypes.float32) # t2010: "cuda:0 f32[1, 71, 2048, 64]"
# t2012 = prims.mul(t2010, t66) # t2012: "cuda:0 f32[1, 71, 2048, 64]"
# t2016 = prims.add(t2007, t2012) # t2016: "cuda:0 f32[1, 71, 2048, 64]"
# t2017 = prims.convert_element_type(t2016, dtypes.bfloat16) # t2017: "cuda:0 bf16[1, 71, 2048, 64]"
# t2020 = prims.cat((t1995, t2018), -1) # t2020: "cuda:0 bf16[1, 71, 2048, 64]"
# t2023 = prims.cat((t2017, t2021), -1) # t2023: "cuda:0 bf16[1, 71, 2048, 64]"
del t1974, t1975, t1976, t1996, t1997, t1998, t2018, t2021
t2056 = torch.nn.functional.linear(t2055, t_transformer_h_12_mlp_proj_weight, None) # t2056: "cuda:0 bf16[1, 2048, 4544]"
# t2056 = ltorch.linear(t2055, t_transformer_h_12_mlp_proj_weight, None) # t2056: "cuda:0 bf16[1, 2048, 4544]"
# t2056 = prims.linear(t2055, t_transformer_h_12_mlp_proj_weight, None) # t2056: "cuda:0 bf16[1, 2048, 4544]"
(t2024, t2025, t2026, t2027) = cudnn_sdpa_fwd(t2020, t2023, t1973, None, 0.0, True, scale=0.125)
t2030 = torch.permute(t2024, (0, 2, 1, 3)) # t2030: "cuda:0 bf16[1, 2048, 71, 64]"
# t2030 = ltorch.permute(t2024, (0, 2, 1, 3)) # t2030: "cuda:0 bf16[1, 2048, 71, 64]"
# t2030 = prims.transpose(t2024, (0, 2, 1, 3)) # t2030: "cuda:0 bf16[1, 2048, 71, 64]"
t2034 = torch.reshape(t2030, (1, 2048, 4544)) # t2034: "cuda:0 bf16[1, 2048, 4544]"
# t2034 = ltorch.reshape(t2030, (1, 2048, 4544)) # t2034: "cuda:0 bf16[1, 2048, 4544]"
# t2034 = prims.reshape(t2030, (1, 2048, 4544)) # t2034: "cuda:0 bf16[1, 2048, 4544]"
del t2030
t2035 = torch.nn.functional.linear(t2034, t_transformer_h_12_attn_proj_weight, None) # t2035: "cuda:0 bf16[1, 2048, 4544]"
# t2035 = ltorch.linear(t2034, t_transformer_h_12_attn_proj_weight, None) # t2035: "cuda:0 bf16[1, 2048, 4544]"
# t2035 = prims.linear(t2034, t_transformer_h_12_attn_proj_weight, None) # t2035: "cuda:0 bf16[1, 2048, 4544]"
t5627 = torch.unsqueeze(t_transformer_h_13_norm_1_weight, 0) # t5627: "cuda:0 bf16[1, 4544]"
# t5627 = ltorch.unsqueeze(t_transformer_h_13_norm_1_weight, 0) # t5627: "cuda:0 bf16[1, 4544]"
# t5627 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_weight, [1, 4544], [1]) # t5627: "cuda:0 bf16[1, 4544]"
t5628 = torch.unsqueeze(t5627, 1) # t5628: "cuda:0 bf16[1, 1, 4544]"
# t5628 = ltorch.unsqueeze(t5627, 1) # t5628: "cuda:0 bf16[1, 1, 4544]"
# t5628 = prims.broadcast_in_dim(t5627, [1, 1, 4544], [0, 2]) # t5628: "cuda:0 bf16[1, 1, 4544]"
del t5627
t2082 = Tensor.expand(t5628, (1, 2048, 4544)) # t2082: "cuda:0 bf16[1, 2048, 4544]"
# t2082 = ltorch.expand(t5628, (1, 2048, 4544)) # t2082: "cuda:0 bf16[1, 2048, 4544]"
# t2082 = prims.broadcast_in_dim(t5628, (1, 2048, 4544), (0, 1, 2)) # t2082: "cuda:0 bf16[1, 2048, 4544]"
del t5628
t5630 = torch.unsqueeze(t_transformer_h_13_norm_1_bias, 0) # t5630: "cuda:0 bf16[1, 4544]"
# t5630 = ltorch.unsqueeze(t_transformer_h_13_norm_1_bias, 0) # t5630: "cuda:0 bf16[1, 4544]"
# t5630 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_bias, [1, 4544], [1]) # t5630: "cuda:0 bf16[1, 4544]"
t5631 = torch.unsqueeze(t5630, 1) # t5631: "cuda:0 bf16[1, 1, 4544]"
# t5631 = ltorch.unsqueeze(t5630, 1) # t5631: "cuda:0 bf16[1, 1, 4544]"
# t5631 = prims.broadcast_in_dim(t5630, [1, 1, 4544], [0, 2]) # t5631: "cuda:0 bf16[1, 1, 4544]"
del t5630
t2085 = Tensor.expand(t5631, (1, 2048, 4544)) # t2085: "cuda:0 bf16[1, 2048, 4544]"
# t2085 = ltorch.expand(t5631, (1, 2048, 4544)) # t2085: "cuda:0 bf16[1, 2048, 4544]"
# t2085 = prims.broadcast_in_dim(t5631, (1, 2048, 4544), (0, 1, 2)) # t2085: "cuda:0 bf16[1, 2048, 4544]"
del t5631
[t2064, t2071, t2076, t2088] = nvFusion26(t1903, t2035, t2056, t2082, t2085)
# t2062 = prims.convert_element_type(t1903, dtypes.float32) # t2062: "cuda:0 f32[1, 2048, 4544]"
# t2057 = prims.convert_element_type(t2056, dtypes.float32) # t2057: "cuda:0 f32[1, 2048, 4544]"
# t2058 = prims.convert_element_type(t2035, dtypes.float32) # t2058: "cuda:0 f32[1, 2048, 4544]"
# t2059 = prims.add(t2057, t2058) # t2059: "cuda:0 f32[1, 2048, 4544]"
# t2063 = prims.add(t2059, t2062) # t2063: "cuda:0 f32[1, 2048, 4544]"
# t2064 = prims.convert_element_type(t2063, dtypes.bfloat16) # t2064: "cuda:0 bf16[1, 2048, 4544]"
# (t2070, t2071) = prims.var_mean(t2063, (2,), correction=0)
# t2072 = prims.broadcast_in_dim(t2070, [1, 2048, 1], [0, 1]) # t2072: "cuda:0 f32[1, 2048, 1]"
# t2073 = prims.broadcast_in_dim(t2071, [1, 2048, 1], [0, 1]) # t2073: "cuda:0 f32[1, 2048, 1]"
# t2075 = prims.add(t2072, 1e-05) # t2075: "cuda:0 f32[1, 2048, 1]"
# t2076 = prims.rsqrt(t2075) # t2076: "cuda:0 f32[1, 2048, 1]"
# t2077 = prims.broadcast_in_dim(t2073, (1, 2048, 4544), (0, 1, 2)) # t2077: "cuda:0 f32[1, 2048, 4544]"
# t2079 = prims.sub(t2063, t2077) # t2079: "cuda:0 f32[1, 2048, 4544]"
# t2080 = prims.broadcast_in_dim(t2076, (1, 2048, 4544), (0, 1, 2)) # t2080: "cuda:0 f32[1, 2048, 4544]"
# t2081 = prims.mul(t2079, t2080) # t2081: "cuda:0 f32[1, 2048, 4544]"
# t2083 = prims.convert_element_type(t2082, dtypes.float32) # t2083: "cuda:0 f32[1, 2048, 4544]"
# t2084 = prims.mul(t2081, t2083) # t2084: "cuda:0 f32[1, 2048, 4544]"
# t2086 = prims.convert_element_type(t2085, dtypes.float32) # t2086: "cuda:0 f32[1, 2048, 4544]"
# t2087 = prims.add(t2084, t2086) # t2087: "cuda:0 f32[1, 2048, 4544]"
# t2088 = prims.convert_element_type(t2087, dtypes.bfloat16) # t2088: "cuda:0 bf16[1, 2048, 4544]"
del t2085
t2197 = torch.nn.functional.linear(t2088, t_transformer_h_13_mlp_fc_weight, None) # t2197: "cuda:0 bf16[1, 2048, 18176]"
# t2197 = ltorch.linear(t2088, t_transformer_h_13_mlp_fc_weight, None) # t2197: "cuda:0 bf16[1, 2048, 18176]"
# t2197 = prims.linear(t2088, t_transformer_h_13_mlp_fc_weight, None) # t2197: "cuda:0 bf16[1, 2048, 18176]"
t2089 = torch.nn.functional.linear(t2088, t_transformer_h_13_attn_attn_weight, None) # t2089: "cuda:0 bf16[1, 2048, 4672]"
# t2089 = ltorch.linear(t2088, t_transformer_h_13_attn_attn_weight, None) # t2089: "cuda:0 bf16[1, 2048, 4672]"
# t2089 = prims.linear(t2088, t_transformer_h_13_attn_attn_weight, None) # t2089: "cuda:0 bf16[1, 2048, 4672]"
t2095 = torch.reshape(t2089, (1, 2048, 1, 73, 64)) # t2095: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2095 = ltorch.reshape(t2089, (1, 2048, 1, 73, 64)) # t2095: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2095 = prims.reshape(t2089, (1, 2048, 1, 73, 64)) # t2095: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t2089
t2101 = torch.permute(t2095, (0, 2, 3, 1, 4)) # t2101: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2101 = ltorch.permute(t2095, (0, 2, 3, 1, 4)) # t2101: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2101 = prims.transpose(t2095, (0, 2, 3, 1, 4)) # t2101: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t2095
(t2102, t2103, t2104) = torch.split(t2101, (71, 1, 1), 2)
# (t2102, t2103, t2104) = ltorch.split(t2101, (71, 1, 1), 2)
# t2102 = prims.slice_prim(t2101, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t2102: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2103 = prims.slice_prim(t2101, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t2103: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t2104 = prims.slice_prim(t2101, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t2104: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t2101
t2110 = Tensor.expand(t2103, (1, 1, 71, 2048, 64)) # t2110: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2110 = ltorch.expand(t2103, (1, 1, 71, 2048, 64)) # t2110: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2110 = prims.broadcast_in_dim(t2103, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2110: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2103
t2116 = Tensor.expand(t2104, (1, 1, 71, 2048, 64)) # t2116: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2116 = ltorch.expand(t2104, (1, 1, 71, 2048, 64)) # t2116: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2116 = prims.broadcast_in_dim(t2104, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2116: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2104
t2122 = torch.reshape(t2102, (1, 71, 2048, 64)) # t2122: "cuda:0 bf16[1, 71, 2048, 64]"
# t2122 = ltorch.reshape(t2102, (1, 71, 2048, 64)) # t2122: "cuda:0 bf16[1, 71, 2048, 64]"
# t2122 = prims.reshape(t2102, (1, 71, 2048, 64)) # t2122: "cuda:0 bf16[1, 71, 2048, 64]"
del t2102
t2128 = torch.reshape(t2110, (1, 71, 2048, 64)) # t2128: "cuda:0 bf16[1, 71, 2048, 64]"
# t2128 = ltorch.reshape(t2110, (1, 71, 2048, 64)) # t2128: "cuda:0 bf16[1, 71, 2048, 64]"
# t2128 = prims.reshape(t2110, (1, 71, 2048, 64)) # t2128: "cuda:0 bf16[1, 71, 2048, 64]"
del t2110
t2134 = torch.reshape(t2116, (1, 71, 2048, 64)) # t2134: "cuda:0 bf16[1, 71, 2048, 64]"
# t2134 = ltorch.reshape(t2116, (1, 71, 2048, 64)) # t2134: "cuda:0 bf16[1, 71, 2048, 64]"
# t2134 = prims.reshape(t2116, (1, 71, 2048, 64)) # t2134: "cuda:0 bf16[1, 71, 2048, 64]"
del t2116
t2135 = torch_slice_prim_impl(t2122, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2135: "cuda:0 bf16[1, 71, 2048, 64]"
t2136 = torch_slice_prim_impl(t2135, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2136: "cuda:0 bf16[1, 71, 2048, 32]"
t2137 = torch_slice_prim_impl(t2135, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2137: "cuda:0 bf16[1, 71, 2048, 32]"
t2157 = torch_slice_prim_impl(t2128, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2157: "cuda:0 bf16[1, 71, 2048, 64]"
t2158 = torch_slice_prim_impl(t2157, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2158: "cuda:0 bf16[1, 71, 2048, 32]"
t2159 = torch_slice_prim_impl(t2157, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2159: "cuda:0 bf16[1, 71, 2048, 32]"
t2179 = torch_slice_prim_impl(t2122, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2179: "cuda:0 bf16[1, 71, 2048, 0]"
del t2122
t2182 = torch_slice_prim_impl(t2128, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2182: "cuda:0 bf16[1, 71, 2048, 0]"
del t2128
[t2181, t2184, t2216] = nvFusion27(t2135, t2136, t2137, t2157, t2158, t2159, t2179, t2182, t2197, t61, t66)
# t2198 = prims.convert_element_type(t2197, dtypes.float32) # t2198: "cuda:0 f32[1, 2048, 18176]"
# t2200 = prims.div(t2198, 1.4142135623730951) # t2200: "cuda:0 f32[1, 2048, 18176]"
# t2203 = prims.erf(t2200) # t2203: "cuda:0 f32[1, 2048, 18176]"
# t2207 = prims.mul(0.5, t2203) # t2207: "cuda:0 f32[1, 2048, 18176]"
# t2211 = prims.add(0.5, t2207) # t2211: "cuda:0 f32[1, 2048, 18176]"
# t2215 = prims.mul(t2198, t2211) # t2215: "cuda:0 f32[1, 2048, 18176]"
# t2216 = prims.convert_element_type(t2215, dtypes.bfloat16) # t2216: "cuda:0 bf16[1, 2048, 18176]"
# t2138 = prims.convert_element_type(t2137, dtypes.float32) # t2138: "cuda:0 f32[1, 71, 2048, 32]"
# t2139 = prims.neg(t2138) # t2139: "cuda:0 f32[1, 71, 2048, 32]"
# t2140 = prims.convert_element_type(t2139, dtypes.bfloat16) # t2140: "cuda:0 bf16[1, 71, 2048, 32]"
# t2142 = prims.cat((t2140, t2136), -1) # t2142: "cuda:0 bf16[1, 71, 2048, 64]"
# t2144 = prims.convert_element_type(t2135, dtypes.float32) # t2144: "cuda:0 f32[1, 71, 2048, 64]"
# t2146 = prims.mul(t2144, t61) # t2146: "cuda:0 f32[1, 71, 2048, 64]"
# t2149 = prims.convert_element_type(t2142, dtypes.float32) # t2149: "cuda:0 f32[1, 71, 2048, 64]"
# t2151 = prims.mul(t2149, t66) # t2151: "cuda:0 f32[1, 71, 2048, 64]"
# t2155 = prims.add(t2146, t2151) # t2155: "cuda:0 f32[1, 71, 2048, 64]"
# t2156 = prims.convert_element_type(t2155, dtypes.bfloat16) # t2156: "cuda:0 bf16[1, 71, 2048, 64]"
# t2160 = prims.convert_element_type(t2159, dtypes.float32) # t2160: "cuda:0 f32[1, 71, 2048, 32]"
# t2161 = prims.neg(t2160) # t2161: "cuda:0 f32[1, 71, 2048, 32]"
# t2162 = prims.convert_element_type(t2161, dtypes.bfloat16) # t2162: "cuda:0 bf16[1, 71, 2048, 32]"
# t2164 = prims.cat((t2162, t2158), -1) # t2164: "cuda:0 bf16[1, 71, 2048, 64]"
# t2166 = prims.convert_element_type(t2157, dtypes.float32) # t2166: "cuda:0 f32[1, 71, 2048, 64]"
# t2168 = prims.mul(t2166, t61) # t2168: "cuda:0 f32[1, 71, 2048, 64]"
# t2171 = prims.convert_element_type(t2164, dtypes.float32) # t2171: "cuda:0 f32[1, 71, 2048, 64]"
# t2173 = prims.mul(t2171, t66) # t2173: "cuda:0 f32[1, 71, 2048, 64]"
# t2177 = prims.add(t2168, t2173) # t2177: "cuda:0 f32[1, 71, 2048, 64]"
# t2178 = prims.convert_element_type(t2177, dtypes.bfloat16) # t2178: "cuda:0 bf16[1, 71, 2048, 64]"
# t2181 = prims.cat((t2156, t2179), -1) # t2181: "cuda:0 bf16[1, 71, 2048, 64]"
# t2184 = prims.cat((t2178, t2182), -1) # t2184: "cuda:0 bf16[1, 71, 2048, 64]"
del t2135, t2136, t2137, t2157, t2158, t2159, t2179, t2182
t2217 = torch.nn.functional.linear(t2216, t_transformer_h_13_mlp_proj_weight, None) # t2217: "cuda:0 bf16[1, 2048, 4544]"
# t2217 = ltorch.linear(t2216, t_transformer_h_13_mlp_proj_weight, None) # t2217: "cuda:0 bf16[1, 2048, 4544]"
# t2217 = prims.linear(t2216, t_transformer_h_13_mlp_proj_weight, None) # t2217: "cuda:0 bf16[1, 2048, 4544]"
(t2185, t2186, t2187, t2188) = cudnn_sdpa_fwd(t2181, t2184, t2134, None, 0.0, True, scale=0.125)
t2191 = torch.permute(t2185, (0, 2, 1, 3)) # t2191: "cuda:0 bf16[1, 2048, 71, 64]"
# t2191 = ltorch.permute(t2185, (0, 2, 1, 3)) # t2191: "cuda:0 bf16[1, 2048, 71, 64]"
# t2191 = prims.transpose(t2185, (0, 2, 1, 3)) # t2191: "cuda:0 bf16[1, 2048, 71, 64]"
t2195 = torch.reshape(t2191, (1, 2048, 4544)) # t2195: "cuda:0 bf16[1, 2048, 4544]"
# t2195 = ltorch.reshape(t2191, (1, 2048, 4544)) # t2195: "cuda:0 bf16[1, 2048, 4544]"
# t2195 = prims.reshape(t2191, (1, 2048, 4544)) # t2195: "cuda:0 bf16[1, 2048, 4544]"
del t2191
t2196 = torch.nn.functional.linear(t2195, t_transformer_h_13_attn_proj_weight, None) # t2196: "cuda:0 bf16[1, 2048, 4544]"
# t2196 = ltorch.linear(t2195, t_transformer_h_13_attn_proj_weight, None) # t2196: "cuda:0 bf16[1, 2048, 4544]"
# t2196 = prims.linear(t2195, t_transformer_h_13_attn_proj_weight, None) # t2196: "cuda:0 bf16[1, 2048, 4544]"
t5653 = torch.unsqueeze(t_transformer_h_14_norm_1_weight, 0) # t5653: "cuda:0 bf16[1, 4544]"
# t5653 = ltorch.unsqueeze(t_transformer_h_14_norm_1_weight, 0) # t5653: "cuda:0 bf16[1, 4544]"
# t5653 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_weight, [1, 4544], [1]) # t5653: "cuda:0 bf16[1, 4544]"
t5654 = torch.unsqueeze(t5653, 1) # t5654: "cuda:0 bf16[1, 1, 4544]"
# t5654 = ltorch.unsqueeze(t5653, 1) # t5654: "cuda:0 bf16[1, 1, 4544]"
# t5654 = prims.broadcast_in_dim(t5653, [1, 1, 4544], [0, 2]) # t5654: "cuda:0 bf16[1, 1, 4544]"
del t5653
t2243 = Tensor.expand(t5654, (1, 2048, 4544)) # t2243: "cuda:0 bf16[1, 2048, 4544]"
# t2243 = ltorch.expand(t5654, (1, 2048, 4544)) # t2243: "cuda:0 bf16[1, 2048, 4544]"
# t2243 = prims.broadcast_in_dim(t5654, (1, 2048, 4544), (0, 1, 2)) # t2243: "cuda:0 bf16[1, 2048, 4544]"
del t5654
t5656 = torch.unsqueeze(t_transformer_h_14_norm_1_bias, 0) # t5656: "cuda:0 bf16[1, 4544]"
# t5656 = ltorch.unsqueeze(t_transformer_h_14_norm_1_bias, 0) # t5656: "cuda:0 bf16[1, 4544]"
# t5656 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_bias, [1, 4544], [1]) # t5656: "cuda:0 bf16[1, 4544]"
t5657 = torch.unsqueeze(t5656, 1) # t5657: "cuda:0 bf16[1, 1, 4544]"
# t5657 = ltorch.unsqueeze(t5656, 1) # t5657: "cuda:0 bf16[1, 1, 4544]"
# t5657 = prims.broadcast_in_dim(t5656, [1, 1, 4544], [0, 2]) # t5657: "cuda:0 bf16[1, 1, 4544]"
del t5656
t2246 = Tensor.expand(t5657, (1, 2048, 4544)) # t2246: "cuda:0 bf16[1, 2048, 4544]"
# t2246 = ltorch.expand(t5657, (1, 2048, 4544)) # t2246: "cuda:0 bf16[1, 2048, 4544]"
# t2246 = prims.broadcast_in_dim(t5657, (1, 2048, 4544), (0, 1, 2)) # t2246: "cuda:0 bf16[1, 2048, 4544]"
del t5657
[t2225, t2232, t2237, t2249] = nvFusion28(t2064, t2196, t2217, t2243, t2246)
# t2223 = prims.convert_element_type(t2064, dtypes.float32) # t2223: "cuda:0 f32[1, 2048, 4544]"
# t2218 = prims.convert_element_type(t2217, dtypes.float32) # t2218: "cuda:0 f32[1, 2048, 4544]"
# t2219 = prims.convert_element_type(t2196, dtypes.float32) # t2219: "cuda:0 f32[1, 2048, 4544]"
# t2220 = prims.add(t2218, t2219) # t2220: "cuda:0 f32[1, 2048, 4544]"
# t2224 = prims.add(t2220, t2223) # t2224: "cuda:0 f32[1, 2048, 4544]"
# t2225 = prims.convert_element_type(t2224, dtypes.bfloat16) # t2225: "cuda:0 bf16[1, 2048, 4544]"
# (t2231, t2232) = prims.var_mean(t2224, (2,), correction=0)
# t2233 = prims.broadcast_in_dim(t2231, [1, 2048, 1], [0, 1]) # t2233: "cuda:0 f32[1, 2048, 1]"
# t2234 = prims.broadcast_in_dim(t2232, [1, 2048, 1], [0, 1]) # t2234: "cuda:0 f32[1, 2048, 1]"
# t2236 = prims.add(t2233, 1e-05) # t2236: "cuda:0 f32[1, 2048, 1]"
# t2237 = prims.rsqrt(t2236) # t2237: "cuda:0 f32[1, 2048, 1]"
# t2238 = prims.broadcast_in_dim(t2234, (1, 2048, 4544), (0, 1, 2)) # t2238: "cuda:0 f32[1, 2048, 4544]"
# t2240 = prims.sub(t2224, t2238) # t2240: "cuda:0 f32[1, 2048, 4544]"
# t2241 = prims.broadcast_in_dim(t2237, (1, 2048, 4544), (0, 1, 2)) # t2241: "cuda:0 f32[1, 2048, 4544]"
# t2242 = prims.mul(t2240, t2241) # t2242: "cuda:0 f32[1, 2048, 4544]"
# t2244 = prims.convert_element_type(t2243, dtypes.float32) # t2244: "cuda:0 f32[1, 2048, 4544]"
# t2245 = prims.mul(t2242, t2244) # t2245: "cuda:0 f32[1, 2048, 4544]"
# t2247 = prims.convert_element_type(t2246, dtypes.float32) # t2247: "cuda:0 f32[1, 2048, 4544]"
# t2248 = prims.add(t2245, t2247) # t2248: "cuda:0 f32[1, 2048, 4544]"
# t2249 = prims.convert_element_type(t2248, dtypes.bfloat16) # t2249: "cuda:0 bf16[1, 2048, 4544]"
del t2246
t2358 = torch.nn.functional.linear(t2249, t_transformer_h_14_mlp_fc_weight, None) # t2358: "cuda:0 bf16[1, 2048, 18176]"
# t2358 = ltorch.linear(t2249, t_transformer_h_14_mlp_fc_weight, None) # t2358: "cuda:0 bf16[1, 2048, 18176]"
# t2358 = prims.linear(t2249, t_transformer_h_14_mlp_fc_weight, None) # t2358: "cuda:0 bf16[1, 2048, 18176]"
t2250 = torch.nn.functional.linear(t2249, t_transformer_h_14_attn_attn_weight, None) # t2250: "cuda:0 bf16[1, 2048, 4672]"
# t2250 = ltorch.linear(t2249, t_transformer_h_14_attn_attn_weight, None) # t2250: "cuda:0 bf16[1, 2048, 4672]"
# t2250 = prims.linear(t2249, t_transformer_h_14_attn_attn_weight, None) # t2250: "cuda:0 bf16[1, 2048, 4672]"
t2256 = torch.reshape(t2250, (1, 2048, 1, 73, 64)) # t2256: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2256 = ltorch.reshape(t2250, (1, 2048, 1, 73, 64)) # t2256: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2256 = prims.reshape(t2250, (1, 2048, 1, 73, 64)) # t2256: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t2250
t2262 = torch.permute(t2256, (0, 2, 3, 1, 4)) # t2262: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2262 = ltorch.permute(t2256, (0, 2, 3, 1, 4)) # t2262: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2262 = prims.transpose(t2256, (0, 2, 3, 1, 4)) # t2262: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t2256
(t2263, t2264, t2265) = torch.split(t2262, (71, 1, 1), 2)
# (t2263, t2264, t2265) = ltorch.split(t2262, (71, 1, 1), 2)
# t2263 = prims.slice_prim(t2262, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t2263: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2264 = prims.slice_prim(t2262, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t2264: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t2265 = prims.slice_prim(t2262, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t2265: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t2262
t2271 = Tensor.expand(t2264, (1, 1, 71, 2048, 64)) # t2271: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2271 = ltorch.expand(t2264, (1, 1, 71, 2048, 64)) # t2271: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2271 = prims.broadcast_in_dim(t2264, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2271: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2264
t2277 = Tensor.expand(t2265, (1, 1, 71, 2048, 64)) # t2277: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2277 = ltorch.expand(t2265, (1, 1, 71, 2048, 64)) # t2277: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2277 = prims.broadcast_in_dim(t2265, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2277: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2265
t2283 = torch.reshape(t2263, (1, 71, 2048, 64)) # t2283: "cuda:0 bf16[1, 71, 2048, 64]"
# t2283 = ltorch.reshape(t2263, (1, 71, 2048, 64)) # t2283: "cuda:0 bf16[1, 71, 2048, 64]"
# t2283 = prims.reshape(t2263, (1, 71, 2048, 64)) # t2283: "cuda:0 bf16[1, 71, 2048, 64]"
del t2263
t2289 = torch.reshape(t2271, (1, 71, 2048, 64)) # t2289: "cuda:0 bf16[1, 71, 2048, 64]"
# t2289 = ltorch.reshape(t2271, (1, 71, 2048, 64)) # t2289: "cuda:0 bf16[1, 71, 2048, 64]"
# t2289 = prims.reshape(t2271, (1, 71, 2048, 64)) # t2289: "cuda:0 bf16[1, 71, 2048, 64]"
del t2271
t2295 = torch.reshape(t2277, (1, 71, 2048, 64)) # t2295: "cuda:0 bf16[1, 71, 2048, 64]"
# t2295 = ltorch.reshape(t2277, (1, 71, 2048, 64)) # t2295: "cuda:0 bf16[1, 71, 2048, 64]"
# t2295 = prims.reshape(t2277, (1, 71, 2048, 64)) # t2295: "cuda:0 bf16[1, 71, 2048, 64]"
del t2277
t2296 = torch_slice_prim_impl(t2283, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2296: "cuda:0 bf16[1, 71, 2048, 64]"
t2297 = torch_slice_prim_impl(t2296, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2297: "cuda:0 bf16[1, 71, 2048, 32]"
t2298 = torch_slice_prim_impl(t2296, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2298: "cuda:0 bf16[1, 71, 2048, 32]"
t2318 = torch_slice_prim_impl(t2289, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2318: "cuda:0 bf16[1, 71, 2048, 64]"
t2319 = torch_slice_prim_impl(t2318, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2319: "cuda:0 bf16[1, 71, 2048, 32]"
t2320 = torch_slice_prim_impl(t2318, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2320: "cuda:0 bf16[1, 71, 2048, 32]"
t2340 = torch_slice_prim_impl(t2283, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2340: "cuda:0 bf16[1, 71, 2048, 0]"
del t2283
t2343 = torch_slice_prim_impl(t2289, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2343: "cuda:0 bf16[1, 71, 2048, 0]"
del t2289
[t2342, t2345, t2377] = nvFusion29(t2296, t2297, t2298, t2318, t2319, t2320, t2340, t2343, t2358, t61, t66)
# t2359 = prims.convert_element_type(t2358, dtypes.float32) # t2359: "cuda:0 f32[1, 2048, 18176]"
# t2361 = prims.div(t2359, 1.4142135623730951) # t2361: "cuda:0 f32[1, 2048, 18176]"
# t2364 = prims.erf(t2361) # t2364: "cuda:0 f32[1, 2048, 18176]"
# t2368 = prims.mul(0.5, t2364) # t2368: "cuda:0 f32[1, 2048, 18176]"
# t2372 = prims.add(0.5, t2368) # t2372: "cuda:0 f32[1, 2048, 18176]"
# t2376 = prims.mul(t2359, t2372) # t2376: "cuda:0 f32[1, 2048, 18176]"
# t2377 = prims.convert_element_type(t2376, dtypes.bfloat16) # t2377: "cuda:0 bf16[1, 2048, 18176]"
# t2299 = prims.convert_element_type(t2298, dtypes.float32) # t2299: "cuda:0 f32[1, 71, 2048, 32]"
# t2300 = prims.neg(t2299) # t2300: "cuda:0 f32[1, 71, 2048, 32]"
# t2301 = prims.convert_element_type(t2300, dtypes.bfloat16) # t2301: "cuda:0 bf16[1, 71, 2048, 32]"
# t2303 = prims.cat((t2301, t2297), -1) # t2303: "cuda:0 bf16[1, 71, 2048, 64]"
# t2305 = prims.convert_element_type(t2296, dtypes.float32) # t2305: "cuda:0 f32[1, 71, 2048, 64]"
# t2307 = prims.mul(t2305, t61) # t2307: "cuda:0 f32[1, 71, 2048, 64]"
# t2310 = prims.convert_element_type(t2303, dtypes.float32) # t2310: "cuda:0 f32[1, 71, 2048, 64]"
# t2312 = prims.mul(t2310, t66) # t2312: "cuda:0 f32[1, 71, 2048, 64]"
# t2316 = prims.add(t2307, t2312) # t2316: "cuda:0 f32[1, 71, 2048, 64]"
# t2317 = prims.convert_element_type(t2316, dtypes.bfloat16) # t2317: "cuda:0 bf16[1, 71, 2048, 64]"
# t2321 = prims.convert_element_type(t2320, dtypes.float32) # t2321: "cuda:0 f32[1, 71, 2048, 32]"
# t2322 = prims.neg(t2321) # t2322: "cuda:0 f32[1, 71, 2048, 32]"
# t2323 = prims.convert_element_type(t2322, dtypes.bfloat16) # t2323: "cuda:0 bf16[1, 71, 2048, 32]"
# t2325 = prims.cat((t2323, t2319), -1) # t2325: "cuda:0 bf16[1, 71, 2048, 64]"
# t2327 = prims.convert_element_type(t2318, dtypes.float32) # t2327: "cuda:0 f32[1, 71, 2048, 64]"
# t2329 = prims.mul(t2327, t61) # t2329: "cuda:0 f32[1, 71, 2048, 64]"
# t2332 = prims.convert_element_type(t2325, dtypes.float32) # t2332: "cuda:0 f32[1, 71, 2048, 64]"
# t2334 = prims.mul(t2332, t66) # t2334: "cuda:0 f32[1, 71, 2048, 64]"
# t2338 = prims.add(t2329, t2334) # t2338: "cuda:0 f32[1, 71, 2048, 64]"
# t2339 = prims.convert_element_type(t2338, dtypes.bfloat16) # t2339: "cuda:0 bf16[1, 71, 2048, 64]"
# t2342 = prims.cat((t2317, t2340), -1) # t2342: "cuda:0 bf16[1, 71, 2048, 64]"
# t2345 = prims.cat((t2339, t2343), -1) # t2345: "cuda:0 bf16[1, 71, 2048, 64]"
del t2296, t2297, t2298, t2318, t2319, t2320, t2340, t2343
t2378 = torch.nn.functional.linear(t2377, t_transformer_h_14_mlp_proj_weight, None) # t2378: "cuda:0 bf16[1, 2048, 4544]"
# t2378 = ltorch.linear(t2377, t_transformer_h_14_mlp_proj_weight, None) # t2378: "cuda:0 bf16[1, 2048, 4544]"
# t2378 = prims.linear(t2377, t_transformer_h_14_mlp_proj_weight, None) # t2378: "cuda:0 bf16[1, 2048, 4544]"
(t2346, t2347, t2348, t2349) = cudnn_sdpa_fwd(t2342, t2345, t2295, None, 0.0, True, scale=0.125)
t2352 = torch.permute(t2346, (0, 2, 1, 3)) # t2352: "cuda:0 bf16[1, 2048, 71, 64]"
# t2352 = ltorch.permute(t2346, (0, 2, 1, 3)) # t2352: "cuda:0 bf16[1, 2048, 71, 64]"
# t2352 = prims.transpose(t2346, (0, 2, 1, 3)) # t2352: "cuda:0 bf16[1, 2048, 71, 64]"
t2356 = torch.reshape(t2352, (1, 2048, 4544)) # t2356: "cuda:0 bf16[1, 2048, 4544]"
# t2356 = ltorch.reshape(t2352, (1, 2048, 4544)) # t2356: "cuda:0 bf16[1, 2048, 4544]"
# t2356 = prims.reshape(t2352, (1, 2048, 4544)) # t2356: "cuda:0 bf16[1, 2048, 4544]"
del t2352
t2357 = torch.nn.functional.linear(t2356, t_transformer_h_14_attn_proj_weight, None) # t2357: "cuda:0 bf16[1, 2048, 4544]"
# t2357 = ltorch.linear(t2356, t_transformer_h_14_attn_proj_weight, None) # t2357: "cuda:0 bf16[1, 2048, 4544]"
# t2357 = prims.linear(t2356, t_transformer_h_14_attn_proj_weight, None) # t2357: "cuda:0 bf16[1, 2048, 4544]"
t5679 = torch.unsqueeze(t_transformer_h_15_norm_1_weight, 0) # t5679: "cuda:0 bf16[1, 4544]"
# t5679 = ltorch.unsqueeze(t_transformer_h_15_norm_1_weight, 0) # t5679: "cuda:0 bf16[1, 4544]"
# t5679 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_weight, [1, 4544], [1]) # t5679: "cuda:0 bf16[1, 4544]"
t5680 = torch.unsqueeze(t5679, 1) # t5680: "cuda:0 bf16[1, 1, 4544]"
# t5680 = ltorch.unsqueeze(t5679, 1) # t5680: "cuda:0 bf16[1, 1, 4544]"
# t5680 = prims.broadcast_in_dim(t5679, [1, 1, 4544], [0, 2]) # t5680: "cuda:0 bf16[1, 1, 4544]"
del t5679
t2404 = Tensor.expand(t5680, (1, 2048, 4544)) # t2404: "cuda:0 bf16[1, 2048, 4544]"
# t2404 = ltorch.expand(t5680, (1, 2048, 4544)) # t2404: "cuda:0 bf16[1, 2048, 4544]"
# t2404 = prims.broadcast_in_dim(t5680, (1, 2048, 4544), (0, 1, 2)) # t2404: "cuda:0 bf16[1, 2048, 4544]"
del t5680
t5682 = torch.unsqueeze(t_transformer_h_15_norm_1_bias, 0) # t5682: "cuda:0 bf16[1, 4544]"
# t5682 = ltorch.unsqueeze(t_transformer_h_15_norm_1_bias, 0) # t5682: "cuda:0 bf16[1, 4544]"
# t5682 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_bias, [1, 4544], [1]) # t5682: "cuda:0 bf16[1, 4544]"
t5683 = torch.unsqueeze(t5682, 1) # t5683: "cuda:0 bf16[1, 1, 4544]"
# t5683 = ltorch.unsqueeze(t5682, 1) # t5683: "cuda:0 bf16[1, 1, 4544]"
# t5683 = prims.broadcast_in_dim(t5682, [1, 1, 4544], [0, 2]) # t5683: "cuda:0 bf16[1, 1, 4544]"
del t5682
t2407 = Tensor.expand(t5683, (1, 2048, 4544)) # t2407: "cuda:0 bf16[1, 2048, 4544]"
# t2407 = ltorch.expand(t5683, (1, 2048, 4544)) # t2407: "cuda:0 bf16[1, 2048, 4544]"
# t2407 = prims.broadcast_in_dim(t5683, (1, 2048, 4544), (0, 1, 2)) # t2407: "cuda:0 bf16[1, 2048, 4544]"
del t5683
[t2386, t2393, t2398, t2410] = nvFusion30(t2225, t2357, t2378, t2404, t2407)
# t2384 = prims.convert_element_type(t2225, dtypes.float32) # t2384: "cuda:0 f32[1, 2048, 4544]"
# t2379 = prims.convert_element_type(t2378, dtypes.float32) # t2379: "cuda:0 f32[1, 2048, 4544]"
# t2380 = prims.convert_element_type(t2357, dtypes.float32) # t2380: "cuda:0 f32[1, 2048, 4544]"
# t2381 = prims.add(t2379, t2380) # t2381: "cuda:0 f32[1, 2048, 4544]"
# t2385 = prims.add(t2381, t2384) # t2385: "cuda:0 f32[1, 2048, 4544]"
# t2386 = prims.convert_element_type(t2385, dtypes.bfloat16) # t2386: "cuda:0 bf16[1, 2048, 4544]"
# (t2392, t2393) = prims.var_mean(t2385, (2,), correction=0)
# t2394 = prims.broadcast_in_dim(t2392, [1, 2048, 1], [0, 1]) # t2394: "cuda:0 f32[1, 2048, 1]"
# t2395 = prims.broadcast_in_dim(t2393, [1, 2048, 1], [0, 1]) # t2395: "cuda:0 f32[1, 2048, 1]"
# t2397 = prims.add(t2394, 1e-05) # t2397: "cuda:0 f32[1, 2048, 1]"
# t2398 = prims.rsqrt(t2397) # t2398: "cuda:0 f32[1, 2048, 1]"
# t2399 = prims.broadcast_in_dim(t2395, (1, 2048, 4544), (0, 1, 2)) # t2399: "cuda:0 f32[1, 2048, 4544]"
# t2401 = prims.sub(t2385, t2399) # t2401: "cuda:0 f32[1, 2048, 4544]"
# t2402 = prims.broadcast_in_dim(t2398, (1, 2048, 4544), (0, 1, 2)) # t2402: "cuda:0 f32[1, 2048, 4544]"
# t2403 = prims.mul(t2401, t2402) # t2403: "cuda:0 f32[1, 2048, 4544]"
# t2405 = prims.convert_element_type(t2404, dtypes.float32) # t2405: "cuda:0 f32[1, 2048, 4544]"
# t2406 = prims.mul(t2403, t2405) # t2406: "cuda:0 f32[1, 2048, 4544]"
# t2408 = prims.convert_element_type(t2407, dtypes.float32) # t2408: "cuda:0 f32[1, 2048, 4544]"
# t2409 = prims.add(t2406, t2408) # t2409: "cuda:0 f32[1, 2048, 4544]"
# t2410 = prims.convert_element_type(t2409, dtypes.bfloat16) # t2410: "cuda:0 bf16[1, 2048, 4544]"
del t2407
t2519 = torch.nn.functional.linear(t2410, t_transformer_h_15_mlp_fc_weight, None) # t2519: "cuda:0 bf16[1, 2048, 18176]"
# t2519 = ltorch.linear(t2410, t_transformer_h_15_mlp_fc_weight, None) # t2519: "cuda:0 bf16[1, 2048, 18176]"
# t2519 = prims.linear(t2410, t_transformer_h_15_mlp_fc_weight, None) # t2519: "cuda:0 bf16[1, 2048, 18176]"
t2411 = torch.nn.functional.linear(t2410, t_transformer_h_15_attn_attn_weight, None) # t2411: "cuda:0 bf16[1, 2048, 4672]"
# t2411 = ltorch.linear(t2410, t_transformer_h_15_attn_attn_weight, None) # t2411: "cuda:0 bf16[1, 2048, 4672]"
# t2411 = prims.linear(t2410, t_transformer_h_15_attn_attn_weight, None) # t2411: "cuda:0 bf16[1, 2048, 4672]"
t2417 = torch.reshape(t2411, (1, 2048, 1, 73, 64)) # t2417: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2417 = ltorch.reshape(t2411, (1, 2048, 1, 73, 64)) # t2417: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2417 = prims.reshape(t2411, (1, 2048, 1, 73, 64)) # t2417: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t2411
t2423 = torch.permute(t2417, (0, 2, 3, 1, 4)) # t2423: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2423 = ltorch.permute(t2417, (0, 2, 3, 1, 4)) # t2423: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2423 = prims.transpose(t2417, (0, 2, 3, 1, 4)) # t2423: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t2417
(t2424, t2425, t2426) = torch.split(t2423, (71, 1, 1), 2)
# (t2424, t2425, t2426) = ltorch.split(t2423, (71, 1, 1), 2)
# t2424 = prims.slice_prim(t2423, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t2424: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2425 = prims.slice_prim(t2423, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t2425: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t2426 = prims.slice_prim(t2423, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t2426: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t2423
t2432 = Tensor.expand(t2425, (1, 1, 71, 2048, 64)) # t2432: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2432 = ltorch.expand(t2425, (1, 1, 71, 2048, 64)) # t2432: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2432 = prims.broadcast_in_dim(t2425, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2432: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2425
t2438 = Tensor.expand(t2426, (1, 1, 71, 2048, 64)) # t2438: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2438 = ltorch.expand(t2426, (1, 1, 71, 2048, 64)) # t2438: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2438 = prims.broadcast_in_dim(t2426, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2438: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2426
t2444 = torch.reshape(t2424, (1, 71, 2048, 64)) # t2444: "cuda:0 bf16[1, 71, 2048, 64]"
# t2444 = ltorch.reshape(t2424, (1, 71, 2048, 64)) # t2444: "cuda:0 bf16[1, 71, 2048, 64]"
# t2444 = prims.reshape(t2424, (1, 71, 2048, 64)) # t2444: "cuda:0 bf16[1, 71, 2048, 64]"
del t2424
t2450 = torch.reshape(t2432, (1, 71, 2048, 64)) # t2450: "cuda:0 bf16[1, 71, 2048, 64]"
# t2450 = ltorch.reshape(t2432, (1, 71, 2048, 64)) # t2450: "cuda:0 bf16[1, 71, 2048, 64]"
# t2450 = prims.reshape(t2432, (1, 71, 2048, 64)) # t2450: "cuda:0 bf16[1, 71, 2048, 64]"
del t2432
t2456 = torch.reshape(t2438, (1, 71, 2048, 64)) # t2456: "cuda:0 bf16[1, 71, 2048, 64]"
# t2456 = ltorch.reshape(t2438, (1, 71, 2048, 64)) # t2456: "cuda:0 bf16[1, 71, 2048, 64]"
# t2456 = prims.reshape(t2438, (1, 71, 2048, 64)) # t2456: "cuda:0 bf16[1, 71, 2048, 64]"
del t2438
t2457 = torch_slice_prim_impl(t2444, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2457: "cuda:0 bf16[1, 71, 2048, 64]"
t2458 = torch_slice_prim_impl(t2457, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2458: "cuda:0 bf16[1, 71, 2048, 32]"
t2459 = torch_slice_prim_impl(t2457, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2459: "cuda:0 bf16[1, 71, 2048, 32]"
t2479 = torch_slice_prim_impl(t2450, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2479: "cuda:0 bf16[1, 71, 2048, 64]"
t2480 = torch_slice_prim_impl(t2479, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2480: "cuda:0 bf16[1, 71, 2048, 32]"
t2481 = torch_slice_prim_impl(t2479, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2481: "cuda:0 bf16[1, 71, 2048, 32]"
t2501 = torch_slice_prim_impl(t2444, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2501: "cuda:0 bf16[1, 71, 2048, 0]"
del t2444
t2504 = torch_slice_prim_impl(t2450, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2504: "cuda:0 bf16[1, 71, 2048, 0]"
del t2450
[t2503, t2506, t2538] = nvFusion31(t2457, t2458, t2459, t2479, t2480, t2481, t2501, t2504, t2519, t61, t66)
# t2520 = prims.convert_element_type(t2519, dtypes.float32) # t2520: "cuda:0 f32[1, 2048, 18176]"
# t2522 = prims.div(t2520, 1.4142135623730951) # t2522: "cuda:0 f32[1, 2048, 18176]"
# t2525 = prims.erf(t2522) # t2525: "cuda:0 f32[1, 2048, 18176]"
# t2529 = prims.mul(0.5, t2525) # t2529: "cuda:0 f32[1, 2048, 18176]"
# t2533 = prims.add(0.5, t2529) # t2533: "cuda:0 f32[1, 2048, 18176]"
# t2537 = prims.mul(t2520, t2533) # t2537: "cuda:0 f32[1, 2048, 18176]"
# t2538 = prims.convert_element_type(t2537, dtypes.bfloat16) # t2538: "cuda:0 bf16[1, 2048, 18176]"
# t2460 = prims.convert_element_type(t2459, dtypes.float32) # t2460: "cuda:0 f32[1, 71, 2048, 32]"
# t2461 = prims.neg(t2460) # t2461: "cuda:0 f32[1, 71, 2048, 32]"
# t2462 = prims.convert_element_type(t2461, dtypes.bfloat16) # t2462: "cuda:0 bf16[1, 71, 2048, 32]"
# t2464 = prims.cat((t2462, t2458), -1) # t2464: "cuda:0 bf16[1, 71, 2048, 64]"
# t2466 = prims.convert_element_type(t2457, dtypes.float32) # t2466: "cuda:0 f32[1, 71, 2048, 64]"
# t2468 = prims.mul(t2466, t61) # t2468: "cuda:0 f32[1, 71, 2048, 64]"
# t2471 = prims.convert_element_type(t2464, dtypes.float32) # t2471: "cuda:0 f32[1, 71, 2048, 64]"
# t2473 = prims.mul(t2471, t66) # t2473: "cuda:0 f32[1, 71, 2048, 64]"
# t2477 = prims.add(t2468, t2473) # t2477: "cuda:0 f32[1, 71, 2048, 64]"
# t2478 = prims.convert_element_type(t2477, dtypes.bfloat16) # t2478: "cuda:0 bf16[1, 71, 2048, 64]"
# t2482 = prims.convert_element_type(t2481, dtypes.float32) # t2482: "cuda:0 f32[1, 71, 2048, 32]"
# t2483 = prims.neg(t2482) # t2483: "cuda:0 f32[1, 71, 2048, 32]"
# t2484 = prims.convert_element_type(t2483, dtypes.bfloat16) # t2484: "cuda:0 bf16[1, 71, 2048, 32]"
# t2486 = prims.cat((t2484, t2480), -1) # t2486: "cuda:0 bf16[1, 71, 2048, 64]"
# t2488 = prims.convert_element_type(t2479, dtypes.float32) # t2488: "cuda:0 f32[1, 71, 2048, 64]"
# t2490 = prims.mul(t2488, t61) # t2490: "cuda:0 f32[1, 71, 2048, 64]"
# t2493 = prims.convert_element_type(t2486, dtypes.float32) # t2493: "cuda:0 f32[1, 71, 2048, 64]"
# t2495 = prims.mul(t2493, t66) # t2495: "cuda:0 f32[1, 71, 2048, 64]"
# t2499 = prims.add(t2490, t2495) # t2499: "cuda:0 f32[1, 71, 2048, 64]"
# t2500 = prims.convert_element_type(t2499, dtypes.bfloat16) # t2500: "cuda:0 bf16[1, 71, 2048, 64]"
# t2503 = prims.cat((t2478, t2501), -1) # t2503: "cuda:0 bf16[1, 71, 2048, 64]"
# t2506 = prims.cat((t2500, t2504), -1) # t2506: "cuda:0 bf16[1, 71, 2048, 64]"
del t2457, t2458, t2459, t2479, t2480, t2481, t2501, t2504
t2539 = torch.nn.functional.linear(t2538, t_transformer_h_15_mlp_proj_weight, None) # t2539: "cuda:0 bf16[1, 2048, 4544]"
# t2539 = ltorch.linear(t2538, t_transformer_h_15_mlp_proj_weight, None) # t2539: "cuda:0 bf16[1, 2048, 4544]"
# t2539 = prims.linear(t2538, t_transformer_h_15_mlp_proj_weight, None) # t2539: "cuda:0 bf16[1, 2048, 4544]"
(t2507, t2508, t2509, t2510) = cudnn_sdpa_fwd(t2503, t2506, t2456, None, 0.0, True, scale=0.125)
t2513 = torch.permute(t2507, (0, 2, 1, 3)) # t2513: "cuda:0 bf16[1, 2048, 71, 64]"
# t2513 = ltorch.permute(t2507, (0, 2, 1, 3)) # t2513: "cuda:0 bf16[1, 2048, 71, 64]"
# t2513 = prims.transpose(t2507, (0, 2, 1, 3)) # t2513: "cuda:0 bf16[1, 2048, 71, 64]"
t2517 = torch.reshape(t2513, (1, 2048, 4544)) # t2517: "cuda:0 bf16[1, 2048, 4544]"
# t2517 = ltorch.reshape(t2513, (1, 2048, 4544)) # t2517: "cuda:0 bf16[1, 2048, 4544]"
# t2517 = prims.reshape(t2513, (1, 2048, 4544)) # t2517: "cuda:0 bf16[1, 2048, 4544]"
del t2513
t2518 = torch.nn.functional.linear(t2517, t_transformer_h_15_attn_proj_weight, None) # t2518: "cuda:0 bf16[1, 2048, 4544]"
# t2518 = ltorch.linear(t2517, t_transformer_h_15_attn_proj_weight, None) # t2518: "cuda:0 bf16[1, 2048, 4544]"
# t2518 = prims.linear(t2517, t_transformer_h_15_attn_proj_weight, None) # t2518: "cuda:0 bf16[1, 2048, 4544]"
t5705 = torch.unsqueeze(t_transformer_h_16_norm_1_weight, 0) # t5705: "cuda:0 bf16[1, 4544]"
# t5705 = ltorch.unsqueeze(t_transformer_h_16_norm_1_weight, 0) # t5705: "cuda:0 bf16[1, 4544]"
# t5705 = prims.broadcast_in_dim(t_transformer_h_16_norm_1_weight, [1, 4544], [1]) # t5705: "cuda:0 bf16[1, 4544]"
t5706 = torch.unsqueeze(t5705, 1) # t5706: "cuda:0 bf16[1, 1, 4544]"
# t5706 = ltorch.unsqueeze(t5705, 1) # t5706: "cuda:0 bf16[1, 1, 4544]"
# t5706 = prims.broadcast_in_dim(t5705, [1, 1, 4544], [0, 2]) # t5706: "cuda:0 bf16[1, 1, 4544]"
del t5705
t2565 = Tensor.expand(t5706, (1, 2048, 4544)) # t2565: "cuda:0 bf16[1, 2048, 4544]"
# t2565 = ltorch.expand(t5706, (1, 2048, 4544)) # t2565: "cuda:0 bf16[1, 2048, 4544]"
# t2565 = prims.broadcast_in_dim(t5706, (1, 2048, 4544), (0, 1, 2)) # t2565: "cuda:0 bf16[1, 2048, 4544]"
del t5706
t5708 = torch.unsqueeze(t_transformer_h_16_norm_1_bias, 0) # t5708: "cuda:0 bf16[1, 4544]"
# t5708 = ltorch.unsqueeze(t_transformer_h_16_norm_1_bias, 0) # t5708: "cuda:0 bf16[1, 4544]"
# t5708 = prims.broadcast_in_dim(t_transformer_h_16_norm_1_bias, [1, 4544], [1]) # t5708: "cuda:0 bf16[1, 4544]"
t5709 = torch.unsqueeze(t5708, 1) # t5709: "cuda:0 bf16[1, 1, 4544]"
# t5709 = ltorch.unsqueeze(t5708, 1) # t5709: "cuda:0 bf16[1, 1, 4544]"
# t5709 = prims.broadcast_in_dim(t5708, [1, 1, 4544], [0, 2]) # t5709: "cuda:0 bf16[1, 1, 4544]"
del t5708
t2568 = Tensor.expand(t5709, (1, 2048, 4544)) # t2568: "cuda:0 bf16[1, 2048, 4544]"
# t2568 = ltorch.expand(t5709, (1, 2048, 4544)) # t2568: "cuda:0 bf16[1, 2048, 4544]"
# t2568 = prims.broadcast_in_dim(t5709, (1, 2048, 4544), (0, 1, 2)) # t2568: "cuda:0 bf16[1, 2048, 4544]"
del t5709
[t2547, t2554, t2559, t2571] = nvFusion32(t2386, t2518, t2539, t2565, t2568)
# t2545 = prims.convert_element_type(t2386, dtypes.float32) # t2545: "cuda:0 f32[1, 2048, 4544]"
# t2540 = prims.convert_element_type(t2539, dtypes.float32) # t2540: "cuda:0 f32[1, 2048, 4544]"
# t2541 = prims.convert_element_type(t2518, dtypes.float32) # t2541: "cuda:0 f32[1, 2048, 4544]"
# t2542 = prims.add(t2540, t2541) # t2542: "cuda:0 f32[1, 2048, 4544]"
# t2546 = prims.add(t2542, t2545) # t2546: "cuda:0 f32[1, 2048, 4544]"
# t2547 = prims.convert_element_type(t2546, dtypes.bfloat16) # t2547: "cuda:0 bf16[1, 2048, 4544]"
# (t2553, t2554) = prims.var_mean(t2546, (2,), correction=0)
# t2555 = prims.broadcast_in_dim(t2553, [1, 2048, 1], [0, 1]) # t2555: "cuda:0 f32[1, 2048, 1]"
# t2556 = prims.broadcast_in_dim(t2554, [1, 2048, 1], [0, 1]) # t2556: "cuda:0 f32[1, 2048, 1]"
# t2558 = prims.add(t2555, 1e-05) # t2558: "cuda:0 f32[1, 2048, 1]"
# t2559 = prims.rsqrt(t2558) # t2559: "cuda:0 f32[1, 2048, 1]"
# t2560 = prims.broadcast_in_dim(t2556, (1, 2048, 4544), (0, 1, 2)) # t2560: "cuda:0 f32[1, 2048, 4544]"
# t2562 = prims.sub(t2546, t2560) # t2562: "cuda:0 f32[1, 2048, 4544]"
# t2563 = prims.broadcast_in_dim(t2559, (1, 2048, 4544), (0, 1, 2)) # t2563: "cuda:0 f32[1, 2048, 4544]"
# t2564 = prims.mul(t2562, t2563) # t2564: "cuda:0 f32[1, 2048, 4544]"
# t2566 = prims.convert_element_type(t2565, dtypes.float32) # t2566: "cuda:0 f32[1, 2048, 4544]"
# t2567 = prims.mul(t2564, t2566) # t2567: "cuda:0 f32[1, 2048, 4544]"
# t2569 = prims.convert_element_type(t2568, dtypes.float32) # t2569: "cuda:0 f32[1, 2048, 4544]"
# t2570 = prims.add(t2567, t2569) # t2570: "cuda:0 f32[1, 2048, 4544]"
# t2571 = prims.convert_element_type(t2570, dtypes.bfloat16) # t2571: "cuda:0 bf16[1, 2048, 4544]"
del t2568
t2680 = torch.nn.functional.linear(t2571, t_transformer_h_16_mlp_fc_weight, None) # t2680: "cuda:0 bf16[1, 2048, 18176]"
# t2680 = ltorch.linear(t2571, t_transformer_h_16_mlp_fc_weight, None) # t2680: "cuda:0 bf16[1, 2048, 18176]"
# t2680 = prims.linear(t2571, t_transformer_h_16_mlp_fc_weight, None) # t2680: "cuda:0 bf16[1, 2048, 18176]"
t2572 = torch.nn.functional.linear(t2571, t_transformer_h_16_attn_attn_weight, None) # t2572: "cuda:0 bf16[1, 2048, 4672]"
# t2572 = ltorch.linear(t2571, t_transformer_h_16_attn_attn_weight, None) # t2572: "cuda:0 bf16[1, 2048, 4672]"
# t2572 = prims.linear(t2571, t_transformer_h_16_attn_attn_weight, None) # t2572: "cuda:0 bf16[1, 2048, 4672]"
t2578 = torch.reshape(t2572, (1, 2048, 1, 73, 64)) # t2578: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2578 = ltorch.reshape(t2572, (1, 2048, 1, 73, 64)) # t2578: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2578 = prims.reshape(t2572, (1, 2048, 1, 73, 64)) # t2578: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t2572
t2584 = torch.permute(t2578, (0, 2, 3, 1, 4)) # t2584: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2584 = ltorch.permute(t2578, (0, 2, 3, 1, 4)) # t2584: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2584 = prims.transpose(t2578, (0, 2, 3, 1, 4)) # t2584: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t2578
(t2585, t2586, t2587) = torch.split(t2584, (71, 1, 1), 2)
# (t2585, t2586, t2587) = ltorch.split(t2584, (71, 1, 1), 2)
# t2585 = prims.slice_prim(t2584, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t2585: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2586 = prims.slice_prim(t2584, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t2586: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t2587 = prims.slice_prim(t2584, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t2587: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t2584
t2593 = Tensor.expand(t2586, (1, 1, 71, 2048, 64)) # t2593: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2593 = ltorch.expand(t2586, (1, 1, 71, 2048, 64)) # t2593: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2593 = prims.broadcast_in_dim(t2586, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2593: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2586
t2599 = Tensor.expand(t2587, (1, 1, 71, 2048, 64)) # t2599: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2599 = ltorch.expand(t2587, (1, 1, 71, 2048, 64)) # t2599: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2599 = prims.broadcast_in_dim(t2587, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2599: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2587
t2605 = torch.reshape(t2585, (1, 71, 2048, 64)) # t2605: "cuda:0 bf16[1, 71, 2048, 64]"
# t2605 = ltorch.reshape(t2585, (1, 71, 2048, 64)) # t2605: "cuda:0 bf16[1, 71, 2048, 64]"
# t2605 = prims.reshape(t2585, (1, 71, 2048, 64)) # t2605: "cuda:0 bf16[1, 71, 2048, 64]"
del t2585
t2611 = torch.reshape(t2593, (1, 71, 2048, 64)) # t2611: "cuda:0 bf16[1, 71, 2048, 64]"
# t2611 = ltorch.reshape(t2593, (1, 71, 2048, 64)) # t2611: "cuda:0 bf16[1, 71, 2048, 64]"
# t2611 = prims.reshape(t2593, (1, 71, 2048, 64)) # t2611: "cuda:0 bf16[1, 71, 2048, 64]"
del t2593
t2617 = torch.reshape(t2599, (1, 71, 2048, 64)) # t2617: "cuda:0 bf16[1, 71, 2048, 64]"
# t2617 = ltorch.reshape(t2599, (1, 71, 2048, 64)) # t2617: "cuda:0 bf16[1, 71, 2048, 64]"
# t2617 = prims.reshape(t2599, (1, 71, 2048, 64)) # t2617: "cuda:0 bf16[1, 71, 2048, 64]"
del t2599
t2618 = torch_slice_prim_impl(t2605, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2618: "cuda:0 bf16[1, 71, 2048, 64]"
t2619 = torch_slice_prim_impl(t2618, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2619: "cuda:0 bf16[1, 71, 2048, 32]"
t2620 = torch_slice_prim_impl(t2618, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2620: "cuda:0 bf16[1, 71, 2048, 32]"
t2640 = torch_slice_prim_impl(t2611, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2640: "cuda:0 bf16[1, 71, 2048, 64]"
t2641 = torch_slice_prim_impl(t2640, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2641: "cuda:0 bf16[1, 71, 2048, 32]"
t2642 = torch_slice_prim_impl(t2640, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2642: "cuda:0 bf16[1, 71, 2048, 32]"
t2662 = torch_slice_prim_impl(t2605, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2662: "cuda:0 bf16[1, 71, 2048, 0]"
del t2605
t2665 = torch_slice_prim_impl(t2611, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2665: "cuda:0 bf16[1, 71, 2048, 0]"
del t2611
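# nvFusion33 fuses two independent elementwise regions of block 16 (a hedged reading
# of the prims below; names in the sketch are illustrative only):
#   (1) exact-erf GELU on the MLP fc output t2680:
#         gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
#   (2) rotary position embedding of the query (t2618) and key (t2640) heads, with
#       t61/t66 read as the broadcast cos/sin tables:
#         x1, x2  = x[..., :32], x[..., 32:]
#         rotated = cat((-x2, x1), dim=-1)
#         roped   = (x.float() * cos + rotated.float() * sin).to(bfloat16)
#       t2662/t2665 are zero-width slices (rotary covers the full head dim), so the
#       trailing cat into t2664/t2667 is effectively a no-op append.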
[t2664, t2667, t2699] = nvFusion33(t2618, t2619, t2620, t2640, t2641, t2642, t2662, t2665, t2680, t61, t66)
# t2681 = prims.convert_element_type(t2680, dtypes.float32) # t2681: "cuda:0 f32[1, 2048, 18176]"
# t2683 = prims.div(t2681, 1.4142135623730951) # t2683: "cuda:0 f32[1, 2048, 18176]"
# t2686 = prims.erf(t2683) # t2686: "cuda:0 f32[1, 2048, 18176]"
# t2690 = prims.mul(0.5, t2686) # t2690: "cuda:0 f32[1, 2048, 18176]"
# t2694 = prims.add(0.5, t2690) # t2694: "cuda:0 f32[1, 2048, 18176]"
# t2698 = prims.mul(t2681, t2694) # t2698: "cuda:0 f32[1, 2048, 18176]"
# t2699 = prims.convert_element_type(t2698, dtypes.bfloat16) # t2699: "cuda:0 bf16[1, 2048, 18176]"
# t2621 = prims.convert_element_type(t2620, dtypes.float32) # t2621: "cuda:0 f32[1, 71, 2048, 32]"
# t2622 = prims.neg(t2621) # t2622: "cuda:0 f32[1, 71, 2048, 32]"
# t2623 = prims.convert_element_type(t2622, dtypes.bfloat16) # t2623: "cuda:0 bf16[1, 71, 2048, 32]"
# t2625 = prims.cat((t2623, t2619), -1) # t2625: "cuda:0 bf16[1, 71, 2048, 64]"
# t2627 = prims.convert_element_type(t2618, dtypes.float32) # t2627: "cuda:0 f32[1, 71, 2048, 64]"
# t2629 = prims.mul(t2627, t61) # t2629: "cuda:0 f32[1, 71, 2048, 64]"
# t2632 = prims.convert_element_type(t2625, dtypes.float32) # t2632: "cuda:0 f32[1, 71, 2048, 64]"
# t2634 = prims.mul(t2632, t66) # t2634: "cuda:0 f32[1, 71, 2048, 64]"
# t2638 = prims.add(t2629, t2634) # t2638: "cuda:0 f32[1, 71, 2048, 64]"
# t2639 = prims.convert_element_type(t2638, dtypes.bfloat16) # t2639: "cuda:0 bf16[1, 71, 2048, 64]"
# t2643 = prims.convert_element_type(t2642, dtypes.float32) # t2643: "cuda:0 f32[1, 71, 2048, 32]"
# t2644 = prims.neg(t2643) # t2644: "cuda:0 f32[1, 71, 2048, 32]"
# t2645 = prims.convert_element_type(t2644, dtypes.bfloat16) # t2645: "cuda:0 bf16[1, 71, 2048, 32]"
# t2647 = prims.cat((t2645, t2641), -1) # t2647: "cuda:0 bf16[1, 71, 2048, 64]"
# t2649 = prims.convert_element_type(t2640, dtypes.float32) # t2649: "cuda:0 f32[1, 71, 2048, 64]"
# t2651 = prims.mul(t2649, t61) # t2651: "cuda:0 f32[1, 71, 2048, 64]"
# t2654 = prims.convert_element_type(t2647, dtypes.float32) # t2654: "cuda:0 f32[1, 71, 2048, 64]"
# t2656 = prims.mul(t2654, t66) # t2656: "cuda:0 f32[1, 71, 2048, 64]"
# t2660 = prims.add(t2651, t2656) # t2660: "cuda:0 f32[1, 71, 2048, 64]"
# t2661 = prims.convert_element_type(t2660, dtypes.bfloat16) # t2661: "cuda:0 bf16[1, 71, 2048, 64]"
# t2664 = prims.cat((t2639, t2662), -1) # t2664: "cuda:0 bf16[1, 71, 2048, 64]"
# t2667 = prims.cat((t2661, t2665), -1) # t2667: "cuda:0 bf16[1, 71, 2048, 64]"
del t2618, t2619, t2620, t2640, t2641, t2642, t2662, t2665
t2700 = torch.nn.functional.linear(t2699, t_transformer_h_16_mlp_proj_weight, None) # t2700: "cuda:0 bf16[1, 2048, 4544]"
# t2700 = ltorch.linear(t2699, t_transformer_h_16_mlp_proj_weight, None) # t2700: "cuda:0 bf16[1, 2048, 4544]"
# t2700 = prims.linear(t2699, t_transformer_h_16_mlp_proj_weight, None) # t2700: "cuda:0 bf16[1, 2048, 4544]"
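# cudnn flash-attention forward for block 16, read here as (query=t2664, key=t2667,
# value=t2617, attn_mask=None, dropout_p=0.0, is_causal=True) with softmax scale
# 0.125 = 1/sqrt(64), the head dim. The extra outputs t2669-t2671 are saved for the
# backward pass (presumably softmax statistics and rng/workspace state).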
(t2668, t2669, t2670, t2671) = cudnn_sdpa_fwd(t2664, t2667, t2617, None, 0.0, True, scale=0.125)
t2674 = torch.permute(t2668, (0, 2, 1, 3)) # t2674: "cuda:0 bf16[1, 2048, 71, 64]"
# t2674 = ltorch.permute(t2668, (0, 2, 1, 3)) # t2674: "cuda:0 bf16[1, 2048, 71, 64]"
# t2674 = prims.transpose(t2668, (0, 2, 1, 3)) # t2674: "cuda:0 bf16[1, 2048, 71, 64]"
t2678 = torch.reshape(t2674, (1, 2048, 4544)) # t2678: "cuda:0 bf16[1, 2048, 4544]"
# t2678 = ltorch.reshape(t2674, (1, 2048, 4544)) # t2678: "cuda:0 bf16[1, 2048, 4544]"
# t2678 = prims.reshape(t2674, (1, 2048, 4544)) # t2678: "cuda:0 bf16[1, 2048, 4544]"
del t2674
t2679 = torch.nn.functional.linear(t2678, t_transformer_h_16_attn_proj_weight, None) # t2679: "cuda:0 bf16[1, 2048, 4544]"
# t2679 = ltorch.linear(t2678, t_transformer_h_16_attn_proj_weight, None) # t2679: "cuda:0 bf16[1, 2048, 4544]"
# t2679 = prims.linear(t2678, t_transformer_h_16_attn_proj_weight, None) # t2679: "cuda:0 bf16[1, 2048, 4544]"
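# The block-17 norm_1 weight and bias are broadcast from [4544] to [1, 2048, 4544]
# (two unsqueezes followed by an expand) so they can enter the elementwise fusion below.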
t5731 = torch.unsqueeze(t_transformer_h_17_norm_1_weight, 0) # t5731: "cuda:0 bf16[1, 4544]"
# t5731 = ltorch.unsqueeze(t_transformer_h_17_norm_1_weight, 0) # t5731: "cuda:0 bf16[1, 4544]"
# t5731 = prims.broadcast_in_dim(t_transformer_h_17_norm_1_weight, [1, 4544], [1]) # t5731: "cuda:0 bf16[1, 4544]"
t5732 = torch.unsqueeze(t5731, 1) # t5732: "cuda:0 bf16[1, 1, 4544]"
# t5732 = ltorch.unsqueeze(t5731, 1) # t5732: "cuda:0 bf16[1, 1, 4544]"
# t5732 = prims.broadcast_in_dim(t5731, [1, 1, 4544], [0, 2]) # t5732: "cuda:0 bf16[1, 1, 4544]"
del t5731
t2726 = Tensor.expand(t5732, (1, 2048, 4544)) # t2726: "cuda:0 bf16[1, 2048, 4544]"
# t2726 = ltorch.expand(t5732, (1, 2048, 4544)) # t2726: "cuda:0 bf16[1, 2048, 4544]"
# t2726 = prims.broadcast_in_dim(t5732, (1, 2048, 4544), (0, 1, 2)) # t2726: "cuda:0 bf16[1, 2048, 4544]"
del t5732
t5734 = torch.unsqueeze(t_transformer_h_17_norm_1_bias, 0) # t5734: "cuda:0 bf16[1, 4544]"
# t5734 = ltorch.unsqueeze(t_transformer_h_17_norm_1_bias, 0) # t5734: "cuda:0 bf16[1, 4544]"
# t5734 = prims.broadcast_in_dim(t_transformer_h_17_norm_1_bias, [1, 4544], [1]) # t5734: "cuda:0 bf16[1, 4544]"
t5735 = torch.unsqueeze(t5734, 1) # t5735: "cuda:0 bf16[1, 1, 4544]"
# t5735 = ltorch.unsqueeze(t5734, 1) # t5735: "cuda:0 bf16[1, 1, 4544]"
# t5735 = prims.broadcast_in_dim(t5734, [1, 1, 4544], [0, 2]) # t5735: "cuda:0 bf16[1, 1, 4544]"
del t5734
t2729 = Tensor.expand(t5735, (1, 2048, 4544)) # t2729: "cuda:0 bf16[1, 2048, 4544]"
# t2729 = ltorch.expand(t5735, (1, 2048, 4544)) # t2729: "cuda:0 bf16[1, 2048, 4544]"
# t2729 = prims.broadcast_in_dim(t5735, (1, 2048, 4544), (0, 1, 2)) # t2729: "cuda:0 bf16[1, 2048, 4544]"
del t5735
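# nvFusion34 fuses the block-16 residual adds (attn proj t2679 + mlp proj t2700 +
# residual t2547) with the block-17 input LayerNorm: mean/variance over the last dim
# in fp32, y = (x - mean) * rsqrt(var + 1e-5) * weight + bias, cast back to bf16.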
[t2708, t2715, t2720, t2732] = nvFusion34(t2547, t2679, t2700, t2726, t2729)
# t2706 = prims.convert_element_type(t2547, dtypes.float32) # t2706: "cuda:0 f32[1, 2048, 4544]"
# t2701 = prims.convert_element_type(t2700, dtypes.float32) # t2701: "cuda:0 f32[1, 2048, 4544]"
# t2702 = prims.convert_element_type(t2679, dtypes.float32) # t2702: "cuda:0 f32[1, 2048, 4544]"
# t2703 = prims.add(t2701, t2702) # t2703: "cuda:0 f32[1, 2048, 4544]"
# t2707 = prims.add(t2703, t2706) # t2707: "cuda:0 f32[1, 2048, 4544]"
# t2708 = prims.convert_element_type(t2707, dtypes.bfloat16) # t2708: "cuda:0 bf16[1, 2048, 4544]"
# (t2714, t2715) = prims.var_mean(t2707, (2,), correction=0)
# t2716 = prims.broadcast_in_dim(t2714, [1, 2048, 1], [0, 1]) # t2716: "cuda:0 f32[1, 2048, 1]"
# t2717 = prims.broadcast_in_dim(t2715, [1, 2048, 1], [0, 1]) # t2717: "cuda:0 f32[1, 2048, 1]"
# t2719 = prims.add(t2716, 1e-05) # t2719: "cuda:0 f32[1, 2048, 1]"
# t2720 = prims.rsqrt(t2719) # t2720: "cuda:0 f32[1, 2048, 1]"
# t2721 = prims.broadcast_in_dim(t2717, (1, 2048, 4544), (0, 1, 2)) # t2721: "cuda:0 f32[1, 2048, 4544]"
# t2723 = prims.sub(t2707, t2721) # t2723: "cuda:0 f32[1, 2048, 4544]"
# t2724 = prims.broadcast_in_dim(t2720, (1, 2048, 4544), (0, 1, 2)) # t2724: "cuda:0 f32[1, 2048, 4544]"
# t2725 = prims.mul(t2723, t2724) # t2725: "cuda:0 f32[1, 2048, 4544]"
# t2727 = prims.convert_element_type(t2726, dtypes.float32) # t2727: "cuda:0 f32[1, 2048, 4544]"
# t2728 = prims.mul(t2725, t2727) # t2728: "cuda:0 f32[1, 2048, 4544]"
# t2730 = prims.convert_element_type(t2729, dtypes.float32) # t2730: "cuda:0 f32[1, 2048, 4544]"
# t2731 = prims.add(t2728, t2730) # t2731: "cuda:0 f32[1, 2048, 4544]"
# t2732 = prims.convert_element_type(t2731, dtypes.bfloat16) # t2732: "cuda:0 bf16[1, 2048, 4544]"
del t2729
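# Transformer block 17 repeats the block-16 pattern below: parallel MLP fc and fused
# QKV projections, the (71, 1, 1) multi-query split, rotary embedding + GELU fused in
# nvFusion35, cudnn SDPA, the attention/MLP output projections, and the fused
# residual + LayerNorm for block 18 in nvFusion36.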
t2841 = torch.nn.functional.linear(t2732, t_transformer_h_17_mlp_fc_weight, None) # t2841: "cuda:0 bf16[1, 2048, 18176]"
# t2841 = ltorch.linear(t2732, t_transformer_h_17_mlp_fc_weight, None) # t2841: "cuda:0 bf16[1, 2048, 18176]"
# t2841 = prims.linear(t2732, t_transformer_h_17_mlp_fc_weight, None) # t2841: "cuda:0 bf16[1, 2048, 18176]"
t2733 = torch.nn.functional.linear(t2732, t_transformer_h_17_attn_attn_weight, None) # t2733: "cuda:0 bf16[1, 2048, 4672]"
# t2733 = ltorch.linear(t2732, t_transformer_h_17_attn_attn_weight, None) # t2733: "cuda:0 bf16[1, 2048, 4672]"
# t2733 = prims.linear(t2732, t_transformer_h_17_attn_attn_weight, None) # t2733: "cuda:0 bf16[1, 2048, 4672]"
t2739 = torch.reshape(t2733, (1, 2048, 1, 73, 64)) # t2739: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2739 = ltorch.reshape(t2733, (1, 2048, 1, 73, 64)) # t2739: "cuda:0 bf16[1, 2048, 1, 73, 64]"
# t2739 = prims.reshape(t2733, (1, 2048, 1, 73, 64)) # t2739: "cuda:0 bf16[1, 2048, 1, 73, 64]"
del t2733
t2745 = torch.permute(t2739, (0, 2, 3, 1, 4)) # t2745: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2745 = ltorch.permute(t2739, (0, 2, 3, 1, 4)) # t2745: "cuda:0 bf16[1, 1, 73, 2048, 64]"
# t2745 = prims.transpose(t2739, (0, 2, 3, 1, 4)) # t2745: "cuda:0 bf16[1, 1, 73, 2048, 64]"
del t2739
(t2746, t2747, t2748) = torch.split(t2745, (71, 1, 1), 2)
# (t2746, t2747, t2748) = ltorch.split(t2745, (71, 1, 1), 2)
# t2746 = prims.slice_prim(t2745, [0, 0, 0, 0, 0], [1, 1, 71, 2048, 64], [1, 1, 1, 1, 1]) # t2746: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2747 = prims.slice_prim(t2745, [0, 0, 71, 0, 0], [1, 1, 72, 2048, 64], [1, 1, 1, 1, 1]) # t2747: "cuda:0 bf16[1, 1, 1, 2048, 64]"
# t2748 = prims.slice_prim(t2745, [0, 0, 72, 0, 0], [1, 1, 73, 2048, 64], [1, 1, 1, 1, 1]) # t2748: "cuda:0 bf16[1, 1, 1, 2048, 64]"
del t2745
t2754 = Tensor.expand(t2747, (1, 1, 71, 2048, 64)) # t2754: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2754 = ltorch.expand(t2747, (1, 1, 71, 2048, 64)) # t2754: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2754 = prims.broadcast_in_dim(t2747, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2754: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2747
t2760 = Tensor.expand(t2748, (1, 1, 71, 2048, 64)) # t2760: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2760 = ltorch.expand(t2748, (1, 1, 71, 2048, 64)) # t2760: "cuda:0 bf16[1, 1, 71, 2048, 64]"
# t2760 = prims.broadcast_in_dim(t2748, (1, 1, 71, 2048, 64), (0, 1, 2, 3, 4)) # t2760: "cuda:0 bf16[1, 1, 71, 2048, 64]"
del t2748
t2766 = torch.reshape(t2746, (1, 71, 2048, 64)) # t2766: "cuda:0 bf16[1, 71, 2048, 64]"
# t2766 = ltorch.reshape(t2746, (1, 71, 2048, 64)) # t2766: "cuda:0 bf16[1, 71, 2048, 64]"
# t2766 = prims.reshape(t2746, (1, 71, 2048, 64)) # t2766: "cuda:0 bf16[1, 71, 2048, 64]"
del t2746
t2772 = torch.reshape(t2754, (1, 71, 2048, 64)) # t2772: "cuda:0 bf16[1, 71, 2048, 64]"
# t2772 = ltorch.reshape(t2754, (1, 71, 2048, 64)) # t2772: "cuda:0 bf16[1, 71, 2048, 64]"
# t2772 = prims.reshape(t2754, (1, 71, 2048, 64)) # t2772: "cuda:0 bf16[1, 71, 2048, 64]"
del t2754
t2778 = torch.reshape(t2760, (1, 71, 2048, 64)) # t2778: "cuda:0 bf16[1, 71, 2048, 64]"
# t2778 = ltorch.reshape(t2760, (1, 71, 2048, 64)) # t2778: "cuda:0 bf16[1, 71, 2048, 64]"
# t2778 = prims.reshape(t2760, (1, 71, 2048, 64)) # t2778: "cuda:0 bf16[1, 71, 2048, 64]"
del t2760
t2779 = torch_slice_prim_impl(t2766, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2779: "cuda:0 bf16[1, 71, 2048, 64]"
t2780 = torch_slice_prim_impl(t2779, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2780: "cuda:0 bf16[1, 71, 2048, 32]"
t2781 = torch_slice_prim_impl(t2779, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2781: "cuda:0 bf16[1, 71, 2048, 32]"
t2801 = torch_slice_prim_impl(t2772, [0, 0, 0, 0], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2801: "cuda:0 bf16[1, 71, 2048, 64]"
t2802 = torch_slice_prim_impl(t2801, [0, 0, 0, 0], [1, 71, 2048, 32], [1, 1, 1, 1]) # t2802: "cuda:0 bf16[1, 71, 2048, 32]"
t2803 = torch_slice_prim_impl(t2801, [0, 0, 0, 32], [1, 71, 2048, 64], [1, 1, 1, 1]) # t2803: "cuda:0 bf16[1, 71, 2048, 32]"
t2823 = torch_slice_prim_impl(t2766, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2823: "cuda:0 bf16[1, 71, 2048, 0]"
del t2766
t2826 = torch_slice_prim_impl(t2772, [0, 0, 0, 0], [1, 71, 2048, 0], [1, 1, 1, 1]) # t2826: "cuda:0 bf16[1, 71, 2048, 0]"
del t2772
[t2825, t2828, t2860] = nvFusion35(t2779, t2780, t2781, t2801, t2802, t2803, t2823, t2826, t2841, t61, t66)
# t2842 = prims.convert_element_type(t2841, dtypes.float32) # t2842: "cuda:0 f32[1, 2048, 18176]"
# t2844 = prims.div(t2842, 1.4142135623730951) # t2844: "cuda:0 f32[1, 2048, 18176]"
# t2847 = prims.erf(t2844) # t2847: "cuda:0 f32[1, 2048, 18176]"
# t2851 = prims.mul(0.5, t2847) # t2851: "cuda:0 f32[1, 2048, 18176]"
# t2855 = prims.add(0.5, t2851) # t2855: "cuda:0 f32[1, 2048, 18176]"
# t2859 = prims.mul(t2842, t2855) # t2859: "cuda:0 f32[1, 2048, 18176]"
# t2860 = prims.convert_element_type(t2859, dtypes.bfloat16) # t2860: "cuda:0 bf16[1, 2048, 18176]"
# t2782 = prims.convert_element_type(t2781, dtypes.float32) # t2782: "cuda:0 f32[1, 71, 2048, 32]"
# t2783 = prims.neg(t2782) # t2783: "cuda:0 f32[1, 71, 2048, 32]"
# t2784 = prims.convert_element_type(t2783, dtypes.bfloat16) # t2784: "cuda:0 bf16[1, 71, 2048, 32]"
# t2786 = prims.cat((t2784, t2780), -1) # t2786: "cuda:0 bf16[1, 71, 2048, 64]"
# t2788 = prims.convert_element_type(t2779, dtypes.float32) # t2788: "cuda:0 f32[1, 71, 2048, 64]"
# t2790 = prims.mul(t2788, t61) # t2790: "cuda:0 f32[1, 71, 2048, 64]"
# t2793 = prims.convert_element_type(t2786, dtypes.float32) # t2793: "cuda:0 f32[1, 71, 2048, 64]"
# t2795 = prims.mul(t2793, t66) # t2795: "cuda:0 f32[1, 71, 2048, 64]"
# t2799 = prims.add(t2790, t2795) # t2799: "cuda:0 f32[1, 71, 2048, 64]"
# t2800 = prims.convert_element_type(t2799, dtypes.bfloat16) # t2800: "cuda:0 bf16[1, 71, 2048, 64]"
# t2804 = prims.convert_element_type(t2803, dtypes.float32) # t2804: "cuda:0 f32[1, 71, 2048, 32]"
# t2805 = prims.neg(t2804) # t2805: "cuda:0 f32[1, 71, 2048, 32]"
# t2806 = prims.convert_element_type(t2805, dtypes.bfloat16) # t2806: "cuda:0 bf16[1, 71, 2048, 32]"
# t2808 = prims.cat((t2806, t2802), -1) # t2808: "cuda:0 bf16[1, 71, 2048, 64]"
# t2810 = prims.convert_element_type(t2801, dtypes.float32) # t2810: "cuda:0 f32[1, 71, 2048, 64]"
# t2812 = prims.mul(t2810, t61) # t2812: "cuda:0 f32[1, 71, 2048, 64]"
# t2815 = prims.convert_element_type(t2808, dtypes.float32) # t2815: "cuda:0 f32[1, 71, 2048, 64]"
# t2817 = prims.mul(t2815, t66) # t2817: "cuda:0 f32[1, 71, 2048, 64]"
# t2821 = prims.add(t2812, t2817) # t2821: "cuda:0 f32[1, 71, 2048, 64]"
# t2822 = prims.convert_element_type(t2821, dtypes.bfloat16) # t2822: "cuda:0 bf16[1, 71, 2048, 64]"
# t2825 = prims.cat((t2800, t2823), -1) # t2825: "cuda:0 bf16[1, 71, 2048, 64]"
# t2828 = prims.cat((t2822, t2826), -1) # t2828: "cuda:0 bf16[1, 71, 2048, 64]"
del t2779, t2780, t2781, t2801, t2802, t2803, t2823, t2826
t2861 = torch.nn.functional.linear(t2860, t_transformer_h_17_mlp_proj_weight, None) # t2861: "cuda:0 bf16[1, 2048, 4544]"
# t2861 = ltorch.linear(t2860, t_transformer_h_17_mlp_proj_weight, None) # t2861: "cuda:0 bf16[1, 2048, 4544]"
# t2861 = prims.linear(t2860, t_transformer_h_17_mlp_proj_weight, None) # t2861: "cuda:0 bf16[1, 2048, 4544]"
(t2829, t2830, t2831, t2832) = cudnn_sdpa_fwd(t2825, t2828, t2778, None, 0.0, True, scale=0.125)
t2835 = torch.permute(t2829, (0, 2, 1, 3)) # t2835: "cuda:0 bf16[1, 2048, 71, 64]"
# t2835 = ltorch.permute(t2829, (0, 2, 1, 3)) # t2835: "cuda:0 bf16[1, 2048, 71, 64]"
# t2835 = prims.transpose(t2829, (0, 2, 1, 3)) # t2835: "cuda:0 bf16[1, 2048, 71, 64]"
t2839 = torch.reshape(t2835, (1, 2048, 4544)) # t2839: "cuda:0 bf16[1, 2048, 4544]"
# t2839 = ltorch.reshape(t2835, (1, 2048, 4544)) # t2839: "cuda:0 bf16[1, 2048, 4544]"
# t2839 = prims.reshape(t2835, (1, 2048, 4544)) # t2839: "cuda:0 bf16[1, 2048, 4544]"
del t2835
t2840 = torch.nn.functional.linear(t2839, t_transformer_h_17_attn_proj_weight, None) # t2840: "cuda:0 bf16[1, 2048, 4544]"
# t2840 = ltorch.linear(t2839, t_transformer_h_17_attn_proj_weight, None) # t2840: "cuda:0 bf16[1, 2048, 4544]"
# t2840 = prims.linear(t2839, t_transformer_h_17_attn_proj_weight, None) # t2840: "cuda:0 bf16[1, 2048, 4544]"
t5757 = torch.unsqueeze(t_transformer_h_18_norm_1_weight, 0) # t5757: "cuda:0 bf16[1, 4544]"
# t5757 = ltorch.unsqueeze(t_transformer_h_18_norm_1_weight, 0) # t5757: "cuda:0 bf16[1, 4544]"
# t5757 = prims.broadcast_in_dim(t_transformer_h_18_norm_1_weight, [1, 4544], [1]) # t5757: "cuda:0 bf16[1, 4544]"
t5758 = torch.unsqueeze(t5757, 1) # t5758: "cuda:0 bf16[1, 1, 4544]"
# t5758 = ltorch.unsqueeze(t5757, 1) # t5758: "cuda:0 bf16[1, 1, 4544]"
# t5758 = prims.broadcast_in_dim(t5757, [1, 1, 4544], [0, 2]) # t5758: "cuda:0 bf16[1, 1, 4544]"
del t5757
t2887 = Tensor.expand(t5758, (1, 2048, 4544)) # t2887: "cuda:0 bf16[1, 2048, 4544]"
# t2887 = ltorch.expand(t5758, (1, 2048, 4544)) # t2887: "cuda:0 bf16[1, 2048, 4544]"
# t2887 = prims.broadcast_in_dim(t5758, (1, 2048, 4544), (0, 1, 2)) # t2887: "cuda:0 bf16[1, 2048, 4544]"
del t5758
t5760 = torch.unsqueeze(t_transformer_h_18_norm_1_bias, 0) # t5760: "cuda:0 bf16[1, 4544]"
# t5760 = ltorch.unsqueeze(t_transformer_h_18_norm_1_bias, 0) # t5760: "cuda:0 bf16[1, 4544]"
# t5760 = prims.broadcast_in_dim(t_transformer_h_18_norm_1_bias, [1, 4544], [1]) # t5760: "cuda:0 bf16[1, 4544]"
t5761 = torch.unsqueeze(t5760, 1) # t5761: "cuda:0 bf16[1, 1, 4544]"
# t5761 = ltorch.unsqueeze(t5760, 1) # t5761: "cuda:0 bf16[1, 1, 4544]"
# t5761 = prims.broadcast_in_dim(t5760, [1, 1, 4544], [0, 2]) # t5761: "cuda:0 bf16[1, 1, 4544]"
del t5760
t2890 = Tensor.expand(t5761, (1, 2048, 4544)) # t2890: "cuda:0 bf16[1, 2048, 4544]"
# t2890 = ltorch.expand(t5761, (1, 2048, 4544)) # t2890: "cuda:0 bf16[1, 2048, 4544]"
# t2890 = prims.broadcast_in_dim(t5761, (1, 2048, 4544), (0, 1, 2)) # t2890: "cuda:0 bf16[1, 2048, 4544]"
del t5761
[t2869, t2876, t2881, t2893] = nvFusion36(t2708, t2840, t2861, t2887, t2890)
# t2867 = prims.convert_element_type(t2708, dtypes.float32) # t2867: "cuda:0 f32[1, 2048, 4544]"
# t2862 = prims.convert_element_type(t2861, dtypes.float32) # t2862: "cuda:0 f32[1, 2048, 4544]"
# t2863 = prims.convert_element_type(t2840, dtypes.float32) # t2863: "cuda:0 f32[1, 2048, 4544]"
# t2864 = prims.add(t2862, t2863) # t2864: "cuda:0 f32[1, 2048, 4544]"
# t2868 = prims.add(t2864, t2867) # t2868: "cuda:0 f32[1, 2048, 4544]"
# t2869 = prims.convert_element_type(t2868, dtypes.bfloat16) # t2869: "cuda:0 bf16[1, 2048, 4544]"
# (t2875, t2876) = prims.var_mean(t2868, (2,), correction=0)
# t2877 = prims.broadcast_in_dim(t2875, [1, 2048, 1], [0, 1]) # t2877: "cuda:0 f32[1, 2048, 1]"
# t2878 = prims.broadcast_in_dim(t2876, [1, 2048, 1], [0, 1]) # t2878: "cuda:0 f32[1, 2048, 1]"