# Source: GitHub gist shunting314/c6ed9f571919e3b414166f1696dcc61b
# Author: @shunting314
# Created: February 24, 2023 21:22
from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# Triton source for a pointwise kernel, handed to the async compiler as a string.
# The kernel computes out[i] = sin(in[i]) + cos(in[i]) for each element.
# Notes on the embedded source:
#   * `xnumel` is overwritten with the constant 6 inside the kernel, so the
#     runtime value passed by the caller is ignored.
#   * `xmask` disables the lanes of a block whose index is >= xnumel, for both
#     the load and the store.
#   * The kernel body must stay indented and the decorator must stay on one
#     line: the string is compiled as Python source.
triton__0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor
@pointwise(size_hints=[8], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.sin(tmp0)
    tmp2 = tl.cos(tmp0)
    tmp3 = tmp1 + tmp2
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp3, xmask)
''')
# NOTE(review): presumably blocks until the kernel submitted above has finished
# compiling; it is handed globals(), so it may rebind `triton__0` to the
# compiled kernel in this module's namespace — confirm against AsyncCompile.wait.
async_compile.wait(globals())
# The compile helper is not referenced again below; drop the module-level name.
del async_compile
def call(args):
    """Run the compiled sin(x) + cos(x) kernel on a single input tensor.

    Args:
        args: A list holding exactly one tensor. The list is emptied in place
            after unpacking, so the caller's list no longer references the
            input. The kernel is launched over 6 elements with an fp32 pointer
            signature, so the input is assumed to be a float32 CUDA tensor
            with 6 contiguous elements (e.g. shape (2, 3)) — TODO confirm
            against the producer of this wrapper.

    Returns:
        A 1-tuple containing a new float32 CUDA tensor of shape (2, 3) and
        strides (3, 1) that the kernel writes its result into.

    Raises:
        ValueError: If ``args`` does not contain exactly one element.
    """
    arg0_1, = args
    # Drop the caller-visible reference so that `del arg0_1` below releases
    # the last reference held by this function's side.
    args.clear()
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((2, 3), (3, 1), device='cuda', dtype=torch.float32)
        stream0 = get_cuda_stream(0)
        # Launch over 6 elements; the grid is derived from the same count.
        triton__0.run(arg0_1, buf0, 6, grid=grid(6), stream=stream0)
        del arg0_1
        return (buf0, )
# Script entry point: build one random input matching the wrapper's expected
# layout and hand a zero-argument callable to print_performance.
if __name__ == "__main__":
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    # Same shape, strides and dtype as the output buffer allocated in call().
    arg0_1 = rand_strided((2, 3), (3, 1), device='cuda:0', dtype=torch.float32)
    # call() clears the list it receives, so wrap the tensor in a fresh list
    # on every invocation.
    print_performance(lambda: call([arg0_1]))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment