Skip to content

Instantly share code, notes, and snippets.

View mattteochen's full-sized avatar

Kaixi mattteochen

View GitHub Profile
@mattteochen
mattteochen / llama4_graph0_thunder_0_repro.py
Created November 3, 2025 11:41
llama4_graph0_thunder_0_repro
from math import inf
from math import nan
NoneType = type(None)
import torch
from torch import device
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
import thunder
@mattteochen
mattteochen / llama4_thunder_example_inductor_subgraph.py
Created November 3, 2025 11:37
llama4_thunder_example_inductor_subgraph
# AOT ID: ['3_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
@mattteochen
mattteochen / triton_red_fused__to_copy__unsafe_view_add_index_mean_mul_pow_rsqrt_sort_view_17.py
Created October 30, 2025 15:40
triton_red_fused__to_copy__unsafe_view_add_index_mean_mul_pow_rsqrt_sort_view_17
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_3, hidden_states_5, hidden_states_6, attn_output_7, hidden_states_10, view_8, token_ids_sorted_by_expert_id, token_ids_sorted_by_expert_inverse_id, linear_14, outs_sorted_by_token_id, outs_sorted_by_token_id_1, hidden_states_14, hidden_states_15, float_11, pow_5, mean_4, add_11, rsqrt_4, mul_19, output_4, hidden_states_16], Original ATen: [aten._unsafe_view, aten.add, aten.view, aten.sort, aten.index, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_11 => add_11
# attn_output_3 => view_26
# attn_output_7 => view_57
# float_11 => convert_element_type_65
# hidden_states_10 => add_7
# hidden_states_14 => add_9
# hidden_states_15 => add_10
# hidden_states_16 => mul_23
@mattteochen
mattteochen / triton_poi_fused__grouped_mm_index_mul_sigmoid_sort_transpose_view_16.py
Created October 30, 2025 15:31
triton_poi_fused__grouped_mm_index_mul_sigmoid_sort_transpose_view_16
# Topologically Sorted Source Nodes: [hidden_states_12, view_8, token_ids_sorted_by_expert_id, router_scores, hidden_states_13, tokens_sorted_by_expert_id, transpose_11, _grouped_mm, transpose_12, _grouped_mm_1], Original ATen: [aten.view, aten.sort, aten.sigmoid, aten.mul, aten.index, aten.transpose, aten._grouped_mm]
# Source node to ATen node mapping:
# _grouped_mm => _grouped_mm
# _grouped_mm_1 => _grouped_mm_1
# hidden_states_12 => view_58
# hidden_states_13 => mul_17
# router_scores => sigmoid_1
# token_ids_sorted_by_expert_id => sort
# tokens_sorted_by_expert_id => index
# transpose_11 => permute_23
@mattteochen
mattteochen / triton_per_fused_cumsum_scatter_sum_14.py
Created October 30, 2025 15:26
triton_per_fused_cumsum_scatter_sum_14
# Topologically Sorted Source Nodes: [counts_1, tokens_per_expert, offsets], Original ATen: [aten.scatter, aten.sum, aten.cumsum]
# Source node to ATen node mapping:
# counts_1 => scatter_upon_const_tensor
# offsets => cumsum
# tokens_per_expert => sum_3
# Graph fragment:
# %getitem_1 : Tensor "i64[1, 1][1, 1]cuda:0" = PlaceHolder[target=getitem_1]
# %scatter_upon_const_tensor : Tensor "i32[1, 128][128, 1]cuda:0"[num_users=1] = call_function[target=torch._inductor.fx_passes.post_grad.scatter_upon_const_tensor](args = (), kwargs = {shape: [1, 128], background_val: 0, dtype: torch.int32, dim: 1, selector: %getitem_1, val: 1})
# %sum_3 : Tensor "i64[128][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%scatter_upon_const_tensor, [0]), kwargs = {})
# %cumsum : Tensor "i32[128][1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.cumsum.default](args = (%sum_3, 0), kwargs = {dtype: torch.int32})
@mattteochen
mattteochen / triton_red_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13.py
Created October 30, 2025 15:23
triton_red_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_3, hidden_states_5, hidden_states_6, attn_output_7, hidden_states_10, float_10, pow_4, mean_3, add_8, rsqrt_3, mul_14, output_3, hidden_states_11], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_8 => add_8
# attn_output_3 => view_26
# attn_output_7 => view_57
# float_10 => convert_element_type_51
# hidden_states_10 => add_7
# hidden_states_11 => mul_16
# hidden_states_3 => add_2
# hidden_states_5 => view_32
@mattteochen
mattteochen / triton_red_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.py
Created October 30, 2025 14:45
triton_red_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_3, hidden_states_5, hidden_states_6, float_7, pow_3, mean_2, add_5, rsqrt_2, mul_9, output_2, hidden_states_7], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_5 => add_5
# attn_output_3 => view_26
# float_7 => convert_element_type_31
# hidden_states_3 => add_2
# hidden_states_5 => view_32
# hidden_states_6 => add_4
# hidden_states_7 => mul_11
# mean_2 => mean_2
@mattteochen
mattteochen / triton_poi_fused__unsafe_view_mul_silu_11.py
Created October 30, 2025 14:36
triton_poi_fused__unsafe_view_mul_silu_11
# Topologically Sorted Source Nodes: [linear_4, silu, linear_5, down_proj], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
# Source node to ATen node mapping:
# down_proj => mul_9
# linear_4 => view_28
# linear_5 => view_30
# silu => convert_element_type_25, convert_element_type_26, mul_8, sigmoid
# Graph fragment:
# %mm_4 : Tensor "bf16[1, 16384][16384, 1]cuda:0" = PlaceHolder[target=mm_4]
# %mm_5 : Tensor "bf16[1, 16384][16384, 1]cuda:0" = PlaceHolder[target=mm_5]
# %view_28 : Tensor "bf16[1, 1, 16384][16384, 16384, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_4, [1, 1, 16384]), kwargs = {})
@mattteochen
mattteochen / triton_red_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.py
Created October 30, 2025 14:32
triton_red_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_3, float_6, pow_2, mean_1, add_3, rsqrt_1, mul_6, output_1, hidden_states_4], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_3 => add_3
# attn_output_3 => view_26
# float_6 => convert_element_type_21
# hidden_states_3 => add_2
# hidden_states_4 => mul_7
# mean_1 => mean_1
# mul_6 => mul_6
# output_1 => convert_element_type_22
@mattteochen
mattteochen / triton_red_fused__softmax_add_exp_mul_prepare_softmax_online_slice_sub_view_9.py
Created October 30, 2025 14:23
triton_red_fused__softmax_add_exp_mul_prepare_softmax_online_slice_sub_view_9
# Topologically Sorted Source Nodes: [matmul_1, attn_weights, causal_mask, attn_weights_1, attn_weights_2, ], Original ATen: [aten.view, aten.mul, aten.slice, aten.add, aten._softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
# Source node to ATen node mapping:
# => exp_default_1, prepare_softmax_online_default_1, sub_tensor_1
# attn_weights => mul_5
# attn_weights_1 => add_1
# attn_weights_2 => convert_element_type_15, convert_element_type_16, div
# causal_mask => slice_1
# matmul_1 => view_20
# Graph fragment:
# %bmm_1 : Tensor "bf16[40, 1, 4100][4100, 4100, 1]cuda:0" = PlaceHolder[target=bmm_1]