Test triton
import triton
import triton.language as tl


@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    # Copy 10 elements from in_ptr0 to out_ptr0, one program per grid slot.
    xnumel = 10
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel  # mask off indices past the 10 valid elements
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)


if __name__ == "__main__":
    import torch

    print(f"torch version {torch.__version__} triton version {triton.__version__}")
    inp = torch.randn(10, device='cuda')
    out = torch.randn(10, device='cuda')
    kernel[(10,)](inp, out, 10, XBLOCK=16)
    print(inp, out)
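
If the launch succeeds, out becomes an exact copy of inp. A minimal self-check sketch, assuming a CUDA-capable GPU and reusing inp/out from the script above (the assertion is an addition, not part of the original gist):

    # Verify the copy: after the launch, out must equal inp exactly.
    torch.testing.assert_close(out, inp)
    print("triton copy kernel OK")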
malfet commented Jan 6, 2024

An even simpler Triton test:

import torch

import triton
import triton.language as tl


@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
    # Deliberately empty: merely compiling and launching this kernel
    # exercises the full Triton JIT pipeline.
    pass


X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
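
A successful launch returns a compiled-kernel handle, so just reaching this line confirms the compiler toolchain and CUDA driver work. On Triton 2.x the handle also exposes the generated code; the snippet below is a hedged sketch of that inspection (the attribute layout varies across Triton versions):

    # Assumes a Triton version (e.g. 2.x) where a launch returns a
    # CompiledKernel carrying an `asm` dict of intermediate representations.
    print(pgm.asm.keys())        # e.g. dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])
    print(pgm.asm["ptx"][:200])  # start of the generated PTX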
