Skip to content

Instantly share code, notes, and snippets.

@drisspg

drisspg/file.py Secret

Created April 24, 2024 02:48
Show Gist options
  • Save drisspg/ce4c041f8df8a5a7983c5174705cf2b5 to your computer and use it in GitHub Desktop.
Save drisspg/ce4c041f8df8a5a7983c5174705cf2b5 to your computer and use it in GitHub Desktop.
aot_graphs_flex_atten_backwards.py
INFO: TRACED GRAPH
===== Forward graph 0 =====
/home/drisspg/meta/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "bf16[8, 16, 2048, 64]", primals_2: "bf16[8, 16, 2048, 64]", primals_3: "bf16[8, 16, 2048, 64]"):
# File: /home/drisspg/meta/pytorch/torch/nn/attention/_templated_attention.py:89 in _templated_attention, code: out, _ = templated_attention_hop(query, key, value, score_mod)
sdpa_score = self.sdpa_score
templated_attention = torch.ops.higher_order.templated_attention(primals_1, primals_2, primals_3, sdpa_score); sdpa_score = None
getitem: "bf16[8, 16, 2048, 64]" = templated_attention[0]
getitem_1: "f32[8, 16, 2048]" = templated_attention[1]; templated_attention = None
detach: "bf16[8, 16, 2048, 64]" = torch.ops.aten.detach.default(getitem)
detach_1: "bf16[8, 16, 2048, 64]" = torch.ops.aten.detach.default(detach); detach = None
detach_2: "f32[8, 16, 2048]" = torch.ops.aten.detach.default(getitem_1); getitem_1 = None
detach_3: "f32[8, 16, 2048]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
detach_4: "bf16[8, 16, 2048, 64]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
detach_5: "bf16[8, 16, 2048, 64]" = torch.ops.aten.detach.default(detach_4); detach_4 = None
detach_6: "f32[8, 16, 2048]" = torch.ops.aten.detach.default(detach_3); detach_3 = None
detach_7: "f32[8, 16, 2048]" = torch.ops.aten.detach.default(detach_6); detach_6 = None
return [getitem, primals_1, primals_2, primals_3, detach_5, detach_7]
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "bf16[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
# File: /home/drisspg/meta/pytorch/torch/nn/attention/_templated_attention.py:89 in _templated_attention, code: out, _ = templated_attention_hop(query, key, value, score_mod)
mul: "bf16[]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
add: "bf16[]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
return add
INFO: TRACED GRAPH
===== Backward graph 0 =====
<eval_with_key>.7 class GraphModule(torch.nn.Module):
def forward(self, primals_1: "bf16[8, 16, 2048, 64]", primals_2: "bf16[8, 16, 2048, 64]", primals_3: "bf16[8, 16, 2048, 64]", detach_5: "bf16[8, 16, 2048, 64]", detach_7: "f32[8, 16, 2048]", tangents_1: "bf16[8, 16, 2048, 64]"):
# File: /home/drisspg/meta/pytorch/torch/nn/attention/_templated_attention.py:89 in _templated_attention, code: out, _ = templated_attention_hop(query, key, value, score_mod)
fw_graph = self.fw_graph
joint_graph = self.joint_graph
templated_attention_backward = torch.ops.higher_order.templated_attention_backward(primals_1, primals_2, primals_3, detach_5, detach_7, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = detach_5 = detach_7 = tangents_1 = fw_graph = joint_graph = None
getitem_2: "bf16[8, 16, 2048, 64]" = templated_attention_backward[0]
getitem_3: "bf16[8, 16, 2048, 64]" = templated_attention_backward[1]
getitem_4: "bf16[8, 16, 2048, 64]" = templated_attention_backward[2]; templated_attention_backward = None
return [getitem_2, getitem_3, getitem_4]
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "bf16[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
# File: /home/drisspg/meta/pytorch/torch/nn/attention/_templated_attention.py:89 in _templated_attention, code: out, _ = templated_attention_hop(query, key, value, score_mod)
mul: "bf16[]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
add: "bf16[]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "bf16[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "bf16[]"):
# File: /home/drisspg/meta/pytorch/torch/nn/attention/_templated_attention.py:89 in _templated_attention, code: out, _ = templated_attention_hop(query, key, value, score_mod)
mul: "bf16[]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
add: "bf16[]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
mul_1: "bf16[]" = torch.ops.aten.mul.Tensor(arg5_1, 2); arg5_1 = None
return [mul_1, None, None, None, None]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment