@shunting314
Created June 19, 2024 07:04
class GraphModule(torch.nn.Module):
def forward(self, primals_3: "f32[768]", primals_9: "f32[768]", primals_15: "f32[768]", primals_21: "f32[768]", primals_27: "f32[768]", primals_33: "f32[768]", primals_39: "f32[768]", primals_45: "f32[768]", primals_51: "f32[768]", primals_57: "f32[768]", primals_63: "f32[768]", primals_69: "f32[768]", primals_75: "f32[768]", primals_81: "f32[768]", primals_87: "f32[768]", primals_93: "f32[768]", primals_99: "f32[768]", primals_105: "f32[768]", primals_111: "f32[768]", primals_117: "f32[768]", primals_123: "f32[768]", primals_129: "f32[768]", primals_135: "f32[768]", primals_141: "f32[768]", primals_147: "f32[768]", primals_150: "i64[32, 1024]", primals_151: "i64[32, 1024]", iota: "i64[1024]", embedding: "f32[32, 1024, 768]", embedding_1: "f32[1024, 768]", getitem_1: "f32[32, 1024, 1]", rsqrt: "f32[32, 1024, 1]", view: "bf16[32768, 768]", permute_1: "bf16[32, 12, 1024, 64]", permute_2: "bf16[32, 12, 1024, 64]", permute_3: "bf16[32, 12, 1024, 64]", getitem_5: "bf16[32, 12, 1024, 64]", getitem_6: "f32[32, 12, 1024]", getitem_11: "i64[]", getitem_12: "i64[]", mul_2: "f32[32, 1024, 768]", view_8: "bf16[32768, 768]", addmm_2: "bf16[32768, 3072]", view_10: "bf16[32768, 3072]", mul_8: "f32[32, 1024, 768]", view_12: "bf16[32768, 768]", permute_9: "bf16[32, 12, 1024, 64]", permute_10: "bf16[32, 12, 1024, 64]", permute_11: "bf16[32, 12, 1024, 64]", getitem_21: "bf16[32, 12, 1024, 64]", getitem_22: "f32[32, 12, 1024]", getitem_27: "i64[]", getitem_28: "i64[]", mul_10: "f32[32, 1024, 768]", view_20: "bf16[32768, 768]", addmm_6: "bf16[32768, 3072]", view_22: "bf16[32768, 3072]", mul_16: "f32[32, 1024, 768]", view_24: "bf16[32768, 768]", permute_17: "bf16[32, 12, 1024, 64]", permute_18: "bf16[32, 12, 1024, 64]", permute_19: "bf16[32, 12, 1024, 64]", getitem_37: "bf16[32, 12, 1024, 64]", getitem_38: "f32[32, 12, 1024]", getitem_43: "i64[]", getitem_44: "i64[]", mul_18: "f32[32, 1024, 768]", view_32: "bf16[32768, 768]", addmm_10: "bf16[32768, 3072]", view_34: "bf16[32768, 3072]", mul_24: "f32[32, 1024, 768]", view_36: "bf16[32768, 768]", permute_25: "bf16[32, 12, 1024, 64]", permute_26: "bf16[32, 12, 1024, 64]", permute_27: "bf16[32, 12, 1024, 64]", getitem_53: "bf16[32, 12, 1024, 64]", getitem_54: "f32[32, 12, 1024]", getitem_59: "i64[]", getitem_60: "i64[]", mul_26: "f32[32, 1024, 768]", view_44: "bf16[32768, 768]", addmm_14: "bf16[32768, 3072]", view_46: "bf16[32768, 3072]", mul_32: "f32[32, 1024, 768]", view_48: "bf16[32768, 768]", permute_33: "bf16[32, 12, 1024, 64]", permute_34: "bf16[32, 12, 1024, 64]", permute_35: "bf16[32, 12, 1024, 64]", getitem_69: "bf16[32, 12, 1024, 64]", getitem_70: "f32[32, 12, 1024]", getitem_75: "i64[]", getitem_76: "i64[]", mul_34: "f32[32, 1024, 768]", view_56: "bf16[32768, 768]", addmm_18: "bf16[32768, 3072]", view_58: "bf16[32768, 3072]", mul_40: "f32[32, 1024, 768]", view_60: "bf16[32768, 768]", permute_41: "bf16[32, 12, 1024, 64]", permute_42: "bf16[32, 12, 1024, 64]", permute_43: "bf16[32, 12, 1024, 64]", getitem_85: "bf16[32, 12, 1024, 64]", getitem_86: "f32[32, 12, 1024]", getitem_91: "i64[]", getitem_92: "i64[]", mul_42: "f32[32, 1024, 768]", view_68: "bf16[32768, 768]", addmm_22: "bf16[32768, 3072]", view_70: "bf16[32768, 3072]", mul_48: "f32[32, 1024, 768]", view_72: "bf16[32768, 768]", permute_49: "bf16[32, 12, 1024, 64]", permute_50: "bf16[32, 12, 1024, 64]", permute_51: "bf16[32, 12, 1024, 64]", getitem_101: "bf16[32, 12, 1024, 64]", getitem_102: "f32[32, 12, 1024]", getitem_107: "i64[]", getitem_108: "i64[]", mul_50: "f32[32, 1024, 768]", view_80: "bf16[32768, 768]", addmm_26: "bf16[32768, 3072]", view_82: "bf16[32768, 3072]", mul_56: "f32[32, 1024, 768]", view_84: "bf16[32768, 768]", permute_57: "bf16[32, 12, 1024, 64]", permute_58: "bf16[32, 12, 1024, 64]", permute_59: "bf16[32, 12, 1024, 64]", getitem_117: "bf16[32, 12, 1024, 64]", getitem_118: "f32[32, 12, 1024]", getitem_123: "i64[]", getitem_124: "i64[]", mul_58: "f32[32, 1024, 768]", view_92: "bf16[32768, 768]", addmm_30: "bf16[32768, 3072]", view_94: "bf16[32768, 3072]", mul_64: "f32[32, 1024, 768]", view_96: "bf16[32768, 768]", permute_65: "bf16[32, 12, 1024, 64]", permute_66: "bf16[32, 12, 1024, 64]", permute_67: "bf16[32, 12, 1024, 64]", getitem_133: "bf16[32, 12, 1024, 64]", getitem_134: "f32[32, 12, 1024]", getitem_139: "i64[]", getitem_140: "i64[]", mul_66: "f32[32, 1024, 768]", view_104: "bf16[32768, 768]", addmm_34: "bf16[32768, 3072]", view_106: "bf16[32768, 3072]", mul_72: "f32[32, 1024, 768]", view_108: "bf16[32768, 768]", permute_73: "bf16[32, 12, 1024, 64]", permute_74: "bf16[32, 12, 1024, 64]", permute_75: "bf16[32, 12, 1024, 64]", getitem_149: "bf16[32, 12, 1024, 64]", getitem_150: "f32[32, 12, 1024]", getitem_155: "i64[]", getitem_156: "i64[]", mul_74: "f32[32, 1024, 768]", view_116: "bf16[32768, 768]", addmm_38: "bf16[32768, 3072]", view_118: "bf16[32768, 3072]", mul_80: "f32[32, 1024, 768]", view_120: "bf16[32768, 768]", permute_81: "bf16[32, 12, 1024, 64]", permute_82: "bf16[32, 12, 1024, 64]", permute_83: "bf16[32, 12, 1024, 64]", getitem_165: "bf16[32, 12, 1024, 64]", getitem_166: "f32[32, 12, 1024]", getitem_171: "i64[]", getitem_172: "i64[]", mul_82: "f32[32, 1024, 768]", view_128: "bf16[32768, 768]", addmm_42: "bf16[32768, 3072]", view_130: "bf16[32768, 3072]", mul_88: "f32[32, 1024, 768]", view_132: "bf16[32768, 768]", permute_89: "bf16[32, 12, 1024, 64]", permute_90: "bf16[32, 12, 1024, 64]", permute_91: "bf16[32, 12, 1024, 64]", getitem_181: "bf16[32, 12, 1024, 64]", getitem_182: "f32[32, 12, 1024]", getitem_187: "i64[]", getitem_188: "i64[]", mul_90: "f32[32, 1024, 768]", view_140: "bf16[32768, 768]", addmm_46: "bf16[32768, 3072]", view_142: "bf16[32768, 3072]", mul_96: "f32[32, 1024, 768]", view_144: "bf16[32768, 768]", mm_default_2: "bf16[32768, 50264]", amax: "f32[32768, 1]", log: "f32[32768, 1]", convert_element_type_295: "f32[]", permute_99: "bf16[50257, 768]", div_2: "f32[32, 1024, 1]", permute_101: "bf16[768, 3072]", permute_105: "bf16[3072, 768]", div_3: "f32[32, 1024, 1]", permute_109: "bf16[768, 768]", permute_117: "bf16[2304, 768]", div_4: "f32[32, 1024, 1]", permute_121: "bf16[768, 3072]", permute_125: "bf16[3072, 768]", div_5: "f32[32, 1024, 1]", permute_129: "bf16[768, 768]", permute_137: "bf16[2304, 768]", div_6: "f32[32, 1024, 1]", permute_141: "bf16[768, 3072]", permute_145: "bf16[3072, 768]", div_7: "f32[32, 1024, 1]", permute_149: "bf16[768, 768]", permute_157: "bf16[2304, 768]", div_8: "f32[32, 1024, 1]", permute_161: "bf16[768, 3072]", permute_165: "bf16[3072, 768]", div_9: "f32[32, 1024, 1]", permute_169: "bf16[768, 768]", permute_177: "bf16[2304, 768]", div_10: "f32[32, 1024, 1]", permute_181: "bf16[768, 3072]", permute_185: "bf16[3072, 768]", div_11: "f32[32, 1024, 1]", permute_189: "bf16[768, 768]", permute_197: "bf16[2304, 768]", div_12: "f32[32, 1024, 1]", permute_201: "bf16[768, 3072]", permute_205: "bf16[3072, 768]", div_13: "f32[32, 1024, 1]", permute_209: "bf16[768, 768]", permute_217: "bf16[2304, 768]", div_14: "f32[32, 1024, 1]", permute_221: "bf16[768, 3072]", permute_225: "bf16[3072, 768]", div_15: "f32[32, 1024, 1]", permute_229: "bf16[768, 768]", permute_237: "bf16[2304, 768]", div_16: "f32[32, 1024, 1]", permute_241: "bf16[768, 3072]", permute_245: "bf16[3072, 768]", div_17: "f32[32, 1024, 1]", permute_249: "bf16[768, 768]", permute_257: "bf16[2304, 768]", div_18: "f32[32, 1024, 1]", permute_261: "bf16[768, 3072]", permute_265: "bf16[3072, 768]", div_19: "f32[32, 1024, 1]", permute_269: "bf16[768, 768]", permute_277: "bf16[2304, 768]", div_20: "f32[32, 1024, 1]", permute_281: "bf16[768, 3072]", permute_285: "bf16[3072, 768]", div_21: "f32[32, 1024, 1]", permute_289: "bf16[768, 768]", permute_297: "bf16[2304, 768]", div_22: "f32[32, 1024, 1]", permute_301: "bf16[768, 3072]", permute_305: "bf16[3072, 768]", div_23: "f32[32, 1024, 1]", permute_309: "bf16[768, 768]", permute_317: "bf16[2304, 768]", div_24: "f32[32, 1024, 1]", permute_321: "bf16[768, 3072]", permute_325: "bf16[3072, 768]", div_25: "f32[32, 1024, 1]", permute_329: "bf16[768, 768]", permute_337: "bf16[2304, 768]", tangents_1: "f32[]"):
# File: /home/shunting/ws/llm.c/train_gpt2.py:170 in forward, code: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
div_1: "f32[]" = torch.ops.aten.div.Tensor(tangents_1, convert_element_type_295); tangents_1 = convert_element_type_295 = None
view_147: "i64[32768]" = torch.ops.aten.reshape.default(primals_151, [-1]); primals_151 = None
unsqueeze_1: "i64[32768, 1]" = torch.ops.aten.unsqueeze.default(view_147, 1); view_147 = None
ne_3: "b8[32768, 1]" = torch.ops.aten.ne.Scalar(unsqueeze_1, -1)
full_default: "i64[]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
where_2: "i64[32768, 1]" = torch.ops.aten.where.self(ne_3, unsqueeze_1, full_default); unsqueeze_1 = full_default = None
full_default_3: "f32[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
scatter: "f32[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0); full_default_3 = where_2 = None
full_default_1: "f32[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
where_3: "f32[32768, 1]" = torch.ops.aten.where.self(ne_3, div_1, full_default_1); ne_3 = div_1 = None
mul_98: "f32[32768, 50257]" = torch.ops.aten.mul.Tensor(scatter, where_3); scatter = where_3 = None
# No stacktrace found for following nodes
slice_tensor_1: "bf16[32768, 50257]" = torch.ops.aten.slice.Tensor(mm_default_2, 1, 0, -7); mm_default_2 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:169 in forward, code: logits = self.lm_head(x)
view_145: "bf16[32, 1024, 50257]" = torch.ops.aten.reshape.default(slice_tensor_1, [32, 1024, 50257]); slice_tensor_1 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:170 in forward, code: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
view_146: "bf16[32768, 50257]" = torch.ops.aten.reshape.default(view_145, [-1, 50257]); view_145 = None
convert_element_type_292: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(view_146, torch.float32); view_146 = None
sub_25: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(convert_element_type_292, amax); convert_element_type_292 = amax = None
sub_26: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(sub_25, log); sub_25 = log = None
convert_element_type_293: "bf16[32768, 50257]" = torch.ops.prims.convert_element_type.default(sub_26, torch.bfloat16); sub_26 = None
convert_element_type_294: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(convert_element_type_293, torch.float32); convert_element_type_293 = None
exp_1: "f32[32768, 50257]" = torch.ops.aten.exp.default(convert_element_type_294); convert_element_type_294 = None
sum_4: "f32[32768, 1]" = torch.ops.aten.sum.dim_IntList(mul_98, [1], True)
mul_99: "f32[32768, 50257]" = torch.ops.aten.mul.Tensor(exp_1, sum_4); exp_1 = sum_4 = None
sub_27: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(mul_98, mul_99); mul_98 = mul_99 = None
convert_element_type_299: "bf16[32768, 50257]" = torch.ops.prims.convert_element_type.default(sub_27, torch.bfloat16); sub_27 = None
view_148: "bf16[32, 1024, 50257]" = torch.ops.aten.reshape.default(convert_element_type_299, [32, 1024, 50257]); convert_element_type_299 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:169 in forward, code: logits = self.lm_head(x)
view_149: "bf16[32768, 50257]" = torch.ops.aten.reshape.default(view_148, [32768, 50257]); view_148 = None
permute_97: "bf16[50257, 32768]" = torch.ops.aten.permute.default(view_149, [1, 0])
# No stacktrace found for following nodes
constant_pad_nd_default_2: "bf16[50264, 32768]" = torch.ops.aten.constant_pad_nd.default(permute_97, [0, 0, 0, 7]); permute_97 = None
mm_default_1: "bf16[50264, 768]" = torch.ops.aten.mm.default(constant_pad_nd_default_2, view_144); constant_pad_nd_default_2 = view_144 = None
slice_tensor: "bf16[50257, 768]" = torch.ops.aten.slice.Tensor(mm_default_1, 0, 0, -7); mm_default_1 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:169 in forward, code: logits = self.lm_head(x)
permute_98: "bf16[768, 50257]" = torch.ops.aten.permute.default(slice_tensor, [1, 0]); slice_tensor = None
# No stacktrace found for following nodes
constant_pad_nd_default: "bf16[32768, 50264]" = torch.ops.aten.constant_pad_nd.default(view_149, [0, 7, 0, 0]); view_149 = None
constant_pad_nd_default_1: "bf16[50264, 768]" = torch.ops.aten.constant_pad_nd.default(permute_99, [0, 0, 0, 7]); permute_99 = None
mm_default: "bf16[32768, 768]" = torch.ops.aten.mm.default(constant_pad_nd_default, constant_pad_nd_default_1); constant_pad_nd_default = constant_pad_nd_default_1 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:169 in forward, code: logits = self.lm_head(x)
view_150: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_default, [32, 1024, 768]); mm_default = None
permute_100: "bf16[50257, 768]" = torch.ops.aten.permute.default(permute_98, [1, 0]); permute_98 = None
convert_element_type_304: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_150, torch.float32); view_150 = None
convert_element_type_305: "f32[50257, 768]" = torch.ops.prims.convert_element_type.default(permute_100, torch.float32); permute_100 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:165 in forward, code: x = self.transformer.ln_f(x)
mul_101: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_304, primals_147); primals_147 = None
mul_102: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_101, 768)
sum_5: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_101, [2], True)
mul_103: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_101, mul_96); mul_101 = None
sum_6: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_103, [2], True); mul_103 = None
mul_104: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_96, sum_6); sum_6 = None
sub_29: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_102, sum_5); mul_102 = sum_5 = None
sub_30: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_29, mul_104); sub_29 = mul_104 = None
mul_105: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_2, sub_30); div_2 = sub_30 = None
mul_106: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_304, mul_96); mul_96 = None
sum_7: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_106, [0, 1]); mul_106 = None
sum_8: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_304, [0, 1]); convert_element_type_304 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_306: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(mul_105, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_151: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_306, [32768, 768]); convert_element_type_306 = None
mm_3: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_151, permute_101); permute_101 = None
permute_102: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_151, [1, 0])
mm_4: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_102, view_142); permute_102 = view_142 = None
permute_103: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_4, [1, 0]); mm_4 = None
sum_9: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_151, [0], True, dtype = torch.float32); view_151 = None
view_152: "f32[768]" = torch.ops.aten.reshape.default(sum_9, [768]); sum_9 = None
permute_104: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_103, [1, 0]); permute_103 = None
view_153: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_3, [32, 1024, 3072]); mm_3 = None
convert_element_type_312: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_153, torch.float32); view_153 = None
convert_element_type_313: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_104, torch.float32); permute_104 = None
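# NOTE: mm_3/mm_4/sum_9 above are the standard backward of a linear layer
# y = x @ W.T + b, and the same three-node pattern recurs for every c_proj, c_fc and
# c_attn below: grad_x = grad_y @ W (mm_3), grad_W = grad_y.T @ x (mm_4, via the
# surrounding permutes), and grad_b = grad_y.sum(dim=0) (sum_9, accumulated in
# float32 even though the matmuls themselves run in bf16).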
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_141: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_46, [32, 1024, 3072]); addmm_46 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_92: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_141, 0.5)
mul_107: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_312, mul_92); mul_92 = None
convert_element_type_281: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_141, torch.float32)
pow_12: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_281, 3.0)
mul_93: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_12, 0.044715); pow_12 = None
add_94: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_141, mul_93); view_141 = mul_93 = None
mul_94: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_94, 0.7978845608028654); add_94 = None
tanh_11: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_94); mul_94 = None
add_95: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_11, 1.0)
mul_108: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_312, add_95); convert_element_type_312 = add_95 = None
convert_element_type_315: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_108, torch.bfloat16); mul_108 = None
mul_109: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_11, tanh_11); tanh_11 = None
sub_31: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_109); mul_109 = None
mul_110: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_107, sub_31); mul_107 = sub_31 = None
mul_111: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_110, 0.7978845608028654); mul_110 = None
convert_element_type_316: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_111, torch.bfloat16)
mul_112: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_111, 0.044715); mul_111 = None
pow_13: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_281, 2.0); convert_element_type_281 = None
mul_113: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_13, 3.0); pow_13 = None
mul_114: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_112, mul_113); mul_112 = mul_113 = None
convert_element_type_317: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_114, torch.bfloat16); mul_114 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_99: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_316, convert_element_type_317); convert_element_type_316 = convert_element_type_317 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_115: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_315, 0.5); convert_element_type_315 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_100: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_99, mul_115); add_99 = mul_115 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_154: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_100, [32768, 3072]); add_100 = None
mm_5: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_154, permute_105); permute_105 = None
permute_106: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_154, [1, 0])
mm_6: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_106, view_140); permute_106 = view_140 = None
permute_107: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_6, [1, 0]); mm_6 = None
sum_10: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_154, [0], True, dtype = torch.float32); view_154 = None
view_155: "f32[3072]" = torch.ops.aten.reshape.default(sum_10, [3072]); sum_10 = None
permute_108: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_107, [1, 0]); permute_107 = None
view_156: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_5, [32, 1024, 768]); mm_5 = None
convert_element_type_323: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_156, torch.float32); view_156 = None
convert_element_type_324: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_108, torch.float32); permute_108 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_117: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_323, primals_141); primals_141 = None
mul_118: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_117, 768)
sum_11: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_117, [2], True)
mul_119: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_117, mul_90); mul_117 = None
sum_12: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_119, [2], True); mul_119 = None
mul_120: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_90, sum_12); sum_12 = None
sub_33: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_118, sum_11); mul_118 = sum_11 = None
sub_34: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_33, mul_120); sub_33 = mul_120 = None
mul_121: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_3, sub_34); div_3 = sub_34 = None
mul_122: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_323, mul_90); mul_90 = None
sum_13: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_122, [0, 1]); mul_122 = None
sum_14: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_323, [0, 1]); convert_element_type_323 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_101: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(mul_105, mul_121); mul_105 = mul_121 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_326: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_101, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_157: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_326, [32768, 768]); convert_element_type_326 = None
mm_7: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_157, permute_109); permute_109 = None
permute_110: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_157, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_92: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_181, [0, 2, 1, 3])
view_137: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_92, [32, 1024, 768]); permute_92 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_138: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_137, [32768, 768]); view_137 = None
mm_8: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_110, view_138); permute_110 = view_138 = None
permute_111: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_8, [1, 0]); mm_8 = None
sum_15: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_157, [0], True, dtype = torch.float32); view_157 = None
view_158: "f32[768]" = torch.ops.aten.reshape.default(sum_15, [768]); sum_15 = None
permute_112: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_111, [1, 0]); permute_111 = None
view_159: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_7, [32, 1024, 768]); mm_7 = None
convert_element_type_332: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_112, torch.float32); permute_112 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_160: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_159, [32, 1024, 12, 64]); view_159 = None
permute_113: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_160, [0, 2, 1, 3]); view_160 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_113, permute_90, permute_89, permute_91, getitem_181, getitem_182, None, None, 1024, 1024, 0.0, True, getitem_187, getitem_188, scale = 0.125); permute_113 = permute_90 = permute_89 = permute_91 = getitem_181 = getitem_182 = getitem_187 = getitem_188 = None
getitem_194: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward[0]
getitem_195: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward[1]
getitem_196: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward[2]; _scaled_dot_product_flash_attention_backward = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_114: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_196, [0, 2, 1, 3]); getitem_196 = None
view_161: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_114, [32, 1024, 768]); permute_114 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_115: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_194, [0, 2, 1, 3]); getitem_194 = None
view_162: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_115, [32, 1024, 768]); permute_115 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_116: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_195, [0, 2, 1, 3]); getitem_195 = None
view_163: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_116, [32, 1024, 768]); permute_116 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_162, view_163, view_161], 2); view_162 = view_163 = view_161 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_164: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat, [32768, 2304]); cat = None
mm_9: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_164, permute_117); permute_117 = None
permute_118: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_164, [1, 0])
mm_10: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_118, view_132); permute_118 = view_132 = None
permute_119: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_10, [1, 0]); mm_10 = None
sum_16: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_164, [0], True, dtype = torch.float32); view_164 = None
view_165: "f32[2304]" = torch.ops.aten.reshape.default(sum_16, [2304]); sum_16 = None
permute_120: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_119, [1, 0]); permute_119 = None
view_166: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_9, [32, 1024, 768]); mm_9 = None
convert_element_type_339: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_166, torch.float32); view_166 = None
convert_element_type_340: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_120, torch.float32); permute_120 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_124: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_339, primals_135); primals_135 = None
mul_125: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_124, 768)
sum_17: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_124, [2], True)
mul_126: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_124, mul_88); mul_124 = None
sum_18: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_126, [2], True); mul_126 = None
mul_127: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_88, sum_18); sum_18 = None
sub_36: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_125, sum_17); mul_125 = sum_17 = None
sub_37: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_36, mul_127); sub_36 = mul_127 = None
mul_128: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_4, sub_37); div_4 = sub_37 = None
mul_129: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_339, mul_88); mul_88 = None
sum_19: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_129, [0, 1]); mul_129 = None
sum_20: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_339, [0, 1]); convert_element_type_339 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_102: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_101, mul_128); add_101 = mul_128 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_342: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_102, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_167: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_342, [32768, 768]); convert_element_type_342 = None
mm_11: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_167, permute_121); permute_121 = None
permute_122: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_167, [1, 0])
mm_12: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_122, view_130); permute_122 = view_130 = None
permute_123: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_12, [1, 0]); mm_12 = None
sum_21: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_167, [0], True, dtype = torch.float32); view_167 = None
view_168: "f32[768]" = torch.ops.aten.reshape.default(sum_21, [768]); sum_21 = None
permute_124: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_123, [1, 0]); permute_123 = None
view_169: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_11, [32, 1024, 3072]); mm_11 = None
convert_element_type_348: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_169, torch.float32); view_169 = None
convert_element_type_349: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_124, torch.float32); permute_124 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_129: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_42, [32, 1024, 3072]); addmm_42 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_84: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_129, 0.5)
mul_130: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_348, mul_84); mul_84 = None
convert_element_type_257: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_129, torch.float32)
pow_11: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_257, 3.0)
mul_85: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_11, 0.044715); pow_11 = None
add_86: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_129, mul_85); view_129 = mul_85 = None
mul_86: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_86, 0.7978845608028654); add_86 = None
tanh_10: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_86); mul_86 = None
add_87: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_10, 1.0)
mul_131: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_348, add_87); convert_element_type_348 = add_87 = None
convert_element_type_351: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_131, torch.bfloat16); mul_131 = None
mul_132: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_10, tanh_10); tanh_10 = None
sub_38: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_132); mul_132 = None
mul_133: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_130, sub_38); mul_130 = sub_38 = None
mul_134: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_133, 0.7978845608028654); mul_133 = None
convert_element_type_352: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16)
mul_135: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_134, 0.044715); mul_134 = None
pow_14: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_257, 2.0); convert_element_type_257 = None
mul_136: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_14, 3.0); pow_14 = None
mul_137: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_135, mul_136); mul_135 = mul_136 = None
convert_element_type_353: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_137, torch.bfloat16); mul_137 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_103: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_352, convert_element_type_353); convert_element_type_352 = convert_element_type_353 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_138: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_351, 0.5); convert_element_type_351 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_104: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_103, mul_138); add_103 = mul_138 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_170: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_104, [32768, 3072]); add_104 = None
mm_13: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_170, permute_125); permute_125 = None
permute_126: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_170, [1, 0])
mm_14: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_126, view_128); permute_126 = view_128 = None
permute_127: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_14, [1, 0]); mm_14 = None
sum_22: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_170, [0], True, dtype = torch.float32); view_170 = None
view_171: "f32[3072]" = torch.ops.aten.reshape.default(sum_22, [3072]); sum_22 = None
permute_128: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_127, [1, 0]); permute_127 = None
view_172: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_13, [32, 1024, 768]); mm_13 = None
convert_element_type_359: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_172, torch.float32); view_172 = None
convert_element_type_360: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_128, torch.float32); permute_128 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_140: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_359, primals_129); primals_129 = None
mul_141: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_140, 768)
sum_23: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_140, [2], True)
mul_142: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_140, mul_82); mul_140 = None
sum_24: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_142, [2], True); mul_142 = None
mul_143: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_82, sum_24); sum_24 = None
sub_40: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_141, sum_23); mul_141 = sum_23 = None
sub_41: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_40, mul_143); sub_40 = mul_143 = None
mul_144: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_5, sub_41); div_5 = sub_41 = None
mul_145: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_359, mul_82); mul_82 = None
sum_25: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_145, [0, 1]); mul_145 = None
sum_26: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_359, [0, 1]); convert_element_type_359 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_105: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_102, mul_144); add_102 = mul_144 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_362: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_105, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_173: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_362, [32768, 768]); convert_element_type_362 = None
mm_15: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_173, permute_129); permute_129 = None
permute_130: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_173, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_84: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_165, [0, 2, 1, 3])
view_125: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_84, [32, 1024, 768]); permute_84 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_126: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_125, [32768, 768]); view_125 = None
mm_16: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_130, view_126); permute_130 = view_126 = None
permute_131: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_16, [1, 0]); mm_16 = None
sum_27: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_173, [0], True, dtype = torch.float32); view_173 = None
view_174: "f32[768]" = torch.ops.aten.reshape.default(sum_27, [768]); sum_27 = None
permute_132: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_131, [1, 0]); permute_131 = None
view_175: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_15, [32, 1024, 768]); mm_15 = None
convert_element_type_368: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_132, torch.float32); permute_132 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_176: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_175, [32, 1024, 12, 64]); view_175 = None
permute_133: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_176, [0, 2, 1, 3]); view_176 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_1 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_133, permute_82, permute_81, permute_83, getitem_165, getitem_166, None, None, 1024, 1024, 0.0, True, getitem_171, getitem_172, scale = 0.125); permute_133 = permute_82 = permute_81 = permute_83 = getitem_165 = getitem_166 = getitem_171 = getitem_172 = None
getitem_197: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_1[0]
getitem_198: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_1[1]
getitem_199: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_1[2]; _scaled_dot_product_flash_attention_backward_1 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_134: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_199, [0, 2, 1, 3]); getitem_199 = None
view_177: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_134, [32, 1024, 768]); permute_134 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_135: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_197, [0, 2, 1, 3]); getitem_197 = None
view_178: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_135, [32, 1024, 768]); permute_135 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_136: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_198, [0, 2, 1, 3]); getitem_198 = None
view_179: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_136, [32, 1024, 768]); permute_136 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_1: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_178, view_179, view_177], 2); view_178 = view_179 = view_177 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_180: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_1, [32768, 2304]); cat_1 = None
mm_17: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_180, permute_137); permute_137 = None
permute_138: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_180, [1, 0])
mm_18: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_138, view_120); permute_138 = view_120 = None
permute_139: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_18, [1, 0]); mm_18 = None
sum_28: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_180, [0], True, dtype = torch.float32); view_180 = None
view_181: "f32[2304]" = torch.ops.aten.reshape.default(sum_28, [2304]); sum_28 = None
permute_140: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_139, [1, 0]); permute_139 = None
view_182: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_17, [32, 1024, 768]); mm_17 = None
convert_element_type_375: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_182, torch.float32); view_182 = None
convert_element_type_376: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_140, torch.float32); permute_140 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_147: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_375, primals_123); primals_123 = None
mul_148: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_147, 768)
sum_29: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_147, [2], True)
mul_149: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_147, mul_80); mul_147 = None
sum_30: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_149, [2], True); mul_149 = None
mul_150: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_80, sum_30); sum_30 = None
sub_43: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_148, sum_29); mul_148 = sum_29 = None
sub_44: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_43, mul_150); sub_43 = mul_150 = None
mul_151: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_6, sub_44); div_6 = sub_44 = None
mul_152: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_375, mul_80); mul_80 = None
sum_31: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_152, [0, 1]); mul_152 = None
sum_32: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_375, [0, 1]); convert_element_type_375 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_106: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_105, mul_151); add_105 = mul_151 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_378: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_106, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_183: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_378, [32768, 768]); convert_element_type_378 = None
mm_19: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_183, permute_141); permute_141 = None
permute_142: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_183, [1, 0])
mm_20: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_142, view_118); permute_142 = view_118 = None
permute_143: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_20, [1, 0]); mm_20 = None
sum_33: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_183, [0], True, dtype = torch.float32); view_183 = None
view_184: "f32[768]" = torch.ops.aten.reshape.default(sum_33, [768]); sum_33 = None
permute_144: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_143, [1, 0]); permute_143 = None
view_185: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_19, [32, 1024, 3072]); mm_19 = None
convert_element_type_384: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_185, torch.float32); view_185 = None
convert_element_type_385: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_144, torch.float32); permute_144 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_117: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_38, [32, 1024, 3072]); addmm_38 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_76: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_117, 0.5)
mul_153: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_384, mul_76); mul_76 = None
convert_element_type_233: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_117, torch.float32)
pow_10: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 3.0)
mul_77: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_10, 0.044715); pow_10 = None
add_78: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_117, mul_77); view_117 = mul_77 = None
mul_78: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_78, 0.7978845608028654); add_78 = None
tanh_9: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_78); mul_78 = None
add_79: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_9, 1.0)
mul_154: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_384, add_79); convert_element_type_384 = add_79 = None
convert_element_type_387: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_154, torch.bfloat16); mul_154 = None
mul_155: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_9, tanh_9); tanh_9 = None
sub_45: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_155); mul_155 = None
mul_156: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_153, sub_45); mul_153 = sub_45 = None
mul_157: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_156, 0.7978845608028654); mul_156 = None
convert_element_type_388: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_157, torch.bfloat16)
mul_158: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_157, 0.044715); mul_157 = None
pow_15: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_233, 2.0); convert_element_type_233 = None
mul_159: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_15, 3.0); pow_15 = None
mul_160: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_158, mul_159); mul_158 = mul_159 = None
convert_element_type_389: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_160, torch.bfloat16); mul_160 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_107: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_388, convert_element_type_389); convert_element_type_388 = convert_element_type_389 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_161: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_387, 0.5); convert_element_type_387 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_108: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_107, mul_161); add_107 = mul_161 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_186: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_108, [32768, 3072]); add_108 = None
mm_21: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_186, permute_145); permute_145 = None
permute_146: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_186, [1, 0])
mm_22: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_146, view_116); permute_146 = view_116 = None
permute_147: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_22, [1, 0]); mm_22 = None
sum_34: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_186, [0], True, dtype = torch.float32); view_186 = None
view_187: "f32[3072]" = torch.ops.aten.reshape.default(sum_34, [3072]); sum_34 = None
permute_148: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_147, [1, 0]); permute_147 = None
view_188: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_21, [32, 1024, 768]); mm_21 = None
convert_element_type_395: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_188, torch.float32); view_188 = None
convert_element_type_396: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_148, torch.float32); permute_148 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_163: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_395, primals_117); primals_117 = None
mul_164: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_163, 768)
sum_35: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_163, [2], True)
mul_165: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_163, mul_74); mul_163 = None
sum_36: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_165, [2], True); mul_165 = None
mul_166: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_74, sum_36); sum_36 = None
sub_47: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_164, sum_35); mul_164 = sum_35 = None
sub_48: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_47, mul_166); sub_47 = mul_166 = None
mul_167: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_7, sub_48); div_7 = sub_48 = None
mul_168: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_395, mul_74); mul_74 = None
sum_37: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_168, [0, 1]); mul_168 = None
sum_38: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_395, [0, 1]); convert_element_type_395 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_109: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_106, mul_167); add_106 = mul_167 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_398: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_109, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_189: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_398, [32768, 768]); convert_element_type_398 = None
mm_23: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_189, permute_149); permute_149 = None
permute_150: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_189, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_76: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_149, [0, 2, 1, 3])
view_113: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_76, [32, 1024, 768]); permute_76 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_114: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_113, [32768, 768]); view_113 = None
mm_24: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_150, view_114); permute_150 = view_114 = None
permute_151: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_24, [1, 0]); mm_24 = None
sum_39: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_189, [0], True, dtype = torch.float32); view_189 = None
view_190: "f32[768]" = torch.ops.aten.reshape.default(sum_39, [768]); sum_39 = None
permute_152: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_151, [1, 0]); permute_151 = None
view_191: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_23, [32, 1024, 768]); mm_23 = None
convert_element_type_404: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_152, torch.float32); permute_152 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_192: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_191, [32, 1024, 12, 64]); view_191 = None
permute_153: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_192, [0, 2, 1, 3]); view_192 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_2 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_153, permute_74, permute_73, permute_75, getitem_149, getitem_150, None, None, 1024, 1024, 0.0, True, getitem_155, getitem_156, scale = 0.125); permute_153 = permute_74 = permute_73 = permute_75 = getitem_149 = getitem_150 = getitem_155 = getitem_156 = None
getitem_200: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_2[0]
getitem_201: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_2[1]
getitem_202: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_2[2]; _scaled_dot_product_flash_attention_backward_2 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_154: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_202, [0, 2, 1, 3]); getitem_202 = None
view_193: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_154, [32, 1024, 768]); permute_154 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_155: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_200, [0, 2, 1, 3]); getitem_200 = None
view_194: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_155, [32, 1024, 768]); permute_155 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_156: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_201, [0, 2, 1, 3]); getitem_201 = None
view_195: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_156, [32, 1024, 768]); permute_156 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_2: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_194, view_195, view_193], 2); view_194 = view_195 = view_193 = None
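# ---------------------------------------------------------------------------
# Aside (editor's sketch, not part of the captured graph): cat_2 is the
# backward of qkv.split(self.n_embd, dim=2); the gradients of the q, k and v
# slices are simply concatenated back into one [B, T, 3*C] tensor. Toy
# illustration:
import torch
qkv = torch.randn(2, 4, 9, requires_grad=True)
q, k, v = qkv.split(3, dim=2)
(q.sum() + 2 * k.sum() + 3 * v.sum()).backward()
print(qkv.grad[0, 0])  # [1, 1, 1, 2, 2, 2, 3, 3, 3]: per-slice grads side by side
# ---------------------------------------------------------------------------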
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_196: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_2, [32768, 2304]); cat_2 = None
mm_25: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_196, permute_157); permute_157 = None
permute_158: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_196, [1, 0])
mm_26: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_158, view_108); permute_158 = view_108 = None
permute_159: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_26, [1, 0]); mm_26 = None
sum_40: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_196, [0], True, dtype = torch.float32); view_196 = None
view_197: "f32[2304]" = torch.ops.aten.reshape.default(sum_40, [2304]); sum_40 = None
permute_160: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_159, [1, 0]); permute_159 = None
view_198: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_25, [32, 1024, 768]); mm_25 = None
convert_element_type_411: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_198, torch.float32); view_198 = None
convert_element_type_412: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_160, torch.float32); permute_160 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_170: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_411, primals_111); primals_111 = None
mul_171: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_170, 768)
sum_41: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_170, [2], True)
mul_172: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_170, mul_72); mul_170 = None
sum_42: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_172, [2], True); mul_172 = None
mul_173: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_72, sum_42); sum_42 = None
sub_50: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_171, sum_41); mul_171 = sum_41 = None
sub_51: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_50, mul_173); sub_50 = mul_173 = None
mul_174: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_8, sub_51); div_8 = sub_51 = None
mul_175: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_411, mul_72); mul_72 = None
sum_43: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_175, [0, 1]); mul_175 = None
sum_44: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_411, [0, 1]); convert_element_type_411 = None
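# ---------------------------------------------------------------------------
# Aside (editor's sketch, not part of the captured graph): mul_170 .. sum_44
# are the decomposed LayerNorm backward for ln_1. With N = 768, w the affine
# weight (primals_111), xhat the saved normalized input (mul_72) and g the
# incoming gradient, they evaluate
#   dx = (rstd / N) * (N*g*w - sum(g*w) - xhat * sum(g*w*xhat))   # div_8 carries the rstd / N factor
#   dw = sum over (B, T) of g * xhat                              # sum_43
#   db = sum over (B, T) of g                                     # sum_44
# Eager cross-check with hypothetical small shapes:
import torch
x = torch.randn(2, 4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)
y = torch.nn.functional.layer_norm(x, (8,), w, b)
y.backward(torch.randn_like(y))  # x.grad, w.grad, b.grad realize the three formulas above
# ---------------------------------------------------------------------------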
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_110: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_109, mul_174); add_109 = mul_174 = None
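# NOTE (editor's comment, not part of the captured graph): add_110 is plain
# gradient accumulation for the residual x = x + self.attn(self.ln_1(x)); the
# gradient already flowing into x (add_109) and the gradient returning through
# the attention branch (mul_174) are summed, so the skip connection passes
# gradients through unchanged. The same accumulation recurs at every residual
# below (add_113, add_114, add_117, ...).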
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_414: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_110, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_199: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_414, [32768, 768]); convert_element_type_414 = None
mm_27: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_199, permute_161); permute_161 = None
permute_162: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_199, [1, 0])
mm_28: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_162, view_106); permute_162 = view_106 = None
permute_163: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_28, [1, 0]); mm_28 = None
sum_45: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_199, [0], True, dtype = torch.float32); view_199 = None
view_200: "f32[768]" = torch.ops.aten.reshape.default(sum_45, [768]); sum_45 = None
permute_164: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_163, [1, 0]); permute_163 = None
view_201: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_27, [32, 1024, 3072]); mm_27 = None
convert_element_type_420: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_201, torch.float32); view_201 = None
convert_element_type_421: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_164, torch.float32); permute_164 = None
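# NOTE (editor's comment, not part of the captured graph): the
# convert_element_type nodes bracketing each matmul show the mixed-precision
# split: the mm kernels and activation gradients stay in bfloat16, elementwise
# backward math (GELU, LayerNorm) is carried out in float32, and parameter
# gradients such as convert_element_type_421 are upcast to float32, consistent
# with bfloat16 autocast over float32 master weights.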
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_105: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_34, [32, 1024, 3072]); addmm_34 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_68: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_105, 0.5)
mul_176: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_420, mul_68); mul_68 = None
convert_element_type_209: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_105, torch.float32)
pow_9: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_209, 3.0)
mul_69: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_9, 0.044715); pow_9 = None
add_70: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_105, mul_69); view_105 = mul_69 = None
mul_70: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_70, 0.7978845608028654); add_70 = None
tanh_8: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_70); mul_70 = None
add_71: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_8, 1.0)
mul_177: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_420, add_71); convert_element_type_420 = add_71 = None
convert_element_type_423: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_177, torch.bfloat16); mul_177 = None
mul_178: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_8, tanh_8); tanh_8 = None
sub_52: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_178); mul_178 = None
mul_179: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_176, sub_52); mul_176 = sub_52 = None
mul_180: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_179, 0.7978845608028654); mul_179 = None
convert_element_type_424: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_180, torch.bfloat16)
mul_181: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_180, 0.044715); mul_180 = None
pow_16: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_209, 2.0); convert_element_type_209 = None
mul_182: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_16, 3.0); pow_16 = None
mul_183: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_181, mul_182); mul_181 = mul_182 = None
convert_element_type_425: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_183, torch.bfloat16); mul_183 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_111: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_424, convert_element_type_425); convert_element_type_424 = convert_element_type_425 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_184: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_423, 0.5); convert_element_type_423 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_112: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_111, mul_184); add_111 = mul_184 = None
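# ---------------------------------------------------------------------------
# Aside (editor's sketch, not part of the captured graph): mul_176 .. add_112
# spell out the derivative of the tanh GELU from train_gpt2.py:36. With
# a = 0.044715, c = sqrt(2/pi) ~= 0.7978845608028654 and t = tanh(c*(x + a*x**3)),
#   gelu(x)  = 0.5 * x * (1 + t)
#   gelu'(x) = 0.5 * (1 + t) + 0.5 * x * (1 - t**2) * c * (1 + 3*a*x**2)
# add_112 is the incoming gradient times exactly these two summands (mul_184
# plus add_111). Eager cross-check on hypothetical data:
import math
import torch
a, c = 0.044715, math.sqrt(2.0 / math.pi)
x = torch.randn(5, dtype=torch.float64, requires_grad=True)
y = 0.5 * x * (1.0 + torch.tanh(c * (x + a * torch.pow(x, 3.0))))
y.backward(torch.ones_like(y))
t = torch.tanh(c * (x + a * x**3))
manual = 0.5 * (1 + t) + 0.5 * x * (1 - t**2) * c * (1 + 3 * a * x**2)
assert torch.allclose(x.grad, manual.detach())
# ---------------------------------------------------------------------------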
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_202: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_112, [32768, 3072]); add_112 = None
mm_29: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_202, permute_165); permute_165 = None
permute_166: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_202, [1, 0])
mm_30: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_166, view_104); permute_166 = view_104 = None
permute_167: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_30, [1, 0]); mm_30 = None
sum_46: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_202, [0], True, dtype = torch.float32); view_202 = None
view_203: "f32[3072]" = torch.ops.aten.reshape.default(sum_46, [3072]); sum_46 = None
permute_168: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_167, [1, 0]); permute_167 = None
view_204: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_29, [32, 1024, 768]); mm_29 = None
convert_element_type_431: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_204, torch.float32); view_204 = None
convert_element_type_432: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_168, torch.float32); permute_168 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_186: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_431, primals_105); primals_105 = None
mul_187: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_186, 768)
sum_47: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_186, [2], True)
mul_188: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_186, mul_66); mul_186 = None
sum_48: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_188, [2], True); mul_188 = None
mul_189: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_66, sum_48); sum_48 = None
sub_54: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_187, sum_47); mul_187 = sum_47 = None
sub_55: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_54, mul_189); sub_54 = mul_189 = None
mul_190: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_9, sub_55); div_9 = sub_55 = None
mul_191: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_431, mul_66); mul_66 = None
sum_49: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_191, [0, 1]); mul_191 = None
sum_50: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_431, [0, 1]); convert_element_type_431 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_113: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_110, mul_190); add_110 = mul_190 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_434: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_113, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_205: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_434, [32768, 768]); convert_element_type_434 = None
mm_31: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_205, permute_169); permute_169 = None
permute_170: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_205, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_68: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_133, [0, 2, 1, 3])
view_101: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_68, [32, 1024, 768]); permute_68 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_102: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_101, [32768, 768]); view_101 = None
mm_32: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_170, view_102); permute_170 = view_102 = None
permute_171: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_32, [1, 0]); mm_32 = None
sum_51: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_205, [0], True, dtype = torch.float32); view_205 = None
view_206: "f32[768]" = torch.ops.aten.reshape.default(sum_51, [768]); sum_51 = None
permute_172: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_171, [1, 0]); permute_171 = None
view_207: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_31, [32, 1024, 768]); mm_31 = None
convert_element_type_440: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_172, torch.float32); permute_172 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_208: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_207, [32, 1024, 12, 64]); view_207 = None
permute_173: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_208, [0, 2, 1, 3]); view_208 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_3 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_173, permute_66, permute_65, permute_67, getitem_133, getitem_134, None, None, 1024, 1024, 0.0, True, getitem_139, getitem_140, scale = 0.125); permute_173 = permute_66 = permute_65 = permute_67 = getitem_133 = getitem_134 = getitem_139 = getitem_140 = None
getitem_203: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_3[0]
getitem_204: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_3[1]
getitem_205: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_3[2]; _scaled_dot_product_flash_attention_backward_3 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_174: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_205, [0, 2, 1, 3]); getitem_205 = None
view_209: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_174, [32, 1024, 768]); permute_174 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_175: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_203, [0, 2, 1, 3]); getitem_203 = None
view_210: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_175, [32, 1024, 768]); permute_175 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_176: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_204, [0, 2, 1, 3]); getitem_204 = None
view_211: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_176, [32, 1024, 768]); permute_176 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_3: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_210, view_211, view_209], 2); view_210 = view_211 = view_209 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_212: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_3, [32768, 2304]); cat_3 = None
mm_33: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_212, permute_177); permute_177 = None
permute_178: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_212, [1, 0])
mm_34: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_178, view_96); permute_178 = view_96 = None
permute_179: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_34, [1, 0]); mm_34 = None
sum_52: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_212, [0], True, dtype = torch.float32); view_212 = None
view_213: "f32[2304]" = torch.ops.aten.reshape.default(sum_52, [2304]); sum_52 = None
permute_180: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_179, [1, 0]); permute_179 = None
view_214: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_33, [32, 1024, 768]); mm_33 = None
convert_element_type_447: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_214, torch.float32); view_214 = None
convert_element_type_448: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_180, torch.float32); permute_180 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_193: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_447, primals_99); primals_99 = None
mul_194: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_193, 768)
sum_53: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_193, [2], True)
mul_195: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_193, mul_64); mul_193 = None
sum_54: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_195, [2], True); mul_195 = None
mul_196: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_64, sum_54); sum_54 = None
sub_57: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_194, sum_53); mul_194 = sum_53 = None
sub_58: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_57, mul_196); sub_57 = mul_196 = None
mul_197: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_10, sub_58); div_10 = sub_58 = None
mul_198: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_447, mul_64); mul_64 = None
sum_55: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_198, [0, 1]); mul_198 = None
sum_56: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_447, [0, 1]); convert_element_type_447 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_114: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_113, mul_197); add_113 = mul_197 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_450: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_114, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_215: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_450, [32768, 768]); convert_element_type_450 = None
mm_35: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_215, permute_181); permute_181 = None
permute_182: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_215, [1, 0])
mm_36: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_182, view_94); permute_182 = view_94 = None
permute_183: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_36, [1, 0]); mm_36 = None
sum_57: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_215, [0], True, dtype = torch.float32); view_215 = None
view_216: "f32[768]" = torch.ops.aten.reshape.default(sum_57, [768]); sum_57 = None
permute_184: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_183, [1, 0]); permute_183 = None
view_217: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_35, [32, 1024, 3072]); mm_35 = None
convert_element_type_456: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_217, torch.float32); view_217 = None
convert_element_type_457: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_184, torch.float32); permute_184 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_93: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_30, [32, 1024, 3072]); addmm_30 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_60: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_93, 0.5)
mul_199: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_456, mul_60); mul_60 = None
convert_element_type_185: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_93, torch.float32)
pow_8: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_185, 3.0)
mul_61: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_8, 0.044715); pow_8 = None
add_62: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_93, mul_61); view_93 = mul_61 = None
mul_62: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_62, 0.7978845608028654); add_62 = None
tanh_7: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_62); mul_62 = None
add_63: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_7, 1.0)
mul_200: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_456, add_63); convert_element_type_456 = add_63 = None
convert_element_type_459: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_200, torch.bfloat16); mul_200 = None
mul_201: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_7, tanh_7); tanh_7 = None
sub_59: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_201); mul_201 = None
mul_202: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_199, sub_59); mul_199 = sub_59 = None
mul_203: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_202, 0.7978845608028654); mul_202 = None
convert_element_type_460: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_203, torch.bfloat16)
mul_204: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_203, 0.044715); mul_203 = None
pow_17: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_185, 2.0); convert_element_type_185 = None
mul_205: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_17, 3.0); pow_17 = None
mul_206: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_204, mul_205); mul_204 = mul_205 = None
convert_element_type_461: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_206, torch.bfloat16); mul_206 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_115: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_460, convert_element_type_461); convert_element_type_460 = convert_element_type_461 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_207: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_459, 0.5); convert_element_type_459 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_116: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_115, mul_207); add_115 = mul_207 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_218: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_116, [32768, 3072]); add_116 = None
mm_37: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_218, permute_185); permute_185 = None
permute_186: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_218, [1, 0])
mm_38: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_186, view_92); permute_186 = view_92 = None
permute_187: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_38, [1, 0]); mm_38 = None
sum_58: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_218, [0], True, dtype = torch.float32); view_218 = None
view_219: "f32[3072]" = torch.ops.aten.reshape.default(sum_58, [3072]); sum_58 = None
permute_188: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_187, [1, 0]); permute_187 = None
view_220: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_37, [32, 1024, 768]); mm_37 = None
convert_element_type_467: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_220, torch.float32); view_220 = None
convert_element_type_468: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_188, torch.float32); permute_188 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_209: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_467, primals_93); primals_93 = None
mul_210: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_209, 768)
sum_59: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_209, [2], True)
mul_211: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_209, mul_58); mul_209 = None
sum_60: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_211, [2], True); mul_211 = None
mul_212: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_58, sum_60); sum_60 = None
sub_61: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_210, sum_59); mul_210 = sum_59 = None
sub_62: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_61, mul_212); sub_61 = mul_212 = None
mul_213: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_11, sub_62); div_11 = sub_62 = None
mul_214: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_467, mul_58); mul_58 = None
sum_61: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_214, [0, 1]); mul_214 = None
sum_62: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_467, [0, 1]); convert_element_type_467 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_117: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_114, mul_213); add_114 = mul_213 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_470: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_117, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_221: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_470, [32768, 768]); convert_element_type_470 = None
mm_39: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_221, permute_189); permute_189 = None
permute_190: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_221, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_60: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_117, [0, 2, 1, 3])
view_89: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_60, [32, 1024, 768]); permute_60 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_90: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_89, [32768, 768]); view_89 = None
mm_40: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_190, view_90); permute_190 = view_90 = None
permute_191: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_40, [1, 0]); mm_40 = None
sum_63: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_221, [0], True, dtype = torch.float32); view_221 = None
view_222: "f32[768]" = torch.ops.aten.reshape.default(sum_63, [768]); sum_63 = None
permute_192: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_191, [1, 0]); permute_191 = None
view_223: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_39, [32, 1024, 768]); mm_39 = None
convert_element_type_476: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_192, torch.float32); permute_192 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_224: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_223, [32, 1024, 12, 64]); view_223 = None
permute_193: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_224, [0, 2, 1, 3]); view_224 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_4 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_193, permute_58, permute_57, permute_59, getitem_117, getitem_118, None, None, 1024, 1024, 0.0, True, getitem_123, getitem_124, scale = 0.125); permute_193 = permute_58 = permute_57 = permute_59 = getitem_117 = getitem_118 = getitem_123 = getitem_124 = None
getitem_206: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_4[0]
getitem_207: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_4[1]
getitem_208: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_4[2]; _scaled_dot_product_flash_attention_backward_4 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_194: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_208, [0, 2, 1, 3]); getitem_208 = None
view_225: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_194, [32, 1024, 768]); permute_194 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_195: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_206, [0, 2, 1, 3]); getitem_206 = None
view_226: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_195, [32, 1024, 768]); permute_195 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_196: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]); getitem_207 = None
view_227: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_196, [32, 1024, 768]); permute_196 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_4: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_226, view_227, view_225], 2); view_226 = view_227 = view_225 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_228: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_4, [32768, 2304]); cat_4 = None
mm_41: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_228, permute_197); permute_197 = None
permute_198: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_228, [1, 0])
mm_42: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_198, view_84); permute_198 = view_84 = None
permute_199: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_42, [1, 0]); mm_42 = None
sum_64: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_228, [0], True, dtype = torch.float32); view_228 = None
view_229: "f32[2304]" = torch.ops.aten.reshape.default(sum_64, [2304]); sum_64 = None
permute_200: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_199, [1, 0]); permute_199 = None
view_230: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_41, [32, 1024, 768]); mm_41 = None
convert_element_type_483: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_230, torch.float32); view_230 = None
convert_element_type_484: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_200, torch.float32); permute_200 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_216: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_483, primals_87); primals_87 = None
mul_217: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_216, 768)
sum_65: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_216, [2], True)
mul_218: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_216, mul_56); mul_216 = None
sum_66: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_218, [2], True); mul_218 = None
mul_219: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_56, sum_66); sum_66 = None
sub_64: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_217, sum_65); mul_217 = sum_65 = None
sub_65: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_64, mul_219); sub_64 = mul_219 = None
mul_220: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_12, sub_65); div_12 = sub_65 = None
mul_221: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_483, mul_56); mul_56 = None
sum_67: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_221, [0, 1]); mul_221 = None
sum_68: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_483, [0, 1]); convert_element_type_483 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_118: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_117, mul_220); add_117 = mul_220 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_486: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_118, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_231: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_486, [32768, 768]); convert_element_type_486 = None
mm_43: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_231, permute_201); permute_201 = None
permute_202: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_231, [1, 0])
mm_44: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_202, view_82); permute_202 = view_82 = None
permute_203: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_44, [1, 0]); mm_44 = None
sum_69: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_231, [0], True, dtype = torch.float32); view_231 = None
view_232: "f32[768]" = torch.ops.aten.reshape.default(sum_69, [768]); sum_69 = None
permute_204: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_203, [1, 0]); permute_203 = None
view_233: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_43, [32, 1024, 3072]); mm_43 = None
convert_element_type_492: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_233, torch.float32); view_233 = None
convert_element_type_493: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_204, torch.float32); permute_204 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_81: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_26, [32, 1024, 3072]); addmm_26 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_52: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_81, 0.5)
mul_222: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_492, mul_52); mul_52 = None
convert_element_type_161: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_81, torch.float32)
pow_7: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_161, 3.0)
mul_53: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_7, 0.044715); pow_7 = None
add_54: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_81, mul_53); view_81 = mul_53 = None
mul_54: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_54, 0.7978845608028654); add_54 = None
tanh_6: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_54); mul_54 = None
add_55: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_6, 1.0)
mul_223: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_492, add_55); convert_element_type_492 = add_55 = None
convert_element_type_495: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_223, torch.bfloat16); mul_223 = None
mul_224: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_6, tanh_6); tanh_6 = None
sub_66: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_224); mul_224 = None
mul_225: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_222, sub_66); mul_222 = sub_66 = None
mul_226: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_225, 0.7978845608028654); mul_225 = None
convert_element_type_496: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_226, torch.bfloat16)
mul_227: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_226, 0.044715); mul_226 = None
pow_18: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_161, 2.0); convert_element_type_161 = None
mul_228: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_18, 3.0); pow_18 = None
mul_229: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_227, mul_228); mul_227 = mul_228 = None
convert_element_type_497: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_229, torch.bfloat16); mul_229 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_119: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_496, convert_element_type_497); convert_element_type_496 = convert_element_type_497 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_230: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_495, 0.5); convert_element_type_495 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_120: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_119, mul_230); add_119 = mul_230 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_234: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_120, [32768, 3072]); add_120 = None
mm_45: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_234, permute_205); permute_205 = None
permute_206: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_234, [1, 0])
mm_46: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_206, view_80); permute_206 = view_80 = None
permute_207: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_46, [1, 0]); mm_46 = None
sum_70: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_234, [0], True, dtype = torch.float32); view_234 = None
view_235: "f32[3072]" = torch.ops.aten.reshape.default(sum_70, [3072]); sum_70 = None
permute_208: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_207, [1, 0]); permute_207 = None
view_236: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_45, [32, 1024, 768]); mm_45 = None
convert_element_type_503: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_236, torch.float32); view_236 = None
convert_element_type_504: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_208, torch.float32); permute_208 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_232: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_503, primals_81); primals_81 = None
mul_233: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_232, 768)
sum_71: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_232, [2], True)
mul_234: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_232, mul_50); mul_232 = None
sum_72: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_234, [2], True); mul_234 = None
mul_235: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_50, sum_72); sum_72 = None
sub_68: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_233, sum_71); mul_233 = sum_71 = None
sub_69: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_68, mul_235); sub_68 = mul_235 = None
mul_236: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_13, sub_69); div_13 = sub_69 = None
mul_237: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_503, mul_50); mul_50 = None
sum_73: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_237, [0, 1]); mul_237 = None
sum_74: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_503, [0, 1]); convert_element_type_503 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_121: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_118, mul_236); add_118 = mul_236 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_506: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_121, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_237: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_506, [32768, 768]); convert_element_type_506 = None
mm_47: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_237, permute_209); permute_209 = None
permute_210: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_237, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_52: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_101, [0, 2, 1, 3])
view_77: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_52, [32, 1024, 768]); permute_52 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_78: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_77, [32768, 768]); view_77 = None
mm_48: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_210, view_78); permute_210 = view_78 = None
permute_211: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_48, [1, 0]); mm_48 = None
sum_75: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_237, [0], True, dtype = torch.float32); view_237 = None
view_238: "f32[768]" = torch.ops.aten.reshape.default(sum_75, [768]); sum_75 = None
permute_212: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_211, [1, 0]); permute_211 = None
view_239: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_47, [32, 1024, 768]); mm_47 = None
convert_element_type_512: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_212, torch.float32); permute_212 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_240: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_239, [32, 1024, 12, 64]); view_239 = None
permute_213: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_240, [0, 2, 1, 3]); view_240 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_5 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_213, permute_50, permute_49, permute_51, getitem_101, getitem_102, None, None, 1024, 1024, 0.0, True, getitem_107, getitem_108, scale = 0.125); permute_213 = permute_50 = permute_49 = permute_51 = getitem_101 = getitem_102 = getitem_107 = getitem_108 = None
getitem_209: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_5[0]
getitem_210: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_5[1]
getitem_211: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_5[2]; _scaled_dot_product_flash_attention_backward_5 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_214: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_211, [0, 2, 1, 3]); getitem_211 = None
view_241: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_214, [32, 1024, 768]); permute_214 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_215: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_209, [0, 2, 1, 3]); getitem_209 = None
view_242: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_215, [32, 1024, 768]); permute_215 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_216: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_210, [0, 2, 1, 3]); getitem_210 = None
view_243: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_216, [32, 1024, 768]); permute_216 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_5: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_242, view_243, view_241], 2); view_242 = view_243 = view_241 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_244: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_5, [32768, 2304]); cat_5 = None
mm_49: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_244, permute_217); permute_217 = None
permute_218: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_244, [1, 0])
mm_50: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_218, view_72); permute_218 = view_72 = None
permute_219: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_50, [1, 0]); mm_50 = None
sum_76: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_244, [0], True, dtype = torch.float32); view_244 = None
view_245: "f32[2304]" = torch.ops.aten.reshape.default(sum_76, [2304]); sum_76 = None
permute_220: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_219, [1, 0]); permute_219 = None
view_246: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_49, [32, 1024, 768]); mm_49 = None
convert_element_type_519: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_246, torch.float32); view_246 = None
convert_element_type_520: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_220, torch.float32); permute_220 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_239: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_519, primals_75); primals_75 = None
mul_240: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_239, 768)
sum_77: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_239, [2], True)
mul_241: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_239, mul_48); mul_239 = None
sum_78: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_241, [2], True); mul_241 = None
mul_242: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_48, sum_78); sum_78 = None
sub_71: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_240, sum_77); mul_240 = sum_77 = None
sub_72: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_71, mul_242); sub_71 = mul_242 = None
mul_243: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_14, sub_72); div_14 = sub_72 = None
mul_244: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_519, mul_48); mul_48 = None
sum_79: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_244, [0, 1]); mul_244 = None
sum_80: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_519, [0, 1]); convert_element_type_519 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_122: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_121, mul_243); add_121 = mul_243 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_522: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_122, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_247: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_522, [32768, 768]); convert_element_type_522 = None
mm_51: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_247, permute_221); permute_221 = None
permute_222: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_247, [1, 0])
mm_52: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_222, view_70); permute_222 = view_70 = None
permute_223: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_52, [1, 0]); mm_52 = None
sum_81: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_247, [0], True, dtype = torch.float32); view_247 = None
view_248: "f32[768]" = torch.ops.aten.reshape.default(sum_81, [768]); sum_81 = None
permute_224: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_223, [1, 0]); permute_223 = None
view_249: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_51, [32, 1024, 3072]); mm_51 = None
convert_element_type_528: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_249, torch.float32); view_249 = None
convert_element_type_529: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_224, torch.float32); permute_224 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_69: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_22, [32, 1024, 3072]); addmm_22 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_44: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_69, 0.5)
mul_245: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_528, mul_44); mul_44 = None
convert_element_type_137: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_69, torch.float32)
pow_6: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_137, 3.0)
mul_45: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_6, 0.044715); pow_6 = None
add_46: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_69, mul_45); view_69 = mul_45 = None
mul_46: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_46, 0.7978845608028654); add_46 = None
tanh_5: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_46); mul_46 = None
add_47: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_5, 1.0)
mul_246: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_528, add_47); convert_element_type_528 = add_47 = None
convert_element_type_531: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_246, torch.bfloat16); mul_246 = None
mul_247: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_5, tanh_5); tanh_5 = None
sub_73: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_247); mul_247 = None
mul_248: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_245, sub_73); mul_245 = sub_73 = None
mul_249: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_248, 0.7978845608028654); mul_248 = None
convert_element_type_532: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_249, torch.bfloat16)
mul_250: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_249, 0.044715); mul_249 = None
pow_19: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_137, 2.0); convert_element_type_137 = None
mul_251: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_19, 3.0); pow_19 = None
mul_252: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_250, mul_251); mul_250 = mul_251 = None
convert_element_type_533: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_252, torch.bfloat16); mul_252 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_123: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_532, convert_element_type_533); convert_element_type_532 = convert_element_type_533 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_253: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_531, 0.5); convert_element_type_531 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_124: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_123, mul_253); add_123 = mul_253 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_250: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_124, [32768, 3072]); add_124 = None
mm_53: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_250, permute_225); permute_225 = None
permute_226: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_250, [1, 0])
mm_54: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_226, view_68); permute_226 = view_68 = None
permute_227: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_54, [1, 0]); mm_54 = None
sum_82: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_250, [0], True, dtype = torch.float32); view_250 = None
view_251: "f32[3072]" = torch.ops.aten.reshape.default(sum_82, [3072]); sum_82 = None
permute_228: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_227, [1, 0]); permute_227 = None
view_252: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_53, [32, 1024, 768]); mm_53 = None
convert_element_type_539: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_252, torch.float32); view_252 = None
convert_element_type_540: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_228, torch.float32); permute_228 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_255: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_539, primals_69); primals_69 = None
mul_256: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_255, 768)
sum_83: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_255, [2], True)
mul_257: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_255, mul_42); mul_255 = None
sum_84: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_257, [2], True); mul_257 = None
mul_258: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_42, sum_84); sum_84 = None
sub_75: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_256, sum_83); mul_256 = sum_83 = None
sub_76: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_75, mul_258); sub_75 = mul_258 = None
mul_259: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_15, sub_76); div_15 = sub_76 = None
mul_260: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_539, mul_42); mul_42 = None
sum_85: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_260, [0, 1]); mul_260 = None
sum_86: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_539, [0, 1]); convert_element_type_539 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_125: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_122, mul_259); add_122 = mul_259 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_542: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_125, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_253: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_542, [32768, 768]); convert_element_type_542 = None
mm_55: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_253, permute_229); permute_229 = None
permute_230: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_253, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_44: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_85, [0, 2, 1, 3])
view_65: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_44, [32, 1024, 768]); permute_44 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_66: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_65, [32768, 768]); view_65 = None
mm_56: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_230, view_66); permute_230 = view_66 = None
permute_231: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_56, [1, 0]); mm_56 = None
sum_87: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_253, [0], True, dtype = torch.float32); view_253 = None
view_254: "f32[768]" = torch.ops.aten.reshape.default(sum_87, [768]); sum_87 = None
permute_232: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_231, [1, 0]); permute_231 = None
view_255: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_55, [32, 1024, 768]); mm_55 = None
convert_element_type_548: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_232, torch.float32); permute_232 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_256: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_255, [32, 1024, 12, 64]); view_255 = None
permute_233: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_256, [0, 2, 1, 3]); view_256 = None
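# mm_55/mm_56 are the standard Linear backward for the attention c_proj:
# dX = dY @ W (mm_55, using the saved weight permute_229) and dW = dY^T @ X
# (mm_56, with the input X rematerialized from the saved attention output
# getitem_85 rather than stored). sum_87 accumulates the bias grad directly in
# fp32 even though the matmuls run in bf16. view_256/permute_233 then undo the
# forward transpose(1, 2).contiguous().view(B, T, C) so the grad enters
# attention backward in (B, nh, T, hs) layout.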
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_6 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_233, permute_42, permute_41, permute_43, getitem_85, getitem_86, None, None, 1024, 1024, 0.0, True, getitem_91, getitem_92, scale = 0.125); permute_233 = permute_42 = permute_41 = permute_43 = getitem_85 = getitem_86 = getitem_91 = getitem_92 = None
getitem_212: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_6[0]
getitem_213: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_6[1]
getitem_214: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_6[2]; _scaled_dot_product_flash_attention_backward_6 = None
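# The fused flash-attention backward consumes the saved forward output
# (getitem_85), the per-row logsumexp (getitem_86) and the philox RNG state
# (getitem_91/92), and returns (dq, dk, dv) = (getitem_212, getitem_213,
# getitem_214) in a single kernel. scale=0.125 is 1/sqrt(head_dim=64);
# dropout_p=0.0 and is_causal=True match the forward call.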
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_234: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_214, [0, 2, 1, 3]); getitem_214 = None
view_257: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_234, [32, 1024, 768]); permute_234 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_235: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_212, [0, 2, 1, 3]); getitem_212 = None
view_258: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_235, [32, 1024, 768]); permute_235 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_236: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_213, [0, 2, 1, 3]); getitem_213 = None
view_259: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_236, [32, 1024, 768]); permute_236 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_6: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_258, view_259, view_257], 2); view_258 = view_259 = view_257 = None
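# The forward's q, k, v = qkv.split(self.n_embd, dim=2) differentiates to a plain
# concatenation: each of dq, dk, dv is permuted back to (B, T, nh, hs), reshaped
# to (B, T, C), and concatenated along dim 2 (cat_6) to form the grad w.r.t. qkv.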
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_260: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_6, [32768, 2304]); cat_6 = None
mm_57: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_260, permute_237); permute_237 = None
permute_238: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_260, [1, 0])
mm_58: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_238, view_60); permute_238 = view_60 = None
permute_239: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_58, [1, 0]); mm_58 = None
sum_88: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_260, [0], True, dtype = torch.float32); view_260 = None
view_261: "f32[2304]" = torch.ops.aten.reshape.default(sum_88, [2304]); sum_88 = None
permute_240: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_239, [1, 0]); permute_239 = None
view_262: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_57, [32, 1024, 768]); mm_57 = None
convert_element_type_555: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_262, torch.float32); view_262 = None
convert_element_type_556: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_240, torch.float32); permute_240 = None
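# c_attn backward follows the same Linear pattern as c_proj above, now with the
# fused 2304-wide qkv grad: dX = dQKV @ W (mm_57), dW = dQKV^T @ X (mm_58), plus
# an fp32 bias-grad reduction (sum_88).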
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_262: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_555, primals_63); primals_63 = None
mul_263: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_262, 768)
sum_89: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_262, [2], True)
mul_264: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_262, mul_40); mul_262 = None
sum_90: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_264, [2], True); mul_264 = None
mul_265: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_40, sum_90); sum_90 = None
sub_78: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_263, sum_89); mul_263 = sum_89 = None
sub_79: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_78, mul_265); sub_78 = mul_265 = None
mul_266: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_16, sub_79); div_16 = sub_79 = None
mul_267: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_555, mul_40); mul_40 = None
sum_91: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_267, [0, 1]); mul_267 = None
sum_92: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_555, [0, 1]); convert_element_type_555 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_126: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_125, mul_266); add_125 = mul_266 = None
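# From here the backward walks the remaining transformer blocks with the identical
# structure: MLP backward followed by its ln_2 backward and a residual add, then
# attention backward followed by its ln_1 backward and another residual add, each
# time consuming that block's saved activations.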
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_558: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_126, torch.bfloat16)
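# Note the residual grad accumulator (add_126) stays in fp32; it is cast down to
# bf16 (convert_element_type_558) only at the point of use in a bf16 matmul, so
# the running accumulation across blocks keeps full precision.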
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_263: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_558, [32768, 768]); convert_element_type_558 = None
mm_59: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_263, permute_241); permute_241 = None
permute_242: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_263, [1, 0])
mm_60: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_242, view_58); permute_242 = view_58 = None
permute_243: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_60, [1, 0]); mm_60 = None
sum_93: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_263, [0], True, dtype = torch.float32); view_263 = None
view_264: "f32[768]" = torch.ops.aten.reshape.default(sum_93, [768]); sum_93 = None
permute_244: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_243, [1, 0]); permute_243 = None
view_265: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_59, [32, 1024, 3072]); mm_59 = None
convert_element_type_564: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_265, torch.float32); view_265 = None
convert_element_type_565: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_244, torch.float32); permute_244 = None
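# convert_element_type_565 promotes the bf16 weight grad to fp32 before it is
# returned, as happens for every parameter grad in this graph, presumably so the
# optimizer updates fp32 master weights; convert_element_type_564 promotes the
# activation grad so the GELU backward below runs in fp32.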
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_57: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_18, [32, 1024, 3072]); addmm_18 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_36: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_57, 0.5)
mul_268: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_564, mul_36); mul_36 = None
convert_element_type_113: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_57, torch.float32)
pow_5: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_113, 3.0)
mul_37: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_5, 0.044715); pow_5 = None
add_38: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_57, mul_37); view_57 = mul_37 = None
mul_38: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_38, 0.7978845608028654); add_38 = None
tanh_4: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_38); mul_38 = None
add_39: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_4, 1.0)
mul_269: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_564, add_39); convert_element_type_564 = add_39 = None
convert_element_type_567: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_269, torch.bfloat16); mul_269 = None
mul_270: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_4, tanh_4); tanh_4 = None
sub_80: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_270); mul_270 = None
mul_271: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_268, sub_80); mul_268 = sub_80 = None
mul_272: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_271, 0.7978845608028654); mul_271 = None
convert_element_type_568: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_272, torch.bfloat16)
mul_273: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_272, 0.044715); mul_272 = None
pow_20: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_113, 2.0); convert_element_type_113 = None
mul_274: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_20, 3.0); pow_20 = None
mul_275: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_273, mul_274); mul_273 = mul_274 = None
convert_element_type_569: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_275, torch.bfloat16); mul_275 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_127: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_568, convert_element_type_569); convert_element_type_568 = convert_element_type_569 = None
mul_276: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_567, 0.5); convert_element_type_567 = None
add_128: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_127, mul_276); add_127 = mul_276 = None
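# mul_268 through add_128 are the backward of the tanh-approximated GELU
# (train_gpt2.py:36). With u = sqrt(2/pi) * (x + 0.044715 * x**3) and the forward
# tanh recomputed from the saved addmm_18 activation (recompute trades FLOPs for
# memory), the grad assembled here is, as an eager-mode sketch:
#   dgelu = 0.5 * (1 + tanh(u))
#           + 0.5 * x * (1 - tanh(u)**2)
#             * 0.7978845608028654 * (1 + 3 * 0.044715 * x**2)
#   dx = dy * dgelu   # add_128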
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_266: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_128, [32768, 3072]); add_128 = None
mm_61: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_266, permute_245); permute_245 = None
permute_246: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_266, [1, 0])
mm_62: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_246, view_56); permute_246 = view_56 = None
permute_247: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_62, [1, 0]); mm_62 = None
sum_94: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_266, [0], True, dtype = torch.float32); view_266 = None
view_267: "f32[3072]" = torch.ops.aten.reshape.default(sum_94, [3072]); sum_94 = None
permute_248: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_247, [1, 0]); permute_247 = None
view_268: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_61, [32, 1024, 768]); mm_61 = None
convert_element_type_575: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_268, torch.float32); view_268 = None
convert_element_type_576: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_248, torch.float32); permute_248 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_278: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_575, primals_57); primals_57 = None
mul_279: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_278, 768)
sum_95: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_278, [2], True)
mul_280: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_278, mul_34); mul_278 = None
sum_96: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_280, [2], True); mul_280 = None
mul_281: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_34, sum_96); sum_96 = None
sub_82: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_279, sum_95); mul_279 = sum_95 = None
sub_83: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_82, mul_281); sub_82 = mul_281 = None
mul_282: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_17, sub_83); div_17 = sub_83 = None
mul_283: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_575, mul_34); mul_34 = None
sum_97: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_283, [0, 1]); mul_283 = None
sum_98: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_575, [0, 1]); convert_element_type_575 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_129: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_126, mul_282); add_126 = mul_282 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_578: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_129, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_269: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_578, [32768, 768]); convert_element_type_578 = None
mm_63: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_269, permute_249); permute_249 = None
permute_250: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_269, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_36: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_69, [0, 2, 1, 3])
view_53: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_36, [32, 1024, 768]); permute_36 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_54: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_53, [32768, 768]); view_53 = None
mm_64: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_250, view_54); permute_250 = view_54 = None
permute_251: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_64, [1, 0]); mm_64 = None
sum_99: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_269, [0], True, dtype = torch.float32); view_269 = None
view_270: "f32[768]" = torch.ops.aten.reshape.default(sum_99, [768]); sum_99 = None
permute_252: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_251, [1, 0]); permute_251 = None
view_271: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_63, [32, 1024, 768]); mm_63 = None
convert_element_type_584: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_252, torch.float32); permute_252 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_272: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_271, [32, 1024, 12, 64]); view_271 = None
permute_253: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_272, [0, 2, 1, 3]); view_272 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_7 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_253, permute_34, permute_33, permute_35, getitem_69, getitem_70, None, None, 1024, 1024, 0.0, True, getitem_75, getitem_76, scale = 0.125); permute_253 = permute_34 = permute_33 = permute_35 = getitem_69 = getitem_70 = getitem_75 = getitem_76 = None
getitem_215: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_7[0]
getitem_216: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_7[1]
getitem_217: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_7[2]; _scaled_dot_product_flash_attention_backward_7 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_254: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_217, [0, 2, 1, 3]); getitem_217 = None
view_273: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_254, [32, 1024, 768]); permute_254 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_255: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_215, [0, 2, 1, 3]); getitem_215 = None
view_274: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_255, [32, 1024, 768]); permute_255 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_256: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_216, [0, 2, 1, 3]); getitem_216 = None
view_275: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_256, [32, 1024, 768]); permute_256 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_7: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_274, view_275, view_273], 2); view_274 = view_275 = view_273 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_276: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_7, [32768, 2304]); cat_7 = None
mm_65: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_276, permute_257); permute_257 = None
permute_258: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_276, [1, 0])
mm_66: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_258, view_48); permute_258 = view_48 = None
permute_259: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_66, [1, 0]); mm_66 = None
sum_100: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_276, [0], True, dtype = torch.float32); view_276 = None
view_277: "f32[2304]" = torch.ops.aten.reshape.default(sum_100, [2304]); sum_100 = None
permute_260: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_259, [1, 0]); permute_259 = None
view_278: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_65, [32, 1024, 768]); mm_65 = None
convert_element_type_591: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_278, torch.float32); view_278 = None
convert_element_type_592: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_260, torch.float32); permute_260 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_285: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_591, primals_51); primals_51 = None
mul_286: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_285, 768)
sum_101: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_285, [2], True)
mul_287: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_285, mul_32); mul_285 = None
sum_102: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_287, [2], True); mul_287 = None
mul_288: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_32, sum_102); sum_102 = None
sub_85: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_286, sum_101); mul_286 = sum_101 = None
sub_86: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_85, mul_288); sub_85 = mul_288 = None
mul_289: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_18, sub_86); div_18 = sub_86 = None
mul_290: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_591, mul_32); mul_32 = None
sum_103: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_290, [0, 1]); mul_290 = None
sum_104: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_591, [0, 1]); convert_element_type_591 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_130: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_129, mul_289); add_129 = mul_289 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_594: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_130, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_279: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_594, [32768, 768]); convert_element_type_594 = None
mm_67: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_279, permute_261); permute_261 = None
permute_262: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_279, [1, 0])
mm_68: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_262, view_46); permute_262 = view_46 = None
permute_263: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_68, [1, 0]); mm_68 = None
sum_105: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_279, [0], True, dtype = torch.float32); view_279 = None
view_280: "f32[768]" = torch.ops.aten.reshape.default(sum_105, [768]); sum_105 = None
permute_264: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_263, [1, 0]); permute_263 = None
view_281: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_67, [32, 1024, 3072]); mm_67 = None
convert_element_type_600: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_281, torch.float32); view_281 = None
convert_element_type_601: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_264, torch.float32); permute_264 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_45: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_14, [32, 1024, 3072]); addmm_14 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_28: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_45, 0.5)
mul_291: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_600, mul_28); mul_28 = None
convert_element_type_89: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_45, torch.float32)
pow_4: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_89, 3.0)
mul_29: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_4, 0.044715); pow_4 = None
add_30: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_45, mul_29); view_45 = mul_29 = None
mul_30: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_30, 0.7978845608028654); add_30 = None
tanh_3: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_30); mul_30 = None
add_31: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_3, 1.0)
mul_292: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_600, add_31); convert_element_type_600 = add_31 = None
convert_element_type_603: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_292, torch.bfloat16); mul_292 = None
mul_293: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_3, tanh_3); tanh_3 = None
sub_87: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_293); mul_293 = None
mul_294: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_291, sub_87); mul_291 = sub_87 = None
mul_295: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_294, 0.7978845608028654); mul_294 = None
convert_element_type_604: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_295, torch.bfloat16)
mul_296: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_295, 0.044715); mul_295 = None
pow_21: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_89, 2.0); convert_element_type_89 = None
mul_297: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_21, 3.0); pow_21 = None
mul_298: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_296, mul_297); mul_296 = mul_297 = None
convert_element_type_605: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_298, torch.bfloat16); mul_298 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_131: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_604, convert_element_type_605); convert_element_type_604 = convert_element_type_605 = None
mul_299: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_603, 0.5); convert_element_type_603 = None
add_132: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_131, mul_299); add_131 = mul_299 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_282: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_132, [32768, 3072]); add_132 = None
mm_69: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_282, permute_265); permute_265 = None
permute_266: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_282, [1, 0])
mm_70: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_266, view_44); permute_266 = view_44 = None
permute_267: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_70, [1, 0]); mm_70 = None
sum_106: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_282, [0], True, dtype = torch.float32); view_282 = None
view_283: "f32[3072]" = torch.ops.aten.reshape.default(sum_106, [3072]); sum_106 = None
permute_268: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_267, [1, 0]); permute_267 = None
view_284: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_69, [32, 1024, 768]); mm_69 = None
convert_element_type_611: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_284, torch.float32); view_284 = None
convert_element_type_612: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_268, torch.float32); permute_268 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_301: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_611, primals_45); primals_45 = None
mul_302: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_301, 768)
sum_107: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_301, [2], True)
mul_303: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_301, mul_26); mul_301 = None
sum_108: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_303, [2], True); mul_303 = None
mul_304: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_26, sum_108); sum_108 = None
sub_89: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_302, sum_107); mul_302 = sum_107 = None
sub_90: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_89, mul_304); sub_89 = mul_304 = None
mul_305: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_19, sub_90); div_19 = sub_90 = None
mul_306: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_611, mul_26); mul_26 = None
sum_109: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_306, [0, 1]); mul_306 = None
sum_110: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_611, [0, 1]); convert_element_type_611 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_133: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_130, mul_305); add_130 = mul_305 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_614: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_133, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_285: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_614, [32768, 768]); convert_element_type_614 = None
mm_71: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_285, permute_269); permute_269 = None
permute_270: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_285, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_28: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_53, [0, 2, 1, 3])
view_41: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_28, [32, 1024, 768]); permute_28 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_42: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_41, [32768, 768]); view_41 = None
mm_72: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_270, view_42); permute_270 = view_42 = None
permute_271: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_72, [1, 0]); mm_72 = None
sum_111: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_285, [0], True, dtype = torch.float32); view_285 = None
view_286: "f32[768]" = torch.ops.aten.reshape.default(sum_111, [768]); sum_111 = None
permute_272: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_271, [1, 0]); permute_271 = None
view_287: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_71, [32, 1024, 768]); mm_71 = None
convert_element_type_620: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_272, torch.float32); permute_272 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_288: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_287, [32, 1024, 12, 64]); view_287 = None
permute_273: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_288, [0, 2, 1, 3]); view_288 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_8 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_273, permute_26, permute_25, permute_27, getitem_53, getitem_54, None, None, 1024, 1024, 0.0, True, getitem_59, getitem_60, scale = 0.125); permute_273 = permute_26 = permute_25 = permute_27 = getitem_53 = getitem_54 = getitem_59 = getitem_60 = None
getitem_218: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_8[0]
getitem_219: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_8[1]
getitem_220: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_8[2]; _scaled_dot_product_flash_attention_backward_8 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_274: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_220, [0, 2, 1, 3]); getitem_220 = None
view_289: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_274, [32, 1024, 768]); permute_274 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_275: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_218, [0, 2, 1, 3]); getitem_218 = None
view_290: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_275, [32, 1024, 768]); permute_275 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_276: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_219, [0, 2, 1, 3]); getitem_219 = None
view_291: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_276, [32, 1024, 768]); permute_276 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_8: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_290, view_291, view_289], 2); view_290 = view_291 = view_289 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_292: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_8, [32768, 2304]); cat_8 = None
mm_73: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_292, permute_277); permute_277 = None
permute_278: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_292, [1, 0])
mm_74: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_278, view_36); permute_278 = view_36 = None
permute_279: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_74, [1, 0]); mm_74 = None
sum_112: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_292, [0], True, dtype = torch.float32); view_292 = None
view_293: "f32[2304]" = torch.ops.aten.reshape.default(sum_112, [2304]); sum_112 = None
permute_280: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_279, [1, 0]); permute_279 = None
view_294: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_73, [32, 1024, 768]); mm_73 = None
convert_element_type_627: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_294, torch.float32); view_294 = None
convert_element_type_628: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_280, torch.float32); permute_280 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_308: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_627, primals_39); primals_39 = None
mul_309: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_308, 768)
sum_113: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_308, [2], True)
mul_310: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_308, mul_24); mul_308 = None
sum_114: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_310, [2], True); mul_310 = None
mul_311: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_24, sum_114); sum_114 = None
sub_92: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_309, sum_113); mul_309 = sum_113 = None
sub_93: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_92, mul_311); sub_92 = mul_311 = None
mul_312: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_20, sub_93); div_20 = sub_93 = None
mul_313: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_627, mul_24); mul_24 = None
sum_115: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_313, [0, 1]); mul_313 = None
sum_116: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_627, [0, 1]); convert_element_type_627 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_134: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_133, mul_312); add_133 = mul_312 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_630: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_134, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_295: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_630, [32768, 768]); convert_element_type_630 = None
mm_75: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_295, permute_281); permute_281 = None
permute_282: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_295, [1, 0])
mm_76: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_282, view_34); permute_282 = view_34 = None
permute_283: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_76, [1, 0]); mm_76 = None
sum_117: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_295, [0], True, dtype = torch.float32); view_295 = None
view_296: "f32[768]" = torch.ops.aten.reshape.default(sum_117, [768]); sum_117 = None
permute_284: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_283, [1, 0]); permute_283 = None
view_297: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_75, [32, 1024, 3072]); mm_75 = None
convert_element_type_636: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_297, torch.float32); view_297 = None
convert_element_type_637: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_284, torch.float32); permute_284 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_33: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_10, [32, 1024, 3072]); addmm_10 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_20: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_33, 0.5)
mul_314: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_636, mul_20); mul_20 = None
convert_element_type_65: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_33, torch.float32)
pow_3: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_65, 3.0)
mul_21: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_3, 0.044715); pow_3 = None
add_22: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_33, mul_21); view_33 = mul_21 = None
mul_22: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_22, 0.7978845608028654); add_22 = None
tanh_2: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_22); mul_22 = None
add_23: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_2, 1.0)
mul_315: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_636, add_23); convert_element_type_636 = add_23 = None
convert_element_type_639: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_315, torch.bfloat16); mul_315 = None
mul_316: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_2, tanh_2); tanh_2 = None
sub_94: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_316); mul_316 = None
mul_317: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_314, sub_94); mul_314 = sub_94 = None
mul_318: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_317, 0.7978845608028654); mul_317 = None
convert_element_type_640: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_318, torch.bfloat16)
mul_319: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_318, 0.044715); mul_318 = None
pow_22: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_65, 2.0); convert_element_type_65 = None
mul_320: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_22, 3.0); pow_22 = None
mul_321: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_319, mul_320); mul_319 = mul_320 = None
convert_element_type_641: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_321, torch.bfloat16); mul_321 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_135: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_640, convert_element_type_641); convert_element_type_640 = convert_element_type_641 = None
mul_322: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_639, 0.5); convert_element_type_639 = None
add_136: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_135, mul_322); add_135 = mul_322 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_298: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_136, [32768, 3072]); add_136 = None
mm_77: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_298, permute_285); permute_285 = None
permute_286: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_298, [1, 0])
mm_78: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_286, view_32); permute_286 = view_32 = None
permute_287: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_78, [1, 0]); mm_78 = None
sum_118: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_298, [0], True, dtype = torch.float32); view_298 = None
view_299: "f32[3072]" = torch.ops.aten.reshape.default(sum_118, [3072]); sum_118 = None
permute_288: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_287, [1, 0]); permute_287 = None
view_300: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_77, [32, 1024, 768]); mm_77 = None
convert_element_type_647: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None
convert_element_type_648: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_288, torch.float32); permute_288 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_324: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_647, primals_33); primals_33 = None
mul_325: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_324, 768)
sum_119: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_324, [2], True)
mul_326: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_324, mul_18); mul_324 = None
sum_120: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_326, [2], True); mul_326 = None
mul_327: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_18, sum_120); sum_120 = None
sub_96: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_325, sum_119); mul_325 = sum_119 = None
sub_97: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_96, mul_327); sub_96 = mul_327 = None
mul_328: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_21, sub_97); div_21 = sub_97 = None
mul_329: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_647, mul_18); mul_18 = None
sum_121: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_329, [0, 1]); mul_329 = None
sum_122: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_647, [0, 1]); convert_element_type_647 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_137: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_134, mul_328); add_134 = mul_328 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_650: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_137, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_301: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_650, [32768, 768]); convert_element_type_650 = None
mm_79: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_301, permute_289); permute_289 = None
permute_290: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_301, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_20: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_37, [0, 2, 1, 3])
view_29: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_20, [32, 1024, 768]); permute_20 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_30: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_29, [32768, 768]); view_29 = None
mm_80: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_290, view_30); permute_290 = view_30 = None
permute_291: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_80, [1, 0]); mm_80 = None
sum_123: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_301, [0], True, dtype = torch.float32); view_301 = None
view_302: "f32[768]" = torch.ops.aten.reshape.default(sum_123, [768]); sum_123 = None
permute_292: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_291, [1, 0]); permute_291 = None
view_303: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_79, [32, 1024, 768]); mm_79 = None
convert_element_type_656: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_292, torch.float32); permute_292 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_304: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_303, [32, 1024, 12, 64]); view_303 = None
permute_293: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_304, [0, 2, 1, 3]); view_304 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_9 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_293, permute_18, permute_17, permute_19, getitem_37, getitem_38, None, None, 1024, 1024, 0.0, True, getitem_43, getitem_44, scale = 0.125); permute_293 = permute_18 = permute_17 = permute_19 = getitem_37 = getitem_38 = getitem_43 = getitem_44 = None
getitem_221: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_9[0]
getitem_222: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_9[1]
getitem_223: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_9[2]; _scaled_dot_product_flash_attention_backward_9 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_294: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_223, [0, 2, 1, 3]); getitem_223 = None
view_305: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_294, [32, 1024, 768]); permute_294 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_295: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_221, [0, 2, 1, 3]); getitem_221 = None
view_306: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_295, [32, 1024, 768]); permute_295 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_296: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_222, [0, 2, 1, 3]); getitem_222 = None
view_307: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_296, [32, 1024, 768]); permute_296 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_9: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_306, view_307, view_305], 2); view_306 = view_307 = view_305 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_308: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_9, [32768, 2304]); cat_9 = None
mm_81: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_308, permute_297); permute_297 = None
permute_298: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_308, [1, 0])
mm_82: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_298, view_24); permute_298 = view_24 = None
permute_299: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_82, [1, 0]); mm_82 = None
sum_124: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_308, [0], True, dtype = torch.float32); view_308 = None
view_309: "f32[2304]" = torch.ops.aten.reshape.default(sum_124, [2304]); sum_124 = None
permute_300: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_299, [1, 0]); permute_299 = None
view_310: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_81, [32, 1024, 768]); mm_81 = None
convert_element_type_663: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_310, torch.float32); view_310 = None
convert_element_type_664: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_300, torch.float32); permute_300 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_331: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_663, primals_27); primals_27 = None
mul_332: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_331, 768)
sum_125: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_331, [2], True)
mul_333: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_331, mul_16); mul_331 = None
sum_126: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_333, [2], True); mul_333 = None
mul_334: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_16, sum_126); sum_126 = None
sub_99: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_332, sum_125); mul_332 = sum_125 = None
sub_100: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_99, mul_334); sub_99 = mul_334 = None
mul_335: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_22, sub_100); div_22 = sub_100 = None
mul_336: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_663, mul_16); mul_16 = None
sum_127: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_336, [0, 1]); mul_336 = None
sum_128: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_663, [0, 1]); convert_element_type_663 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_138: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_137, mul_335); add_137 = mul_335 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_666: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_138, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_311: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_666, [32768, 768]); convert_element_type_666 = None
mm_83: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_311, permute_301); permute_301 = None
permute_302: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_311, [1, 0])
mm_84: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_302, view_22); permute_302 = view_22 = None
permute_303: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_84, [1, 0]); mm_84 = None
sum_129: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_311, [0], True, dtype = torch.float32); view_311 = None
view_312: "f32[768]" = torch.ops.aten.reshape.default(sum_129, [768]); sum_129 = None
permute_304: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_303, [1, 0]); permute_303 = None
view_313: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_83, [32, 1024, 3072]); mm_83 = None
convert_element_type_672: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_313, torch.float32); view_313 = None
convert_element_type_673: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_304, torch.float32); permute_304 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_21: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_6, [32, 1024, 3072]); addmm_6 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_12: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_21, 0.5)
mul_337: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_672, mul_12); mul_12 = None
convert_element_type_41: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_21, torch.float32)
pow_2: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_41, 3.0)
mul_13: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_2, 0.044715); pow_2 = None
add_14: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_21, mul_13); view_21 = mul_13 = None
mul_14: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_14, 0.7978845608028654); add_14 = None
tanh_1: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_14); mul_14 = None
add_15: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh_1, 1.0)
mul_338: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_672, add_15); convert_element_type_672 = add_15 = None
convert_element_type_675: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_338, torch.bfloat16); mul_338 = None
mul_339: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh_1, tanh_1); tanh_1 = None
sub_101: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_339); mul_339 = None
mul_340: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_337, sub_101); mul_337 = sub_101 = None
mul_341: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_340, 0.7978845608028654); mul_340 = None
convert_element_type_676: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_341, torch.bfloat16)
mul_342: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_341, 0.044715); mul_341 = None
pow_23: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_41, 2.0); convert_element_type_41 = None
mul_343: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_23, 3.0); pow_23 = None
mul_344: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_342, mul_343); mul_342 = mul_343 = None
convert_element_type_677: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_344, torch.bfloat16); mul_344 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_139: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_676, convert_element_type_677); convert_element_type_676 = convert_element_type_677 = None
mul_345: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_675, 0.5); convert_element_type_675 = None
add_140: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_139, mul_345); add_139 = mul_345 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_314: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_140, [32768, 3072]); add_140 = None
mm_85: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_314, permute_305); permute_305 = None
permute_306: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_314, [1, 0])
mm_86: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_306, view_20); permute_306 = view_20 = None
permute_307: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_86, [1, 0]); mm_86 = None
sum_130: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_314, [0], True, dtype = torch.float32); view_314 = None
view_315: "f32[3072]" = torch.ops.aten.reshape.default(sum_130, [3072]); sum_130 = None
permute_308: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_307, [1, 0]); permute_307 = None
view_316: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_85, [32, 1024, 768]); mm_85 = None
convert_element_type_683: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_316, torch.float32); view_316 = None
convert_element_type_684: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_308, torch.float32); permute_308 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_347: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_683, primals_21); primals_21 = None
mul_348: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_347, 768)
sum_131: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_347, [2], True)
mul_349: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_347, mul_10); mul_347 = None
sum_132: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_349, [2], True); mul_349 = None
mul_350: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_10, sum_132); sum_132 = None
sub_103: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_348, sum_131); mul_348 = sum_131 = None
sub_104: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_103, mul_350); sub_103 = mul_350 = None
mul_351: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_23, sub_104); div_23 = sub_104 = None
mul_352: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_683, mul_10); mul_10 = None
sum_133: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_352, [0, 1]); mul_352 = None
sum_134: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_683, [0, 1]); convert_element_type_683 = None
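# mul_347..sum_134 are the LayerNorm backward for ln_2, with mul_10 the saved
# normalized activation x_hat and g = dY * gamma (mul_347).  For N = 768:
#
#   dx     = (rsqrt / N) * (N * g
#                           - g.sum(-1, keepdim=True)
#                           - x_hat * (g * x_hat).sum(-1, keepdim=True))
#   dgamma = (dY * x_hat).sum((0, 1))    # sum_133
#   dbeta  = dY.sum((0, 1))              # sum_134
#
# div_23 (defined earlier in the graph, presumably rsqrt / 768 like div_26
# near the end of this listing) supplies the (rsqrt / N) factor.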
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_141: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_138, mul_351); add_138 = mul_351 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_686: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_141, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_317: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_686, [32768, 768]); convert_element_type_686 = None
mm_87: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_317, permute_309); permute_309 = None
permute_310: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_317, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_12: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_21, [0, 2, 1, 3])
view_17: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_12, [32, 1024, 768]); permute_12 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_18: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_17, [32768, 768]); view_17 = None
mm_88: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_310, view_18); permute_310 = view_18 = None
permute_311: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_88, [1, 0]); mm_88 = None
sum_135: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_317, [0], True, dtype = torch.float32); view_317 = None
view_318: "f32[768]" = torch.ops.aten.reshape.default(sum_135, [768]); sum_135 = None
permute_312: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_311, [1, 0]); permute_311 = None
view_319: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_87, [32, 1024, 768]); mm_87 = None
convert_element_type_692: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_312, torch.float32); permute_312 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_320: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_319, [32, 1024, 12, 64]); view_319 = None
permute_313: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_320, [0, 2, 1, 3]); view_320 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_10 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_313, permute_10, permute_9, permute_11, getitem_21, getitem_22, None, None, 1024, 1024, 0.0, True, getitem_27, getitem_28, scale = 0.125); permute_313 = permute_10 = permute_9 = permute_11 = getitem_21 = getitem_22 = getitem_27 = getitem_28 = None
getitem_224: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_10[0]
getitem_225: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_10[1]
getitem_226: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_10[2]; _scaled_dot_product_flash_attention_backward_10 = None
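# The flash-attention backward above consumes the upstream gradient
# (permute_313), the saved q/k/v, the forward output (getitem_21), the
# per-row logsumexp (getitem_22) and the philox RNG state
# (getitem_27/getitem_28), with scale = 0.125 = 1/sqrt(64) for the 64-dim
# heads; its three outputs, unpacked as getitem_224..226, are the query, key
# and value gradients.  A hedged sketch of the front-end call that lowers to
# this kernel (illustrative shapes, assuming a CUDA device with flash
# attention available):
#
#   import torch
#   import torch.nn.functional as F
#   q = torch.randn(1, 12, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
#   k = torch.randn_like(q, requires_grad=True)
#   v = torch.randn_like(q, requires_grad=True)
#   out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#   out.sum().backward()  # autograd dispatches to the flash backward used here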
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_314: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_226, [0, 2, 1, 3]); getitem_226 = None
view_321: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_314, [32, 1024, 768]); permute_314 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_315: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_224, [0, 2, 1, 3]); getitem_224 = None
view_322: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_315, [32, 1024, 768]); permute_315 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_316: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_225, [0, 2, 1, 3]); getitem_225 = None
view_323: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_316, [32, 1024, 768]); permute_316 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_10: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_322, view_323, view_321], 2); view_322 = view_323 = view_321 = None
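# cat_10 is the adjoint of qkv.split(self.n_embd, dim=2): the backward of a
# split simply concatenates the per-chunk gradients back along the split dim,
# here restoring the [32, 1024, 2304] layout the c_attn backward below needs.
# A hedged toy example:
#
#   import torch
#   x = torch.randn(2, 3, 6, requires_grad=True)
#   a, b, c = x.split(2, dim=2)
#   (1 * a.sum() + 2 * b.sum() + 3 * c.sum()).backward()
#   # x.grad == torch.cat([g1, g2, g3], dim=2) with gi filled with i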
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_324: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_10, [32768, 2304]); cat_10 = None
mm_89: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_324, permute_317); permute_317 = None
permute_318: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_324, [1, 0])
mm_90: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_318, view_12); permute_318 = view_12 = None
permute_319: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_90, [1, 0]); mm_90 = None
sum_136: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_324, [0], True, dtype = torch.float32); view_324 = None
view_325: "f32[2304]" = torch.ops.aten.reshape.default(sum_136, [2304]); sum_136 = None
permute_320: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_319, [1, 0]); permute_319 = None
view_326: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_89, [32, 1024, 768]); mm_89 = None
convert_element_type_699: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_326, torch.float32); view_326 = None
convert_element_type_700: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_320, torch.float32); permute_320 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_354: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_699, primals_15); primals_15 = None
mul_355: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_354, 768)
sum_137: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_354, [2], True)
mul_356: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_354, mul_8); mul_354 = None
sum_138: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_356, [2], True); mul_356 = None
mul_357: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_8, sum_138); sum_138 = None
sub_106: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_355, sum_137); mul_355 = sum_137 = None
sub_107: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_106, mul_357); sub_106 = mul_357 = None
mul_358: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_24, sub_107); div_24 = sub_107 = None
mul_359: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_699, mul_8); mul_8 = None
sum_139: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_359, [0, 1]); mul_359 = None
sum_140: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_699, [0, 1]); convert_element_type_699 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_142: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_141, mul_358); add_141 = mul_358 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
convert_element_type_702: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_142, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:93 in forward, code: x = self.c_proj(x)
view_327: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_702, [32768, 768]); convert_element_type_702 = None
mm_91: "bf16[32768, 3072]" = torch.ops.aten.mm.default(view_327, permute_321); permute_321 = None
permute_322: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_327, [1, 0])
mm_92: "bf16[768, 3072]" = torch.ops.aten.mm.default(permute_322, view_10); permute_322 = view_10 = None
permute_323: "bf16[3072, 768]" = torch.ops.aten.permute.default(mm_92, [1, 0]); mm_92 = None
sum_141: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_327, [0], True, dtype = torch.float32); view_327 = None
view_328: "f32[768]" = torch.ops.aten.reshape.default(sum_141, [768]); sum_141 = None
permute_324: "bf16[768, 3072]" = torch.ops.aten.permute.default(permute_323, [1, 0]); permute_323 = None
view_329: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(mm_91, [32, 1024, 3072]); mm_91 = None
convert_element_type_708: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_329, torch.float32); view_329 = None
convert_element_type_709: "f32[768, 3072]" = torch.ops.prims.convert_element_type.default(permute_324, torch.float32); permute_324 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_9: "bf16[32, 1024, 3072]" = torch.ops.aten.reshape.default(addmm_2, [32, 1024, 3072]); addmm_2 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_4: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(view_9, 0.5)
mul_360: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_708, mul_4); mul_4 = None
convert_element_type_17: "f32[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(view_9, torch.float32)
pow_1: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_17, 3.0)
mul_5: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(pow_1, 0.044715); pow_1 = None
add_6: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(view_9, mul_5); view_9 = mul_5 = None
mul_6: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(add_6, 0.7978845608028654); add_6 = None
tanh: "f32[32, 1024, 3072]" = torch.ops.aten.tanh.default(mul_6); mul_6 = None
add_7: "f32[32, 1024, 3072]" = torch.ops.aten.add.Tensor(tanh, 1.0)
mul_361: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_708, add_7); convert_element_type_708 = add_7 = None
convert_element_type_711: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_361, torch.bfloat16); mul_361 = None
mul_362: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(tanh, tanh); tanh = None
sub_108: "f32[32, 1024, 3072]" = torch.ops.aten.sub.Tensor(1, mul_362); mul_362 = None
mul_363: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_360, sub_108); mul_360 = sub_108 = None
mul_364: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_363, 0.7978845608028654); mul_363 = None
convert_element_type_712: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_364, torch.bfloat16)
mul_365: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_364, 0.044715); mul_364 = None
pow_24: "f32[32, 1024, 3072]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_17, 2.0); convert_element_type_17 = None
mul_366: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Scalar(pow_24, 3.0); pow_24 = None
mul_367: "f32[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(mul_365, mul_366); mul_365 = mul_366 = None
convert_element_type_713: "bf16[32, 1024, 3072]" = torch.ops.prims.convert_element_type.default(mul_367, torch.bfloat16); mul_367 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_143: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(convert_element_type_712, convert_element_type_713); convert_element_type_712 = convert_element_type_713 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
mul_368: "bf16[32, 1024, 3072]" = torch.ops.aten.mul.Tensor(convert_element_type_711, 0.5); convert_element_type_711 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:36 in forward, code: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
add_144: "bf16[32, 1024, 3072]" = torch.ops.aten.add.Tensor(add_143, mul_368); add_143 = mul_368 = None
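# Note that the forward-side GELU intermediates for this block (view_9, tanh,
# add_7) are recomputed here in the backward from the saved matmul output
# addmm_2 rather than being saved from the forward pass -- the partitioner's
# usual rematerialization trade of a little extra compute for activation
# memory.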
# File: /home/shunting/ws/llm.c/train_gpt2.py:91 in forward, code: x = self.c_fc(x)
view_330: "bf16[32768, 3072]" = torch.ops.aten.reshape.default(add_144, [32768, 3072]); add_144 = None
mm_93: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_330, permute_325); permute_325 = None
permute_326: "bf16[3072, 32768]" = torch.ops.aten.permute.default(view_330, [1, 0])
mm_94: "bf16[3072, 768]" = torch.ops.aten.mm.default(permute_326, view_8); permute_326 = view_8 = None
permute_327: "bf16[768, 3072]" = torch.ops.aten.permute.default(mm_94, [1, 0]); mm_94 = None
sum_142: "f32[1, 3072]" = torch.ops.aten.sum.dim_IntList(view_330, [0], True, dtype = torch.float32); view_330 = None
view_331: "f32[3072]" = torch.ops.aten.reshape.default(sum_142, [3072]); sum_142 = None
permute_328: "bf16[3072, 768]" = torch.ops.aten.permute.default(permute_327, [1, 0]); permute_327 = None
view_332: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_93, [32, 1024, 768]); mm_93 = None
convert_element_type_719: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_332, torch.float32); view_332 = None
convert_element_type_720: "f32[3072, 768]" = torch.ops.prims.convert_element_type.default(permute_328, torch.float32); permute_328 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
mul_370: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_719, primals_9); primals_9 = None
mul_371: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_370, 768)
sum_143: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_370, [2], True)
mul_372: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_370, mul_2); mul_370 = None
sum_144: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_372, [2], True); mul_372 = None
mul_373: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_2, sum_144); sum_144 = None
sub_110: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_371, sum_143); mul_371 = sum_143 = None
sub_111: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_110, mul_373); sub_110 = mul_373 = None
mul_374: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_25, sub_111); div_25 = sub_111 = None
mul_375: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_719, mul_2); mul_2 = None
sum_145: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_375, [0, 1]); mul_375 = None
sum_146: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_719, [0, 1]); convert_element_type_719 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:107 in forward, code: x = x + self.mlp(self.ln_2(x))
add_145: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_142, mul_374); add_142 = mul_374 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
convert_element_type_722: "bf16[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(add_145, torch.bfloat16)
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_333: "bf16[32768, 768]" = torch.ops.aten.reshape.default(convert_element_type_722, [32768, 768]); convert_element_type_722 = None
mm_95: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_333, permute_329); permute_329 = None
permute_330: "bf16[768, 32768]" = torch.ops.aten.permute.default(view_333, [1, 0])
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
permute_4: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_5, [0, 2, 1, 3])
view_5: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_4, [32, 1024, 768]); permute_4 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:78 in forward, code: y = self.c_proj(y)
view_6: "bf16[32768, 768]" = torch.ops.aten.reshape.default(view_5, [32768, 768]); view_5 = None
mm_96: "bf16[768, 768]" = torch.ops.aten.mm.default(permute_330, view_6); permute_330 = view_6 = None
permute_331: "bf16[768, 768]" = torch.ops.aten.permute.default(mm_96, [1, 0]); mm_96 = None
sum_147: "f32[1, 768]" = torch.ops.aten.sum.dim_IntList(view_333, [0], True, dtype = torch.float32); view_333 = None
view_334: "f32[768]" = torch.ops.aten.reshape.default(sum_147, [768]); sum_147 = None
permute_332: "bf16[768, 768]" = torch.ops.aten.permute.default(permute_331, [1, 0]); permute_331 = None
view_335: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_95, [32, 1024, 768]); mm_95 = None
convert_element_type_728: "f32[768, 768]" = torch.ops.prims.convert_element_type.default(permute_332, torch.float32); permute_332 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:76 in forward, code: y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
view_336: "bf16[32, 1024, 12, 64]" = torch.ops.aten.reshape.default(view_335, [32, 1024, 12, 64]); view_335 = None
permute_333: "bf16[32, 12, 1024, 64]" = torch.ops.aten.permute.default(view_336, [0, 2, 1, 3]); view_336 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:68 in forward, code: y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
_scaled_dot_product_flash_attention_backward_11 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_333, permute_2, permute_1, permute_3, getitem_5, getitem_6, None, None, 1024, 1024, 0.0, True, getitem_11, getitem_12, scale = 0.125); permute_333 = permute_2 = permute_1 = permute_3 = getitem_5 = getitem_6 = getitem_11 = getitem_12 = None
getitem_227: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_11[0]
getitem_228: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_11[1]
getitem_229: "bf16[32, 12, 1024, 64]" = _scaled_dot_product_flash_attention_backward_11[2]; _scaled_dot_product_flash_attention_backward_11 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:65 in forward, code: v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_334: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_229, [0, 2, 1, 3]); getitem_229 = None
view_337: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_334, [32, 1024, 768]); permute_334 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:64 in forward, code: q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_335: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_227, [0, 2, 1, 3]); getitem_227 = None
view_338: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_335, [32, 1024, 768]); permute_335 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:63 in forward, code: k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
permute_336: "bf16[32, 1024, 12, 64]" = torch.ops.aten.permute.default(getitem_228, [0, 2, 1, 3]); getitem_228 = None
view_339: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(permute_336, [32, 1024, 768]); permute_336 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:62 in forward, code: q, k, v = qkv.split(self.n_embd, dim=2)
cat_11: "bf16[32, 1024, 2304]" = torch.ops.aten.cat.default([view_338, view_339, view_337], 2); view_338 = view_339 = view_337 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:61 in forward, code: qkv = self.c_attn(x)
view_340: "bf16[32768, 2304]" = torch.ops.aten.reshape.default(cat_11, [32768, 2304]); cat_11 = None
mm_97: "bf16[32768, 768]" = torch.ops.aten.mm.default(view_340, permute_337); permute_337 = None
permute_338: "bf16[2304, 32768]" = torch.ops.aten.permute.default(view_340, [1, 0])
mm_98: "bf16[2304, 768]" = torch.ops.aten.mm.default(permute_338, view); permute_338 = view = None
permute_339: "bf16[768, 2304]" = torch.ops.aten.permute.default(mm_98, [1, 0]); mm_98 = None
sum_148: "f32[1, 2304]" = torch.ops.aten.sum.dim_IntList(view_340, [0], True, dtype = torch.float32); view_340 = None
view_341: "f32[2304]" = torch.ops.aten.reshape.default(sum_148, [2304]); sum_148 = None
permute_340: "bf16[2304, 768]" = torch.ops.aten.permute.default(permute_339, [1, 0]); permute_339 = None
view_342: "bf16[32, 1024, 768]" = torch.ops.aten.reshape.default(mm_97, [32, 1024, 768]); mm_97 = None
convert_element_type_735: "f32[32, 1024, 768]" = torch.ops.prims.convert_element_type.default(view_342, torch.float32); view_342 = None
convert_element_type_736: "f32[2304, 768]" = torch.ops.prims.convert_element_type.default(permute_340, torch.float32); permute_340 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
mul_377: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_735, primals_3); primals_3 = None
mul_378: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_377, 768)
sum_149: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_377, [2], True)
# File: /home/shunting/ws/llm.c/train_gpt2.py:161 in forward, code: x = tok_emb + pos_emb
add: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(embedding, embedding_1); embedding = embedding_1 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
sub: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(add, getitem_1); add = getitem_1 = None
mul: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = None
mul_379: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul_377, mul); mul_377 = None
sum_150: "f32[32, 1024, 1]" = torch.ops.aten.sum.dim_IntList(mul_379, [2], True); mul_379 = None
mul_380: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(mul, sum_150); sum_150 = None
sub_113: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(mul_378, sum_149); mul_378 = sum_149 = None
sub_114: "f32[32, 1024, 768]" = torch.ops.aten.sub.Tensor(sub_113, mul_380); sub_113 = mul_380 = None
div_26: "f32[32, 1024, 1]" = torch.ops.aten.div.Tensor(rsqrt, 768); rsqrt = None
mul_381: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(div_26, sub_114); div_26 = sub_114 = None
mul_382: "f32[32, 1024, 768]" = torch.ops.aten.mul.Tensor(convert_element_type_735, mul); mul = None
sum_151: "f32[768]" = torch.ops.aten.sum.dim_IntList(mul_382, [0, 1]); mul_382 = None
sum_152: "f32[768]" = torch.ops.aten.sum.dim_IntList(convert_element_type_735, [0, 1]); convert_element_type_735 = None
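# For the first LayerNorm even the normalized input is rebuilt in the
# backward: add = tok_emb + pos_emb is recomputed, then
# x_hat = (add - mean) * rsqrt from the saved mean (getitem_1) and rsqrt,
# and div_26 = rsqrt / 768 supplies the (rsqrt / N) factor inline, playing
# the role div_23..div_25 played for the later layers.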
# File: /home/shunting/ws/llm.c/train_gpt2.py:106 in forward, code: x = x + self.attn(self.ln_1(x))
add_146: "f32[32, 1024, 768]" = torch.ops.aten.add.Tensor(add_145, mul_381); add_145 = mul_381 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:161 in forward, code: x = tok_emb + pos_emb
sum_153: "f32[1, 1024, 768]" = torch.ops.aten.sum.dim_IntList(add_146, [0], True, dtype = torch.float32)
view_343: "f32[1024, 768]" = torch.ops.aten.reshape.default(sum_153, [1024, 768]); sum_153 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:160 in forward, code: pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
full_default_5: "b8[1024, 1]" = torch.ops.aten.full.default([1024, 1], False, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
where_4: "f32[1024, 768]" = torch.ops.aten.where.self(full_default_5, full_default_1, view_343); full_default_5 = view_343 = None
full_default_7: "f32[1024, 768]" = torch.ops.aten.full.default([1024, 768], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
_unsafe_index_put: "f32[1024, 768]" = torch.ops.prims._unsafe_index_put_.default(full_default_7, [iota], where_4, True); full_default_7 = iota = where_4 = None
# File: /home/shunting/ws/llm.c/train_gpt2.py:159 in forward, code: tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
eq_1: "b8[32, 1024]" = torch.ops.aten.eq.Scalar(primals_150, -1)
unsqueeze_3: "b8[32, 1024, 1]" = torch.ops.aten.unsqueeze.default(eq_1, -1); eq_1 = None
where_5: "f32[32, 1024, 768]" = torch.ops.aten.where.self(unsqueeze_3, full_default_1, add_146); unsqueeze_3 = full_default_1 = add_146 = None
full_default_9: "f32[50257, 768]" = torch.ops.aten.full.default([50257, 768], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
_unsafe_index_put_1: "f32[50257, 768]" = torch.ops.prims._unsafe_index_put_.default(full_default_9, [primals_150], where_5, True); full_default_9 = primals_150 = where_5 = None
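# Embedding backward is a scatter-add, since repeated indices must accumulate.
# For wpe the batch dimension is summed first (every sample sees positions
# 0..1023), then written through _unsafe_index_put_ with accumulate=True
# (full_default_5 is all False, so where_4 passes view_343 through unchanged).
# For wte the per-token gradients in add_146 are scattered into a zeroed
# [50257, 768] buffer keyed by the token ids (primals_150); the
# eq(primals_150, -1) mask reflects aten.embedding's default padding_idx of
# -1, zeroing gradients for rows whose index equals padding_idx.  The same
# accumulation in plain PyTorch (hedged sketch, made-up sizes):
#
#   import torch
#   grad = torch.zeros(10, 4)
#   idx = torch.tensor([3, 3, 7])
#   upd = torch.ones(3, 4)
#   grad.index_put_((idx,), upd, accumulate=True)  # row 3 receives 2.0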
return [_unsafe_index_put_1, _unsafe_index_put, sum_151, sum_152, convert_element_type_736, view_341, convert_element_type_728, view_334, sum_145, sum_146, convert_element_type_720, view_331, convert_element_type_709, view_328, sum_139, sum_140, convert_element_type_700, view_325, convert_element_type_692, view_318, sum_133, sum_134, convert_element_type_684, view_315, convert_element_type_673, view_312, sum_127, sum_128, convert_element_type_664, view_309, convert_element_type_656, view_302, sum_121, sum_122, convert_element_type_648, view_299, convert_element_type_637, view_296, sum_115, sum_116, convert_element_type_628, view_293, convert_element_type_620, view_286, sum_109, sum_110, convert_element_type_612, view_283, convert_element_type_601, view_280, sum_103, sum_104, convert_element_type_592, view_277, convert_element_type_584, view_270, sum_97, sum_98, convert_element_type_576, view_267, convert_element_type_565, view_264, sum_91, sum_92, convert_element_type_556, view_261, convert_element_type_548, view_254, sum_85, sum_86, convert_element_type_540, view_251, convert_element_type_529, view_248, sum_79, sum_80, convert_element_type_520, view_245, convert_element_type_512, view_238, sum_73, sum_74, convert_element_type_504, view_235, convert_element_type_493, view_232, sum_67, sum_68, convert_element_type_484, view_229, convert_element_type_476, view_222, sum_61, sum_62, convert_element_type_468, view_219, convert_element_type_457, view_216, sum_55, sum_56, convert_element_type_448, view_213, convert_element_type_440, view_206, sum_49, sum_50, convert_element_type_432, view_203, convert_element_type_421, view_200, sum_43, sum_44, convert_element_type_412, view_197, convert_element_type_404, view_190, sum_37, sum_38, convert_element_type_396, view_187, convert_element_type_385, view_184, sum_31, sum_32, convert_element_type_376, view_181, convert_element_type_368, view_174, sum_25, sum_26, convert_element_type_360, view_171, convert_element_type_349, view_168, sum_19, sum_20, convert_element_type_340, view_165, convert_element_type_332, view_158, sum_13, sum_14, convert_element_type_324, view_155, convert_element_type_313, view_152, sum_7, sum_8, convert_element_type_305, None, None]
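# The returned list holds one gradient per differentiable graph input, in
# parameter order: wte (_unsafe_index_put_1), wpe (_unsafe_index_put), then
# each block's LayerNorm / attention / MLP weight and bias gradients; the
# trailing Nones presumably correspond to the integer token-id inputs, which
# carry no gradient.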