# Gist by @xmfan, last active February 3, 2024
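# How output like this is typically obtained (a sketch, not part of the gist:
# the driver below is hypothetical and only mirrors the (2, 1) and (1, 2)
# parameter shapes used in call(); the compiled-autograd context manager shown
# is the torch._dynamo.compiled_autograd.enable API from early-2024 builds):
#
#   TORCH_LOGS="output_code" python repro.py
#
# with repro.py along the lines of:
#
#   import torch
#   w1 = torch.nn.Parameter(torch.randn(2, 1, device="cuda"))
#   w2 = torch.nn.Parameter(torch.randn(1, 2, device="cuda"))
#   x = torch.randn(1, 2, device="cuda")
#   loss = (x @ w1 @ w2).sum()
#   with torch._dynamo.compiled_autograd.enable(torch.compile):
#       loss.backward()  # backward graph gets compiled; wrapper code like this file is dumped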
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
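# reinterpret_tensor(t, size, stride, offset) returns a no-copy view over t's
# storage with the given size/stride/offset; call() below uses it, e.g., to
# read a (2, 1) column buffer back as a (1, 2) row.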
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_xmfan/7g/c7gwxjqcjp4fkf4zhh6bvlxq66vbvqg6btrugmuovositqsn7ksy.py
# Source Nodes: [mm, mm_1], Original ATen: [aten.mm]
# mm => mm
# mm_1 => mm_1
triton_poi_fused_mm_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(
    size_hints=[4],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mm_0', 'mutated_arg_names': [], 'no_x_dim': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    # Load the scalar incoming gradient and broadcast it to all 4 elements of
    # both output buffers (the two mm gradients share the same broadcasted input).
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tl.store(out_ptr0 + (x0), tmp1, xmask)
    tl.store(out_ptr1 + (x0), tmp1, xmask)
''')
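# A sketch of what the fused kernel above computes, in eager terms (derived by
# reading the kernel body; this helper is illustrative only and never called):
def _eager_triton_poi_fused_mm_0(in0):
    # Broadcast the scalar incoming gradient into two independent 4-element buffers.
    out = in0.reshape(()).broadcast_to(4).contiguous()
    return out.clone(), out.clone()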
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
async_compile.wait(globals())
del async_compile
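# wait(globals()) blocks until the pending Triton compilation above finishes
# and binds the ready kernel into this module's namespace, after which the
# compiler object itself is no longer needed.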
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
    args.clear()
    assert_size_stride(arg0_1, (), ())
    assert_size_stride(arg1_1, (1, 2), (1, 1))
    assert_size_stride(arg2_1, (2, 1), (1, 1))
    assert_size_stride(arg3_1, (2, 1), (1, 1))
    assert_size_stride(arg4_1, (2, 2), (2, 1))
    assert_size_stride(arg5_1, (1, 2), (2, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
        buf4 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
        # Source Nodes: [mm, mm_1], Original ATen: [aten.mm]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mm_0.run(arg0_1, buf0, buf4, 4, grid=grid(4), stream=stream0)
        del arg0_1
        buf1 = empty_strided_cuda((2, 1), (1, 1), torch.float32)
        # Source Nodes: [mm], Original ATen: [aten.mm]
        extern_kernels.mm(buf0, arg2_1, out=buf1)
        del arg2_1
        del buf0
        # Accumulate the (2, 1) gradient into arg3_1.grad in place, like
        # autograd's AccumulateGrad node.
        inductor_ops.accumulate_grad_(arg3_1, reinterpret_tensor(buf1, (2, 1), (1, 1), 0))
        del arg3_1
        buf5 = buf1; del buf1  # reuse
        # Source Nodes: [mm_1], Original ATen: [aten.mm]
        extern_kernels.mm(buf4, reinterpret_tensor(arg1_1, (2, 1), (1, 1), 0), out=buf5)
        del arg1_1
        del buf4
        buf6 = empty_strided_cuda((1, 2), (2, 1), torch.float32)
        # Source Nodes: [mm_2], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf5, (1, 2), (1, 1), 0), arg4_1, out=buf6)
        del arg4_1
        del buf5
        # Accumulate the (1, 2) gradient into arg5_1.grad.
        inductor_ops.accumulate_grad_(arg5_1, reinterpret_tensor(buf6, (1, 2), (2, 1), 0))
        del buf6
        del arg5_1
    return ()
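# For reference, an eager-mode transcription of call() above (a sketch, not
# Inductor output; the helper name and argument names are made up):
def _eager_call_sketch(grad_out, act, weight, param_a, weight2, param_b):
    g = grad_out.broadcast_to(2, 2)                           # triton_poi_fused_mm_0
    grad_a = g @ weight                                       # mm: (2, 2) @ (2, 1) -> (2, 1)
    grad_b = (g @ act.reshape(2, 1)).reshape(1, 2) @ weight2  # mm_1 then mm_2 -> (1, 2)
    # accumulate_grad_ semantics: set .grad when unset, otherwise add in place.
    for p, grad in ((param_a, grad_a), (param_b, grad_b)):
        p.grad = grad if p.grad is None else p.grad + grad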
# def benchmark_compiled_module(times=10, repeat=10):
def benchmark_compiled_module():
    # from torch._dynamo.testing import rand_strided
    # from torch._inductor.utils import print_performance
    # arg0_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
    # arg1_1 = rand_strided((1, 2), (1, 1), device='cuda:0', dtype=torch.float32)
    # arg2_1 = rand_strided((2, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    # arg3_1 = rand_strided((2, 1), (1, 1), device='cuda:0', dtype=torch.float32)
    # arg4_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
    # arg5_1 = rand_strided((1, 2), (2, 1), device='cuda:0', dtype=torch.float32)
    arg0_1 = torch.tensor(1., device="cuda")
    arg1_1 = torch.tensor([[-0.8230, -0.7359]], device="cuda")
    arg2_1 = torch.tensor([[0.2271], [-0.5247]], device="cuda")
    # Build the Parameters directly on CUDA so they stay leaf tensors and
    # accumulate_grad_ populates their .grad cleanly; Parameter(...).to("cuda")
    # would return a non-leaf copy.
    arg3_1 = torch.nn.Parameter(torch.tensor([[-0.8230], [-0.7359]], device="cuda"))
    arg4_1 = torch.tensor([[-2.1788, 0.5684], [-1.0845, -1.3986]], device="cuda")
    arg5_1 = torch.nn.Parameter(torch.tensor([[-0.0053, 0.3793]], device="cuda"))
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
    torch.manual_seed(0)
    fn()
    print(arg5_1.grad)
    # tensor([[5.0872, 1.2942]], device='cuda:0')
    print(arg3_1.grad)
    # tensor([[-0.2976], [-0.2976]], device='cuda:0')
    # return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # from torch._inductor.wrapper_benchmark import compiled_module_main
    # compiled_module_main('None', benchmark_compiled_module)
    benchmark_compiled_module()