
Kaixi mattteochen

mattteochen / fx_graph_transformed.py
Created October 24, 2025 14:16
meta-llama/Llama-3.2-1B Torch Inductor transformed FX graph
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "i64[1, 1]", arg1_1: "bf16[128256, 2048]", arg2_1: "i64[1]", arg3_1: "i64[1, 1]", arg4_1: "bf16[1, 8, 107, 64]", arg5_1: "bf16[1, 1, 1, 107]", arg6_1: "f32[32]", arg7_1: "bf16[2048]", arg8_1: "bf16[2048, 2048]", arg9_1: "bf16[512, 2048]", arg10_1: "bf16[512, 2048]", arg11_1: "bf16[1, 8, 107, 64]", arg12_1: "bf16[2048, 2048]", arg13_1: "bf16[2048]", arg14_1: "bf16[8192, 2048]", arg15_1: "bf16[8192, 2048]", arg16_1: "bf16[2048, 8192]", arg17_1: "bf16[2048]"):
        # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py:422 in forward, code: inputs_embeds = self.embed_tokens(input_ids)
        embedding: "bf16[1, 1, 2048]" = torch.ops.aten.embedding.default(arg1_1, arg0_1); arg0_1 = None

        # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py:70 in forward, code: hidden_states = hidden_states.to(torch.float32)
        convert_element_type_3: "f32[1, 1, 2048
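fx_graph_transformed.py above (and fx_graph_readable.py below) are the file names Inductor uses for its debug dumps. A minimal sketch of how these artifacts can be reproduced, assuming a CUDA machine, a recent PyTorch, and Hugging Face access to the gated meta-llama/Llama-3.2-1B checkpoint; the prompt and generation length here are illustrative:

import os
# Set before importing torch: dumps fx_graph_readable.py, fx_graph_transformed.py
# and the generated Triton kernels under ./torch_compile_debug/
os.environ["TORCH_COMPILE_DEBUG"] = "1"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda").eval()

# Compile the forward pass with the Inductor backend
model.forward = torch.compile(model.forward, backend="inductor")

inputs = tokenizer("Hello", return_tensors="pt").to("cuda")
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=8)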
mattteochen / fx_graph_readable.py
Created October 24, 2025 14:18
meta-llama/Llama-3.2-1B Torch Inductor initial FX graph
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "i64[1, 1]", arg1_1: "bf16[128256, 2048]", arg2_1: "i64[1]", arg3_1: "i64[1, 1]", arg4_1: "bf16[1, 8, 107, 64]", arg5_1: "bf16[1, 1, 1, 107]", arg6_1: "f32[32]", arg7_1: "bf16[2048]", arg8_1: "bf16[2048, 2048]", arg9_1: "bf16[512, 2048]", arg10_1: "bf16[512, 2048]", arg11_1: "bf16[1, 8, 107, 64]", arg12_1: "bf16[2048, 2048]", arg13_1: "bf16[2048]", arg14_1: "bf16[8192, 2048]", arg15_1: "bf16[8192, 2048]", arg16_1: "bf16[2048, 8192]", arg17_1: "bf16[2048]"):
        # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py:422 in forward, code: inputs_embeds = self.embed_tokens(input_ids)
        embedding: "bf16[1, 1, 2048]" = torch.ops.aten.embedding.default(arg1_1, arg0_1); arg0_1 = None

        # File: /usr/local/lib/python3.12/dist-packages/transformers/modeling_attn_mask_utils.py:241 in _unmask_unattended, code: return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
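The _unmask_unattended line quoted above (where min_dtype is torch.finfo(dtype).min in the Transformers source) reappears in the fused attention kernels below as the eq/all/bitwise_not/mul nodes. Spelled out in eager terms:

fully_masked = torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)  # rows where every key position is masked
expanded_mask = expanded_mask.mul(~fully_masked)  # zero those rows so SDPA's softmax cannot produce NaNs on padding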
mattteochen / triton_poi_fused__to_copy__unsafe_view_add_bmm_cat_cos_index_copy_mul_neg_sin_slice_transpose_unsqueeze_view_2.py
Created October 24, 2025 14:23
# Topologically Sorted Source Nodes: [, freqs, emb, cos, cos_1, cos_2, cos_3, sin, sin_1, sin_2, sin_3, linear_1, view_1, key_states, mul_7, x2_1, neg_1, x1_1, cat_2, mul_8, k_embed, index_copy_], Original ATen: [aten.bmm, aten.transpose, aten.cat, aten.cos, aten.mul, aten._to_copy, aten.unsqueeze, aten.sin, aten._unsafe_view, aten.view, aten.slice, aten.neg, aten.add, aten.index_copy]
# Source node to ATen node mapping:
# => unsqueeze_default
# cat_2 => cat_1
# cos => cos
# cos_1 => mul_1
# cos_2 => convert_element_type_1
# cos_3 => unsqueeze_4
# emb => clone, expand_3, unsqueeze_3, view_3
# freqs => permute
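Inductor derives the kernel name from the fused ATen ops, and the triton_poi_/triton_red_/triton_per_ prefixes mark pointwise, reduction, and persistent-reduction kernels. This pointwise kernel (_2) applies the rotary position embedding to key_states and scatters the result into the static KV cache. A sketch of the same math in eager terms, following the Hugging Face Llama code (cache and cache_position are illustrative names for the index_copy_ target and index):

def rotate_half(x):
    # cat_2 = cat([neg(x2_1), x1_1]): rotate the last dimension by half
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

k_embed = (key_states * cos) + (rotate_half(key_states) * sin)  # mul_7, mul_8, k_embed
cache.index_copy_(2, cache_position, k_embed)  # index_copy_: write into the bf16[1, 8, 107, 64] cache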
mattteochen / triton_poi_fused__unsafe_view_index_copy_transpose_view_3.py
Created October 24, 2025 14:25
# Topologically Sorted Source Nodes: [linear_2, view_2, value_states, index_copy__1], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten.index_copy]
# Source node to ATen node mapping:
# index_copy__1 => index_put_1
# linear_2 => view_11
# value_states => permute_6
# view_2 => view_12
# Graph fragment:
# %arg2_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg2_1]
# %mm_2 : Tensor "bf16[1, 512][512, 1]cuda:0" = PlaceHolder[target=mm_2]
# %index_put_1 : Tensor "bf16[1, 8, 107, 64][54784, 6848, 64, 1]cuda:0" = PlaceHolder[target=index_put_1]
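This smaller pointwise kernel (_3) covers the value path, which needs no rotary embedding: the bf16[1, 512] projection output mm_2 is reshaped into one new [1, 8, 1, 64] token and scattered into the preallocated [1, 8, 107, 64] cache. Roughly, with cache and cache_position again illustrative:

value_states = mm_2.view(1, 1, 8, 64).transpose(1, 2)  # view_2 / value_states (permute_6): [1, 8, 1, 64]
cache.index_copy_(2, cache_position, value_states)     # index_copy__1 along the sequence dimension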
mattteochen / triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.py
Created October 27, 2025 09:39
# kernel path: /tmp/tmp_ud_1sth/6p/c6paalaw6ymqqxwpy3yqkd2sratjy75kwuwakoizzdscamrznvdq.py
# Topologically Sorted Source Nodes: [inputs_embeds, hidden_states, pow_1, variance, add, rsqrt, hidden_states_1, to_4, hidden_states_2], Original ATen: [aten.embedding, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add
# hidden_states => convert_element_type_3
# hidden_states_1 => mul_3
# hidden_states_2 => mul_4
# inputs_embeds => embedding
# pow_1 => pow_1
# rsqrt => rsqrt
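This reduction kernel fuses the embedding gather (inputs_embeds) with the first LlamaRMSNorm of the layer: the _to_copy/pow/mean/add/rsqrt/mul chain is the norm computed in fp32 and cast back to bf16. The eager equivalent, with node names from the mapping above in comments:

def rms_norm(hidden_states, weight, eps):
    hidden_states = hidden_states.to(torch.float32)              # hidden_states (_to_copy)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)       # pow_1, variance (mean)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)  # add, rsqrt, hidden_states_1
    return weight * hidden_states.to(torch.bfloat16)             # to_4, hidden_states_2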
mattteochen / triton_poi_fused__to_copy_bmm_expand_unsqueeze_1.py
Created October 27, 2025 11:06
# kernel path: /tmp/tmpzs1g0hn_/6q/c6q3n2z4eyt5mlgkcbaswion4chjqn3m2unptu6d5w5dwenreeyo.py
# Topologically Sorted Source Nodes: [getitem_1, expand, , getitem_2, position_ids_expanded], Original ATen: [aten.unsqueeze, aten.expand, aten.bmm, aten._to_copy]
# Source node to ATen node mapping:
# => mm_default, squeeze_dim, squeeze_dim_1
# expand => expand
# getitem_1 => unsqueeze, unsqueeze_1
# getitem_2 => unsqueeze_2
# position_ids_expanded => convert_element_type
# Graph fragment:
# %arg3_1 : Tensor "i64[1, 1][1, 1]cuda:0" = PlaceHolder[target=arg3_1]
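This pointwise kernel computes the rotary frequencies: inv_freq (the f32[32] graph input arg6_1) is unsqueezed and expanded, the i64[1, 1] position ids (arg3_1) are cast to fp32, and a batched matmul produces the freqs consumed by the RoPE kernels. In eager terms, following the HF rotary embedding:

inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)  # unsqueeze, unsqueeze_1, expand
position_ids_expanded = position_ids[:, None, :].to(torch.float32)                # unsqueeze_2, position_ids_expanded
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)               # the bmm (mm_default)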
mattteochen / triton_per_fused__scaled_dot_product_cudnn_attention__to_copy__unsafe_view_add_all_bitwise_not_bmm_cat_clone_cos_eq_expand_mul_neg_sin_slice_transpose_unsqueeze_view_4.py
Created October 27, 2025 13:06
# kernel path: /tmp/tmpzs1g0hn_/gd/cgdlz6cc7uvsuyhpw2hykkhdjlxnf46r2vltqwksqucfzfobxb5f.py
# Topologically Sorted Source Nodes: [linear, view, query_states, , freqs, emb, cos, cos_1, cos_2, cos_3, mul_5, x2, neg, x1, cat_1, sin, sin_1, sin_2, sin_3, mul_6, q_embed, key, value, eq, all_1, invert, causal_mask, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten.bmm, aten.cat, aten.cos, aten.mul, aten._to_copy, aten.unsqueeze, aten.slice, aten.neg, aten.sin, aten.add, aten.expand, aten.clone, aten.eq, aten.all, aten.bitwise_not, aten._scaled_dot_product_cudnn_attention]
# Source node to ATen node mapping:
# => unsqueeze_default
# all_1 => any_2, logical_not, logical_not_1
# attn_output => _scaled_dot_product_cudnn_attention
# cat_1 => cat
# causal_mask => mul
# cos => cos
# cos_1 => mul_1
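This persistent-reduction kernel (_4) is the largest fusion: the query projection reshape, RoPE on the query (q_embed), the repeat_kv expansion of the 8 KV heads to 32 query heads (the expand/clone nodes), the _unmask_unattended mask fix (eq/all/bitwise_not/mul), and the hand-off to cuDNN scaled-dot-product attention. A sketch of the tail end in eager terms, with names taken from the node mapping and shapes from the graph inputs:

import torch.nn.functional as F

# repeat_kv: each of the 8 KV heads serves 4 of the 32 query heads
key = key_states[:, :, None, :, :].expand(1, 8, 4, 107, 64).reshape(1, 32, 107, 64)
value = value_states[:, :, None, :, :].expand(1, 8, 4, 107, 64).reshape(1, 32, 107, 64)
attn_output = F.scaled_dot_product_attention(  # lowered to _scaled_dot_product_cudnn_attention
    q_embed, key, value, attn_mask=causal_mask)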
mattteochen / triton_poi_fused__scaled_dot_product_cudnn_attention__to_copy__unsafe_view_add_all_bitwise_not_bmm_cat_clone_cos_expand_mul_neg_sin_slice_transpose_unsqueeze_view_5.py
Created October 27, 2025 13:11
# Topologically Sorted Source Nodes: [linear, view, query_states, , freqs, emb, cos, cos_1, cos_2, cos_3, mul_5, x2, neg, x1, cat_1, sin, sin_1, sin_2, sin_3, mul_6, q_embed, key, value, all_1, invert, causal_mask, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten.bmm, aten.cat, aten.cos, aten.mul, aten._to_copy, aten.unsqueeze, aten.slice, aten.neg, aten.sin, aten.add, aten.expand, aten.clone, aten.all, aten.bitwise_not, aten._scaled_dot_product_cudnn_attention]
# Source node to ATen node mapping:
# => unsqueeze_default
# all_1 => logical_not_1
# attn_output => _scaled_dot_product_cudnn_attention
# cat_1 => cat
# causal_mask => mul
# cos => cos
# cos_1 => mul_1
# cos_2 => convert_element_type_1
mattteochen / triton_poi_fused__scaled_dot_product_cudnn_attention__to_copy__unsafe_view_add_all_bitwise_not_bmm_cat_clone_cos_expand_mul_neg_sin_slice_transpose_unsqueeze_view_6.py
Created October 27, 2025 16:24
# Topologically Sorted Source Nodes: [linear, view, query_states, , freqs, emb, cos, cos_1, cos_2, cos_3, mul_5, x2, neg, x1, cat_1, sin, sin_1, sin_2, sin_3, mul_6, q_embed, key, value, all_1, invert, causal_mask, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten.bmm, aten.cat, aten.cos, aten.mul, aten._to_copy, aten.unsqueeze, aten.slice, aten.neg, aten.sin, aten.add, aten.expand, aten.clone, aten.all, aten.bitwise_not, aten._scaled_dot_product_cudnn_attention]
# Source node to ATen node mapping:
# => unsqueeze_default
# all_1 => logical_not_1
# attn_output => _scaled_dot_product_cudnn_attention
# cat_1 => cat
# causal_mask => mul
# cos => cos
# cos_1 => mul_1
# cos_2 => convert_element_type_1
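Kernels _5 and _6 appear to be pointwise variants of the same fused region as _4 (the op list is identical except that the eq node has been folded away), presumably produced by later recompilations of the same call site given the differing creation times. Instead of fishing these files out of the /tmp cache paths shown above, the same FX graphs and Triton kernels can be streamed to stderr; a sketch (run_llama.py is a placeholder for whatever driver script triggers the compile):

import torch
torch._logging.set_logs(graph_code=True, output_code=True)
# equivalently, from the shell: TORCH_LOGS="graph_code,output_code" python run_llama.py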
mattteochen / triton_red_fused__to_copy__unsafe_view_add_embedding_mean_mul_pow_rsqrt_7.py
Created October 27, 2025 16:43
# Topologically Sorted Source Nodes: [inputs_embeds, attn_output_3, hidden_states_5, hidden_states_6, pow_2, variance_1, add_4, rsqrt_1, hidden_states_7, to_8, hidden_states_8], Original ATen: [aten.embedding, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_4 => add_4
# attn_output_3 => view_17
# hidden_states_5 => add_3
# hidden_states_6 => convert_element_type_13
# hidden_states_7 => mul_9
# hidden_states_8 => mul_10
# inputs_embeds => embedding
# pow_2 => pow_2
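The final reduction kernel (_7) is the post-attention norm: the residual add hidden_states_5 = inputs_embeds + attn_output is fused with the layer's second RMSNorm. Reusing the rms_norm sketch above (post_attention_weight is an illustrative name for the norm's weight):

hidden_states = inputs_embeds + attn_output  # hidden_states_5 (add_3), the residual connection
hidden_states = rms_norm(hidden_states, post_attention_weight, eps)  # pow_2, variance_1, add_4, rsqrt_1 ... hidden_states_8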