##########
#Graph0-ThunderFn0 last backward trace
##########
# Constructed by Delete Last Used (took 0 milliseconds)
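#
# This is the Thunder-generated backward trace for a model with two transformer
# blocks (transformer.h.0 and transformer.h.1) in a Llama-style configuration:
# hidden size 4096, 32 heads of dim 128, MLP width 11008, vocab 32000, RMSNorm
# with eps 1e-6, rotary embeddings, sequence length 16384, batch 1. The blocks
# are activation-checkpointed, so the trace first recomputes the checkpointed
# activations (RMSNorms, QKV projection, RoPE, SDPA) and then applies the chain
# rule region by region: lm_head -> ln_f -> block 1 -> block 0. Each nvFusionN
# call is a region handed to the nvFuser executor and each TorchCompileN call a
# region handed to the torch.compile executor; the commented lines beneath each
# call list the fused prims with dtype/shape annotations.
#
# A trace like this is typically obtained along these lines (sketch only; the
# exact API usage may differ between Thunder versions):
#   jmodel = thunder.jit(model)
#   loss = jmodel(idx).sum()
#   loss.backward()
#   print(thunder.last_backward_traces(jmodel)[-1])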
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, _, = saved_for_backward
clear_mutable_collection(saved_for_backward)
del saved_for_backward
t582, = cotangents
clear_mutable_collection(cotangents)
del cotangents
l_idx_, l_self_buffers_cos_, l_self_buffers_sin_, \
l_self_modules_lm_head_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, \
l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, \
l_self_modules_transformer_modules_ln_f_parameters_weight_, rsqrt, t533, t563, \
x, x_1, x_4, = C0
clear_mutable_collection(C0)
del C0
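# nvFusion0 flattens the incoming logits cotangent to (16384, 32000); the matmul
# below then maps it through the lm_head weight, giving the gradient with respect
# to the final hidden states, shape (16384, 4096).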
[t583] = nvFusion0(t582)
# t583 = prims.reshape(t582, (16384, 32000)) # t583: "cuda:0 bf16[16384, 32000]"
t584 = torch.matmul(t583, l_self_modules_lm_head_parameters_weight_) # t584: "cuda:0 bf16[16384, 4096]"
# t584 = ltorch.matmul(t583, l_self_modules_lm_head_parameters_weight_) # t584: "cuda:0 bf16[16384, 4096]"
# t584 = prims.matmul(t583, l_self_modules_lm_head_parameters_weight_) # t584: "cuda:0 bf16[16384, 4096]"
del t583, l_self_modules_lm_head_parameters_weight_
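# nvFusion1: backward of the final RMSNorm (transformer.ln_f). The norm input is
# rebuilt from the two saved residual-stream tensors (t533 + t563) and the saved
# rsqrt; the fusion emits the gradient w.r.t. that input (t614, plus its flattened
# view t749 and transpose t753) and the ln_f weight gradient (t594).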
[t614, t749, t753, t594] = nvFusion1(t584, l_self_modules_transformer_modules_ln_f_parameters_weight_, t533, t563, rsqrt)
# t585 = prims.reshape(t584, (1, 16384, 4096)) # t585: "cuda:0 bf16[1, 16384, 4096]"
# float_2 = prims.convert_element_type(l_self_modules_transformer_modules_ln_f_parameters_weight_, dtypes.float32) # float_2: "cuda:0 f32[4096]"
# t590 = prims.convert_element_type(t585, dtypes.float32) # t590: "cuda:0 f32[1, 16384, 4096]"
# t578 = prims.broadcast_in_dim(float_2, (1, 16384, 4096), (2,)) # t578: "cuda:0 f32[1, 16384, 4096]"
# t565 = prims.convert_element_type(t533, dtypes.float32) # t565: "cuda:0 f32[1, 16384, 4096]"
# t564 = prims.convert_element_type(t563, dtypes.float32) # t564: "cuda:0 f32[1, 16384, 4096]"
# t591 = prims.mul(t578, t590) # t591: "cuda:0 f32[1, 16384, 4096]"
# t566 = prims.add(t564, t565) # t566: "cuda:0 f32[1, 16384, 4096]"
# t596 = prims.mul(t566, t591) # t596: "cuda:0 f32[1, 16384, 4096]"
# t597 = prims.sum(t596, (0, 2)) # t597: "cuda:0 f32[16384]"
# t598 = prims.broadcast_in_dim(t597, [1, 16384, 1], [1]) # t598: "cuda:0 f32[1, 16384, 1]"
# t600 = prims.pow(rsqrt, 3.0) # t600: "cuda:0 f32[1, 16384, 1]"
# t599 = prims.mul(-0.5, t598) # t599: "cuda:0 f32[1, 16384, 1]"
# t601 = prims.mul(t599, t600) # t601: "cuda:0 f32[1, 16384, 1]"
# t604 = prims.div(t601, 4096.0) # t604: "cuda:0 f32[1, 16384, 1]"
# t605 = prims.sum(t604, (0, 2)) # t605: "cuda:0 f32[16384]"
# t606 = prims.broadcast_in_dim(t605, [1, 16384], [1]) # t606: "cuda:0 f32[1, 16384]"
# t608 = prims.broadcast_in_dim(t606, [1, 16384, 1], [0, 1]) # t608: "cuda:0 f32[1, 16384, 1]"
# t609 = prims.broadcast_in_dim(t608, (1, 16384, 4096), (0, 1, 2)) # t609: "cuda:0 f32[1, 16384, 4096]"
# t575 = prims.broadcast_in_dim(rsqrt, (1, 16384, 4096), (0, 1, 2)) # t575: "cuda:0 f32[1, 16384, 4096]"
# t610 = prims.mul(t566, t609) # t610: "cuda:0 f32[1, 16384, 4096]"
# t595 = prims.mul(t575, t591) # t595: "cuda:0 f32[1, 16384, 4096]"
# t612 = prims.add(t595, t610) # t612: "cuda:0 f32[1, 16384, 4096]"
# t613 = prims.add(t612, t610) # t613: "cuda:0 f32[1, 16384, 4096]"
# x_normed = prims.mul(t566, t575) # x_normed: "cuda:0 f32[1, 16384, 4096]"
# t614 = prims.convert_element_type(t613, dtypes.bfloat16) # t614: "cuda:0 bf16[1, 16384, 4096]"
# t592 = prims.mul(x_normed, t590) # t592: "cuda:0 f32[1, 16384, 4096]"
# t749 = prims.reshape(t614, (16384, 4096)) # t749: "cuda:0 bf16[16384, 4096]"
# t593 = prims.sum(t592, (0, 1)) # t593: "cuda:0 f32[4096]"
# t753 = prims.transpose(t749, (1, 0)) # t753: "cuda:0 bf16[4096, 16384]"
# t594 = prims.convert_element_type(t593, dtypes.bfloat16) # t594: "cuda:0 bf16[4096]"
del t584, l_self_modules_transformer_modules_ln_f_parameters_weight_, t533, t563, rsqrt
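# Block 1 MLP backward: multiplying the residual gradient by the down-projection
# (proj) weight yields the gradient w.r.t. the SwiGLU output, shape (16384, 11008).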
t750 = torch.matmul(t749, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_) # t750: "cuda:0 bf16[16384, 11008]"
# t750 = ltorch.matmul(t749, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_) # t750: "cuda:0 bf16[16384, 11008]"
# t750 = prims.matmul(t749, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_) # t750: "cuda:0 bf16[16384, 11008]"
del t749, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_
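# nvFusion2 recomputes (due to activation checkpointing) block 1's norm_1:
# RMSNorm of x_1 with eps 1e-6, producing rsqrt (t624), the scaled output (t630)
# and its flattened view (t975).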
[t624, t630, t975] = nvFusion2(x_1, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_)
# t615 = prims.convert_element_type(x_1, dtypes.float32) # t615: "cuda:0 f32[1, 16384, 4096]"
# t616 = prims.mul(t615, t615) # t616: "cuda:0 f32[1, 16384, 4096]"
# t618 = prims.sum(t616, (2,)) # t618: "cuda:0 f32[1, 16384]"
# t619 = prims.broadcast_in_dim(t618, [1, 16384, 1], [0, 1]) # t619: "cuda:0 f32[1, 16384, 1]"
# t621 = prims.div(t619, 4096.0) # t621: "cuda:0 f32[1, 16384, 1]"
# t623 = prims.add(t621, 1e-06) # t623: "cuda:0 f32[1, 16384, 1]"
# t624 = prims.rsqrt(t623) # t624: "cuda:0 f32[1, 16384, 1]"
# t625 = prims.broadcast_in_dim(t624, (1, 16384, 4096), (0, 1, 2)) # t625: "cuda:0 f32[1, 16384, 4096]"
# t626 = prims.mul(t615, t625) # t626: "cuda:0 f32[1, 16384, 4096]"
# t627 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, dtypes.float32) # t627: "cuda:0 f32[4096]"
# t628 = prims.broadcast_in_dim(t627, (1, 16384, 4096), (2,)) # t628: "cuda:0 f32[1, 16384, 4096]"
# t629 = prims.mul(t626, t628) # t629: "cuda:0 f32[1, 16384, 4096]"
# t630 = prims.convert_element_type(t629, dtypes.bfloat16) # t630: "cuda:0 bf16[1, 16384, 4096]"
# t975 = prims.reshape(t630, (16384, 4096)) # t975: "cuda:0 bf16[16384, 4096]"
t631 = torch.nn.functional.linear(t630, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, None) # t631: "cuda:0 bf16[1, 16384, 12288]"
# t631 = ltorch.linear(t630, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, None) # t631: "cuda:0 bf16[1, 16384, 12288]"
# t631 = prims.linear(t630, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, None) # t631: "cuda:0 bf16[1, 16384, 12288]"
del t630
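# TorchCompile1 recomputes the QKV split of t631 and applies the rotary embedding
# (cos/sin with rotate-half) to the query and key, reproducing the rotated q
# (t694), rotated k (t697) and untouched v (t649).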
[t649, t694, t697] = TorchCompile1(t631, l_self_buffers_cos_, l_self_buffers_sin_)
# t632 = prims.reshape(t631, (1, 16384, 32, 3, 128)) # t632: "cuda:0 bf16[1, 16384, 32, 3, 128]"
# t633 = prims.transpose(t632, (0, 2, 3, 1, 4)) # t633: "cuda:0 bf16[1, 32, 3, 16384, 128]"
# (t634, t635, t636) = ltorch.split(t633, (1, 1, 1), 2)
# t634 = prims.slice_prim(t633, [0, 0, 0, 0, 0], [1, 32, 1, 16384, 128], [1, 1, 1, 1, 1]) # t634: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t635 = prims.slice_prim(t633, [0, 0, 1, 0, 0], [1, 32, 2, 16384, 128], [1, 1, 1, 1, 1]) # t635: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t636 = prims.slice_prim(t633, [0, 0, 2, 0, 0], [1, 32, 3, 16384, 128], [1, 1, 1, 1, 1]) # t636: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t637 = prims.reshape(t634, (1, 32, 16384, 128)) # t637: "cuda:0 bf16[1, 32, 16384, 128]"
# t643 = prims.reshape(t635, (1, 32, 16384, 128)) # t643: "cuda:0 bf16[1, 32, 16384, 128]"
# t649 = prims.reshape(t636, (1, 32, 16384, 128)) # t649: "cuda:0 bf16[1, 32, 16384, 128]"
# t650 = prims.slice_prim(t637, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t650: "cuda:0 bf16[1, 32, 16384, 64]"
# t651 = prims.slice_prim(t637, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t651: "cuda:0 bf16[1, 32, 16384, 64]"
# t652 = prims.convert_element_type(t651, dtypes.float32) # t652: "cuda:0 f32[1, 32, 16384, 64]"
# t653 = prims.neg(t652) # t653: "cuda:0 f32[1, 32, 16384, 64]"
# t654 = prims.convert_element_type(t653, dtypes.bfloat16) # t654: "cuda:0 bf16[1, 32, 16384, 64]"
# t655 = prims.cat([t654, t650], -1) # t655: "cuda:0 bf16[1, 32, 16384, 128]"
# t656 = prims.broadcast_in_dim(l_self_buffers_cos_, (1, 32, 16384, 128), (2, 3)) # t656: "cuda:0 bf16[1, 32, 16384, 128]"
# t657 = prims.convert_element_type(t637, dtypes.float32) # t657: "cuda:0 f32[1, 32, 16384, 128]"
# t658 = prims.convert_element_type(t656, dtypes.float32) # t658: "cuda:0 f32[1, 32, 16384, 128]"
# t659 = ltorch.mul(t657, t658) # t659: "cuda:0 f32[1, 32, 16384, 128]"
# t659 = prims.mul(t657, t658) # t659: "cuda:0 f32[1, 32, 16384, 128]"
# t660 = prims.convert_element_type(t659, dtypes.bfloat16) # t660: "cuda:0 bf16[1, 32, 16384, 128]"
# t661 = prims.broadcast_in_dim(l_self_buffers_sin_, (1, 32, 16384, 128), (2, 3)) # t661: "cuda:0 bf16[1, 32, 16384, 128]"
# t662 = prims.convert_element_type(t655, dtypes.float32) # t662: "cuda:0 f32[1, 32, 16384, 128]"
# t663 = prims.convert_element_type(t661, dtypes.float32) # t663: "cuda:0 f32[1, 32, 16384, 128]"
# t664 = ltorch.mul(t662, t663) # t664: "cuda:0 f32[1, 32, 16384, 128]"
# t664 = prims.mul(t662, t663) # t664: "cuda:0 f32[1, 32, 16384, 128]"
# t665 = prims.convert_element_type(t664, dtypes.bfloat16) # t665: "cuda:0 bf16[1, 32, 16384, 128]"
# t668 = ltorch.add(t659, t664, alpha=1) # t668: "cuda:0 f32[1, 32, 16384, 128]"
# t668 = prims.add(t659, t664) # t668: "cuda:0 f32[1, 32, 16384, 128]"
# t669 = prims.convert_element_type(t668, dtypes.bfloat16) # t669: "cuda:0 bf16[1, 32, 16384, 128]"
# t671 = prims.slice_prim(t643, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t671: "cuda:0 bf16[1, 32, 16384, 64]"
# t673 = prims.slice_prim(t643, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t673: "cuda:0 bf16[1, 32, 16384, 64]"
# t674 = prims.convert_element_type(t673, dtypes.float32) # t674: "cuda:0 f32[1, 32, 16384, 64]"
# t675 = prims.neg(t674) # t675: "cuda:0 f32[1, 32, 16384, 64]"
# t676 = prims.convert_element_type(t675, dtypes.bfloat16) # t676: "cuda:0 bf16[1, 32, 16384, 64]"
# t678 = prims.cat([t676, t671], -1) # t678: "cuda:0 bf16[1, 32, 16384, 128]"
# t680 = prims.convert_element_type(t643, dtypes.float32) # t680: "cuda:0 f32[1, 32, 16384, 128]"
# t682 = ltorch.mul(t680, t658) # t682: "cuda:0 f32[1, 32, 16384, 128]"
# t682 = prims.mul(t680, t658) # t682: "cuda:0 f32[1, 32, 16384, 128]"
# t683 = prims.convert_element_type(t682, dtypes.bfloat16) # t683: "cuda:0 bf16[1, 32, 16384, 128]"
# t685 = prims.convert_element_type(t678, dtypes.float32) # t685: "cuda:0 f32[1, 32, 16384, 128]"
# t687 = ltorch.mul(t685, t663) # t687: "cuda:0 f32[1, 32, 16384, 128]"
# t687 = prims.mul(t685, t663) # t687: "cuda:0 f32[1, 32, 16384, 128]"
# t688 = prims.convert_element_type(t687, dtypes.bfloat16) # t688: "cuda:0 bf16[1, 32, 16384, 128]"
# t691 = ltorch.add(t682, t687, alpha=1) # t691: "cuda:0 f32[1, 32, 16384, 128]"
# t691 = prims.add(t682, t687) # t691: "cuda:0 f32[1, 32, 16384, 128]"
# t692 = prims.convert_element_type(t691, dtypes.bfloat16) # t692: "cuda:0 bf16[1, 32, 16384, 128]"
# t693 = prims.slice_prim(t637, [0, 0, 0, 0], [1, 32, 16384, 0], [1, 1, 1, 1]) # t693: "cuda:0 bf16[1, 32, 16384, 0]"
# t694 = prims.cat([t669, t693], -1) # t694: "cuda:0 bf16[1, 32, 16384, 128]"
# t695 = prims.slice_prim(t643, [0, 0, 0, 0], [1, 32, 16384, 0], [1, 1, 1, 1]) # t695: "cuda:0 bf16[1, 32, 16384, 0]"
# t697 = prims.cat([t692, t695], -1) # t697: "cuda:0 bf16[1, 32, 16384, 128]"
del t631
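# The causal SDPA forward is re-run through cuDNN to regenerate the attention
# output t698 and the auxiliary tensors (t699-t701) that cudnn_sdpa_bwd needs;
# scale = 1/sqrt(128).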
(t698, t699, t700, t701) = cudnn_sdpa_fwd(t694, t697, t649, None, 0.0, True, scale=0.08838834764831843)
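# nvFusion3 merges the heads of the recomputed attention output back into
# (1, 16384, 4096) (t703) and keeps a flattened view (t832).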
[t703, t832] = nvFusion3(t698)
# t702 = prims.transpose(t698, (0, 2, 1, 3)) # t702: "cuda:0 bf16[1, 16384, 32, 128]"
# t703 = prims.reshape(t702, (1, 16384, 4096)) # t703: "cuda:0 bf16[1, 16384, 4096]"
# t832 = prims.reshape(t703, (16384, 4096)) # t832: "cuda:0 bf16[16384, 4096]"
t704 = torch.nn.functional.linear(t703, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, None) # t704: "cuda:0 bf16[1, 16384, 4096]"
# t704 = ltorch.linear(t703, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, None) # t704: "cuda:0 bf16[1, 16384, 4096]"
# t704 = prims.linear(t703, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, None) # t704: "cuda:0 bf16[1, 16384, 4096]"
del t703
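# nvFusion4 recomputes block 1's second residual add (attention output + x_1)
# and norm_2, producing rsqrt (t718) and the normalized output (t724, flattened
# as t782) that fed fc_1 and fc_2 in the forward pass.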
[t718, t724, t782] = nvFusion4(t704, x_1, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_)
# t705 = prims.convert_element_type(t704, dtypes.float32) # t705: "cuda:0 f32[1, 16384, 4096]"
# t706 = prims.convert_element_type(x_1, dtypes.float32) # t706: "cuda:0 f32[1, 16384, 4096]"
# t707 = prims.add(t705, t706) # t707: "cuda:0 f32[1, 16384, 4096]"
# t710 = prims.mul(t707, t707) # t710: "cuda:0 f32[1, 16384, 4096]"
# t712 = prims.sum(t710, (2,)) # t712: "cuda:0 f32[1, 16384]"
# t713 = prims.broadcast_in_dim(t712, [1, 16384, 1], [0, 1]) # t713: "cuda:0 f32[1, 16384, 1]"
# t715 = prims.div(t713, 4096.0) # t715: "cuda:0 f32[1, 16384, 1]"
# t717 = prims.add(t715, 1e-06) # t717: "cuda:0 f32[1, 16384, 1]"
# t718 = prims.rsqrt(t717) # t718: "cuda:0 f32[1, 16384, 1]"
# t719 = prims.broadcast_in_dim(t718, (1, 16384, 4096), (0, 1, 2)) # t719: "cuda:0 f32[1, 16384, 4096]"
# t720 = prims.mul(t707, t719) # t720: "cuda:0 f32[1, 16384, 4096]"
# t721 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, dtypes.float32) # t721: "cuda:0 f32[4096]"
# t722 = prims.broadcast_in_dim(t721, (1, 16384, 4096), (2,)) # t722: "cuda:0 f32[1, 16384, 4096]"
# t723 = prims.mul(t720, t722) # t723: "cuda:0 f32[1, 16384, 4096]"
# t724 = prims.convert_element_type(t723, dtypes.bfloat16) # t724: "cuda:0 bf16[1, 16384, 4096]"
# t782 = prims.reshape(t724, (16384, 4096)) # t782: "cuda:0 bf16[16384, 4096]"
t726 = torch.nn.functional.linear(t724, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, None) # t726: "cuda:0 bf16[1, 16384, 11008]"
# t726 = ltorch.linear(t724, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, None) # t726: "cuda:0 bf16[1, 16384, 11008]"
# t726 = prims.linear(t724, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, None) # t726: "cuda:0 bf16[1, 16384, 11008]"
t725 = torch.nn.functional.linear(t724, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, None) # t725: "cuda:0 bf16[1, 16384, 11008]"
# t725 = ltorch.linear(t724, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, None) # t725: "cuda:0 bf16[1, 16384, 11008]"
# t725 = prims.linear(t724, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, None) # t725: "cuda:0 bf16[1, 16384, 11008]"
del t724
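# nvFusion5: backward of the SwiGLU gate, whose forward was silu(fc_1(x)) * fc_2(x).
# Given the upstream gradient t750, it emits the gradient for the fc_2 branch
# (t777 / transposed t781), the gradient for the fc_1 branch through the SiLU
# derivative (t784 / transposed t788), and the recomputed SwiGLU activation (t754).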
[t754, t777, t781, t784, t788] = nvFusion5(t725, t726, t750)
# t727 = prims.convert_element_type(t725, dtypes.float32) # t727: "cuda:0 f32[1, 16384, 11008]"
# t728 = prims.neg(t727) # t728: "cuda:0 f32[1, 16384, 11008]"
# t729 = prims.exp(t728) # t729: "cuda:0 f32[1, 16384, 11008]"
# t730 = prims.add(1.0, t729) # t730: "cuda:0 f32[1, 16384, 11008]"
# t731 = prims.reciprocal(t730) # t731: "cuda:0 f32[1, 16384, 11008]"
# t735 = prims.mul(t727, t731) # t735: "cuda:0 f32[1, 16384, 11008]"
# t738 = prims.convert_element_type(t726, dtypes.float32) # t738: "cuda:0 f32[1, 16384, 11008]"
# t739 = prims.mul(t735, t738) # t739: "cuda:0 f32[1, 16384, 11008]"
# t740 = prims.convert_element_type(t739, dtypes.bfloat16) # t740: "cuda:0 bf16[1, 16384, 11008]"
# t751 = prims.reshape(t750, (1, 16384, 11008)) # t751: "cuda:0 bf16[1, 16384, 11008]"
# t754 = prims.reshape(t740, (16384, 11008)) # t754: "cuda:0 bf16[16384, 11008]"
# t756 = prims.convert_element_type(t751, dtypes.float32) # t756: "cuda:0 f32[1, 16384, 11008]"
# t757 = prims.mul(t738, t756) # t757: "cuda:0 f32[1, 16384, 11008]"
# t758 = prims.mul(t735, t756) # t758: "cuda:0 f32[1, 16384, 11008]"
# t759 = prims.convert_element_type(t758, dtypes.bfloat16) # t759: "cuda:0 bf16[1, 16384, 11008]"
# t762 = prims.mul(t731, t757) # t762: "cuda:0 f32[1, 16384, 11008]"
# t763 = prims.mul(t727, t757) # t763: "cuda:0 f32[1, 16384, 11008]"
# t767 = prims.neg(t763) # t767: "cuda:0 f32[1, 16384, 11008]"
# t768 = prims.mul(t767, t731) # t768: "cuda:0 f32[1, 16384, 11008]"
# t769 = prims.mul(t768, t731) # t769: "cuda:0 f32[1, 16384, 11008]"
# t770 = prims.mul(t769, t729) # t770: "cuda:0 f32[1, 16384, 11008]"
# t771 = prims.neg(t770) # t771: "cuda:0 f32[1, 16384, 11008]"
# t775 = prims.add(t762, t771) # t775: "cuda:0 f32[1, 16384, 11008]"
# t776 = prims.convert_element_type(t775, dtypes.bfloat16) # t776: "cuda:0 bf16[1, 16384, 11008]"
# t777 = prims.reshape(t759, (16384, 11008)) # t777: "cuda:0 bf16[16384, 11008]"
# t781 = prims.transpose(t777, (1, 0)) # t781: "cuda:0 bf16[11008, 16384]"
# t784 = prims.reshape(t776, (16384, 11008)) # t784: "cuda:0 bf16[16384, 11008]"
# t788 = prims.transpose(t784, (1, 0)) # t788: "cuda:0 bf16[11008, 16384]"
del t725, t726, t750
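# The two matmuls below push the branch gradients back through the fc_2 and fc_1
# weights, each yielding a (16384, 4096) gradient w.r.t. the norm_2 output.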
t778 = torch.matmul(t777, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_) # t778: "cuda:0 bf16[16384, 4096]"
# t778 = ltorch.matmul(t777, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_) # t778: "cuda:0 bf16[16384, 4096]"
# t778 = prims.matmul(t777, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_) # t778: "cuda:0 bf16[16384, 4096]"
del t777, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_
t785 = torch.matmul(t784, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_) # t785: "cuda:0 bf16[16384, 4096]"
# t785 = ltorch.matmul(t784, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_) # t785: "cuda:0 bf16[16384, 4096]"
# t785 = prims.matmul(t784, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_) # t785: "cuda:0 bf16[16384, 4096]"
del t784, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_
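# nvFusion6: backward of block 1's norm_2. The two branch gradients are summed,
# backed through the RMSNorm (using the recomputed rsqrt t718), and added to the
# residual gradient t614; outputs are the accumulated residual-stream gradient
# t822 (bf16 views t830/t831) and the norm_2 weight gradient (t799).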
[t822, t830, t799, t831] = nvFusion6(t785, t778, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, x_1, t704, t718, t614)
# t786 = prims.reshape(t785, (1, 16384, 4096)) # t786: "cuda:0 bf16[1, 16384, 4096]"
# t779 = prims.reshape(t778, (1, 16384, 4096)) # t779: "cuda:0 bf16[1, 16384, 4096]"
# t792 = prims.convert_element_type(t786, dtypes.float32) # t792: "cuda:0 f32[1, 16384, 4096]"
# t791 = prims.convert_element_type(t779, dtypes.float32) # t791: "cuda:0 f32[1, 16384, 4096]"
# t721 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, dtypes.float32) # t721: "cuda:0 f32[4096]"
# t793 = prims.add(t791, t792) # t793: "cuda:0 f32[1, 16384, 4096]"
# t722 = prims.broadcast_in_dim(t721, (1, 16384, 4096), (2,)) # t722: "cuda:0 f32[1, 16384, 4096]"
# t706 = prims.convert_element_type(x_1, dtypes.float32) # t706: "cuda:0 f32[1, 16384, 4096]"
# t705 = prims.convert_element_type(t704, dtypes.float32) # t705: "cuda:0 f32[1, 16384, 4096]"
# t796 = prims.mul(t722, t793) # t796: "cuda:0 f32[1, 16384, 4096]"
# t707 = prims.add(t705, t706) # t707: "cuda:0 f32[1, 16384, 4096]"
# t801 = prims.mul(t707, t796) # t801: "cuda:0 f32[1, 16384, 4096]"
# t802 = prims.sum(t801, (0, 2)) # t802: "cuda:0 f32[16384]"
# t803 = prims.broadcast_in_dim(t802, [1, 16384, 1], [1]) # t803: "cuda:0 f32[1, 16384, 1]"
# t805 = prims.pow(t718, 3.0) # t805: "cuda:0 f32[1, 16384, 1]"
# t804 = prims.mul(-0.5, t803) # t804: "cuda:0 f32[1, 16384, 1]"
# t806 = prims.mul(t804, t805) # t806: "cuda:0 f32[1, 16384, 1]"
# t809 = prims.div(t806, 4096.0) # t809: "cuda:0 f32[1, 16384, 1]"
# t810 = prims.sum(t809, (0, 2)) # t810: "cuda:0 f32[16384]"
# t811 = prims.broadcast_in_dim(t810, [1, 16384], [1]) # t811: "cuda:0 f32[1, 16384]"
# t813 = prims.broadcast_in_dim(t811, [1, 16384, 1], [0, 1]) # t813: "cuda:0 f32[1, 16384, 1]"
# t814 = prims.broadcast_in_dim(t813, (1, 16384, 4096), (0, 1, 2)) # t814: "cuda:0 f32[1, 16384, 4096]"
# t719 = prims.broadcast_in_dim(t718, (1, 16384, 4096), (0, 1, 2)) # t719: "cuda:0 f32[1, 16384, 4096]"
# t815 = prims.mul(t707, t814) # t815: "cuda:0 f32[1, 16384, 4096]"
# t800 = prims.mul(t719, t796) # t800: "cuda:0 f32[1, 16384, 4096]"
# t817 = prims.add(t800, t815) # t817: "cuda:0 f32[1, 16384, 4096]"
# t818 = prims.add(t817, t815) # t818: "cuda:0 f32[1, 16384, 4096]"
# t820 = prims.convert_element_type(t614, dtypes.float32) # t820: "cuda:0 f32[1, 16384, 4096]"
# t720 = prims.mul(t707, t719) # t720: "cuda:0 f32[1, 16384, 4096]"
# t822 = prims.add(t820, t818) # t822: "cuda:0 f32[1, 16384, 4096]"
# t797 = prims.mul(t720, t793) # t797: "cuda:0 f32[1, 16384, 4096]"
# t823 = prims.convert_element_type(t822, dtypes.bfloat16) # t823: "cuda:0 bf16[1, 16384, 4096]"
# t798 = prims.sum(t797, (0, 1)) # t798: "cuda:0 f32[4096]"
# t830 = prims.reshape(t823, (16384, 4096)) # t830: "cuda:0 bf16[16384, 4096]"
# t799 = prims.convert_element_type(t798, dtypes.bfloat16) # t799: "cuda:0 bf16[4096]"
# t831 = prims.transpose(t830, (1, 0)) # t831: "cuda:0 bf16[4096, 16384]"
del t785, t778, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, t704, t718, t614
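# The residual gradient is pushed back through the attention output projection
# (proj); nvFusion7 then reshapes it into per-head layout (1, 32, 16384, 128)
# for the SDPA backward.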
t828 = torch.matmul(t830, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_) # t828: "cuda:0 bf16[16384, 4096]"
# t828 = ltorch.matmul(t830, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_) # t828: "cuda:0 bf16[16384, 4096]"
# t828 = prims.matmul(t830, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_) # t828: "cuda:0 bf16[16384, 4096]"
del t830, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_
[t840] = nvFusion7(t828)
# t829 = prims.reshape(t828, (1, 16384, 4096)) # t829: "cuda:0 bf16[1, 16384, 4096]"
# t837 = prims.reshape(t829, (1, 16384, 32, 128)) # t837: "cuda:0 bf16[1, 16384, 32, 128]"
# t840 = prims.transpose(t837, (0, 2, 1, 3)) # t840: "cuda:0 bf16[1, 32, 16384, 128]"
del t828
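# cudnn_sdpa_bwd consumes the output gradient t840, the recomputed q/k/v
# (t694, t697, t649) and the forward outputs from the recomputation above
# (t698-t701), returning the gradients of query (t841), key (t842) and value (t843).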
(t841, t842, t843) = cudnn_sdpa_bwd(t840, t694, t697, t649, None, 0.0, True, t698, t699, t700, t701, scale=0.08838834764831843, cat_grad_qkv=False)
del t840, t694, t697, t649, t698, t699, t700, t701
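# TorchCompile2: RoPE backward for the query and key gradients (cos/sin products
# and the slice/pad that undoes rotate-half), followed by re-packing the q/k/v
# gradients into the fused QKV layout (t973) together with its transpose (t974).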
[t973, t974] = TorchCompile2(t842, t841, l_self_buffers_sin_, l_self_buffers_cos_, t843)
# t845 = prims.slice_prim(t842, [0, 0, 0, 0], [1, 32, 16384, 128], [1, 1, 1, 1]) # t845: "cuda:0 bf16[1, 32, 16384, 128]"
# t849 = prims.slice_prim(t841, [0, 0, 0, 0], [1, 32, 16384, 128], [1, 1, 1, 1]) # t849: "cuda:0 bf16[1, 32, 16384, 128]"
# t852 = prims.convert_element_type(t845, dtypes.float32) # t852: "cuda:0 f32[1, 32, 16384, 128]"
# t893 = prims.convert_element_type(t849, dtypes.float32) # t893: "cuda:0 f32[1, 32, 16384, 128]"
# t684 = prims.broadcast_in_dim(l_self_buffers_sin_, (1, 32, 16384, 128), (2, 3)) # t684: "cuda:0 bf16[1, 32, 16384, 128]"
# t686 = prims.convert_element_type(t684, dtypes.float32) # t686: "cuda:0 f32[1, 32, 16384, 128]"
# t856 = ltorch.mul(t686, t852) # t856: "cuda:0 f32[1, 32, 16384, 128]"
# t856 = prims.mul(t686, t852) # t856: "cuda:0 f32[1, 32, 16384, 128]"
# t679 = prims.broadcast_in_dim(l_self_buffers_cos_, (1, 32, 16384, 128), (2, 3)) # t679: "cuda:0 bf16[1, 32, 16384, 128]"
# t897 = ltorch.mul(t686, t893) # t897: "cuda:0 f32[1, 32, 16384, 128]"
# t897 = prims.mul(t686, t893) # t897: "cuda:0 f32[1, 32, 16384, 128]"
# t859 = prims.convert_element_type(t856, dtypes.bfloat16) # t859: "cuda:0 bf16[1, 32, 16384, 128]"
# t681 = prims.convert_element_type(t679, dtypes.float32) # t681: "cuda:0 f32[1, 32, 16384, 128]"
# t900 = prims.convert_element_type(t897, dtypes.bfloat16) # t900: "cuda:0 bf16[1, 32, 16384, 128]"
# t876 = prims.slice_prim(t859, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t876: "cuda:0 bf16[1, 32, 16384, 64]"
# t864 = ltorch.mul(t681, t852) # t864: "cuda:0 f32[1, 32, 16384, 128]"
# t864 = prims.mul(t681, t852) # t864: "cuda:0 f32[1, 32, 16384, 128]"
# t846 = prims.full((1, 32, 16384, 0), 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t846: "cuda:0 bf16[1, 32, 16384, 0]"
# t925 = prims.slice_prim(t900, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t925: "cuda:0 bf16[1, 32, 16384, 64]"
# t909 = ltorch.mul(t681, t893) # t909: "cuda:0 f32[1, 32, 16384, 128]"
# t909 = prims.mul(t681, t893) # t909: "cuda:0 f32[1, 32, 16384, 128]"
# t878 = prims.convert_element_type(t876, dtypes.float32) # t878: "cuda:0 f32[1, 32, 16384, 64]"
# t867 = prims.convert_element_type(t864, dtypes.bfloat16) # t867: "cuda:0 bf16[1, 32, 16384, 128]"
# t847 = prims.pad(t846, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 128, 0))) # t847: "cuda:0 bf16[1, 32, 16384, 128]"
# t927 = prims.convert_element_type(t925, dtypes.float32) # t927: "cuda:0 f32[1, 32, 16384, 64]"
# t912 = prims.convert_element_type(t909, dtypes.bfloat16) # t912: "cuda:0 bf16[1, 32, 16384, 128]"
# t879 = ltorch.neg(t878) # t879: "cuda:0 f32[1, 32, 16384, 64]"
# t879 = prims.neg(t878) # t879: "cuda:0 f32[1, 32, 16384, 64]"
# t871 = prims.convert_element_type(t847, dtypes.float32) # t871: "cuda:0 f32[1, 32, 16384, 128]"
# t928 = ltorch.neg(t927) # t928: "cuda:0 f32[1, 32, 16384, 64]"
# t928 = prims.neg(t927) # t928: "cuda:0 f32[1, 32, 16384, 64]"
# t880 = prims.convert_element_type(t879, dtypes.bfloat16) # t880: "cuda:0 bf16[1, 32, 16384, 64]"
# t873 = prims.add(t871, t864) # t873: "cuda:0 f32[1, 32, 16384, 128]"
# t929 = prims.convert_element_type(t928, dtypes.bfloat16) # t929: "cuda:0 bf16[1, 32, 16384, 64]"
# t918 = prims.add(t871, t909) # t918: "cuda:0 f32[1, 32, 16384, 128]"
# t882 = prims.pad(t880, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (64, 0, 0))) # t882: "cuda:0 bf16[1, 32, 16384, 128]"
# t874 = prims.convert_element_type(t873, dtypes.bfloat16) # t874: "cuda:0 bf16[1, 32, 16384, 128]"
# t931 = prims.pad(t929, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (64, 0, 0))) # t931: "cuda:0 bf16[1, 32, 16384, 128]"
# t919 = prims.convert_element_type(t918, dtypes.bfloat16) # t919: "cuda:0 bf16[1, 32, 16384, 128]"
# t884 = prims.convert_element_type(t882, dtypes.float32) # t884: "cuda:0 f32[1, 32, 16384, 128]"
# t933 = prims.convert_element_type(t931, dtypes.float32) # t933: "cuda:0 f32[1, 32, 16384, 128]"
# t877 = prims.slice_prim(t859, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t877: "cuda:0 bf16[1, 32, 16384, 64]"
# t885 = prims.add(t873, t884) # t885: "cuda:0 f32[1, 32, 16384, 128]"
# t926 = prims.slice_prim(t900, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t926: "cuda:0 bf16[1, 32, 16384, 64]"
# t934 = prims.add(t918, t933) # t934: "cuda:0 f32[1, 32, 16384, 128]"
# t888 = prims.pad(t877, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t888: "cuda:0 bf16[1, 32, 16384, 128]"
# t886 = prims.convert_element_type(t885, dtypes.bfloat16) # t886: "cuda:0 bf16[1, 32, 16384, 128]"
# t937 = prims.pad(t926, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t937: "cuda:0 bf16[1, 32, 16384, 128]"
# t935 = prims.convert_element_type(t934, dtypes.bfloat16) # t935: "cuda:0 bf16[1, 32, 16384, 128]"
# t890 = prims.convert_element_type(t888, dtypes.float32) # t890: "cuda:0 f32[1, 32, 16384, 128]"
# t939 = prims.convert_element_type(t937, dtypes.float32) # t939: "cuda:0 f32[1, 32, 16384, 128]"
# t891 = prims.add(t885, t890) # t891: "cuda:0 f32[1, 32, 16384, 128]"
# t940 = prims.add(t934, t939) # t940: "cuda:0 f32[1, 32, 16384, 128]"
# t892 = prims.convert_element_type(t891, dtypes.bfloat16) # t892: "cuda:0 bf16[1, 32, 16384, 128]"
# t941 = prims.convert_element_type(t940, dtypes.bfloat16) # t941: "cuda:0 bf16[1, 32, 16384, 128]"
# t946 = prims.reshape(t843, (1, 32, 1, 16384, 128)) # t946: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t951 = prims.reshape(t892, (1, 32, 1, 16384, 128)) # t951: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t956 = prims.reshape(t941, (1, 32, 1, 16384, 128)) # t956: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t957 = ltorch.cat((t956, t951, t946), 2) # t957: "cuda:0 bf16[1, 32, 3, 16384, 128]"
# t957 = prims.cat([t956, t951, t946], 2) # t957: "cuda:0 bf16[1, 32, 3, 16384, 128]"
# t963 = prims.transpose(t957, (0, 3, 1, 2, 4)) # t963: "cuda:0 bf16[1, 16384, 32, 3, 128]"
# t969 = prims.reshape(t963, (1, 16384, 12288)) # t969: "cuda:0 bf16[1, 16384, 12288]"
# t973 = ltorch.reshape(t969, -1, 12288) # t973: "cuda:0 bf16[16384, 12288]"
# t973 = prims.reshape(t969, (16384, 12288)) # t973: "cuda:0 bf16[16384, 12288]"
# t974 = prims.transpose(t973, (1, 0)) # t974: "cuda:0 bf16[12288, 16384]"
del t842, t841, t843
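# Pushing the packed QKV gradient through block 1's attn (QKV) weight gives the
# gradient w.r.t. the norm_1 output; nvFusion8 then completes the norm_1 backward
# (using the recomputed rsqrt t624), adds the residual gradient t822, and emits
# the gradient w.r.t. block 1's input (t1005, views t1167/t1168) plus the norm_1
# weight gradient (t981).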
t971 = torch.matmul(t973, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_) # t971: "cuda:0 bf16[16384, 4096]"
# t971 = ltorch.matmul(t973, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_) # t971: "cuda:0 bf16[16384, 4096]"
# t971 = prims.matmul(t973, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_) # t971: "cuda:0 bf16[16384, 4096]"
del t973, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_
[t1005, t1167, t1168, t981] = nvFusion8(t971, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, x_1, t624, t822)
# t972 = prims.reshape(t971, (1, 16384, 4096)) # t972: "cuda:0 bf16[1, 16384, 4096]"
# t627 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, dtypes.float32) # t627: "cuda:0 f32[4096]"
# t977 = prims.convert_element_type(t972, dtypes.float32) # t977: "cuda:0 f32[1, 16384, 4096]"
# t628 = prims.broadcast_in_dim(t627, (1, 16384, 4096), (2,)) # t628: "cuda:0 f32[1, 16384, 4096]"
# t978 = prims.mul(t628, t977) # t978: "cuda:0 f32[1, 16384, 4096]"
# t615 = prims.convert_element_type(x_1, dtypes.float32) # t615: "cuda:0 f32[1, 16384, 4096]"
# t983 = prims.mul(t615, t978) # t983: "cuda:0 f32[1, 16384, 4096]"
# t984 = prims.sum(t983, (0, 2)) # t984: "cuda:0 f32[16384]"
# t985 = prims.broadcast_in_dim(t984, [1, 16384, 1], [1]) # t985: "cuda:0 f32[1, 16384, 1]"
# t987 = prims.pow(t624, 3.0) # t987: "cuda:0 f32[1, 16384, 1]"
# t986 = prims.mul(-0.5, t985) # t986: "cuda:0 f32[1, 16384, 1]"
# t988 = prims.mul(t986, t987) # t988: "cuda:0 f32[1, 16384, 1]"
# t991 = prims.div(t988, 4096.0) # t991: "cuda:0 f32[1, 16384, 1]"
# t992 = prims.sum(t991, (0, 2)) # t992: "cuda:0 f32[16384]"
# t993 = prims.broadcast_in_dim(t992, [1, 16384], [1]) # t993: "cuda:0 f32[1, 16384]"
# t995 = prims.broadcast_in_dim(t993, [1, 16384, 1], [0, 1]) # t995: "cuda:0 f32[1, 16384, 1]"
# t996 = prims.broadcast_in_dim(t995, (1, 16384, 4096), (0, 1, 2)) # t996: "cuda:0 f32[1, 16384, 4096]"
# t625 = prims.broadcast_in_dim(t624, (1, 16384, 4096), (0, 1, 2)) # t625: "cuda:0 f32[1, 16384, 4096]"
# t997 = prims.mul(t615, t996) # t997: "cuda:0 f32[1, 16384, 4096]"
# t982 = prims.mul(t625, t978) # t982: "cuda:0 f32[1, 16384, 4096]"
# t999 = prims.add(t982, t997) # t999: "cuda:0 f32[1, 16384, 4096]"
# t1000 = prims.add(t999, t997) # t1000: "cuda:0 f32[1, 16384, 4096]"
# t1004 = prims.add(t822, t1000) # t1004: "cuda:0 f32[1, 16384, 4096]"
# t1005 = prims.convert_element_type(t1004, dtypes.bfloat16) # t1005: "cuda:0 bf16[1, 16384, 4096]"
# t626 = prims.mul(t615, t625) # t626: "cuda:0 f32[1, 16384, 4096]"
# t979 = prims.mul(t626, t977) # t979: "cuda:0 f32[1, 16384, 4096]"
# t1167 = prims.reshape(t1005, (16384, 4096)) # t1167: "cuda:0 bf16[16384, 4096]"
# t980 = prims.sum(t979, (0, 1)) # t980: "cuda:0 f32[4096]"
# t1168 = prims.transpose(t1167, (1, 0)) # t1168: "cuda:0 bf16[4096, 16384]"
# t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: "cuda:0 bf16[4096]"
del t971, l_self_modules_transformer_modules_h_modules_1_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, x_1, t624, t822
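# Block 0 backward mirrors block 1: the residual gradient is pushed through the
# MLP down-projection, the checkpointed activations are recomputed (nvFusion9-11,
# TorchCompile0, cudnn_sdpa_fwd), and then the SwiGLU, norm_2, SDPA and norm_1
# backwards follow (nvFusion12-14, cudnn_sdpa_bwd, TorchCompile3).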
t1165 = torch.matmul(t1167, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_) # t1165: "cuda:0 bf16[16384, 11008]"
# t1165 = ltorch.matmul(t1167, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_) # t1165: "cuda:0 bf16[16384, 11008]"
# t1165 = prims.matmul(t1167, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_) # t1165: "cuda:0 bf16[16384, 11008]"
del t1167, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_proj_parameters_weight_
[t1015, t1021, t1390] = nvFusion9(x, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_)
# t1006 = prims.convert_element_type(x, dtypes.float32) # t1006: "cuda:0 f32[1, 16384, 4096]"
# t1007 = prims.mul(t1006, t1006) # t1007: "cuda:0 f32[1, 16384, 4096]"
# t1009 = prims.sum(t1007, (2,)) # t1009: "cuda:0 f32[1, 16384]"
# t1010 = prims.broadcast_in_dim(t1009, [1, 16384, 1], [0, 1]) # t1010: "cuda:0 f32[1, 16384, 1]"
# t1012 = prims.div(t1010, 4096.0) # t1012: "cuda:0 f32[1, 16384, 1]"
# t1014 = prims.add(t1012, 1e-06) # t1014: "cuda:0 f32[1, 16384, 1]"
# t1015 = prims.rsqrt(t1014) # t1015: "cuda:0 f32[1, 16384, 1]"
# t1016 = prims.broadcast_in_dim(t1015, (1, 16384, 4096), (0, 1, 2)) # t1016: "cuda:0 f32[1, 16384, 4096]"
# t1017 = prims.mul(t1006, t1016) # t1017: "cuda:0 f32[1, 16384, 4096]"
# t1018 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, dtypes.float32) # t1018: "cuda:0 f32[4096]"
# t1019 = prims.broadcast_in_dim(t1018, (1, 16384, 4096), (2,)) # t1019: "cuda:0 f32[1, 16384, 4096]"
# t1020 = prims.mul(t1017, t1019) # t1020: "cuda:0 f32[1, 16384, 4096]"
# t1021 = prims.convert_element_type(t1020, dtypes.bfloat16) # t1021: "cuda:0 bf16[1, 16384, 4096]"
# t1390 = prims.reshape(t1021, (16384, 4096)) # t1390: "cuda:0 bf16[16384, 4096]"
t1022 = torch.nn.functional.linear(t1021, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, None) # t1022: "cuda:0 bf16[1, 16384, 12288]"
# t1022 = ltorch.linear(t1021, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, None) # t1022: "cuda:0 bf16[1, 16384, 12288]"
# t1022 = prims.linear(t1021, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_, None) # t1022: "cuda:0 bf16[1, 16384, 12288]"
del t1021
[t1055, t1104, t1107] = TorchCompile0(t1022, l_self_buffers_cos_, l_self_buffers_sin_)
# t1028 = prims.reshape(t1022, (1, 16384, 32, 3, 128)) # t1028: "cuda:0 bf16[1, 16384, 32, 3, 128]"
# t1034 = prims.transpose(t1028, (0, 2, 3, 1, 4)) # t1034: "cuda:0 bf16[1, 32, 3, 16384, 128]"
# (t1035, t1036, t1037) = ltorch.split(t1034, (1, 1, 1), 2)
# t1035 = prims.slice_prim(t1034, [0, 0, 0, 0, 0], [1, 32, 1, 16384, 128], [1, 1, 1, 1, 1]) # t1035: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t1036 = prims.slice_prim(t1034, [0, 0, 1, 0, 0], [1, 32, 2, 16384, 128], [1, 1, 1, 1, 1]) # t1036: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t1037 = prims.slice_prim(t1034, [0, 0, 2, 0, 0], [1, 32, 3, 16384, 128], [1, 1, 1, 1, 1]) # t1037: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t1043 = prims.reshape(t1035, (1, 32, 16384, 128)) # t1043: "cuda:0 bf16[1, 32, 16384, 128]"
# t1049 = prims.reshape(t1036, (1, 32, 16384, 128)) # t1049: "cuda:0 bf16[1, 32, 16384, 128]"
# t1055 = prims.reshape(t1037, (1, 32, 16384, 128)) # t1055: "cuda:0 bf16[1, 32, 16384, 128]"
# t1057 = prims.slice_prim(t1043, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t1057: "cuda:0 bf16[1, 32, 16384, 64]"
# t1059 = prims.slice_prim(t1043, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t1059: "cuda:0 bf16[1, 32, 16384, 64]"
# t1060 = prims.convert_element_type(t1059, dtypes.float32) # t1060: "cuda:0 f32[1, 32, 16384, 64]"
# t1061 = prims.neg(t1060) # t1061: "cuda:0 f32[1, 32, 16384, 64]"
# t1062 = prims.convert_element_type(t1061, dtypes.bfloat16) # t1062: "cuda:0 bf16[1, 32, 16384, 64]"
# t1064 = prims.cat([t1062, t1057], -1) # t1064: "cuda:0 bf16[1, 32, 16384, 128]"
# t1065 = prims.broadcast_in_dim(l_self_buffers_cos_, (1, 32, 16384, 128), (2, 3)) # t1065: "cuda:0 bf16[1, 32, 16384, 128]"
# t1066 = prims.convert_element_type(t1043, dtypes.float32) # t1066: "cuda:0 f32[1, 32, 16384, 128]"
# t1067 = prims.convert_element_type(t1065, dtypes.float32) # t1067: "cuda:0 f32[1, 32, 16384, 128]"
# t1068 = ltorch.mul(t1066, t1067) # t1068: "cuda:0 f32[1, 32, 16384, 128]"
# t1068 = prims.mul(t1066, t1067) # t1068: "cuda:0 f32[1, 32, 16384, 128]"
# t1069 = prims.convert_element_type(t1068, dtypes.bfloat16) # t1069: "cuda:0 bf16[1, 32, 16384, 128]"
# t1070 = prims.broadcast_in_dim(l_self_buffers_sin_, (1, 32, 16384, 128), (2, 3)) # t1070: "cuda:0 bf16[1, 32, 16384, 128]"
# t1071 = prims.convert_element_type(t1064, dtypes.float32) # t1071: "cuda:0 f32[1, 32, 16384, 128]"
# t1072 = prims.convert_element_type(t1070, dtypes.float32) # t1072: "cuda:0 f32[1, 32, 16384, 128]"
# t1073 = ltorch.mul(t1071, t1072) # t1073: "cuda:0 f32[1, 32, 16384, 128]"
# t1073 = prims.mul(t1071, t1072) # t1073: "cuda:0 f32[1, 32, 16384, 128]"
# t1074 = prims.convert_element_type(t1073, dtypes.bfloat16) # t1074: "cuda:0 bf16[1, 32, 16384, 128]"
# t1077 = ltorch.add(t1068, t1073, alpha=1) # t1077: "cuda:0 f32[1, 32, 16384, 128]"
# t1077 = prims.add(t1068, t1073) # t1077: "cuda:0 f32[1, 32, 16384, 128]"
# t1078 = prims.convert_element_type(t1077, dtypes.bfloat16) # t1078: "cuda:0 bf16[1, 32, 16384, 128]"
# t1080 = prims.slice_prim(t1049, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t1080: "cuda:0 bf16[1, 32, 16384, 64]"
# t1082 = prims.slice_prim(t1049, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t1082: "cuda:0 bf16[1, 32, 16384, 64]"
# t1083 = prims.convert_element_type(t1082, dtypes.float32) # t1083: "cuda:0 f32[1, 32, 16384, 64]"
# t1084 = prims.neg(t1083) # t1084: "cuda:0 f32[1, 32, 16384, 64]"
# t1085 = prims.convert_element_type(t1084, dtypes.bfloat16) # t1085: "cuda:0 bf16[1, 32, 16384, 64]"
# t1087 = prims.cat([t1085, t1080], -1) # t1087: "cuda:0 bf16[1, 32, 16384, 128]"
# t1089 = prims.convert_element_type(t1049, dtypes.float32) # t1089: "cuda:0 f32[1, 32, 16384, 128]"
# t1091 = ltorch.mul(t1089, t1067) # t1091: "cuda:0 f32[1, 32, 16384, 128]"
# t1091 = prims.mul(t1089, t1067) # t1091: "cuda:0 f32[1, 32, 16384, 128]"
# t1092 = prims.convert_element_type(t1091, dtypes.bfloat16) # t1092: "cuda:0 bf16[1, 32, 16384, 128]"
# t1094 = prims.convert_element_type(t1087, dtypes.float32) # t1094: "cuda:0 f32[1, 32, 16384, 128]"
# t1096 = ltorch.mul(t1094, t1072) # t1096: "cuda:0 f32[1, 32, 16384, 128]"
# t1096 = prims.mul(t1094, t1072) # t1096: "cuda:0 f32[1, 32, 16384, 128]"
# t1097 = prims.convert_element_type(t1096, dtypes.bfloat16) # t1097: "cuda:0 bf16[1, 32, 16384, 128]"
# t1100 = ltorch.add(t1091, t1096, alpha=1) # t1100: "cuda:0 f32[1, 32, 16384, 128]"
# t1100 = prims.add(t1091, t1096) # t1100: "cuda:0 f32[1, 32, 16384, 128]"
# t1101 = prims.convert_element_type(t1100, dtypes.bfloat16) # t1101: "cuda:0 bf16[1, 32, 16384, 128]"
# t1102 = prims.slice_prim(t1043, [0, 0, 0, 0], [1, 32, 16384, 0], [1, 1, 1, 1]) # t1102: "cuda:0 bf16[1, 32, 16384, 0]"
# t1104 = prims.cat([t1078, t1102], -1) # t1104: "cuda:0 bf16[1, 32, 16384, 128]"
# t1105 = prims.slice_prim(t1049, [0, 0, 0, 0], [1, 32, 16384, 0], [1, 1, 1, 1]) # t1105: "cuda:0 bf16[1, 32, 16384, 0]"
# t1107 = prims.cat([t1101, t1105], -1) # t1107: "cuda:0 bf16[1, 32, 16384, 128]"
del t1022
(t1108, t1109, t1110, t1111) = cudnn_sdpa_fwd(t1104, t1107, t1055, None, 0.0, True, scale=0.08838834764831843)
[t1118, t1247] = nvFusion10(t1108)
# t1114 = prims.transpose(t1108, (0, 2, 1, 3)) # t1114: "cuda:0 bf16[1, 16384, 32, 128]"
# t1118 = prims.reshape(t1114, (1, 16384, 4096)) # t1118: "cuda:0 bf16[1, 16384, 4096]"
# t1247 = prims.reshape(t1118, (16384, 4096)) # t1247: "cuda:0 bf16[16384, 4096]"
t1119 = torch.nn.functional.linear(t1118, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, None) # t1119: "cuda:0 bf16[1, 16384, 4096]"
# t1119 = ltorch.linear(t1118, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, None) # t1119: "cuda:0 bf16[1, 16384, 4096]"
# t1119 = prims.linear(t1118, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_, None) # t1119: "cuda:0 bf16[1, 16384, 4096]"
del t1118
[t1133, t1139, t1197] = nvFusion11(t1119, x, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_)
# t1120 = prims.convert_element_type(t1119, dtypes.float32) # t1120: "cuda:0 f32[1, 16384, 4096]"
# t1121 = prims.convert_element_type(x, dtypes.float32) # t1121: "cuda:0 f32[1, 16384, 4096]"
# t1122 = prims.add(t1120, t1121) # t1122: "cuda:0 f32[1, 16384, 4096]"
# t1125 = prims.mul(t1122, t1122) # t1125: "cuda:0 f32[1, 16384, 4096]"
# t1127 = prims.sum(t1125, (2,)) # t1127: "cuda:0 f32[1, 16384]"
# t1128 = prims.broadcast_in_dim(t1127, [1, 16384, 1], [0, 1]) # t1128: "cuda:0 f32[1, 16384, 1]"
# t1130 = prims.div(t1128, 4096.0) # t1130: "cuda:0 f32[1, 16384, 1]"
# t1132 = prims.add(t1130, 1e-06) # t1132: "cuda:0 f32[1, 16384, 1]"
# t1133 = prims.rsqrt(t1132) # t1133: "cuda:0 f32[1, 16384, 1]"
# t1134 = prims.broadcast_in_dim(t1133, (1, 16384, 4096), (0, 1, 2)) # t1134: "cuda:0 f32[1, 16384, 4096]"
# t1135 = prims.mul(t1122, t1134) # t1135: "cuda:0 f32[1, 16384, 4096]"
# t1136 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, dtypes.float32) # t1136: "cuda:0 f32[4096]"
# t1137 = prims.broadcast_in_dim(t1136, (1, 16384, 4096), (2,)) # t1137: "cuda:0 f32[1, 16384, 4096]"
# t1138 = prims.mul(t1135, t1137) # t1138: "cuda:0 f32[1, 16384, 4096]"
# t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: "cuda:0 bf16[1, 16384, 4096]"
# t1197 = prims.reshape(t1139, (16384, 4096)) # t1197: "cuda:0 bf16[16384, 4096]"
t1141 = torch.nn.functional.linear(t1139, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, None) # t1141: "cuda:0 bf16[1, 16384, 11008]"
# t1141 = ltorch.linear(t1139, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, None) # t1141: "cuda:0 bf16[1, 16384, 11008]"
# t1141 = prims.linear(t1139, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_, None) # t1141: "cuda:0 bf16[1, 16384, 11008]"
t1140 = torch.nn.functional.linear(t1139, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, None) # t1140: "cuda:0 bf16[1, 16384, 11008]"
# t1140 = ltorch.linear(t1139, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, None) # t1140: "cuda:0 bf16[1, 16384, 11008]"
# t1140 = prims.linear(t1139, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_, None) # t1140: "cuda:0 bf16[1, 16384, 11008]"
del t1139
[t1169, t1192, t1196, t1199, t1203] = nvFusion12(t1140, t1141, t1165)
# t1142 = prims.convert_element_type(t1140, dtypes.float32) # t1142: "cuda:0 f32[1, 16384, 11008]"
# t1143 = prims.neg(t1142) # t1143: "cuda:0 f32[1, 16384, 11008]"
# t1144 = prims.exp(t1143) # t1144: "cuda:0 f32[1, 16384, 11008]"
# t1145 = prims.add(1.0, t1144) # t1145: "cuda:0 f32[1, 16384, 11008]"
# t1146 = prims.reciprocal(t1145) # t1146: "cuda:0 f32[1, 16384, 11008]"
# t1150 = prims.mul(t1142, t1146) # t1150: "cuda:0 f32[1, 16384, 11008]"
# t1153 = prims.convert_element_type(t1141, dtypes.float32) # t1153: "cuda:0 f32[1, 16384, 11008]"
# t1154 = prims.mul(t1150, t1153) # t1154: "cuda:0 f32[1, 16384, 11008]"
# t1155 = prims.convert_element_type(t1154, dtypes.bfloat16) # t1155: "cuda:0 bf16[1, 16384, 11008]"
# t1166 = prims.reshape(t1165, (1, 16384, 11008)) # t1166: "cuda:0 bf16[1, 16384, 11008]"
# t1169 = prims.reshape(t1155, (16384, 11008)) # t1169: "cuda:0 bf16[16384, 11008]"
# t1171 = prims.convert_element_type(t1166, dtypes.float32) # t1171: "cuda:0 f32[1, 16384, 11008]"
# t1172 = prims.mul(t1153, t1171) # t1172: "cuda:0 f32[1, 16384, 11008]"
# t1173 = prims.mul(t1150, t1171) # t1173: "cuda:0 f32[1, 16384, 11008]"
# t1174 = prims.convert_element_type(t1173, dtypes.bfloat16) # t1174: "cuda:0 bf16[1, 16384, 11008]"
# t1177 = prims.mul(t1146, t1172) # t1177: "cuda:0 f32[1, 16384, 11008]"
# t1178 = prims.mul(t1142, t1172) # t1178: "cuda:0 f32[1, 16384, 11008]"
# t1182 = prims.neg(t1178) # t1182: "cuda:0 f32[1, 16384, 11008]"
# t1183 = prims.mul(t1182, t1146) # t1183: "cuda:0 f32[1, 16384, 11008]"
# t1184 = prims.mul(t1183, t1146) # t1184: "cuda:0 f32[1, 16384, 11008]"
# t1185 = prims.mul(t1184, t1144) # t1185: "cuda:0 f32[1, 16384, 11008]"
# t1186 = prims.neg(t1185) # t1186: "cuda:0 f32[1, 16384, 11008]"
# t1190 = prims.add(t1177, t1186) # t1190: "cuda:0 f32[1, 16384, 11008]"
# t1191 = prims.convert_element_type(t1190, dtypes.bfloat16) # t1191: "cuda:0 bf16[1, 16384, 11008]"
# t1192 = prims.reshape(t1174, (16384, 11008)) # t1192: "cuda:0 bf16[16384, 11008]"
# t1196 = prims.transpose(t1192, (1, 0)) # t1196: "cuda:0 bf16[11008, 16384]"
# t1199 = prims.reshape(t1191, (16384, 11008)) # t1199: "cuda:0 bf16[16384, 11008]"
# t1203 = prims.transpose(t1199, (1, 0)) # t1203: "cuda:0 bf16[11008, 16384]"
del t1140, t1141, t1165
t1193 = torch.matmul(t1192, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_) # t1193: "cuda:0 bf16[16384, 4096]"
# t1193 = ltorch.matmul(t1192, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_) # t1193: "cuda:0 bf16[16384, 4096]"
# t1193 = prims.matmul(t1192, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_) # t1193: "cuda:0 bf16[16384, 4096]"
del t1192, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_2_parameters_weight_
t1200 = torch.matmul(t1199, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_) # t1200: "cuda:0 bf16[16384, 4096]"
# t1200 = ltorch.matmul(t1199, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_) # t1200: "cuda:0 bf16[16384, 4096]"
# t1200 = prims.matmul(t1199, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_) # t1200: "cuda:0 bf16[16384, 4096]"
del t1199, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_mlp_modules_fc_1_parameters_weight_
[t1237, t1245, t1214, t1246] = nvFusion13(t1200, t1193, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, x, t1119, t1133, t1005)
# t1201 = prims.reshape(t1200, (1, 16384, 4096)) # t1201: "cuda:0 bf16[1, 16384, 4096]"
# t1194 = prims.reshape(t1193, (1, 16384, 4096)) # t1194: "cuda:0 bf16[1, 16384, 4096]"
# t1207 = prims.convert_element_type(t1201, dtypes.float32) # t1207: "cuda:0 f32[1, 16384, 4096]"
# t1206 = prims.convert_element_type(t1194, dtypes.float32) # t1206: "cuda:0 f32[1, 16384, 4096]"
# t1136 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, dtypes.float32) # t1136: "cuda:0 f32[4096]"
# t1208 = prims.add(t1206, t1207) # t1208: "cuda:0 f32[1, 16384, 4096]"
# t1137 = prims.broadcast_in_dim(t1136, (1, 16384, 4096), (2,)) # t1137: "cuda:0 f32[1, 16384, 4096]"
# t1121 = prims.convert_element_type(x, dtypes.float32) # t1121: "cuda:0 f32[1, 16384, 4096]"
# t1120 = prims.convert_element_type(t1119, dtypes.float32) # t1120: "cuda:0 f32[1, 16384, 4096]"
# t1211 = prims.mul(t1137, t1208) # t1211: "cuda:0 f32[1, 16384, 4096]"
# t1122 = prims.add(t1120, t1121) # t1122: "cuda:0 f32[1, 16384, 4096]"
# t1216 = prims.mul(t1122, t1211) # t1216: "cuda:0 f32[1, 16384, 4096]"
# t1217 = prims.sum(t1216, (0, 2)) # t1217: "cuda:0 f32[16384]"
# t1218 = prims.broadcast_in_dim(t1217, [1, 16384, 1], [1]) # t1218: "cuda:0 f32[1, 16384, 1]"
# t1220 = prims.pow(t1133, 3.0) # t1220: "cuda:0 f32[1, 16384, 1]"
# t1219 = prims.mul(-0.5, t1218) # t1219: "cuda:0 f32[1, 16384, 1]"
# t1221 = prims.mul(t1219, t1220) # t1221: "cuda:0 f32[1, 16384, 1]"
# t1224 = prims.div(t1221, 4096.0) # t1224: "cuda:0 f32[1, 16384, 1]"
# t1225 = prims.sum(t1224, (0, 2)) # t1225: "cuda:0 f32[16384]"
# t1226 = prims.broadcast_in_dim(t1225, [1, 16384], [1]) # t1226: "cuda:0 f32[1, 16384]"
# t1228 = prims.broadcast_in_dim(t1226, [1, 16384, 1], [0, 1]) # t1228: "cuda:0 f32[1, 16384, 1]"
# t1229 = prims.broadcast_in_dim(t1228, (1, 16384, 4096), (0, 1, 2)) # t1229: "cuda:0 f32[1, 16384, 4096]"
# t1134 = prims.broadcast_in_dim(t1133, (1, 16384, 4096), (0, 1, 2)) # t1134: "cuda:0 f32[1, 16384, 4096]"
# t1230 = prims.mul(t1122, t1229) # t1230: "cuda:0 f32[1, 16384, 4096]"
# t1215 = prims.mul(t1134, t1211) # t1215: "cuda:0 f32[1, 16384, 4096]"
# t1232 = prims.add(t1215, t1230) # t1232: "cuda:0 f32[1, 16384, 4096]"
# t1233 = prims.add(t1232, t1230) # t1233: "cuda:0 f32[1, 16384, 4096]"
# t1235 = prims.convert_element_type(t1005, dtypes.float32) # t1235: "cuda:0 f32[1, 16384, 4096]"
# t1135 = prims.mul(t1122, t1134) # t1135: "cuda:0 f32[1, 16384, 4096]"
# t1237 = prims.add(t1235, t1233) # t1237: "cuda:0 f32[1, 16384, 4096]"
# t1212 = prims.mul(t1135, t1208) # t1212: "cuda:0 f32[1, 16384, 4096]"
# t1238 = prims.convert_element_type(t1237, dtypes.bfloat16) # t1238: "cuda:0 bf16[1, 16384, 4096]"
# t1213 = prims.sum(t1212, (0, 1)) # t1213: "cuda:0 f32[4096]"
# t1245 = prims.reshape(t1238, (16384, 4096)) # t1245: "cuda:0 bf16[16384, 4096]"
# t1214 = prims.convert_element_type(t1213, dtypes.bfloat16) # t1214: "cuda:0 bf16[4096]"
# t1246 = prims.transpose(t1245, (1, 0)) # t1246: "cuda:0 bf16[4096, 16384]"
del t1200, t1193, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_2_parameters_weight_, t1119, t1133, t1005
t1243 = torch.matmul(t1245, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_) # t1243: "cuda:0 bf16[16384, 4096]"
# t1243 = ltorch.matmul(t1245, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_) # t1243: "cuda:0 bf16[16384, 4096]"
# t1243 = prims.matmul(t1245, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_) # t1243: "cuda:0 bf16[16384, 4096]"
del t1245, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_proj_parameters_weight_
[t1255] = nvFusion14(t1243)
# t1244 = prims.reshape(t1243, (1, 16384, 4096)) # t1244: "cuda:0 bf16[1, 16384, 4096]"
# t1252 = prims.reshape(t1244, (1, 16384, 32, 128)) # t1252: "cuda:0 bf16[1, 16384, 32, 128]"
# t1255 = prims.transpose(t1252, (0, 2, 1, 3)) # t1255: "cuda:0 bf16[1, 32, 16384, 128]"
del t1243
(t1256, t1257, t1258) = cudnn_sdpa_bwd(t1255, t1104, t1107, t1055, None, 0.0, True, t1108, t1109, t1110, t1111, scale=0.08838834764831843, cat_grad_qkv=False)
del t1255, t1104, t1107, t1055, t1108, t1109, t1110, t1111
[t1388, t1389] = TorchCompile3(t1257, t1256, l_self_buffers_sin_, l_self_buffers_cos_, t1258)
# t1260 = prims.slice_prim(t1257, [0, 0, 0, 0], [1, 32, 16384, 128], [1, 1, 1, 1]) # t1260: "cuda:0 bf16[1, 32, 16384, 128]"
# t1264 = prims.slice_prim(t1256, [0, 0, 0, 0], [1, 32, 16384, 128], [1, 1, 1, 1]) # t1264: "cuda:0 bf16[1, 32, 16384, 128]"
# t1267 = prims.convert_element_type(t1260, dtypes.float32) # t1267: "cuda:0 f32[1, 32, 16384, 128]"
# t1308 = prims.convert_element_type(t1264, dtypes.float32) # t1308: "cuda:0 f32[1, 32, 16384, 128]"
# t1093 = prims.broadcast_in_dim(l_self_buffers_sin_, (1, 32, 16384, 128), (2, 3)) # t1093: "cuda:0 bf16[1, 32, 16384, 128]"
# t1095 = prims.convert_element_type(t1093, dtypes.float32) # t1095: "cuda:0 f32[1, 32, 16384, 128]"
# t1271 = ltorch.mul(t1095, t1267) # t1271: "cuda:0 f32[1, 32, 16384, 128]"
# t1271 = prims.mul(t1095, t1267) # t1271: "cuda:0 f32[1, 32, 16384, 128]"
# t1088 = prims.broadcast_in_dim(l_self_buffers_cos_, (1, 32, 16384, 128), (2, 3)) # t1088: "cuda:0 bf16[1, 32, 16384, 128]"
# t1312 = ltorch.mul(t1095, t1308) # t1312: "cuda:0 f32[1, 32, 16384, 128]"
# t1312 = prims.mul(t1095, t1308) # t1312: "cuda:0 f32[1, 32, 16384, 128]"
# t1274 = prims.convert_element_type(t1271, dtypes.bfloat16) # t1274: "cuda:0 bf16[1, 32, 16384, 128]"
# t1090 = prims.convert_element_type(t1088, dtypes.float32) # t1090: "cuda:0 f32[1, 32, 16384, 128]"
# t1315 = prims.convert_element_type(t1312, dtypes.bfloat16) # t1315: "cuda:0 bf16[1, 32, 16384, 128]"
# t1291 = prims.slice_prim(t1274, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t1291: "cuda:0 bf16[1, 32, 16384, 64]"
# t1279 = ltorch.mul(t1090, t1267) # t1279: "cuda:0 f32[1, 32, 16384, 128]"
# t1279 = prims.mul(t1090, t1267) # t1279: "cuda:0 f32[1, 32, 16384, 128]"
# t1261 = prims.full((1, 32, 16384, 0), 0, device=devices.Device("cuda:0"), dtype=dtypes.bfloat16) # t1261: "cuda:0 bf16[1, 32, 16384, 0]"
# t1340 = prims.slice_prim(t1315, [0, 0, 0, 0], [1, 32, 16384, 64], [1, 1, 1, 1]) # t1340: "cuda:0 bf16[1, 32, 16384, 64]"
# t1324 = ltorch.mul(t1090, t1308) # t1324: "cuda:0 f32[1, 32, 16384, 128]"
# t1324 = prims.mul(t1090, t1308) # t1324: "cuda:0 f32[1, 32, 16384, 128]"
# t1293 = prims.convert_element_type(t1291, dtypes.float32) # t1293: "cuda:0 f32[1, 32, 16384, 64]"
# t1282 = prims.convert_element_type(t1279, dtypes.bfloat16) # t1282: "cuda:0 bf16[1, 32, 16384, 128]"
# t1262 = prims.pad(t1261, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 128, 0))) # t1262: "cuda:0 bf16[1, 32, 16384, 128]"
# t1342 = prims.convert_element_type(t1340, dtypes.float32) # t1342: "cuda:0 f32[1, 32, 16384, 64]"
# t1327 = prims.convert_element_type(t1324, dtypes.bfloat16) # t1327: "cuda:0 bf16[1, 32, 16384, 128]"
# t1294 = ltorch.neg(t1293) # t1294: "cuda:0 f32[1, 32, 16384, 64]"
# t1294 = prims.neg(t1293) # t1294: "cuda:0 f32[1, 32, 16384, 64]"
# t1286 = prims.convert_element_type(t1262, dtypes.float32) # t1286: "cuda:0 f32[1, 32, 16384, 128]"
# t1343 = ltorch.neg(t1342) # t1343: "cuda:0 f32[1, 32, 16384, 64]"
# t1343 = prims.neg(t1342) # t1343: "cuda:0 f32[1, 32, 16384, 64]"
# t1295 = prims.convert_element_type(t1294, dtypes.bfloat16) # t1295: "cuda:0 bf16[1, 32, 16384, 64]"
# t1288 = prims.add(t1286, t1279) # t1288: "cuda:0 f32[1, 32, 16384, 128]"
# t1344 = prims.convert_element_type(t1343, dtypes.bfloat16) # t1344: "cuda:0 bf16[1, 32, 16384, 64]"
# t1333 = prims.add(t1286, t1324) # t1333: "cuda:0 f32[1, 32, 16384, 128]"
# t1297 = prims.pad(t1295, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (64, 0, 0))) # t1297: "cuda:0 bf16[1, 32, 16384, 128]"
# t1289 = prims.convert_element_type(t1288, dtypes.bfloat16) # t1289: "cuda:0 bf16[1, 32, 16384, 128]"
# t1346 = prims.pad(t1344, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (64, 0, 0))) # t1346: "cuda:0 bf16[1, 32, 16384, 128]"
# t1334 = prims.convert_element_type(t1333, dtypes.bfloat16) # t1334: "cuda:0 bf16[1, 32, 16384, 128]"
# t1299 = prims.convert_element_type(t1297, dtypes.float32) # t1299: "cuda:0 f32[1, 32, 16384, 128]"
# t1348 = prims.convert_element_type(t1346, dtypes.float32) # t1348: "cuda:0 f32[1, 32, 16384, 128]"
# t1292 = prims.slice_prim(t1274, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t1292: "cuda:0 bf16[1, 32, 16384, 64]"
# t1300 = prims.add(t1288, t1299) # t1300: "cuda:0 f32[1, 32, 16384, 128]"
# t1341 = prims.slice_prim(t1315, [0, 0, 0, 64], [1, 32, 16384, 128], [1, 1, 1, 1]) # t1341: "cuda:0 bf16[1, 32, 16384, 64]"
# t1349 = prims.add(t1333, t1348) # t1349: "cuda:0 f32[1, 32, 16384, 128]"
# t1303 = prims.pad(t1292, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t1303: "cuda:0 bf16[1, 32, 16384, 128]"
# t1301 = prims.convert_element_type(t1300, dtypes.bfloat16) # t1301: "cuda:0 bf16[1, 32, 16384, 128]"
# t1352 = prims.pad(t1341, 0.0, ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 64, 0))) # t1352: "cuda:0 bf16[1, 32, 16384, 128]"
# t1350 = prims.convert_element_type(t1349, dtypes.bfloat16) # t1350: "cuda:0 bf16[1, 32, 16384, 128]"
# t1305 = prims.convert_element_type(t1303, dtypes.float32) # t1305: "cuda:0 f32[1, 32, 16384, 128]"
# t1354 = prims.convert_element_type(t1352, dtypes.float32) # t1354: "cuda:0 f32[1, 32, 16384, 128]"
# t1306 = prims.add(t1300, t1305) # t1306: "cuda:0 f32[1, 32, 16384, 128]"
# t1355 = prims.add(t1349, t1354) # t1355: "cuda:0 f32[1, 32, 16384, 128]"
# t1307 = prims.convert_element_type(t1306, dtypes.bfloat16) # t1307: "cuda:0 bf16[1, 32, 16384, 128]"
# t1356 = prims.convert_element_type(t1355, dtypes.bfloat16) # t1356: "cuda:0 bf16[1, 32, 16384, 128]"
# t1361 = prims.reshape(t1258, (1, 32, 1, 16384, 128)) # t1361: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t1366 = prims.reshape(t1307, (1, 32, 1, 16384, 128)) # t1366: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t1371 = prims.reshape(t1356, (1, 32, 1, 16384, 128)) # t1371: "cuda:0 bf16[1, 32, 1, 16384, 128]"
# t1372 = ltorch.cat((t1371, t1366, t1361), 2) # t1372: "cuda:0 bf16[1, 32, 3, 16384, 128]"
# t1372 = prims.cat([t1371, t1366, t1361], 2) # t1372: "cuda:0 bf16[1, 32, 3, 16384, 128]"
# t1378 = prims.transpose(t1372, (0, 3, 1, 2, 4)) # t1378: "cuda:0 bf16[1, 16384, 32, 3, 128]"
# t1384 = prims.reshape(t1378, (1, 16384, 12288)) # t1384: "cuda:0 bf16[1, 16384, 12288]"
# t1388 = ltorch.reshape(t1384, -1, 12288) # t1388: "cuda:0 bf16[16384, 12288]"
# t1388 = prims.reshape(t1384, (16384, 12288)) # t1388: "cuda:0 bf16[16384, 12288]"
# t1389 = prims.transpose(t1388, (1, 0)) # t1389: "cuda:0 bf16[12288, 16384]"
del t1257, t1256, l_self_buffers_sin_, l_self_buffers_cos_, t1258
t1386 = torch.matmul(t1388, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_) # t1386: "cuda:0 bf16[16384, 4096]"
# t1386 = ltorch.matmul(t1388, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_) # t1386: "cuda:0 bf16[16384, 4096]"
# t1386 = prims.matmul(t1388, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_) # t1386: "cuda:0 bf16[16384, 4096]"
del t1388, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_attn_modules_attn_parameters_weight_
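# Note (added): nvFusion15 below looks like the RMSNorm (norm_1) backward for
# block 0: it produces the norm weight gradient t1396 and the input gradient,
# which is added to the incoming residual-stream gradient t1237 to give t1420.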
[t1396, t1420] = nvFusion15(t1386, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, x, t1015, t1237)
# t1387 = prims.reshape(t1386, (1, 16384, 4096)) # t1387: "cuda:0 bf16[1, 16384, 4096]"
# t1018 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, dtypes.float32) # t1018: "cuda:0 f32[4096]"
# t1392 = prims.convert_element_type(t1387, dtypes.float32) # t1392: "cuda:0 f32[1, 16384, 4096]"
# t1019 = prims.broadcast_in_dim(t1018, (1, 16384, 4096), (2,)) # t1019: "cuda:0 f32[1, 16384, 4096]"
# t1393 = prims.mul(t1019, t1392) # t1393: "cuda:0 f32[1, 16384, 4096]"
# t1006 = prims.convert_element_type(x, dtypes.float32) # t1006: "cuda:0 f32[1, 16384, 4096]"
# t1398 = prims.mul(t1006, t1393) # t1398: "cuda:0 f32[1, 16384, 4096]"
# t1399 = prims.sum(t1398, (0, 2)) # t1399: "cuda:0 f32[16384]"
# t1400 = prims.broadcast_in_dim(t1399, [1, 16384, 1], [1]) # t1400: "cuda:0 f32[1, 16384, 1]"
# t1402 = prims.pow(t1015, 3.0) # t1402: "cuda:0 f32[1, 16384, 1]"
# t1401 = prims.mul(-0.5, t1400) # t1401: "cuda:0 f32[1, 16384, 1]"
# t1403 = prims.mul(t1401, t1402) # t1403: "cuda:0 f32[1, 16384, 1]"
# t1406 = prims.div(t1403, 4096.0) # t1406: "cuda:0 f32[1, 16384, 1]"
# t1407 = prims.sum(t1406, (0, 2)) # t1407: "cuda:0 f32[16384]"
# t1408 = prims.broadcast_in_dim(t1407, [1, 16384], [1]) # t1408: "cuda:0 f32[1, 16384]"
# t1410 = prims.broadcast_in_dim(t1408, [1, 16384, 1], [0, 1]) # t1410: "cuda:0 f32[1, 16384, 1]"
# t1411 = prims.broadcast_in_dim(t1410, (1, 16384, 4096), (0, 1, 2)) # t1411: "cuda:0 f32[1, 16384, 4096]"
# t1016 = prims.broadcast_in_dim(t1015, (1, 16384, 4096), (0, 1, 2)) # t1016: "cuda:0 f32[1, 16384, 4096]"
# t1412 = prims.mul(t1006, t1411) # t1412: "cuda:0 f32[1, 16384, 4096]"
# t1397 = prims.mul(t1016, t1393) # t1397: "cuda:0 f32[1, 16384, 4096]"
# t1414 = prims.add(t1397, t1412) # t1414: "cuda:0 f32[1, 16384, 4096]"
# t1415 = prims.add(t1414, t1412) # t1415: "cuda:0 f32[1, 16384, 4096]"
# t1017 = prims.mul(t1006, t1016) # t1017: "cuda:0 f32[1, 16384, 4096]"
# t1394 = prims.mul(t1017, t1392) # t1394: "cuda:0 f32[1, 16384, 4096]"
# t1419 = prims.add(t1237, t1415) # t1419: "cuda:0 f32[1, 16384, 4096]"
# t1395 = prims.sum(t1394, (0, 1)) # t1395: "cuda:0 f32[4096]"
# t1420 = prims.convert_element_type(t1419, dtypes.bfloat16) # t1420: "cuda:0 bf16[1, 16384, 4096]"
# t1396 = prims.convert_element_type(t1395, dtypes.bfloat16) # t1396: "cuda:0 bf16[4096]"
del t1386, l_self_modules_transformer_modules_h_modules_0_modules_checkpoint_wrapped_module_modules_norm_1_parameters_weight_, x, t1015, t1237
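# Note (added): embedding_backward scatter-adds the token-embedding input
# gradient t1420 into a (32000, 4096) gradient for the embedding weight,
# indexed by l_idx_ (padding_idx=-1, no frequency scaling, dense gradient).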
t1429 = torch.torch.ops.aten.embedding_backward(t1420, l_idx_, 32000, -1, False, False) # t1429: "cuda:0 bf16[32000, 4096]"
# t1429 = ltorch.embedding_backward(t1420, l_idx_, 32000, -1, False, False) # t1429: "cuda:0 bf16[32000, 4096]"
# t1429 = prims.embedding_backward(t1420, l_idx_, 32000, -1, False, False) # t1429: "cuda:0 bf16[32000, 4096]"
del t1420, l_idx_
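# Note (added): the matmuls below compute the remaining weight gradients as
# grad_output^T @ input for each linear layer; the transposed gradients
# (t1389, t1246, ...) and the saved activations (t1390, t1247, ...) come from
# earlier in the trace, and the products go straight into the return tuple.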
t1391 = torch.matmul(t1389, t1390) # t1391: "cuda:0 bf16[12288, 4096]"
# t1391 = ltorch.matmul(t1389, t1390) # t1391: "cuda:0 bf16[12288, 4096]"
# t1391 = prims.matmul(t1389, t1390) # t1391: "cuda:0 bf16[12288, 4096]"
del t1389, t1390
t1248 = torch.matmul(t1246, t1247) # t1248: "cuda:0 bf16[4096, 4096]"
# t1248 = ltorch.matmul(t1246, t1247) # t1248: "cuda:0 bf16[4096, 4096]"
# t1248 = prims.matmul(t1246, t1247) # t1248: "cuda:0 bf16[4096, 4096]"
del t1246, t1247
t1205 = torch.matmul(t1203, t1197) # t1205: "cuda:0 bf16[11008, 4096]"
# t1205 = ltorch.matmul(t1203, t1197) # t1205: "cuda:0 bf16[11008, 4096]"
# t1205 = prims.matmul(t1203, t1197) # t1205: "cuda:0 bf16[11008, 4096]"
del t1203
t1198 = torch.matmul(t1196, t1197) # t1198: "cuda:0 bf16[11008, 4096]"
# t1198 = ltorch.matmul(t1196, t1197) # t1198: "cuda:0 bf16[11008, 4096]"
# t1198 = prims.matmul(t1196, t1197) # t1198: "cuda:0 bf16[11008, 4096]"
del t1196, t1197
t1170 = torch.matmul(t1168, t1169) # t1170: "cuda:0 bf16[4096, 11008]"
# t1170 = ltorch.matmul(t1168, t1169) # t1170: "cuda:0 bf16[4096, 11008]"
# t1170 = prims.matmul(t1168, t1169) # t1170: "cuda:0 bf16[4096, 11008]"
del t1168, t1169
t976 = torch.matmul(t974, t975) # t976: "cuda:0 bf16[12288, 4096]"
# t976 = ltorch.matmul(t974, t975) # t976: "cuda:0 bf16[12288, 4096]"
# t976 = prims.matmul(t974, t975) # t976: "cuda:0 bf16[12288, 4096]"
del t974, t975
t833 = torch.matmul(t831, t832) # t833: "cuda:0 bf16[4096, 4096]"
# t833 = ltorch.matmul(t831, t832) # t833: "cuda:0 bf16[4096, 4096]"
# t833 = prims.matmul(t831, t832) # t833: "cuda:0 bf16[4096, 4096]"
del t831, t832
t790 = torch.matmul(t788, t782) # t790: "cuda:0 bf16[11008, 4096]"
# t790 = ltorch.matmul(t788, t782) # t790: "cuda:0 bf16[11008, 4096]"
# t790 = prims.matmul(t788, t782) # t790: "cuda:0 bf16[11008, 4096]"
del t788
t783 = torch.matmul(t781, t782) # t783: "cuda:0 bf16[11008, 4096]"
# t783 = ltorch.matmul(t781, t782) # t783: "cuda:0 bf16[11008, 4096]"
# t783 = prims.matmul(t781, t782) # t783: "cuda:0 bf16[11008, 4096]"
del t781, t782
t755 = torch.matmul(t753, t754) # t755: "cuda:0 bf16[4096, 11008]"
# t755 = ltorch.matmul(t753, t754) # t755: "cuda:0 bf16[4096, 11008]"
# t755 = prims.matmul(t753, t754) # t755: "cuda:0 bf16[4096, 11008]"
del t753, t754
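# Note (added): the last two fusions and matmul compute the lm_head weight
# gradient: x_4 (presumably the final hidden states) is flattened to
# (16384, 4096), the incoming logits gradient t582 is flattened and transposed
# to (32000, 16384), and t589 = t587 @ t588 gives the (32000, 4096) gradient.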
[t588] = nvFusion16(x_4)
# t588 = prims.reshape(x_4, (16384, 4096)) # t588: "cuda:0 bf16[16384, 4096]"
del x_4
[t587] = nvFusion17(t582)
# t586 = prims.reshape(t582, (16384, 32000)) # t586: "cuda:0 bf16[16384, 32000]"
# t587 = prims.transpose(t586, (1, 0)) # t587: "cuda:0 bf16[32000, 16384]"
del t582
t589 = torch.matmul(t587, t588) # t589: "cuda:0 bf16[32000, 4096]"
# t589 = ltorch.matmul(t587, t588) # t589: "cuda:0 bf16[32000, 4096]"
# t589 = prims.matmul(t587, t588) # t589: "cuda:0 bf16[32000, 4096]"
del t587, t588
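# Note (added): the return tuple packs the computed gradients, with None for
# inputs that need no gradient (presumably the token indices and the cos/sin
# RoPE buffers), in the order expected by the corresponding forward trace.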
return (None, None, None, t1429, t1396, t1391, t1248, t1214, t1205, t1198, t1170, t981, t976, t833, t799, t790, t783, t755, t594, t589)