Skip to content

Instantly share code, notes, and snippets.

View mattteochen's full-sized avatar

Kaixi mattteochen

View GitHub Profile
# AOT ID: ['3_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
# AOT ID: ['2_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
# AOT ID: ['1_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
@mattteochen
mattteochen / segmented_fusion.txt
Created October 28, 2025 10:19
NVFUSER_DUMP=segmented_fusion
Segment the fusion (Original Fusion Un-modified):
Inputs:
T0_g_int64_t[bS0{1}, iS1{8}]
T1_g___bfloat[iS2{128256}, iS3{2048}]
Outputs:
T5_g___bfloat[bS9{1}, iS10{8}, iS11{2048}]
%kernel_math {
T2_l_int64_t[iS4{8}]
= squeeze( T0_g_int64_t[bS0{1}, iS1{8}], flags = {true, false} )
@mattteochen
mattteochen / python_definition_segments.txt
Created October 28, 2025 10:11
NVFUSER_DUMP=python_definition_segments
Python definition for segmented group 1:
def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    """Populate ``fd`` with the fusion for segmented group 1.

    The definition registers two device inputs, reshapes the integer
    input with a squeeze followed by a broadcast, uses the result as the
    index of an ``index_select`` along dim 0 of the BFloat16 input, and
    registers the selected tensor as the only output.

    Args:
        fd: The fusion definition being built; it is mutated in place.
    """
    # Inputs are declared in the same order as the generated definition:
    # the BFloat16 [128256, 2048] tensor first, the Int [1, 8] tensor second.
    table = fd.define_tensor(
        shape=[128256, 2048],
        contiguity=[True, True],
        dtype=DataType.BFloat16,
        is_cpu=False,
    )
    indices = fd.define_tensor(
        shape=[1, 8],
        contiguity=[None, True],
        dtype=DataType.Int,
        is_cpu=False,
    )

    # Remove dim 0 of the index tensor, then add a trailing broadcast dim.
    squeezed_indices = fd.ops.squeeze(indices, dims=[0], squeeze_expanded=True)
    broadcast_indices = fd.ops.broadcast(
        squeezed_indices,
        is_broadcast_dim=[False, True],
    )

    # Select along dim 0 of the table using the reshaped indices.
    selected = fd.ops.index_select(table, broadcast_indices, dim=0)
    fd.add_output(selected)
@mattteochen
mattteochen / output_code.py
Created October 27, 2025 17:06
output_code
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
@mattteochen
mattteochen / triton_red_fused__to_copy__unsafe_view_add_embedding_mean_mul_pow_rsqrt_9.py
Created October 27, 2025 16:49
triton_red_fused__to_copy__unsafe_view_add_embedding_mean_mul_pow_rsqrt_9
# Topologically Sorted Source Nodes: [inputs_embeds, attn_output_3, hidden_states_5, down_proj, hidden_states_9, hidden_states_10, pow_3, variance_2, add_6, rsqrt_2, hidden_states_11, to_10, hidden_states_12], Original ATen: [aten.embedding, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_6 => add_6
# attn_output_3 => view_17
# down_proj => view_23
# hidden_states_10 => convert_element_type_23
# hidden_states_11 => mul_13
# hidden_states_12 => mul_14
# hidden_states_5 => add_3
# hidden_states_9 => add_5
@mattteochen
mattteochen / triton_poi_fused__unsafe_view_mul_silu_8.py
Created October 27, 2025 16:47
triton_poi_fused__unsafe_view_mul_silu_8
# Topologically Sorted Source Nodes: [linear_4, silu, linear_5, mul_11], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
# Source node to ATen node mapping:
# linear_4 => view_19
# linear_5 => view_21
# mul_11 => mul_12
# silu => convert_element_type_17, convert_element_type_18, mul_11, sigmoid
# Graph fragment:
# %mm_4 : Tensor "bf16[1, 8192][8192, 1]cuda:0" = PlaceHolder[target=mm_4]
# %mm_5 : Tensor "bf16[1, 8192][8192, 1]cuda:0" = PlaceHolder[target=mm_5]
# %view_19 : Tensor "bf16[1, 1, 8192][8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_4, [1, 1, 8192]), kwargs = {})
@mattteochen
mattteochen / triton_red_fused__to_copy__unsafe_view_add_embedding_mean_mul_pow_rsqrt_7.py
Created October 27, 2025 16:43
triton_red_fused__to_copy__unsafe_view_add_embedding_mean_mul_pow_rsqrt_7
# Topologically Sorted Source Nodes: [inputs_embeds, attn_output_3, hidden_states_5, hidden_states_6, pow_2, variance_1, add_4, rsqrt_1, hidden_states_7, to_8, hidden_states_8], Original ATen: [aten.embedding, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add_4 => add_4
# attn_output_3 => view_17
# hidden_states_5 => add_3
# hidden_states_6 => convert_element_type_13
# hidden_states_7 => mul_9
# hidden_states_8 => mul_10
# inputs_embeds => embedding
# pow_2 => pow_2