Kaixi mattteochen
@mattteochen
mattteochen / triton_poi_fused__to_copy__unsafe_view_view_view_as_complex_6.py
Created October 30, 2025 14:02
triton_poi_fused__to_copy__unsafe_view_view_view_as_complex_6
# Topologically Sorted Source Nodes: [linear, query_states, float_4, reshape, xq_], Original ATen: [aten._unsafe_view, aten.view, aten._to_copy, aten.view_as_complex]
# Source node to ATen node mapping:
# float_4 => convert_element_type_9
# linear => view_4
# query_states => view_5
# reshape => view_12
# xq_ => view_as_complex
# Graph fragment:
# %mm : Tensor "bf16[1, 5120][5120, 1]cuda:0" = PlaceHolder[target=mm]
# %view_4 : Tensor "bf16[1, 1, 5120][5120, 5120, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [1, 1, 5120]), kwargs = {})
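Going by the node names and shapes in this preview, the kernel fuses the query-projection reshape with the float cast and complex view that precede rotary embedding. A rough eager-mode sketch, assuming 40 query heads of dim 128 (so 5120 = 40 * 128, which the preview does not state explicitly), could look like this; the xk_ kernel two entries below follows the same pattern for the key projection:

import torch

# Hedged sketch of the fused pattern above; shapes come from the graph fragment,
# the 40 x 128 head split is an assumption for illustration.
mm = torch.randn(1, 5120, dtype=torch.bfloat16)             # %mm: bf16[1, 5120]
query_states = mm.reshape(1, 1, 5120).view(1, 1, 40, 128)   # linear / query_states views
xq = query_states.float()                                   # float_4: bf16 -> f32
xq_ = torch.view_as_complex(xq.reshape(1, 1, 40, 64, 2))    # reshape + view_as_complex -> complex64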
@mattteochen
mattteochen / triton_poi_fused__to_copy_index_copy_transpose_view_5.py
Created October 30, 2025 14:00
triton_poi_fused__to_copy_index_copy_transpose_view_5
# kernel path: /tmp/tmp4dz_fivh/ba/cbaeyiwckwicjrhti36fgmh3bgn3kcfkfiluqir4sqp32fvlmeus.py
# Topologically Sorted Source Nodes: [xk_out, key_states_1, key_states_2, index_copy_], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.index_copy]
# Source node to ATen node mapping:
# index_copy_ => index_put
# key_states_1 => convert_element_type_12
# key_states_2 => permute_6
# xk_out => view_15
# Graph fragment:
# %arg7_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg7_1]
# %view_as_real_1 : Tensor "f32[1, 1, 8, 64, 2][1024, 1024, 128, 2, 1]cuda:0" = PlaceHolder[target=view_as_real_1]
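From the node names, this kernel appears to fuse the bf16 cast and transpose of the rotated keys with the in-place index_copy_ into the key cache. A hedged eager-mode sketch, assuming the cache layout [batch, kv_heads, max_seq_len, head_dim] shown by the matching value-cache kernel below:

import torch

# Hedged sketch; the [1, 8, 4100, 128] cache shape is taken from the value-cache
# kernel below and assumed to match the key cache.
k_cache = torch.zeros(1, 8, 4100, 128, dtype=torch.bfloat16)   # persistent key cache
cache_position = torch.tensor([0], dtype=torch.int64)          # arg7_1: i64[1]
xk_real = torch.randn(1, 1, 8, 64, 2)                          # view_as_real_1: f32[1, 1, 8, 64, 2]
xk_out = xk_real.reshape(1, 1, 8, 128)                         # xk_out: view back to head_dim
key_states = xk_out.to(torch.bfloat16).transpose(1, 2)         # _to_copy + transpose -> [1, 8, 1, 128]
k_cache.index_copy_(2, cache_position, key_states)             # index_copy_ along the sequence dim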
@mattteochen
mattteochen / triton_poi_fused__to_copy__unsafe_view_view_view_as_complex_4.py
Created October 30, 2025 13:57
triton_poi_fused__to_copy__unsafe_view_view_view_as_complex_4
# Topologically Sorted Source Nodes: [linear_1, key_states, float_5, reshape_1, xk_], Original ATen: [aten._unsafe_view, aten.view, aten._to_copy, aten.view_as_complex]
# Source node to ATen node mapping:
# float_5 => convert_element_type_10
# key_states => view_8
# linear_1 => view_7
# reshape_1 => view_13
# xk_ => view_as_complex_1
# Graph fragment:
# %mm_1 : Tensor "bf16[1, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_1]
# %view_7 : Tensor "bf16[1, 1, 1024][1024, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 1, 1024]), kwargs = {})
@mattteochen
mattteochen / triton_poi_fused__unsafe_view_index_copy_transpose_view_3.py
Created October 30, 2025 13:53
triton_poi_fused__unsafe_view_index_copy_transpose_view_3
# Topologically Sorted Source Nodes: [linear_2, view_2, value_states, index_copy__1], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten.index_copy]
# Source node to ATen node mapping:
# index_copy__1 => index_put_1
# linear_2 => view_10
# value_states => permute_4
# view_2 => view_11
# Graph fragment:
# %arg7_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg7_1]
# %mm_2 : Tensor "bf16[1, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_2]
# %index_put_1 : Tensor "bf16[1, 8, 4100, 128][4198400, 524800, 128, 1]cuda:0" = PlaceHolder[target=index_put_1]
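This is the value-side counterpart of the key-cache kernel above: the projection output is reshaped, transposed, and written into the cache with index_copy_, only without the f32-to-bf16 conversion. A compact hedged sketch (8 KV heads of head_dim 128 is an assumption consistent with 1024 = 8 * 128):

import torch

v_cache = torch.zeros(1, 8, 4100, 128, dtype=torch.bfloat16)   # index_put_1 target
cache_position = torch.tensor([0], dtype=torch.int64)          # arg7_1: i64[1]
mm_2 = torch.randn(1, 1024, dtype=torch.bfloat16)              # value projection output
value_states = mm_2.reshape(1, 1, 8, 128).transpose(1, 2)      # views + transpose -> [1, 8, 1, 128]
v_cache.index_copy_(2, cache_position, value_states)           # in-place cache write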
@mattteochen
mattteochen / triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2.py
Created October 30, 2025 13:44
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2
# kernel path: /tmp/tmp4dz_fivh/av/cavam224no6dxhr6ikzn4lnvqxsjqhhk4nyhfsrpw7sy4gm7tknp.py
# Topologically Sorted Source Nodes: [float_3, pow_1, mean, add, rsqrt, mul_1, output, hidden_states], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add
# float_3 => convert_element_type_1
# hidden_states => mul_2
# mean => mean
# mul_1 => mul_1
# output => convert_element_type_2
# pow_1 => pow_1
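The node names spell out the standard RMSNorm pattern: cast to f32, square, mean over the hidden dim, add eps, rsqrt, scale, cast back, then multiply by the learned weight. A hedged eager-mode sketch (hidden size 5120 and the eps value are illustrative assumptions):

import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    x = hidden_states.float()                       # float_3: cast to f32
    variance = x.pow(2).mean(-1, keepdim=True)      # pow_1 + mean
    x = x * torch.rsqrt(variance + eps)             # add + rsqrt + mul_1
    return weight * x.to(hidden_states.dtype)       # output cast + hidden_states scale

out = rms_norm(torch.randn(1, 1, 5120, dtype=torch.bfloat16),
               torch.ones(5120, dtype=torch.bfloat16))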
@mattteochen
mattteochen / triton_poi_fused_ones_like_1.py
Created October 30, 2025 13:37
triton_poi_fused_ones_like_1
# kernel path: /tmp/tmp4dz_fivh/ee/ceemti6wwrcjx35l4qf3p467zo3iv37qntq4nzlsewugulul3fyk.py
# Topologically Sorted Source Nodes: [ones_like], Original ATen: [aten.ones_like]
# Source node to ATen node mapping:
# ones_like => full_default
# Graph fragment:
# %full_default : Tensor "f32[1, 1, 64][64, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([1, 1, 64], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# return %full_default
triton_poi_fused_ones_like_1 = async_compile.triton('triton_poi_fused_ones_like_1', '''
import triton
import triton.language as tl
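The preview cuts off before the kernel body, but the graph fragment is just aten.full filling a [1, 1, 64] f32 tensor with ones. A hedged sketch of what a minimal pointwise fill kernel of this shape can look like in Triton (names and block size are illustrative, not the actual Inductor output):

import torch
import triton
import triton.language as tl

@triton.jit
def ones_fill_kernel(out_ptr, xnumel, XBLOCK: tl.constexpr):
    # one program fills XBLOCK contiguous elements with 1.0
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    tl.store(out_ptr + xindex, tl.full([XBLOCK], 1.0, tl.float32), xmask)

out = torch.empty(1, 1, 64, device="cuda", dtype=torch.float32)
ones_fill_kernel[(1,)](out, out.numel(), XBLOCK=64)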
@mattteochen
mattteochen / triton_poi_fused__to_copy_bmm_expand_unsqueeze_0.py
Created October 30, 2025 13:34
triton_poi_fused__to_copy_bmm_expand_unsqueeze_0
# Topologically Sorted Source Nodes: [getitem, inv_freq_expanded, , getitem_1, position_ids_expanded], Original ATen: [aten.unsqueeze, aten.expand, aten.bmm, aten._to_copy]
# Source node to ATen node mapping:
# => mm_default, squeeze_dim, squeeze_dim_1
# getitem => unsqueeze, unsqueeze_1
# getitem_1 => unsqueeze_2
# inv_freq_expanded => expand
# position_ids_expanded => convert_element_type
# Graph fragment:
# %arg1_1 : Tensor "i64[1, 1][1, 1]cuda:0" = PlaceHolder[target=arg1_1]
# %expand_2 : Tensor "f32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=expand_2]
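These nodes match the HF-style rotary-embedding frequency computation: inv_freq is unsqueezed and expanded to [1, dim/2, 1], position_ids is cast to float and unsqueezed to [1, 1, seq], and a batched matmul produces the angles. A hedged eager-mode sketch (head_dim 128, hence 64 frequencies, and a single position are assumptions):

import torch

inv_freq = 1.0 / (10000 ** (torch.arange(0, 128, 2).float() / 128))   # rotary inverse frequencies
position_ids = torch.zeros(1, 1, dtype=torch.int64)                   # arg1_1: i64[1, 1]
inv_freq_expanded = inv_freq[None, :, None].expand(1, -1, 1)          # unsqueeze + expand
position_ids_expanded = position_ids[:, None, :].float()              # unsqueeze + _to_copy
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)   # bmm -> [1, 1, 64]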
@mattteochen
mattteochen / triton_poi_fused_arange_full_gt_mul_view_0.py
Created October 30, 2025 13:12
triton_poi_fused_arange_full_gt_mul_view_0
# AOT ID: ['1_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
@mattteochen
mattteochen / triton_poi_fused_embedding_0.py
Created October 30, 2025 13:00
triton_poi_fused_embedding_0
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks